The previous post on a practice built around Elasticsearch and an LLM (https://zhuanlan.zhihu.com/p/12528541608) showed that keyword-style retrieval over a conventional database often fails to locate what the user actually cares about. This post tries a vector database instead, using text-embedding similarity to improve retrieval accuracy. The pipeline built here is: user query -> vector database retrieval -> prompt -> LLM -> reply.
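In code terms the pipeline boils down to three steps. The outline below is only a sketch, using the names of the pieces built later in this post:
def rag_answer(user_query):
    # 1. retrieve: embed the query and fetch the nearest chunks from the vector DB
    chunks = vector_db.search(user_query, 2)['documents'][0]
    # 2. augment: splice the retrieved chunks and the question into a prompt
    prompt = build_prompt(prompt_template, context=chunks, query=user_query)
    # 3. generate: let the LLM answer from the supplied context only
    return get_completion(prompt)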
Preparing the data
Refer to the content and code covered in the previous post.
Setting up a retrieval engine locally
An introduction to vector databases. Before getting into them, a few concepts to be clear on:
- The point of a vector database is fast retrieval;
- A vector database does not generate vectors itself; the vectors are produced by an embedding model;
- Vector databases complement traditional relational databases rather than replace them; in practice the two are often used together, depending on requirements.
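To build intuition for the second point: the similarity a vector database computes is typically cosine similarity between vectors that an embedding model produced. A minimal sketch with toy, hand-written vectors (a real pipeline would get them from a model such as text-embedding-ada-002):
from numpy import dot
from numpy.linalg import norm

def cos_sim(a, b):
    '''Cosine similarity: near 1.0 means the directions (meanings) are close.'''
    return dot(a, b) / (norm(a) * norm(b))

v_query = [0.1, 0.9, 0.2]   # toy query vector, for illustration only
v_doc = [0.12, 0.85, 0.25]  # toy document vector
print(cos_sim(v_query, v_doc))  # close to 1.0 -> treated as semantically similar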
Feature comparison of mainstream vector databases
For production environments, Milvus or Weaviate is recommended.
- Milvus: open-source vector database with a managed cloud offering (https://milvus.io/)
- Weaviate: open-source vector database with a managed cloud offering (https://weaviate.io/)
- FAISS: vector search engine open-sourced by Meta (https://github.com/facebookresearch/faiss)
- Pinecone: commercial vector database, cloud service only (https://www.pinecone.io/)
- Qdrant: open-source vector database with a managed cloud offering (https://qdrant.tech/)
- PGVector: open-source vector search extension for Postgres (https://github.com/pgvector/pgvector)
- RediSearch: open-source vector search module for Redis (https://github.com/RediSearch/RediSearch)
- ElasticSearch: also supports vector search (https://www.elastic.co/enterprise-search/vector-search)
Installing chromadb
To keep the demo simple, I use chromadb, an in-memory vector database. Install the dependency: pip install chromadb
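Before wiring in OpenAI embeddings, a quick smoke test of chromadb itself. This sketch relies on chromadb's bundled default embedding model (downloaded on first use), so it runs without an API key; the collection name and documents are made up for the demo:
import chromadb

chroma_client = chromadb.Client()  # in-memory instance
collection = chroma_client.get_or_create_collection(name="smoke_test")
# with no embeddings passed, chroma embeds the documents using its default model
collection.add(documents=["苹果是一种水果", "Python 是一种编程语言"], ids=["id0", "id1"])
result = collection.query(query_texts=["什么是编程语言?"], n_results=1)
print(result["documents"][0][0])  # expected: the Python sentence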
Implementing vector retrieval in code
The following code ingests the content of a PDF document into the vector store. Note that text vectorization via client.embeddings.create is fairly time-consuming, so it is best to test with a small text first.
import chromadb
from chromadb.config import Settings
from numpy import dot
from numpy.linalg import norm
from openai import OpenAI
from dotenv import load_dotenv, find_dotenv
import load_pdf as cxt
import tiktoken
import concurrent.futures
import re
_ = load_dotenv(find_dotenv())  # load the local .env file, which defines OPENAI_API_KEY
client = OpenAI()
def count_tokens(texts, model="text-embedding-ada-002"):
    '''Count the total number of tokens across a list of texts.'''
    encoding = tiktoken.encoding_for_model(model)
    return sum(len(encoding.encode(t)) for t in texts)
class MyVectorDBConnector:
def __init__(self, collection_name, embedding_fn):
chroma_client = chromadb.Client(Settings(allow_reset=True))
        # reset() is just for this demo; a real application would not reset every run
        chroma_client.reset()
        # create (or fetch) a collection
self.collection = chroma_client.get_or_create_collection(
name=collection_name)
self.embedding_fn = embedding_fn
def add_documents(self, documents):
        '''Add the documents and their embeddings to the collection.'''
embeddings = self.embedding_fn(documents)
if len(embeddings) != len(documents):
print(f"Mismatch: {len(embeddings)} embeddings for {len(documents)} documents")
return
self.collection.add(
embeddings=embeddings,
documents=documents,
ids=[f"id{i}" for i in range(len(documents))]
)
    def add_documents_pool(self, documents, batch_size=10):
        '''Add documents and their embeddings to the collection in batches.'''
        with concurrent.futures.ThreadPoolExecutor() as executor:
            # keep each future paired with its batch and starting offset, since
            # as_completed() yields futures out of submission order
            future_to_batch = {}
            for i in range(0, len(documents), batch_size):
                batch_documents = documents[i:i + batch_size]
                future = executor.submit(self.embedding_fn, batch_documents)
                future_to_batch[future] = (i, batch_documents)
            for future in concurrent.futures.as_completed(future_to_batch):
                start, batch_documents = future_to_batch[future]
                try:
                    embeddings = future.result()
                    if embeddings:
                        self.collection.add(
                            embeddings=embeddings,      # one vector per document
                            documents=batch_documents,  # the documents of this batch
                            ids=[f"id{j}" for j in range(start, start + len(batch_documents))]
                        )
                except Exception as e:
                    print(f"An error occurred while processing the batch starting at index {start}: {e}")
def search(self, query, top_n):
        '''Query the vector database for the chunks nearest to the query.'''
results = self.collection.query(
query_embeddings=self.embedding_fn([query]),
n_results=top_n
)
return results
def check_texts_type(texts):
if isinstance(texts, list):
print("texts is a list")
elif isinstance(texts, str):
print("texts is a string")
else:
print(f"texts is of type {type(texts)}")
def get_embeddings(texts, model="text-embedding-ada-002"):
    '''Call the OpenAI embeddings API and return one vector per input text.'''
    try:
        if not isinstance(texts, (list, tuple)):
            print("Error: 'texts' is not an iterable type. Please pass a list or tuple.")
            return []
        response = client.embeddings.create(input=texts, model=model)
        if response.data is None:
            print("Response data is None")
            return []
        return [x.embedding for x in response.data]
    except Exception as e:
        print(f"An error occurred: {e}")
        return []
def sent_tokenize(input_string):
    """Split text into sentences at punctuation marks."""
    # split after each sentence-ending punctuation mark
    sentences = re.split(r'(?<=[。!?;?!])', input_string)
    # drop empty strings
    return [sentence for sentence in sentences if sentence.strip()]
def split_text(paragraphs, chunk_size=300, overlap_size=100):
    '''Split text into chunks of roughly chunk_size characters, with about overlap_size characters shared between neighboring chunks.'''
sentences = [s.strip() for p in paragraphs for s in sent_tokenize(p)]
chunks = []
i = 0
while i < len(sentences):
chunk = sentences[i]
overlap = ''
prev_len = 0
prev = i - 1
        # walk backwards to collect the overlap from preceding sentences
while prev >= 0 and len(sentences[prev])+len(overlap) <= overlap_size:
overlap = sentences[prev] + ' ' + overlap
prev -= 1
chunk = overlap+chunk
next = i + 1
        # walk forwards, appending sentences until the chunk is full
while next < len(sentences) and len(sentences[next])+len(chunk) <= chunk_size:
chunk = chunk + ' ' + sentences[next]
next += 1
chunks.append(chunk)
i = next
return chunks
def main():
    paragraphs = ["文本1", "文本2", "文本3"]  # sample texts
    vector_db = MyVectorDBConnector("demo", get_embeddings)
    vector_db.add_documents(paragraphs)
if __name__ == "__main__":
#paragraphs = cxt.extract_text_from_pdf("/Users/liuqiang/code/ai/lq/RAG/llama2.pdf", min_line_length=10)
paragraphs = cxt.extract_text_from_pdf("/Users/liuqiang/code/ai/lq/RAG/haikangweishi.pdf", min_line_length=10)
    # create a vector database object
vector_db = MyVectorDBConnector("demo", get_embeddings)
#paragraphs = ["文本1", "文本2", "文本3"] # 示例文本
chunks = split_text(paragraphs, 300, 100)
    # add the documents to the vector database
vector_db.add_documents(chunks)
# user_query = "Llama 2有多少参数"
user_query = "海康威视2023年营收是多少?"
#user_query = "Does Llama 2 have a conversational variant"
results = vector_db.search(user_query, 2)
for para in results['documents'][0]:
print(para+"\n")
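Before looking at the retrieval output, it is worth checking the chunking logic in isolation. A quick sketch with toy sentences and deliberately tiny sizes (the values are chosen only to make the overlap visible):
paragraphs = ["第一句。第二句。第三句。第四句。"]
for c in split_text(paragraphs, chunk_size=10, overlap_size=4):
    print(repr(c))
# Each chunk after the first starts with the previous sentence, so
# neighboring chunks share context across the boundary.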
As for the main script, its run results show the retrieved chunks covering the part of the annual report relevant to the query.
Calling the LLM API
The code is as follows:
from openai import OpenAI
from dotenv import load_dotenv, find_dotenv
from myVertor import MyVectorDBConnector, get_embeddings,client
import load_pdf as cxt
import myVertor as mv
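# The prompt below is kept in Chinese because the source PDF and the expected
# answers are Chinese. Roughly translated: "You are a Q&A bot. Answer the
# user's question only from the given context; if the context does not contain
# the answer, reply '我无法回答您的问题' (I cannot answer your question); do not
# output anything not contained in the context; answer in Chinese."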
prompt_template = """
你是一个问答机器人。
你的任务是根据下述给定的已知信息回答用户问题。
已知信息:
{context}
用户问:
{query}
如果已知信息不包含用户问题的答案,或者已知信息不足以回答用户的问题,请直接回复"我无法回答您的问题"。
请不要输出已知信息中不包含的信息或答案。
请用中文回答用户问题。"""
def get_completion(prompt, model='gpt-4o'):
    '''Wrapper around the OpenAI chat completion API.'''
messages = [{"role": "user", "content": prompt}]
response = client.chat.completions.create(
model=model,
messages=messages,
        temperature=0,  # output randomness; 0 means most deterministic
)
return response.choices[0].message.content
def build_prompt(prompt_template, **kwargs):
    '''Fill the prompt template with the given values.'''
inputs = {}
for k, v in kwargs.items():
if isinstance(v, list) and all(isinstance(elem, str) for elem in v):
val = '\n\n'.join(v)
else:
val = v
inputs[k] = val
return prompt_template.format(**inputs)
class RAG_Robot:
def __init__(self, vector_db, llm_api, n_results=2):
self.vector_db = vector_db
self.llm_api = llm_api
self.n_results = n_results
def chat(self, user_query):
        # 1. retrieve
search_results = self.vector_db.search(user_query, self.n_results)
        # 2. build the prompt
prompt = build_prompt(
prompt_template, context=search_results['documents'][0], query=user_query)
print("===Prompt===")
print(prompt)
        # 3. call the LLM
response = self.llm_api(prompt)
return response
def loadDb(vector_db):
    '''Ingest the PDF content into the given vector database.'''
    paragraphs = cxt.extract_text_from_pdf("/Users/liuqiang/code/ai/lq/RAG/haikangweishi.pdf", min_line_length=10)
    #paragraphs = ["文本1", "文本2", "文本3"]  # sample texts
    chunks = mv.split_text(paragraphs, 300, 100)
    # add the documents to the vector database
    vector_db.add_documents(chunks)
if __name__ == '__main__':
vector_db = MyVectorDBConnector("demo", get_embeddings)
loadDb(vector_db)
print("===Prompt===")
bot = RAG_Robot(
vector_db,
llm_api=get_completion
)
user_query = "海康威视2023年营收是多少?"
response = bot.chat(user_query)
print(response)
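The same bot instance can be reused for follow-up questions against the already-loaded index. A quick sketch (the queries below are hypothetical examples, not from the original run):
# ask a few more questions against the same in-memory index
for q in ["海康威视2023年净利润是多少?",   # net profit in 2023
          "海康威视的主营业务是什么?"]:    # main lines of business
    print(bot.chat(q))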
As the prompt construction in the code shows, the user's question plus the supplementary information retrieved from the vector database is processed by the LLM, which returns the answer, as shown in the figure below.
The answer matches the figures in the annual report.