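Document search with embeddings#
Overview#
This example demonstrates how to use the Gemini API to create embeddings so that you can perform document search. You will use the Python client library to build embeddings that let you compare a search string or question against document contents.
In this tutorial, you'll use embeddings to perform document search over a set of documents and answer questions about the Google Car.
Prerequisites#
You can run this quickstart in Google Colab.
To complete this quickstart in your own development environment, make sure it meets the following requirements:
Python 3.9+
An installation of jupyter to run the notebook
Setup#
First, download and install the Gemini API Python library.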
!pip install -U -q google-generativeai
## !pip install -U -q google.colab
import textwrap
import numpy as np
import pandas as pd
import google.generativeai as genai
import google.ai.generativelanguage as glm
# # Used to securely store your API key
# from google.colab import userdata
from IPython.display import Markdown
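Get an API key#
Before you can use the Gemini API, you must first obtain an API key. If you don't already have one, create a key with one click in Google AI Studio.
In Colab, add the key to the secrets manager under the "🔑" in the left panel and name it API_KEY. Once you have the API key, pass it to the SDK. You can do this in two ways:
Put the key in the GOOGLE_API_KEY environment variable (the SDK will automatically pick it up from there).
Pass the key to genai.configure(api_key=…)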
# Or use `os.getenv('API_KEY')` to fetch an environment variable.
# API_KEY=userdata.get('API_KEY')
GOOGLE_API_KEY = "YOUR-API-KEY"
genai.configure(api_key=GOOGLE_API_KEY)
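As a hedged alternative to hard-coding the key in the notebook, the sketch below reads it from the GOOGLE_API_KEY environment variable mentioned above and fails early when the variable is missing. The helper name `load_api_key` is hypothetical and is not part of the SDK.

```python
import os


def load_api_key(env_var="GOOGLE_API_KEY"):
    """Return the API key stored in `env_var`, or raise if it is unset or empty.

    `load_api_key` is a hypothetical helper for this tutorial, not an SDK function.
    """
    key = os.environ.get(env_var, "").strip()
    if not key:
        raise RuntimeError(f"Set the {env_var} environment variable to your API key.")
    return key


# Usage (assumes `genai` was imported as shown earlier):
# genai.configure(api_key=load_api_key())
```

Failing early keeps a missing key from surfacing later as a less obvious error on the first API call.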
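Tip
Key point: Next, you will choose a model. Any embedding model will work for this tutorial, but for real applications it is important to choose a specific model and stick with it, because the outputs of different models are not compatible with each other.
Warning
Note: At this time, the Gemini API is only available in certain regions.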
for m in genai.list_models():
    if 'embedContent' in m.supported_generation_methods:
        print(m.name)
models/embedding-001
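Embedding generation#
In this section, you will see how to generate an embedding for a piece of text using the Gemini API.
Use the embedding-001 model.
API changes to embeddings#
For the new embedding model, embedding-001, there is a new task type parameter and an optional title (only valid with task_type=RETRIEVAL_DOCUMENT).
These new parameters apply only to the newest embedding models. The task types are: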
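Task Type | Description |
---|---|
RETRIEVAL_QUERY | Specifies that the given text is a query in a search/retrieval setting. |
RETRIEVAL_DOCUMENT | Specifies that the given text is a document in a search/retrieval setting. |
SEMANTIC_SIMILARITY | Specifies that the given text will be used for Semantic Textual Similarity (STS). |
CLASSIFICATION | Specifies that the embeddings will be used for classification. |
CLUSTERING | Specifies that the embeddings will be used for clustering. |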
Tip
Note: Specifying a title for RETRIEVAL_DOCUMENT provides better quality embeddings for retrieval.
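The rule that a title is only valid for retrieval documents can be checked before any request is sent. The following is a minimal sketch, assuming a hypothetical helper `build_embed_kwargs` that only assembles keyword arguments for `genai.embed_content`; it does not call the API itself.

```python
def build_embed_kwargs(text, task_type, title=None, model="models/embedding-001"):
    """Assemble keyword arguments for genai.embed_content.

    Hypothetical helper: it only enforces the rule stated above that a title
    is accepted when task_type is retrieval_document.
    """
    task_type = task_type.lower()
    if title is not None and task_type != "retrieval_document":
        raise ValueError("title is only valid with task_type='retrieval_document'")
    kwargs = {"model": model, "content": text, "task_type": task_type}
    if title is not None:
        kwargs["title"] = title
    return kwargs


# Usage (requires a configured `genai` client):
# embedding = genai.embed_content(**build_embed_kwargs("some text", "retrieval_query"))
```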
title = "The next generation of AI for developers and Google Workspace"
sample_text = ("Title: The next generation of AI for developers and Google Workspace"
               "\n"
               "Full article:\n"
               "\n"
               "Gemini API & Google AI Studio: An approachable way to explore and prototype with generative AI applications")
model = 'models/embedding-001'
embedding = genai.embed_content(model=model,
                                content=sample_text,
                                task_type="retrieval_document",
                                title=title)
print(embedding)
{'embedding': [0.03411343, -0.05517662, -0.020209055, -0.0041249567, 0.058917783, 0.014129515, 0.0045353593, 0.0014303668, 0.05976634, 0.08292115, 0.007162964, 0.0069041685, -0.053083427, -0.010905125, 0.0321402, -0.037163995, 0.050372455, 0.019348344, -0.037328612, 0.026647927, 0.030781753, -0.011288501, -0.031485256, -0.060248993, -0.026219442, -0.009794235, 0.006630139, -0.01846516, -0.026324715, 0.020442624, -0.06317684, 0.014559574, -0.052296035, 0.016451128, -9.720217e-05, -0.051706687, -0.0054406044, -0.056967627, 0.011144145, -0.009201792, -0.0021951047, -0.1099701, -0.011712193, 0.021221714, 0.009171804, -0.029621972, 0.034534883, 0.039578073, 0.019021519, -0.06269169, 0.039473332, 0.052403256, 0.061814185, -0.034507945, -0.009557816, -0.0049551064, 0.017839009, -0.021176832, 0.015043588, 0.015390569, -0.006334281, 0.043696404, -0.028341983, 0.028433999, 0.01472686, -0.06585564, -0.044533554, 0.0055523133, 0.035775978, 0.031099156, 0.027357662, 0.028062241, 0.056972917, -0.054656833, -0.027864764, -0.15486294, -0.027930057, 0.043678433, 0.008391214, 0.020209847, 0.002841071, -0.07201404, -0.05025868, -0.034896467, -0.030400582, 0.016623711, -0.050455835, -0.025557702, 0.0050540236, 0.032266915, -0.018223321, -0.04913693, 0.07667526, -0.03066128, -0.0127946865, 0.107169494, -0.0563475, -0.016773727, -0.010336115, -0.05220995, -0.022049127, 0.00478732, -0.039094422, 0.015671168, 0.041542538, 0.016112784, -0.022650082, 0.002988097, -0.061147556, 0.06630078, -0.057244215, -0.013767544, -0.003466806, -0.053994596, 0.04230463, -0.029314812, 0.021347178, 0.04522084, 0.0072983643, 0.0247336, 0.020755325, 0.025620919, 0.021721177, 0.008178421, -0.063603185, 0.025854606, -0.037521806, 0.020877894, 0.033131972, 0.030288944, 0.0033810628, -0.048698667, -0.027295412, 0.05622969, 0.029634966, 0.029705161, 0.010602056, -0.02217137, 0.03195607, 0.047208548, -0.013620148, 0.038463745, 0.0052672145, -0.024868716, -0.0071725682, 0.069668904, -0.092124775, 0.014151632, 
0.0057337824, -0.006012909, -0.035946254, 0.0334855, -0.07426327, 0.033595767, 0.033740804, 0.0394573, -0.048899952, 0.06265119, 0.0028897377, 0.0039847936, 0.0329678, -0.012809373, 0.050108086, 0.009314281, -0.019499086, -0.048860364, -0.015204039, 0.008016007, 0.015667893, 0.03903864, 0.011732032, 0.034669068, -0.01226258, -0.052623957, 0.006695629, -0.05849198, 0.00101155, -0.009621011, -0.0052014636, -0.020959012, -0.02466722, -0.03861565, 0.049754824, 0.048655674, 0.0044479654, -0.020404046, 0.101043485, -0.022594253, -0.06822699, 0.044780556, -0.03859346, -0.015194885, -0.0059435116, -0.016267126, 0.0012126336, 0.054198146, 0.01978253, 0.02905382, 0.034172967, -0.0032252679, 0.020003818, 0.07212547, 0.035888623, -0.00029856138, 0.0044168616, 0.036989234, 0.100975856, 0.0048228516, -0.04405796, 0.00039434276, -0.044601627, -0.011658614, 0.03398768, 0.02250937, 0.034583274, -0.03440395, -0.003274625, -0.005927225, -0.007679341, -0.025777208, -0.02205426, 0.00823437, -0.027172998, -0.015607741, -0.022958823, 0.098416075, -0.045472592, 0.031623535, 0.030663209, -0.03987397, 0.0048750523, 0.057770126, 0.04547866, 0.009881574, 0.044948515, 0.012011639, 0.003141497, 0.0016209317, 0.07142094, 0.025111957, -0.049478546, 0.052195616, 0.041401174, 0.0380032, -0.05878786, -0.007194873, -0.015402912, 0.048243146, 0.025205499, 0.051020827, -0.030305905, -0.031656887, -0.008994425, 0.039839912, -0.043015696, 0.008373317, -0.089018084, -0.045301914, -0.0074205245, 0.0049243467, 0.060365975, -0.06462967, 0.00815101, -0.020417998, 0.00030822973, -0.039288856, 0.04017253, 0.03137731, -0.031728875, 0.03444872, -0.031234143, -0.048502136, 0.033941768, 0.034225147, -0.008299359, 0.033098515, -0.012317135, 0.014448822, 0.06187389, -0.059683096, 0.0012899865, 0.007460227, 0.02652167, -0.07248658, -0.05818953, 0.030052334, -0.015347992, -0.035913672, -0.034901086, -0.0661791, -0.055562418, 0.0130468095, -0.0035763406, -0.0086615095, -0.046888705, -0.005326655, -0.021710252, 
0.072175056, 0.038597208, -0.038364347, -0.005039459, -0.07634857, -0.045539834, -0.07372115, 0.018378831, -0.032071035, -0.030269828, -0.044599615, 0.05132602, 0.04769642, 0.0014855171, -0.028435005, -0.0016777773, 0.0072506564, 0.08448479, 0.04224268, -0.029304115, 0.02165559, 0.0056837583, 0.06214723, 0.0028552667, 0.015904495, -0.016737062, -0.0040876116, -0.037475582, 0.04675168, -0.052556172, -0.016293239, -0.014435561, 0.022127734, -0.0052535324, 0.0190588, -0.011537659, 0.0484614, 0.028816467, 0.024607794, -0.043762755, 0.011608192, 0.0021703655, -0.045297306, -0.0048169023, -0.0071430723, 0.011705694, -0.05006429, 0.029933244, -0.020802287, -0.07580785, -0.012268235, 0.07616304, 0.006857885, 0.01853518, 0.043729052, -0.032221675, -0.010121849, -0.019831154, -0.032731093, 0.051531196, 0.0024111927, 0.090960525, -0.036896333, 0.035708647, 0.03678696, 0.00832481, 0.001757778, 0.04144341, 0.042203393, 0.0045936033, 0.021837635, 0.0066275136, 0.0069022025, -0.008452904, -0.03277543, 0.0061044246, -0.02500629, 0.012441071, 0.018081166, -0.06230864, -0.040046707, -0.019351328, 0.0007255696, 0.002202931, -0.041990966, 0.023313772, 0.039377946, -0.0012839311, 0.010378518, 0.0025737497, 0.043841247, -0.0067742136, 0.045794934, 0.01388272, 0.032243907, 0.0919292, 0.03760722, 0.0060486114, 0.010843367, 0.001991803, -0.04838942, -0.006412631, -0.030764624, -0.015602797, -0.048885867, -0.015245706, -0.0006477355, 0.013608845, 0.0040335134, -0.0015530499, -0.008402027, -0.05728556, -0.027370622, 0.019342335, -0.039145477, 0.049000833, -0.052876346, -0.060248777, 0.009484413, 0.011271402, -0.0019944878, -0.013369263, -0.0130786, 0.0050903647, -0.0003995775, 0.04580157, -0.030488051, -0.07777237, 0.022998745, -0.007693635, -0.013473893, 0.0071830116, 0.014312745, 0.019949466, 0.034036275, -0.0011623668, 0.022655929, -0.0049825236, -0.036455333, 0.0033899196, 0.020583669, -0.010457001, 0.027299065, 0.034606297, -0.0111719165, -0.013660416, -0.02705466, -0.05144293, 
-0.07396907, -0.022817062, -0.0064836126, 0.037774086, -0.06774259, 0.016620712, -0.046481006, -0.030288063, -0.055035893, -0.015402408, -0.014477583, 0.0024700973, 0.024081903, -0.008900536, -0.0032105052, 0.026591286, -0.027869076, -0.014552753, -0.026460772, 0.06831125, -0.019622969, -0.028588912, -0.02271201, 0.0019694276, 0.0079966, -0.013207389, -0.07246265, -0.005246626, -0.03556684, 0.014131167, 0.0018361827, -0.084728725, 0.010380415, -0.038140625, 0.0066234693, -0.023485202, 0.05133969, 0.018931301, -0.0077241925, -0.01968148, -0.0615474, 0.036711216, 0.028462604, -0.02205502, 0.02294784, 0.03529192, 0.044653304, -0.029656367, -0.04243813, -0.024271922, 0.008206945, -0.015324323, 0.028326686, 0.0708875, 0.03499979, -0.04111004, -0.02691298, -0.011054021, 0.035632536, 0.057256706, -0.058149684, 0.022313014, -0.03727344, 0.0095027555, -0.0325091, -0.007395906, 0.009455788, 0.0053972304, -0.028935568, 0.054196633, -0.051867362, -0.010642803, 0.034427024, 0.04308132, 0.020671992, 0.068610825, 0.018303277, -0.08433639, 0.0023544622, -0.009237108, -0.0410166, 0.012912618, -0.035220295, 0.032994937, -0.0063333404, -0.028377546, 0.05429965, -0.022590995, -0.033762764, -0.0061482205, 0.0014308131, 0.05402618, -0.030298075, -0.020893354, 0.04020406, -0.013849863, -0.047842298, 0.032006662, 0.037729368, -0.02878951, 0.002758488, -0.0023380243, -0.052403864, 0.021707276, -0.02718091, 0.0045513017, 0.02493268, -0.016037108, 0.009521465, 0.022595555, -0.03332406, -0.01791281, -0.026219219, 0.015336862, 0.018615942, 0.0014700901, 0.005194217, -0.0059983027, -0.002134208, 0.055935774, 0.0002028429, -0.01381741, 0.0005677742, 0.052481145, -0.0056857914, -0.024219796, -0.0074823913, 0.041230515, 0.005571935, 0.06841099, -0.025634678, -0.037456885, -0.0021465495, -0.05163424, 0.048833348, 0.057269894, -0.0017718605, -0.012836743, 0.054180846, -0.032427873, -0.003244846, 0.01254491, 0.0071952185, 0.02080726, 0.015043071, -0.08000574, 0.047099367, -0.009071923, 0.022494175, 
-0.007407801, -0.018199192, 0.01923855, -0.016820459, 0.026590073, 0.05919531, -0.015211094, -0.051043298, 0.05085604, -0.027980763, -0.01785205, 0.05260401, 0.0039136643, -0.010834236, 0.015846392, 0.011993318, 0.0085244095, -0.09705911, 0.004848937, -0.03151453, -0.049902532, -0.023312563, 0.0169847, 0.051852323, -0.018586924, -0.011750037, 0.020324359, 0.041236103, 0.046270456, 0.045885824, 0.035645086, 0.027820086, -0.054944187, -0.0018159872, 0.06008568, 0.056207847, 0.03509413, 0.07476336, 0.00042090056, -0.01791933, -0.049269866, 0.013118644, 0.03817175, 0.03985353, -0.023338122, -0.05917611, -0.040447813, -0.014515073, -0.01641867, -0.012444603, -0.015801677, 0.01694387, -0.012097041, -0.10444289, -0.044068433, 0.028175205, 0.0032158983, 0.017225135, 0.024197249, 0.0003871886, 0.008296747, 0.0020322825, -0.06488942, -0.028532177, 0.03631236, -0.021784041, -0.028676897, 0.020023972, -0.015093374, -0.0053404626, -0.035407133, -0.03022746, -0.045240995, -0.089037456, -0.05241791, 0.01601896, -0.058039088, 0.06633133, 0.01435994, 0.0024608225, 0.02044063, 0.049869247, 0.013966787, 0.011062478, 0.023516618, 0.010368709, 0.039040443, -0.03096598, -0.01665127, 0.010691767, -0.0089797005, 0.018564576, 0.03291386, 0.0032383145, -0.00884169, -0.008645399, 0.0001677955, -0.04452774, 0.007207213, -0.008696507, 0.0023566217, -0.025329702, -0.042708885, -0.03173582, 0.06427912, 0.030916397, -0.022305708, -0.018711232, -0.008136281, -0.01636213, 0.019092057, 0.010243902, -0.04405114, 0.018331835, -0.025844995, 0.035896596, 0.049257137, -0.053962618, -0.084952496, -0.009314442, -0.03644633, 0.0010881334, -0.042904764, 0.016017154, -0.011390375, 0.056498464, 0.007735383, 0.015750613, 0.023586866, -0.005065194, -0.05339934, 0.030084236, -0.021841932, -0.0035868485, -0.025362536, 0.0315042, 0.039552346, -0.032164883, -0.03519624, -0.013936666, 0.006526046, 0.02818671, -0.018081086, 0.04806136, -0.04418975, -0.064630605, -0.010125073, -0.02926605, 0.022641547, 0.040159058, 
0.022463534, -0.04924557, -0.010198766, -0.019940902, -0.0033762371, -0.07010838, -0.031799905, -0.020567331, -0.015259151, 0.04870838, 0.030047685, -0.016861487, 0.020778332, -0.034649372, -0.0026895248, -0.0053685517, -0.03297844, -0.0048753927, -0.005587019, -0.041837722, 0.0161564, 0.072810896, -0.043315165, 0.03330332]}
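The raw list of floats is hard to read, so a quick sanity check is to look at the vector's length and norm instead. This is a minimal sketch: `describe_embedding` is a hypothetical helper, and the two-element vector is a stand-in rather than real API output.

```python
import numpy as np


def describe_embedding(vector):
    """Return (dimension, L2 norm) for an embedding given as a list of floats."""
    arr = np.asarray(vector, dtype=float)
    return arr.shape[0], float(np.linalg.norm(arr))


# Stand-in vector; with the API you would pass embedding["embedding"] instead.
dim, norm = describe_embedding([0.6, 0.8])
print(dim, norm)
```

A norm close to 1 indicates a normalized vector, which matters later when the dot product is used as a similarity score.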
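Building an embeddings database#
Here are three sample texts to use to build the embeddings database. You will use the Gemini API to create an embedding for each document. Turn them into a dataframe for better visualization.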
DOCUMENT1 = {
"title": "Operating the Climate Control System",
"content": "Your Googlecar has a climate control system that allows you to adjust the temperature and airflow in the car. To operate the climate control system, use the buttons and knobs located on the center console. Temperature: The temperature knob controls the temperature inside the car. Turn the knob clockwise to increase the temperature or counterclockwise to decrease the temperature. Airflow: The airflow knob controls the amount of airflow inside the car. Turn the knob clockwise to increase the airflow or counterclockwise to decrease the airflow. Fan speed: The fan speed knob controls the speed of the fan. Turn the knob clockwise to increase the fan speed or counterclockwise to decrease the fan speed. Mode: The mode button allows you to select the desired mode. The available modes are: Auto: The car will automatically adjust the temperature and airflow to maintain a comfortable level. Cool: The car will blow cool air into the car. Heat: The car will blow warm air into the car. Defrost: The car will blow warm air onto the windshield to defrost it."}
DOCUMENT2 = {
"title": "Touchscreen",
"content": "Your Googlecar has a large touchscreen display that provides access to a variety of features, including navigation, entertainment, and climate control. To use the touchscreen display, simply touch the desired icon. For example, you can touch the \"Navigation\" icon to get directions to your destination or touch the \"Music\" icon to play your favorite songs."}
DOCUMENT3 = {
"title": "Shifting Gears",
"content": "Your Googlecar has an automatic transmission. To shift gears, simply move the shift lever to the desired position. Park: This position is used when you are parked. The wheels are locked and the car cannot move. Reverse: This position is used to back up. Neutral: This position is used when you are stopped at a light or in traffic. The car is not in gear and will not move unless you press the gas pedal. Drive: This position is used to drive forward. Low: This position is used for driving in snow or other slippery conditions."}
documents = [DOCUMENT1, DOCUMENT2, DOCUMENT3]
Organize the contents of the dictionaries into a dataframe for better visualization.
df = pd.DataFrame(documents)
df.columns = ['Title', 'Text']
df
 | Title | Text |
---|---|---|
0 | Operating the Climate Control System | Your Googlecar has a climate control system th... |
1 | Touchscreen | Your Googlecar has a large touchscreen display... |
2 | Shifting Gears | Your Googlecar has an automatic transmission. ... |
Get the embeddings for each of these bodies of text. Add this information to the dataframe.
# Get the embeddings of each text and add them to an Embeddings column in the dataframe
def embed_fn(title, text):
    return genai.embed_content(model=model,
                               content=text,
                               task_type="retrieval_document",
                               title=title)["embedding"]
df['Embeddings'] = df.apply(lambda row: embed_fn(row['Title'], row['Text']), axis=1)
df
 | Title | Text | Embeddings |
---|---|---|---|
0 | Operating the Climate Control System | Your Googlecar has a climate control system th... | [-0.033361107, -0.021217084, -0.049581926, -0.... |
1 | Touchscreen | Your Googlecar has a large touchscreen display... | [0.009660736, -0.030662702, -0.017281422, -0.0... |
2 | Shifting Gears | Your Googlecar has an automatic transmission. ... | [-0.04270796, -0.007160868, -0.03242516, -0.02... |
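Document search with Q&A#
Now that the embeddings are generated, let's create a Q&A system to search these documents. You will ask a question about the Google car, create an embedding of the question, and compare it against the collection of embeddings in the dataframe.
The embedding of the question is a vector (a list of float values) that is compared against the document vectors using the dot product. The vectors returned by the API are already normalized, so the dot product measures how similar two vectors are in direction.
For normalized vectors, the dot product ranges from -1 to 1, inclusive. A dot product of 1 means the two vectors point in the same direction. A value of 0 means the vectors are orthogonal, or unrelated. A value of -1 means the vectors point in opposite directions and are not similar to each other.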
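The three cases can be reproduced with small hand-made vectors. This is an illustrative sketch using NumPy only; the vectors are toy values, not embeddings from the API, and `normalize` is a helper defined here for the example.

```python
import numpy as np


def normalize(v):
    """Scale a vector to unit length."""
    v = np.asarray(v, dtype=float)
    return v / np.linalg.norm(v)


a = normalize([1.0, 2.0, 3.0])

same = float(np.dot(a, a))        # same direction
opposite = float(np.dot(a, -a))   # opposite direction
orthogonal = float(np.dot(normalize([1.0, 0.0]), normalize([0.0, 1.0])))  # perpendicular

print(same, opposite, orthogonal)
```

Up to floating-point rounding, the three values are 1, -1, and 0. The -1 to 1 range holds only because both vectors have unit length.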
Note: with the new embedding model (embedding-001), specify the task type as RETRIEVAL_QUERY for user queries and RETRIEVAL_DOCUMENT when embedding document text.
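Task Type | Description |
---|---|
RETRIEVAL_QUERY | Specifies that the given text is a query in a search/retrieval setting. |
RETRIEVAL_DOCUMENT | Specifies that the given text is a document in a search/retrieval setting. |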
query = "How do you shift gears in the Google car?"
model = 'models/embedding-001'
request = genai.embed_content(model=model,
                              content=query,
                              task_type="retrieval_query")
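Use the find_best_passage function to calculate the dot products and then pick the row with the largest value, which is equivalent to sorting the dataframe from largest to smallest dot product and taking the top result, to retrieve the relevant passage from the database.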
def find_best_passage(query, dataframe):
    """
    Compute the similarity between the query and each document in the dataframe
    using the dot product, and return the text of the best match.
    """
    query_embedding = genai.embed_content(model=model,
                                          content=query,
                                          task_type="retrieval_query")
    dot_products = np.dot(np.stack(dataframe['Embeddings']), query_embedding["embedding"])
    idx = np.argmax(dot_products)
    return dataframe.iloc[idx]['Text']  # Return text from index with max value
passage = find_best_passage(query, df)
passage
'Your Googlecar has an automatic transmission. To shift gears, simply move the shift lever to the desired position. Park: This position is used when you are parked. The wheels are locked and the car cannot move. Reverse: This position is used to back up. Neutral: This position is used when you are stopped at a light or in traffic. The car is not in gear and will not move unless you press the gas pedal. Drive: This position is used to drive forward. Low: This position is used for driving in snow or other slippery conditions.'
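When more than one candidate passage is useful, the same dot products can be sorted instead of reduced with argmax. Below is a minimal sketch with a hypothetical helper `rank_passages`; the 2-D vectors are toy stand-ins for real embeddings.

```python
import numpy as np


def rank_passages(doc_embeddings, query_embedding, k=3):
    """Return (index, score) pairs for the k highest dot-product scores, best first."""
    doc_matrix = np.stack([np.asarray(e, dtype=float) for e in doc_embeddings])
    scores = doc_matrix @ np.asarray(query_embedding, dtype=float)
    order = np.argsort(scores)[::-1][:k]  # indices sorted by descending score
    return [(int(i), float(scores[i])) for i in order]


# Toy 2-D vectors standing in for real embeddings.
docs = [[1.0, 0.0], [0.0, 1.0], [0.6, 0.8]]
query = [0.0, 1.0]
ranking = rank_passages(docs, query, k=2)
print(ranking)
```

With the tutorial's dataframe, the call would be `rank_passages(df['Embeddings'], request['embedding'])`, and each returned index can be looked up with `df.iloc[i]['Text']`.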
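Question and answering application#
Let's try to use the text generation API to create a Q&A system. Input your own custom data below to create a simple question and answering example. You will still use the dot product as the similarity metric.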
def make_prompt(query, relevant_passage):
    # Strip quotes and newlines so the passage cannot break the quoted PASSAGE field.
    escaped = relevant_passage.replace("'", "").replace('"', "").replace("\n", " ")
    # The continuation lines stay flush left so no extra spaces end up in the prompt.
    prompt = textwrap.dedent("""You are a helpful and informative bot that answers questions using text from the reference passage included below. \
Be sure to respond in a complete sentence, being comprehensive, including all relevant background information. \
However, you are talking to a non-technical audience, so be sure to break down complicated concepts and \
strike a friendly and conversational tone. \
If the passage is irrelevant to the answer, you may ignore it.
QUESTION: '{query}'
PASSAGE: '{relevant_passage}'
ANSWER:
""").format(query=query, relevant_passage=escaped)
    return prompt
prompt = make_prompt(query, passage)
print(prompt)
You are a helpful and informative bot that answers questions using text from the reference passage included below. Be sure to respond in a complete sentence, being comprehensive, including all relevant background information. However, you are talking to a non-technical audience, so be sure to break down complicated concepts and strike a friendly and conversational tone. If the passage is irrelevant to the answer, you may ignore it.
QUESTION: 'How do you shift gears in the Google car?'
PASSAGE: 'Your Googlecar has an automatic transmission. To shift gears, simply move the shift lever to the desired position. Park: This position is used when you are parked. The wheels are locked and the car cannot move. Reverse: This position is used to back up. Neutral: This position is used when you are stopped at a light or in traffic. The car is not in gear and will not move unless you press the gas pedal. Drive: This position is used to drive forward. Low: This position is used for driving in snow or other slippery conditions.'
ANSWER:
Choose one of the Gemini content generation models to find the answer to your query.
for m in genai.list_models():
    if 'generateContent' in m.supported_generation_methods:
        print(m.name)
models/gemini-pro
models/gemini-pro-vision
# model = genai.GenerativeModel('gemini-ultra')  # gemini-ultra is not in the list above
model = genai.GenerativeModel('gemini-pro')
answer = model.generate_content(prompt)
Markdown(answer.text)
### Normal case
## The provided passage does not contain information about how to shift gears in a Google car, so I cannot answer your question from this source.
The Google car has an automatic transmission, so you don’t need to worry about shifting gears manually. Simply move the shift lever to the desired position, indicated on the gear shift, and the car will automatically shift to that gear.
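The retrieve-then-generate flow above can be sketched as a single function. This is an illustration under stated assumptions, not the tutorial's own implementation: `answer_question` and its four callable parameters are hypothetical names, and the stand-ins below replace the API calls so the flow runs offline.

```python
import numpy as np


def answer_question(query, passages, embed_query, embed_document, generate, build_prompt):
    """Retrieve the best passage for `query` and pass a prompt to `generate`.

    All four callables are injected so the flow can run without network access;
    their names are assumptions made for this sketch.
    """
    doc_matrix = np.stack([np.asarray(embed_document(p), dtype=float) for p in passages])
    scores = doc_matrix @ np.asarray(embed_query(query), dtype=float)
    best = passages[int(np.argmax(scores))]
    return generate(build_prompt(query, best))


# Offline stand-in: a two-number "embedding" that counts keyword hits.
def toy_embed(text):
    text = text.lower()
    return [float(text.count("gear")), float(text.count("screen"))]


passages = ["Move the shift lever to change gear.", "Touch the screen to open navigation."]
result = answer_question(
    "How do I change gear?",
    passages,
    embed_query=toy_embed,
    embed_document=toy_embed,
    generate=lambda prompt: prompt,  # echo the prompt instead of calling a model
    build_prompt=lambda q, p: f"QUESTION: {q}\nPASSAGE: {p}",
)
print(result)
```

With the objects from this tutorial, `embed_document` and `embed_query` would wrap `genai.embed_content` with the retrieval_document and retrieval_query task types, `build_prompt` would be `make_prompt`, and `generate` would return `model.generate_content(prompt).text`.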
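Next steps#
To learn more about how you can use embeddings, check out the available examples.
To learn how to use other services in the Gemini API, visit the Python quickstart.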