使用嵌入训练文本分类器#

概述#

在本笔记本中,您将学习使用 Gemini API 生成的嵌入来训练模型,该模型可以根据主题对不同类型的新闻组帖子进行分类。

在本教程中,您将训练一个分类器来预测新闻组帖子属于哪个类别。

前提条件#

您可以在 Google Colab 中运行此快速入门。

要在您自己的开发环境中完成本快速入门,请确保您的环境满足以下要求:

  • Python 3.9+

  • 安装 jupyter 以运行笔记本

安装#

首先,下载并安装 Gemini API Python 库。

!pip install google.generativeai 
# keras tensorflow
## !pip install -U -q google.colab
Requirement already satisfied: google.generativeai in /home/st/miniconda3/envs/gemini/lib/python3.10/site-packages (0.3.1)
Requirement already satisfied: google-ai-generativelanguage==0.4.0 in /home/st/miniconda3/envs/gemini/lib/python3.10/site-packages (from google.generativeai) (0.4.0)
Requirement already satisfied: google-auth in /home/st/miniconda3/envs/gemini/lib/python3.10/site-packages (from google.generativeai) (2.25.2)
Requirement already satisfied: google-api-core in /home/st/miniconda3/envs/gemini/lib/python3.10/site-packages (from google.generativeai) (2.15.0)
Requirement already satisfied: protobuf in /home/st/miniconda3/envs/gemini/lib/python3.10/site-packages (from google.generativeai) (4.23.4)
Requirement already satisfied: tqdm in /home/st/miniconda3/envs/gemini/lib/python3.10/site-packages (from google.generativeai) (4.66.1)
Requirement already satisfied: proto-plus<2.0.0dev,>=1.22.3 in /home/st/miniconda3/envs/gemini/lib/python3.10/site-packages (from google-ai-generativelanguage==0.4.0->google.generativeai) (1.23.0)
Requirement already satisfied: googleapis-common-protos<2.0.dev0,>=1.56.2 in /home/st/miniconda3/envs/gemini/lib/python3.10/site-packages (from google-api-core->google.generativeai) (1.62.0)
Requirement already satisfied: requests<3.0.0.dev0,>=2.18.0 in /home/st/miniconda3/envs/gemini/lib/python3.10/site-packages (from google-api-core->google.generativeai) (2.31.0)
Requirement already satisfied: cachetools<6.0,>=2.0.0 in /home/st/miniconda3/envs/gemini/lib/python3.10/site-packages (from google-auth->google.generativeai) (5.3.2)
Requirement already satisfied: pyasn1-modules>=0.2.1 in /home/st/miniconda3/envs/gemini/lib/python3.10/site-packages (from google-auth->google.generativeai) (0.3.0)
Requirement already satisfied: rsa<5,>=3.1.4 in /home/st/miniconda3/envs/gemini/lib/python3.10/site-packages (from google-auth->google.generativeai) (4.9)
Requirement already satisfied: grpcio<2.0dev,>=1.33.2 in /home/st/miniconda3/envs/gemini/lib/python3.10/site-packages (from google-api-core[grpc]!=2.0.*,!=2.1.*,!=2.10.*,!=2.2.*,!=2.3.*,!=2.4.*,!=2.5.*,!=2.6.*,!=2.7.*,!=2.8.*,!=2.9.*,<3.0.0dev,>=1.34.0->google-ai-generativelanguage==0.4.0->google.generativeai) (1.60.0)
Requirement already satisfied: grpcio-status<2.0.dev0,>=1.33.2 in /home/st/miniconda3/envs/gemini/lib/python3.10/site-packages (from google-api-core[grpc]!=2.0.*,!=2.1.*,!=2.10.*,!=2.2.*,!=2.3.*,!=2.4.*,!=2.5.*,!=2.6.*,!=2.7.*,!=2.8.*,!=2.9.*,<3.0.0dev,>=1.34.0->google-ai-generativelanguage==0.4.0->google.generativeai) (1.60.0)
Requirement already satisfied: pyasn1<0.6.0,>=0.4.6 in /home/st/miniconda3/envs/gemini/lib/python3.10/site-packages (from pyasn1-modules>=0.2.1->google-auth->google.generativeai) (0.5.1)
Requirement already satisfied: charset-normalizer<4,>=2 in /home/st/miniconda3/envs/gemini/lib/python3.10/site-packages (from requests<3.0.0.dev0,>=2.18.0->google-api-core->google.generativeai) (3.3.2)
Requirement already satisfied: idna<4,>=2.5 in /home/st/miniconda3/envs/gemini/lib/python3.10/site-packages (from requests<3.0.0.dev0,>=2.18.0->google-api-core->google.generativeai) (3.6)
Requirement already satisfied: urllib3<3,>=1.21.1 in /home/st/miniconda3/envs/gemini/lib/python3.10/site-packages (from requests<3.0.0.dev0,>=2.18.0->google-api-core->google.generativeai) (2.1.0)
Requirement already satisfied: certifi>=2017.4.17 in /home/st/miniconda3/envs/gemini/lib/python3.10/site-packages (from requests<3.0.0.dev0,>=2.18.0->google-api-core->google.generativeai) (2023.11.17)
import re
import tqdm
import keras
import numpy as np
import pandas as pd

