Train a text classifier using embeddings#
Overview#
In this notebook, you will learn to use the embeddings produced by the Gemini API to train a model that can classify different kinds of newsgroup posts based on their topic.
In this tutorial, you will train a classifier to predict which category a newsgroup post belongs to.
Prerequisites#
You can run this quickstart in Google Colab.
To complete this quickstart in your own development environment, ensure that your environment meets the following requirements:
Python 3.9+
An installation of jupyter to run the notebook
Installation#
First, download and install the Gemini API Python library.
!pip install -U -q google-generativeai
import re
import tqdm
import keras
import numpy as np
import pandas as pd
import google.generativeai as genai
import google.ai.generativelanguage as glm
# Used to securely store your API key
# from google.colab import userdata
import seaborn as sns
import matplotlib.pyplot as plt
from keras import layers
from matplotlib.ticker import MaxNLocator
from sklearn.datasets import fetch_20newsgroups
import sklearn.metrics as skmetrics
Grab an API key#
Before you can use the Gemini API, you must first obtain an API key. If you don't already have one, create a key with one click in Google AI Studio.
In Colab, add the key to the secrets manager under the "🔑" in the left panel. Give it the name API_KEY. Once you have the API key, pass it to the SDK. You can do this in two ways:
Put the key in the GOOGLE_API_KEY environment variable (the SDK will automatically pick it up from there).
Pass the key to genai.configure(api_key=…)
# Or use `os.getenv('GOOGLE_API_KEY')` to fetch an environment variable.
# GOOGLE_API_KEY=userdata.get('GOOGLE_API_KEY')
GOOGLE_API_KEY = "YOUR-API-KEY"
genai.configure(api_key=GOOGLE_API_KEY)
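If you take the environment-variable route instead, a minimal sketch (assuming GOOGLE_API_KEY is already exported in your shell) looks like this:
import os
# Read the key from the environment rather than hard-coding it in the notebook.
genai.configure(api_key=os.environ['GOOGLE_API_KEY'])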
Tip
Key point: Next, you will choose a model. Any embedding model will work for this tutorial, but for real applications it's important to choose a specific model and stick with it. The outputs of different models are not compatible with each other.
for m in genai.list_models():
if 'embedContent' in m.supported_generation_methods:
print(m.name)
models/embedding-001
Dataset#
The 20 Newsgroups Text Dataset contains 18,000 newsgroup posts on 20 topics, split into training and test sets. The split between the training and test datasets is based on messages posted before and after a specific date. For this tutorial, you will use subsets of the training and test datasets. You will preprocess the data and organize it into Pandas dataframes.
from sklearn.datasets import fetch_20newsgroups
newsgroups_train = fetch_20newsgroups(subset='train')
newsgroups_test = fetch_20newsgroups(subset='test')
# View list of class names for dataset
newsgroups_train.target_names
['alt.atheism',
'comp.graphics',
'comp.os.ms-windows.misc',
'comp.sys.ibm.pc.hardware',
'comp.sys.mac.hardware',
'comp.windows.x',
'misc.forsale',
'rec.autos',
'rec.motorcycles',
'rec.sport.baseball',
'rec.sport.hockey',
'sci.crypt',
'sci.electronics',
'sci.med',
'sci.space',
'soc.religion.christian',
'talk.politics.guns',
'talk.politics.mideast',
'talk.politics.misc',
'talk.religion.misc']
Here is an example of what a data point from the training set looks like.
idx = newsgroups_train.data[0].index('Lines')
print(newsgroups_train.data[0][idx:])
Lines: 15
I was wondering if anyone out there could enlighten me on this car I saw
the other day. It was a 2-door sports car, looked to be from the late 60s/
early 70s. It was called a Bricklin. The doors were really small. In addition,
the front bumper was separate from the rest of the body. This is
all I know. If anyone can tellme a model name, engine specs, years
of production, where this car is made, history, or whatever info you
have on this funky looking car, please e-mail.
Thanks,
- IL
---- brought to you by your neighborhood Lerxst ----
Now you will begin preprocessing the data for this tutorial. Remove any sensitive information, such as names and emails, as well as redundant parts of the text like "From: " and "\nSubject: ". Organize the information into a Pandas dataframe so it is more readable.
def preprocess_newsgroup_data(newsgroup_dataset):
# Apply functions to remove names, emails, and extraneous words from data points in newsgroups.data
newsgroup_dataset.data = [re.sub(r'[\w\.-]+@[\w\.-]+', '', d) for d in newsgroup_dataset.data] # Remove email
newsgroup_dataset.data = [re.sub(r"\([^()]*\)", "", d) for d in newsgroup_dataset.data] # Remove names
newsgroup_dataset.data = [d.replace("From: ", "") for d in newsgroup_dataset.data] # Remove "From: "
newsgroup_dataset.data = [d.replace("\nSubject: ", "") for d in newsgroup_dataset.data] # Remove "\nSubject: "
# Cut off each text entry after 5,000 characters
newsgroup_dataset.data = [d[0:5000] if len(d) > 5000 else d for d in newsgroup_dataset.data]
# Put data points into dataframe
df_processed = pd.DataFrame(newsgroup_dataset.data, columns=['Text'])
df_processed['Label'] = newsgroup_dataset.target
# Match label to target name index
df_processed['Class Name'] = ''
for idx, row in df_processed.iterrows():
df_processed.at[idx, 'Class Name'] = newsgroup_dataset.target_names[row['Label']]
return df_processed
# Apply preprocessing function to training and test datasets
df_train = preprocess_newsgroup_data(newsgroups_train)
df_test = preprocess_newsgroup_data(newsgroups_test)
df_train.head()
| | Text | Label | Class Name |
|---|---|---|---|
| 0 | WHAT car is this!?\nNntp-Posting-Host: rac3.w... | 7 | rec.autos |
| 1 | SI Clock Poll - Final Call\nSummary: Final ca... | 4 | comp.sys.mac.hardware |
| 2 | PB questions...\nOrganization: Purdue Univers... | 4 | comp.sys.mac.hardware |
| 3 | Re: Weitek P9000 ?\nOrganization: Harris Comp... | 1 | comp.graphics |
| 4 | Re: Shuttle Launch Question\nOrganization: Sm... | 14 | sci.space |
Next, you will sample some of the data by taking 100 data points from the training dataset and dropping a few of the categories to run through this tutorial. Choose the science categories to compare.
def sample_data(df, num_samples, classes_to_keep):
df = df.groupby('Label', as_index = False).apply(lambda x: x.sample(num_samples)).reset_index(drop=True)
df = df[df['Class Name'].str.contains(classes_to_keep)]
# Reset the encoding of the labels after sampling and dropping certain categories
df['Class Name'] = df['Class Name'].astype('category')
df['Encoded Label'] = df['Class Name'].cat.codes
return df
TRAIN_NUM_SAMPLES = 100
TEST_NUM_SAMPLES = 25
CLASSES_TO_KEEP = 'sci' # Class name should contain 'sci' in it to keep science categories
df_train = sample_data(df_train, TRAIN_NUM_SAMPLES, CLASSES_TO_KEEP)
df_test = sample_data(df_test, TEST_NUM_SAMPLES, CLASSES_TO_KEEP)
df_train.value_counts('Class Name')
Class Name
sci.crypt 100
sci.electronics 100
sci.med 100
sci.space 100
Name: count, dtype: int64
df_test.value_counts('Class Name')
Class Name
sci.crypt 25
sci.electronics 25
sci.med 25
sci.space 25
Name: count, dtype: int64
Create the embeddings#
In this section, you will see how to generate embeddings for a piece of text using the embeddings from the Gemini API. To learn more about embeddings, visit the embeddings guide.
Tip
Note: Embeddings are computed one at a time, so large sample sizes may take a long time!
API changes to Embeddings with model embedding-001#
For the new embeddings model, there is a new task type parameter and the optional title (only valid with task_type=RETRIEVAL_DOCUMENT).
These new parameters apply only to the newest embeddings models. The task types are:
| Task Type | Description |
|---|---|
| RETRIEVAL_QUERY | Specifies the given text is a query in a search/retrieval setting. |
| RETRIEVAL_DOCUMENT | Specifies the given text is a document in a search/retrieval setting. |
| SEMANTIC_SIMILARITY | Specifies the given text will be used for Semantic Textual Similarity (STS). |
| CLASSIFICATION | Specifies that the embeddings will be used for classification. |
| CLUSTERING | Specifies that the embeddings will be used for clustering. |
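Before embedding the whole dataset, it can help to sanity-check a single call. Here is a minimal sketch (the sample sentence is only an illustration):
result = genai.embed_content(model='models/embedding-001',
                             content='The shuttle re-entered the atmosphere at dawn.',
                             task_type='classification')
print(len(result['embedding']))  # embedding-001 returns a 768-dimensional vector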
from tqdm.auto import tqdm
tqdm.pandas()
from google.api_core import retry
def make_embed_text_fn(model):
@retry.Retry(timeout=300.0)
def embed_fn(text: str) -> list[float]:
# Set the task_type to CLASSIFICATION.
embedding = genai.embed_content(model=model,
content=text,
task_type="classification")
return embedding['embedding']
return embed_fn
def create_embeddings(model, df):
df['Embeddings'] = df['Text'].progress_apply(make_embed_text_fn(model))
return df
model = 'models/embedding-001'
df_train = create_embeddings(model, df_train)
df_test = create_embeddings(model, df_test)
100%|██████████| 400/400 [02:35<00:00, 2.57it/s]
100%|██████████| 100/100 [00:38<00:00, 2.58it/s]
df_train.head()
| | Text | Label | Class Name | Encoded Label | Embeddings |
|---|---|---|---|---|---|
| 1100 | More technical details\nOrganization: AT&T Be... | 11 | sci.crypt | 0 | [0.005982968, -0.024433807, -0.028595297, -0.0... |
| 1101 | Subject: Re: Keeping Your Mouth Shut \n \nRepl... | 11 | sci.crypt | 0 | [0.021684153, 0.023106724, -0.06751694, -0.053... |
| 1102 | Re: How do they know what keys to ask for? \n... | 11 | sci.crypt | 0 | [0.0026794474, -0.012339441, -0.084823035, -0.... |
| 1103 | Re: Source of random bits on a Unix workstati... | 11 | sci.crypt | 0 | [0.0067265956, -0.06828294, -0.093188696, -0.0... |
| 1104 | Marc VanHeyningen <>Re: More technical details... | 11 | sci.crypt | 0 | [-0.01643939, -0.016774608, -0.020152368, -0.0... |
Build a simple classification model#
Here you will define a simple model with one hidden layer and a single-class probability output. The prediction will correspond to the probability of a piece of text being a particular class of news. When you build your model, Keras will automatically shuffle the data points.
def build_classification_model(input_size: int, num_classes: int) -> keras.Model:
    inputs = x = keras.Input(input_size)
    x = layers.Dense(input_size, activation='relu')(x)
    # Output raw logits: the loss below is constructed with from_logits=True,
    # so the final layer must not apply a sigmoid/softmax activation.
    x = layers.Dense(num_classes)(x)
    return keras.Model(inputs=[inputs], outputs=x)
# Derive the embedding size from the first training element.
embedding_size = len(df_train['Embeddings'].iloc[0])
# Give your model a different name, as you have already used the variable name 'model'
classifier = build_classification_model(embedding_size, len(df_train['Class Name'].unique()))
classifier.summary()
classifier.compile(loss = keras.losses.SparseCategoricalCrossentropy(from_logits=True),
optimizer = keras.optimizers.Adam(learning_rate=0.001),
metrics=['accuracy'])
Model: "model_1"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
input_2 (InputLayer) [(None, 768)] 0
dense_2 (Dense) (None, 768) 590592
dense_3 (Dense) (None, 4) 3076
=================================================================
Total params: 593668 (2.26 MB)
Trainable params: 593668 (2.26 MB)
Non-trainable params: 0 (0.00 Byte)
_________________________________________________________________
embedding_size
768
Train the model to classify newsgroups#
Finally, you can train a simple model. Use a small number of epochs to avoid overfitting; the EarlyStopping callback below also halts training once accuracy stops improving for a few epochs.
NUM_EPOCHS = 20
BATCH_SIZE = 32
# Split the x and y components of the train and validation subsets.
y_train = df_train['Encoded Label']
x_train = np.stack(df_train['Embeddings'])
y_val = df_test['Encoded Label']
x_val = np.stack(df_test['Embeddings'])
# Train the model for the desired number of epochs.
callback = keras.callbacks.EarlyStopping(monitor='accuracy', patience=3)
history = classifier.fit(x=x_train,
y=y_train,
validation_data=(x_val, y_val),
callbacks=[callback],
batch_size=BATCH_SIZE,
epochs=NUM_EPOCHS,)
Evaluate model performance#
Use Keras Model.evaluate to get the loss and accuracy on the test dataset.
classifier.evaluate(x=x_val, y=y_val, return_dict=True)
One way to evaluate the model's performance is to visualize it. Use plot_history to see the loss and accuracy trends over the epochs.
def plot_history(history):
"""
Plotting training and validation learning curves.
Args:
history: model history with all the metric measures
"""
fig, (ax1, ax2) = plt.subplots(1,2)
fig.set_size_inches(20, 8)
# Plot loss
ax1.set_title('Loss')
ax1.plot(history.history['loss'], label = 'train')
ax1.plot(history.history['val_loss'], label = 'test')
ax1.set_ylabel('Loss')
ax1.set_xlabel('Epoch')
ax1.legend(['Train', 'Validation'])
# Plot accuracy
ax2.set_title('Accuracy')
ax2.plot(history.history['accuracy'], label = 'train')
ax2.plot(history.history['val_accuracy'], label = 'test')
ax2.set_ylabel('Accuracy')
ax2.set_xlabel('Epoch')
ax2.legend(['Train', 'Validation'])
plt.show()
plot_history(history)
Besides measuring loss and accuracy, another way to assess model performance is with a confusion matrix. A confusion matrix lets you evaluate the classification model beyond accuracy: you can see what the misclassified points were classified as. To build the confusion matrix for this multi-class classification problem, get the actual values from the test set and the predicted values.
Start by generating the predicted class for each example in the validation set using Model.predict().
y_hat = classifier.predict(x=x_val)
y_hat = np.argmax(y_hat, axis=1)
labels_dict = dict(zip(df_test['Class Name'], df_test['Encoded Label']))
labels_dict
cm = skmetrics.confusion_matrix(y_val, y_hat)
disp = skmetrics.ConfusionMatrixDisplay(confusion_matrix=cm,
display_labels=labels_dict.keys())
disp.plot(xticks_rotation='vertical')
plt.title('Confusion matrix for newsgroup test dataset');
plt.grid(False)
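For a complementary per-class view, you could also print precision and recall with scikit-learn's classification_report (an optional sketch reusing labels_dict from above):
print(skmetrics.classification_report(y_val, y_hat,
                                      target_names=list(labels_dict.keys())))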
Next steps#
To learn more about how you can use embeddings, check out the available examples.
To learn how to use other services in the Gemini API, visit the Python quickstart.