Visualizing embeddings with t-SNE#

Overview#

This tutorial demonstrates how to visualize and perform clustering with the embeddings from the Gemini API. You will visualize a subset of the 20 Newsgroups dataset using t-SNE, and cluster that subset with the KMeans algorithm.

For more information on getting started with embeddings generated from the Gemini API, check out the Python quickstart.

Prerequisites#

You can run this quickstart in Google Colab.

To complete this quickstart in your own development environment, ensure that your environment meets the following requirements:

  • Python 3.9+

  • An installation of jupyter to run the notebook.

Setup#

First, download and install the Gemini API Python library.

#!pip install -q google-generativeai

## !pip install -U -q google.colab
import re
import tqdm
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

import google.generativeai as genai
import google.ai.generativelanguage as glm

# Used to securely store your API key
# from google.colab import userdata

from sklearn.datasets import fetch_20newsgroups
from sklearn.manifold import TSNE
from sklearn.cluster import KMeans
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay

Grab an API Key#

Before you can use the Gemini API, you must first obtain an API key. If you don't already have one, create a key with one click in Google AI Studio.

In Colab, add the key to the secrets manager under the "🔑" in the left panel. Give it the name API_KEY. Once you have the API key, pass it to the SDK. You can do this in two ways:

  • Put the key in the GOOGLE_API_KEY environment variable (the SDK will automatically pick it up from there).

  • Pass the key to genai.configure(api_key=…)
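
For example, a minimal sketch of fetching the key from an environment variable rather than hard-coding it (this assumes you have already exported GOOGLE_API_KEY in your shell):

import os
# Read the API key from the environment instead of embedding it in the notebook.
genai.configure(api_key=os.getenv('GOOGLE_API_KEY'))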

# Or use `os.getenv('API_KEY')` to fetch an environment variable.
# API_KEY=userdata.get('API_KEY')
GOOGLE_API_KEY = "YOUR-API-KEY"
genai.configure(api_key=GOOGLE_API_KEY)

Tip

Key point: Next, you will choose a model. Any embedding model will work for this tutorial, but for real applications it's important to choose a specific model and stick with it. The outputs of different models are not compatible with each other.

Warning

Note: At this time, the Gemini API is only available in certain regions.

for m in genai.list_models():
  if 'embedContent' in m.supported_generation_methods:
    print(m.name)
models/embedding-001

Dataset#

The 20 Newsgroups Text Dataset contains 18,000 newsgroups posts on 20 topics, divided into training and test sets. The split between the training and test datasets is based on messages posted before and after a specific date. For this tutorial, you will use the training subset.

newsgroups_train = fetch_20newsgroups(subset='train')

# View list of class names for dataset
newsgroups_train.target_names
['alt.atheism',
 'comp.graphics',
 'comp.os.ms-windows.misc',
 'comp.sys.ibm.pc.hardware',
 'comp.sys.mac.hardware',
 'comp.windows.x',
 'misc.forsale',
 'rec.autos',
 'rec.motorcycles',
 'rec.sport.baseball',
 'rec.sport.hockey',
 'sci.crypt',
 'sci.electronics',
 'sci.med',
 'sci.space',
 'soc.religion.christian',
 'talk.politics.guns',
 'talk.politics.mideast',
 'talk.politics.misc',
 'talk.religion.misc']

Here is the first example from the training set.

idx = newsgroups_train.data[0].index('Lines')
print(newsgroups_train.data[0][idx:])
Lines: 15

 I was wondering if anyone out there could enlighten me on this car I saw
the other day. It was a 2-door sports car, looked to be from the late 60s/
early 70s. It was called a Bricklin. The doors were really small. In addition,
the front bumper was separate from the rest of the body. This is 
all I know. If anyone can tellme a model name, engine specs, years
of production, where this car is made, history, or whatever info you
have on this funky looking car, please e-mail.

Thanks,
- IL
   ---- brought to you by your neighborhood Lerxst ----
# Apply functions to remove names, emails, and extraneous words from data points in newsgroups.data
newsgroups_train.data = [re.sub(r'[\w\.-]+@[\w\.-]+', '', d) for d in newsgroups_train.data] # Remove email
newsgroups_train.data = [re.sub(r"\([^()]*\)", "", d) for d in newsgroups_train.data] # Remove names
newsgroups_train.data = [d.replace("From: ", "") for d in newsgroups_train.data] # Remove "From: "
newsgroups_train.data = [d.replace("\nSubject: ", "") for d in newsgroups_train.data] # Remove "\nSubject: "
# Put training points into a dataframe
df_train = pd.DataFrame(newsgroups_train.data, columns=['Text'])
df_train['Label'] = newsgroups_train.target
# Match label to target name index
df_train['Class Name'] = df_train['Label'].map(newsgroups_train.target_names.__getitem__)
# Retain only text samples short enough to fit within the embedding model's input limit.
df_train = df_train[df_train['Text'].str.len() < 10000]

df_train
Text Label Class Name
0 WHAT car is this!?\nNntp-Posting-Host: rac3.w... 7 rec.autos
1 SI Clock Poll - Final Call\nSummary: Final ca... 4 comp.sys.mac.hardware
2 PB questions...\nOrganization: Purdue Univers... 4 comp.sys.mac.hardware
3 Re: Weitek P9000 ?\nOrganization: Harris Comp... 1 comp.graphics
4 Re: Shuttle Launch Question\nOrganization: Sm... 14 sci.space
... ... ... ...
11309 Re: Migraines and scans\nDistribution: world... 13 sci.med
11310 Screen Death: Mac Plus/512\nLines: 22\nOrganiz... 4 comp.sys.mac.hardware
11311 Mounting CPU Cooler in vertical case\nOrganiz... 3 comp.sys.ibm.pc.hardware
11312 Re: Sphere from 4 points?\nOrganization: Cent... 1 comp.graphics
11313 stolen CBR900RR\nOrganization: California Ins... 8 rec.motorcycles

11141 rows × 3 columns

Next, you will sample some of the data by taking 150 data points from each category in the training dataset, and dropping a few of the categories to run through this tutorial. You will choose the science categories to compare.

# Take a sample of each label category from df_train
SAMPLE_SIZE = 150
df_train = (df_train.groupby('Label', as_index=False)
                    .apply(lambda x: x.sample(SAMPLE_SIZE))
                    .reset_index(drop=True))

# Choose categories about science
df_train = df_train[df_train['Class Name'].str.contains('sci')]

# Reset the index
df_train = df_train.reset_index()
df_train
index Text Label Class Name
0 1650 Re: 80-bit keyseach machine\nOrganization: Ad... 11 sci.crypt
1 1651 reason for Clipper alg'm secrecy\nReply-To: ... 11 sci.crypt
2 1652 Re: Do we need the clipper for cheap security... 11 sci.crypt
3 1653 Re: Hard drive security for FBI targets\nLine... 11 sci.crypt
4 1654 Re: Estimating Wiretap Costs/Benefits\nSummar... 11 sci.crypt
... ... ... ... ...
595 2245 Re: Lindbergh and the moon \nOrganization: Un... 14 sci.space
596 2246 Re: Solar Sail Data\nArticle-I.D.: aurora.1993... 14 sci.space
597 2247 Re: Boom! Whoosh......\nOrganization: The Do... 14 sci.space
598 2248 Re: Why not give $1 billion to first year-lon... 14 sci.space
599 2249 Re: Comet in Temporary Orbit Around Jupiter?\... 14 sci.space

600 rows × 4 columns

df_train['Class Name'].value_counts()
Class Name
sci.crypt          150
sci.electronics    150
sci.med            150
sci.space          150
Name: count, dtype: int64

Create the embeddings#

In this section, you will see how to generate embeddings for the different texts in the dataframe using the embeddings from the Gemini API.

API changes to embeddings with model embedding-001#

For the new embeddings model, embedding-001, there is a new task type parameter and the optional title (only valid with task_type=RETRIEVAL_DOCUMENT). These new parameters apply only to the newest embeddings models. The task types are:

Task Type            Description
RETRIEVAL_QUERY      Specifies the given text is a query in a search/retrieval setting.
RETRIEVAL_DOCUMENT   Specifies the given text is a document in a search/retrieval setting.
SEMANTIC_SIMILARITY  Specifies the given text will be used for Semantic Textual Similarity (STS).
CLASSIFICATION       Specifies that the embedding will be used for classification.
CLUSTERING           Specifies that the embedding will be used for clustering.
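
As a quick illustration, here is a minimal sketch of embedding a single document with the optional title; the content and title strings are just placeholders:

# Sketch: the title argument is only valid with task_type="retrieval_document".
result = genai.embed_content(model="models/embedding-001",
                             content="The quick brown fox jumps over the lazy dog.",
                             task_type="retrieval_document",
                             title="A short sample document")
print(len(result["embedding"]))  # a 768-dimensional vector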

from tqdm.auto import tqdm
tqdm.pandas()

from google.api_core import retry

def make_embed_text_fn(model):

  @retry.Retry(timeout=300.0)
  def embed_fn(text: str) -> list[float]:
    # Set the task_type to CLUSTERING.
    embedding = genai.embed_content(model=model,
                                    content=text,
                                    task_type="clustering")
    return embedding["embedding"]

  return embed_fn

def create_embeddings(df):
  model = 'models/embedding-001'
  df['Embeddings'] = df['Text'].progress_apply(make_embed_text_fn(model))
  return df

df_train = create_embeddings(df_train)
100%|██████████| 600/600 [03:37<00:00,  2.76it/s]

Dimensionality reduction#

The embedding vector for each document has a length of 768. In order to visualize how the embedded documents are grouped together, you will need to apply dimensionality reduction, as you can only visualize the embeddings in 2D or 3D space. Contextually similar documents should be closer together in space than dissimilar documents.

len(df_train['Embeddings'][0])
768
# Convert df_train['Embeddings'] Pandas series to a np.array of float32
X = np.array(df_train['Embeddings'].to_list(), dtype=np.float32)
X.shape
(600, 768)

You will apply the t-distributed Stochastic Neighbor Embedding (t-SNE) approach to perform the dimensionality reduction. This technique reduces the number of dimensions while preserving clusters (points that are close together stay close together). For the original data, the model tries to construct a distribution over which other data points are "neighbors" (e.g., they share a similar meaning). It then optimizes an objective function to keep a similar distribution in the visualization.

tsne = TSNE(random_state=0, n_iter=1000)
tsne_results = tsne.fit_transform(X)
df_tsne = pd.DataFrame(tsne_results, columns=['TSNE1', 'TSNE2'])
df_tsne['Class Name'] = df_train['Class Name'] # Add labels column from df_train to df_tsne
df_tsne
TSNE1 TSNE2 Class Name
0 30.719217 6.754019 sci.crypt
1 42.622414 0.371172 sci.crypt
2 43.632767 -3.970512 sci.crypt
3 25.857092 9.590956 sci.crypt
4 30.428589 -1.919301 sci.crypt
... ... ... ...
595 -1.371668 32.625580 sci.space
596 1.049393 12.153948 sci.space
597 -0.349571 15.891821 sci.space
598 -0.791861 27.201237 sci.space
599 -18.490322 13.022571 sci.space

600 rows × 3 columns

fig, ax = plt.subplots(figsize=(8,6)) # Set figsize
sns.set_style('darkgrid', {"grid.color": ".6", "grid.linestyle": ":"})
sns.scatterplot(data=df_tsne, x='TSNE1', y='TSNE2', hue='Class Name', palette='hls')
sns.move_legend(ax, "upper left", bbox_to_anchor=(1, 1))
plt.title('Scatter plot of news using t-SNE');
plt.xlabel('TSNE1');
plt.ylabel('TSNE2');
plt.axis('equal')
(-56.389787673950195, 54.20468330383301, -34.96730289459229, 35.87395343780518)
[Figure: t-SNE scatter plot of the newsgroup embeddings, colored by class name]

Compare the results to KMeans#

KMeans clustering is a popular clustering algorithm often used in unsupervised learning. It iteratively determines the best k center points and assigns each example to the closest centroid. Feed the embeddings directly into the KMeans algorithm to compare the visualization of the embeddings with the performance of a machine learning algorithm.

# Apply KMeans
kmeans_model = KMeans(n_clusters=4, random_state=1, n_init='auto')
labels = kmeans_model.fit_predict(X)
df_tsne['Cluster'] = labels
df_tsne
TSNE1 TSNE2 Class Name Cluster
0 30.719217 6.754019 sci.crypt 0
1 42.622414 0.371172 sci.crypt 0
2 43.632767 -3.970512 sci.crypt 0
3 25.857092 9.590956 sci.crypt 0
4 30.428589 -1.919301 sci.crypt 0
... ... ... ... ...
595 -1.371668 32.625580 sci.space 1
596 1.049393 12.153948 sci.space 1
597 -0.349571 15.891821 sci.space 1
598 -0.791861 27.201237 sci.space 1
599 -18.490322 13.022571 sci.space 1

600 rows × 4 columns

fig, ax = plt.subplots(figsize=(8,6)) # Set figsize
sns.set_style('darkgrid', {"grid.color": ".6", "grid.linestyle": ":"})
sns.scatterplot(data=df_tsne, x='TSNE1', y='TSNE2', hue='Cluster', palette='magma')
sns.move_legend(ax, "upper left", bbox_to_anchor=(1, 1))
plt.title('Scatter plot of news using KMeans Clustering');
plt.xlabel('TSNE1');
plt.ylabel('TSNE2');
plt.axis('equal')
(-56.389787673950195, 54.20468330383301, -34.96730289459229, 35.87395343780518)
[Figure: t-SNE scatter plot of the newsgroup embeddings, colored by KMeans cluster]
def get_majority_cluster_per_group(df_tsne_cluster, class_names):
  class_clusters = dict()
  for c in class_names:
    # Get rows of dataframe that are equal to c
    rows = df_tsne_cluster.loc[df_tsne_cluster['Class Name'] == c]
    # Get majority value in Cluster column of the rows selected
    cluster = rows.Cluster.mode().values[0]
    # Populate mapping dictionary
    class_clusters[c] = cluster
  return class_clusters
classes = df_tsne['Class Name'].unique()
class_clusters = get_majority_cluster_per_group(df_tsne, classes)
class_clusters
{'sci.crypt': 0, 'sci.electronics': 3, 'sci.med': 2, 'sci.space': 1}

Get the majority cluster for each group, and see how many of that group's actual members fall within that cluster.

# Convert the Cluster column to use the class name
class_by_id = {v: k for k, v in class_clusters.items()}
df_tsne['Predicted'] = df_tsne['Cluster'].map(class_by_id.__getitem__)

# Filter to the correctly matched rows
correct = df_tsne[df_tsne['Class Name'] == df_tsne['Predicted']]

# Summarise, as a percentage
acc = correct['Class Name'].value_counts() / SAMPLE_SIZE
acc
Class Name
sci.crypt          0.973333
sci.space          0.946667
sci.med            0.933333
sci.electronics    0.926667
Name: count, dtype: float64
# Get predicted values by name, reusing the class_by_id mapping from above
df_tsne['Predicted'] = ''
for idx, row in df_tsne.iterrows():
  # Look up the class name that corresponds to this row's cluster id
  df_tsne.at[idx, 'Predicted'] = class_by_id[row['Cluster']]

df_tsne
TSNE1 TSNE2 Class Name Cluster Predicted
0 30.719217 6.754019 sci.crypt 0 sci.crypt
1 42.622414 0.371172 sci.crypt 0 sci.crypt
2 43.632767 -3.970512 sci.crypt 0 sci.crypt
3 25.857092 9.590956 sci.crypt 0 sci.crypt
4 30.428589 -1.919301 sci.crypt 0 sci.crypt
... ... ... ... ... ...
595 -1.371668 32.625580 sci.space 1 sci.space
596 1.049393 12.153948 sci.space 1 sci.space
597 -0.349571 15.891821 sci.space 1 sci.space
598 -0.791861 27.201237 sci.space 1 sci.space
599 -18.490322 13.022571 sci.space 1 sci.space

600 rows × 5 columns

To better visualize the performance of KMeans on your data, you can use a confusion matrix. The confusion matrix lets you assess the performance of the classification model beyond accuracy: you can see what the misclassified points get classified as. You will need the actual values and the predicted values, which you gathered in the dataframe above.

cm = confusion_matrix(df_tsne['Class Name'].to_list(), df_tsne['Predicted'].to_list())
disp = ConfusionMatrixDisplay(confusion_matrix=cm,
                              display_labels=classes)
disp.plot(xticks_rotation='vertical')
plt.title('Confusion Matrix for Actual and Clustered Newsgroups');
plt.grid(False)
[Figure: confusion matrix of actual newsgroup classes vs. KMeans cluster assignments]

Next steps#

You have now created your own visualization of embeddings with clustering! Try using your own textual data to visualize them as embeddings. You can perform dimensionality reduction to complete the visualization step. Note that t-SNE is good at clustering inputs, but it can take a long time to converge or may get stuck in a local minimum. If you run into this issue, another technique you could consider is principal components analysis (PCA), sketched below.
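
A minimal sketch of swapping PCA in for the t-SNE step above (assuming the same (600, 768) embedding array X built earlier):

from sklearn.decomposition import PCA

# Project the 768-dimensional embeddings down to 2 principal components.
pca = PCA(n_components=2, random_state=0)
pca_results = pca.fit_transform(X)
df_pca = pd.DataFrame(pca_results, columns=['PCA1', 'PCA2'])
df_pca['Class Name'] = df_train['Class Name']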

There are also other clustering algorithms besides KMeans, such as density-based spatial clustering of applications with noise (DBSCAN); a sketch follows.
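
A minimal sketch of DBSCAN on the same embeddings; the eps and min_samples values here are illustrative and would need tuning for your own data:

from sklearn.cluster import DBSCAN

# A label of -1 marks points DBSCAN treats as noise rather than cluster members.
dbscan = DBSCAN(eps=0.5, min_samples=5, metric='cosine')
dbscan_labels = dbscan.fit_predict(X)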

To learn how to use other services in the Gemini API, visit the Python quickstart. To learn more about how you can use embeddings, check out the available examples. To learn how to create them from scratch, see TensorFlow's Word Embeddings tutorial.