import google.generativeai as genai
import google.ai.generativelanguage as glm

# Used to securely store your API key
# from google.colab import userdata

import seaborn as sns
import matplotlib.pyplot as plt

from keras import layers
from matplotlib.ticker import MaxNLocator
from sklearn.datasets import fetch_20newsgroups
import sklearn.metrics as skmetrics

获取 API 密钥#

在使用 Gemini API 之前,您必须先获取 API 密钥。如果您还没有密钥,请在 Google AI Studio 中一键创建密钥。

在 Colab 中,将密钥添加到左侧面板“🔑”下的秘密管理器中。将其命名为 API_KEY。 获得 API 密钥后,将其传递给 SDK。您可以通过两种方式执行此操作:

  • 将密钥放入 GOOGLE_API_KEY 环境变量中(SDK 将自动从那里获取它)。

  • 将密钥传递给 genai.configure(api_key=…)

# Or use `os.getenv('GOOGLE_API_KEY')` to fetch an environment variable.
# GOOGLE_API_KEY=userdata.get('GOOGLE_API_KEY')
GOOGLE_API_KEY = "YOUR-API-KEY"
genai.configure(api_key=GOOGLE_API_KEY)

Tip

要点:接下来,您将选择一个模型。任何嵌入模型都适用于本教程,但对于实际应用程序,选择特定模型并坚持使用非常重要。不同型号的输出互不兼容。

for m in genai.list_models():
  if 'embedContent' in m.supported_generation_methods:
    print(m.name)

# for m in genai.list_models():
#   if 'generateContent' in m.supported_generation_methods:
#     print(m.name)
models/embedding-001

数据集#

20 个新闻组文本数据集包含 20 个主题的 18,000 个新闻组帖子,分为训练集和测试集。训练和测试数据集之间的划分基于特定日期之前和之后发布的消息。在本教程中,您将使用训练和测试数据集的子集。您将预处理数据并将其组织到 Pandas 数据框中。

from sklearn.datasets import fetch_20newsgroups

newsgroups_train = fetch_20newsgroups(subset='train')
newsgroups_test = fetch_20newsgroups(subset='test')

# View list of class names for dataset
newsgroups_train.target_names
['alt.atheism',
 'comp.graphics',
 'comp.os.ms-windows.misc',
 'comp.sys.ibm.pc.hardware',
 'comp.sys.mac.hardware',
 'comp.windows.x',
 'misc.forsale',
 'rec.autos',
 'rec.motorcycles',
 'rec.sport.baseball',
 'rec.sport.hockey',
 'sci.crypt',
 'sci.electronics',
 'sci.med',
 'sci.space',
 'soc.religion.christian',
 'talk.politics.guns',
 'talk.politics.mideast',
 'talk.politics.misc',
 'talk.religion.misc']

以下是训练集中数据点的示例。

idx = newsgroups_train.data[0].index('Lines')
print(newsgroups_train.data[0][idx:])
Lines: 15

 I was wondering if anyone out there could enlighten me on this car I saw
the other day. It was a 2-door sports car, looked to be from the late 60s/
early 70s. It was called a Bricklin. The doors were really small. In addition,
the front bumper was separate from the rest of the body. This is 
all I know. If anyone can tellme a model name, engine specs, years
of production, where this car is made, history, or whatever info you
have on this funky looking car, please e-mail.

Thanks,
- IL
   ---- brought to you by your neighborhood Lerxst ----

现在您将开始预处理本教程的数据。删除任何敏感信息,例如姓名、电子邮件或文本的冗余部分(例如“发件人:”“\n主题:”)。将信息组织到 Pandas 数据框中,使其更具可读性。

def preprocess_newsgroup_data(newsgroup_dataset):
  # Apply functions to remove names, emails, and extraneous words from data points in newsgroups.data
  newsgroup_dataset.data = [re.sub(r'[\w\.-]+@[\w\.-]+', '', d) for d in newsgroup_dataset.data] # Remove email
  newsgroup_dataset.data = [re.sub(r"\([^()]*\)", "", d) for d in newsgroup_dataset.data] # Remove names
  newsgroup_dataset.data = [d.replace("From: ", "") for d in newsgroup_dataset.data] # Remove "From: "
  newsgroup_dataset.data = [d.replace("\nSubject: ", "") for d in newsgroup_dataset.data] # Remove "\nSubject: "

  # Cut off each text entry after 5,000 characters
  newsgroup_dataset.data = [d[0:5000] if len(d) > 5000 else d for d in newsgroup_dataset.data]

  # Put data points into dataframe
  df_processed = pd.DataFrame(newsgroup_dataset.data, columns=['Text'])
  df_processed['Label'] = newsgroup_dataset.target
  # Match label to target name index
  df_processed['Class Name'] = ''
  for idx, row in df_processed.iterrows():
    df_processed.at[idx, 'Class Name'] = newsgroup_dataset.target_names[row['Label']]

  return df_processed
# Apply preprocessing function to training and test datasets
df_train = preprocess_newsgroup_data(newsgroups_train)
df_test = preprocess_newsgroup_data(newsgroups_test)
df_train.head()
Text Label Class Name
0 WHAT car is this!?\nNntp-Posting-Host: rac3.w... 7 rec.autos
1 SI Clock Poll - Final Call\nSummary: Final ca... 4 comp.sys.mac.hardware
2 PB questions...\nOrganization: Purdue Univers... 4 comp.sys.mac.hardware
3 Re: Weitek P9000 ?\nOrganization: Harris Comp... 1 comp.graphics
4 Re: Shuttle Launch Question\nOrganization: Sm... 14 sci.space

接下来,您将通过在训练数据集中获取 100 个数据点并删除一些类别来对一些数据进行采样,以运行本教程。选择要比较的科学类别。

def sample_data(df, num_samples, classes_to_keep):
  df = df.groupby('Label', as_index = False).apply(lambda x: x.sample(num_samples)).reset_index(drop=True)

  df = df[df['Class Name'].str.contains(classes_to_keep)]

  # Reset the encoding of the labels after sampling and dropping certain categories
  df['Class Name'] = df['Class Name'].astype('category')
  df['Encoded Label'] = df['Class Name'].cat.codes

  return df
TRAIN_NUM_SAMPLES = 100
TEST_NUM_SAMPLES = 25
CLASSES_TO_KEEP = 'sci' # Class name should contain 'sci' in it to keep science categories
df_train = sample_data(df_train, TRAIN_NUM_SAMPLES, CLASSES_TO_KEEP)
df_test = sample_data(df_test, TEST_NUM_SAMPLES, CLASSES_TO_KEEP)
df_train.value_counts('Class Name')
Class Name
sci.crypt          100
sci.electronics    100
sci.med            100
sci.space          100
Name: count, dtype: int64
df_test.value_counts('Class Name')
Class Name
sci.crypt          25
sci.electronics    25
sci.med            25
sci.space          25
Name: count, dtype: int64

创建嵌入#

在本节中,您将了解如何使用 Gemini API 中的嵌入为一段文本生成嵌入。要了解有关嵌入的更多信息,请访问嵌入指南

Tip

注意:嵌入一次计算一个,大样本量可能需要很长时间!

嵌入的 API 更改 embedding-001#

对于新的嵌入模型,有一个新的任务类型参数和可选标题(仅在 task_type=RETRIEVAL_DOCUMENT 时有效)。

这些新参数仅适用于最新的嵌入模型。任务类型为:

任务类型

描述

RETRIEVAL_QUERY

指定给定文本是搜索/检索设置中的查询。

RETRIEVAL_DOCUMENT

指定给定文本是搜索/检索设置中的文档。

SEMANTIC_SIMILARITY

指定给定文本将用于语义文本相似性 (STS)。

CLASSIFICATION

指定嵌入将用于分类。

CLUSTERING

指定嵌入将用于聚类。

from tqdm.auto import tqdm
tqdm.pandas()

from google.api_core import retry

def make_embed_text_fn(model):

  @retry.Retry(timeout=300.0)
  def embed_fn(text: str) -> list[float]:
    # Set the task_type to CLASSIFICATION.
    embedding = genai.embed_content(model=model,
                                    content=text,
                                    task_type="classification")
    return embedding['embedding']

  return embed_fn

def create_embeddings(model, df):
  df['Embeddings'] = df['Text'].progress_apply(make_embed_text_fn(model))
  return df
model = 'models/embedding-001'
df_train = create_embeddings(model, df_train)
df_test = create_embeddings(model, df_test)
100%|██████████| 400/400 [02:35<00:00,  2.57it/s]
100%|██████████| 100/100 [00:38<00:00,  2.58it/s]
df_train.head()
Text Label Class Name Encoded Label Embeddings
1100 More technical details\nOrganization: AT&T Be... 11 sci.crypt 0 [0.005982968, -0.024433807, -0.028595297, -0.0...
1101 Subject: Re: Keeping Your Mouth Shut \n \nRepl... 11 sci.crypt 0 [0.021684153, 0.023106724, -0.06751694, -0.053...
1102 Re: How do they know what keys to ask for? \n... 11 sci.crypt 0 [0.0026794474, -0.012339441, -0.084823035, -0....
1103 Re: Source of random bits on a Unix workstati... 11 sci.crypt 0 [0.0067265956, -0.06828294, -0.093188696, -0.0...
1104 Marc VanHeyningen <>Re: More technical details... 11 sci.crypt 0 [-0.01643939, -0.016774608, -0.020152368, -0.0...

构建简单的分类模型#

在这里,您将定义一个具有一个隐藏层和一类概率输出的简单模型。预测将对应于一段文本是特定类别新闻的概率。当您构建模型时,Keras 会自动打乱数据点。

def build_classification_model(input_size: int, num_classes: int) -> keras.Model:
  inputs = x = keras.Input(input_size)
  x = layers.Dense(input_size, activation='relu')(x)
  x = layers.Dense(num_classes, activation='sigmoid')(x)
  return keras.Model(inputs=[inputs], outputs=x)
# Derive the embedding size from the first training element.
embedding_size = len(df_train['Embeddings'].iloc[0])

# Give your model a different name, as you have already used the variable name 'model'
classifier = build_classification_model(embedding_size, len(df_train['Class Name'].unique()))
classifier.summary()

classifier.compile(loss = keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                   optimizer = keras.optimizers.Adam(learning_rate=0.001),
                   metrics=['accuracy'])
Model: "model_1"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 input_2 (InputLayer)        [(None, 768)]             0         
                                                                 
 dense_2 (Dense)             (None, 768)               590592    
                                                                 
 dense_3 (Dense)             (None, 4)                 3076      
                                                                 
=================================================================
Total params: 593668 (2.26 MB)
Trainable params: 593668 (2.26 MB)
Non-trainable params: 0 (0.00 Byte)
_________________________________________________________________
embedding_size
768

训练模型对新闻组进行分类#

最后,您可以训练一个简单的模型。使用少量的 epoch 以避免过度拟合。第一个时期比其他时期花费的时间要长得多,因为嵌入只需要计算一次。

NUM_EPOCHS = 20
BATCH_SIZE = 32

# Split the x and y components of the train and validation subsets.
y_train = df_train['Encoded Label']
x_train = np.stack(df_train['Embeddings'])
y_val = df_test['Encoded Label']
x_val = np.stack(df_test['Embeddings'])

# Train the model for the desired number of epochs.
callback = keras.callbacks.EarlyStopping(monitor='accuracy', patience=3)

history = classifier.fit(x=x_train,
                         y=y_train,
                         validation_data=(x_val, y_val),
                         callbacks=[callback],
                         batch_size=BATCH_SIZE,
                         epochs=NUM_EPOCHS,)

评估模型性能#

使用 Keras Model.evaluate获取测试数据集上的损失和准确性。

classifier.evaluate(x=x_val, y=y_val, return_dict=True)

评估模型性能的一种方法是可视化分类器性能。使用plot_history 查看各个时期的损失和准确性趋势。

def plot_history(history):
  """
    Plotting training and validation learning curves.

    Args:
      history: model history with all the metric measures
  """
  fig, (ax1, ax2) = plt.subplots(1,2)
  fig.set_size_inches(20, 8)

  # Plot loss
  ax1.set_title('Loss')
  ax1.plot(history.history['loss'], label = 'train')
  ax1.plot(history.history['val_loss'], label = 'test')
  ax1.set_ylabel('Loss')

  ax1.set_xlabel('Epoch')
  ax1.legend(['Train', 'Validation'])

  # Plot accuracy
  ax2.set_title('Accuracy')
  ax2.plot(history.history['accuracy'],  label = 'train')
  ax2.plot(history.history['val_accuracy'], label = 'test')
  ax2.set_ylabel('Accuracy')
  ax2.set_xlabel('Epoch')
  ax2.legend(['Train', 'Validation'])

  plt.show()

plot_history(history)

除了测量损失和准确性之外,查看模型性能的另一种方法是使用混淆矩阵。混淆矩阵使您能够评估分类模型在准确性之外的性能。您可以看到错误分类的点被分类为哪些内容。为了构建这个多类分类问题的混淆矩阵,需要获取测试集中的实际值和预测值。

首先使用Model.predict()为验证集中的每个示例生成预测类。

y_hat = classifier.predict(x=x_val)
y_hat = np.argmax(y_hat, axis=1)
labels_dict = dict(zip(df_test['Class Name'], df_test['Encoded Label']))
labels_dict
cm = skmetrics.confusion_matrix(y_val, y_hat)
disp = skmetrics.ConfusionMatrixDisplay(confusion_matrix=cm,
                              display_labels=labels_dict.keys())
disp.plot(xticks_rotation='vertical')
plt.title('Confusion matrix for newsgroup test dataset');
plt.grid(False)

下一步#

要了解有关如何使用嵌入的更多信息,请查看可用的示例
要了解如何使用 Gemini API 中的其他服务,请访问Python 快速入门