Fix bug: unable to load from the local vector database
HildaM committed Jun 19, 2023
1 parent f57e55a commit e3af56e
Showing 4 changed files with 76 additions and 91 deletions.
2 changes: 1 addition & 1 deletion config.json
@@ -1,6 +1,6 @@
 {
     "platform": "bilibili",
-    "room_display_id": "278333",
+    "room_display_id": "your live room ID",
     "chat_type": "none",
     "need_lang": "none",
     "before_prompt": "Please reply briefly:",
18 changes: 11 additions & 7 deletions utils/embeddings.py
@@ -14,13 +14,17 @@
"""
模型1:"sebastian-hofstaetter/distilbert-dot-tas_b-b256-msmarco"
"""
model_name = "sebastian-hofstaetter/distilbert-dot-tas_b-b256-msmarco"
model_kwargs = {'device': 'cpu'}
encode_kwargs = {'normalize_embeddings': False}
DEFAULT_MODEL_NAME = "sebastian-hofstaetter/distilbert-dot-tas_b-b256-msmarco"
DEFAULT_MODEL_KWARGS = {'device': 'cpu'}
DEFAULT_ENCODE_KWARGS = {'normalize_embeddings': False}
hf_embeddings = HuggingFaceEmbeddings(
model_name=model_name,
model_kwargs=model_kwargs,
encode_kwargs=encode_kwargs
model_name=DEFAULT_MODEL_NAME,
model_kwargs=DEFAULT_MODEL_KWARGS,
encode_kwargs=DEFAULT_ENCODE_KWARGS
)

EMBEDDINGS_MAPPING = {"distilbert-dot-tas_b-b256-msmarco": hf_embeddings}

"""
模型列表
"""
EMBEDDINGS_MAPPING = {DEFAULT_MODEL_NAME: hf_embeddings}
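Note that this hunk also changes the dictionary key: EMBEDDINGS_MAPPING is now keyed by the full DEFAULT_MODEL_NAME ("sebastian-hofstaetter/distilbert-dot-tas_b-b256-msmarco") rather than the short "distilbert-dot-tas_b-b256-msmarco", so any caller still looking up the short name will hit a KeyError. A minimal sketch of the intended lookup pattern (resolve_embeddings is a hypothetical helper, not part of the commit):

    from utils.embeddings import EMBEDDINGS_MAPPING, DEFAULT_MODEL_NAME

    def resolve_embeddings(name=None):
        # None falls back to the default model; an unknown name raises KeyError
        return EMBEDDINGS_MAPPING[name or DEFAULT_MODEL_NAME]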
100 changes: 51 additions & 49 deletions utils/faiss_handler.py
@@ -8,14 +8,15 @@
 @Description : Local vector database configuration
 """


 import json
+import logging

 from langchain.vectorstores import FAISS
 import os
 from tqdm.auto import tqdm
 from langchain.text_splitter import RecursiveCharacterTextSplitter
 from langchain.document_loaders import DirectoryLoader, TextLoader
-from utils.embeddings import EMBEDDINGS_MAPPING
+from utils.embeddings import EMBEDDINGS_MAPPING, DEFAULT_MODEL_NAME
 import tiktoken
 import zipfile
 import pickle
@@ -80,48 +81,54 @@ def get_chunks(docs, chunk_size=500, chunk_overlap=20, length_function=tiktoken_
"""


def create_faiss_index_from_zip(path_to_zip_file, embeddings=None, pdf_loader=None,
chunk_size=500, chunk_overlap=20,
project_name="Please_Delete_This_File_After_Running"):
# 存储的文件格式
# structure: project_name
# - source data
# - embeddings
# - faiss_index
if isinstance(embeddings, str):
def create_faiss_index_from_zip(path_to_zip_file, embedding_model_name=None, pdf_loader=None,
chunk_size=500, chunk_overlap=20):
# 获取模型名称
if isinstance(embedding_model_name, str):
import copy
embeddings_str = copy.deepcopy(embeddings)
embeddings_str = copy.deepcopy(embedding_model_name)
else:
embeddings_str = "distilbert-dot-tas_b-b256-msmarco" # 默认模型
embeddings_str = DEFAULT_MODEL_NAME # 默认模型

# 选择模型
if embeddings is None:
embeddings = EMBEDDINGS_MAPPING["distilbert-dot-tas_b-b256-msmarco"]
elif isinstance(embeddings, str):
embeddings = EMBEDDINGS_MAPPING[embeddings]
if embedding_model_name is None:
embeddings = EMBEDDINGS_MAPPING[DEFAULT_MODEL_NAME]
elif isinstance(embedding_model_name, str):
embeddings = EMBEDDINGS_MAPPING[embedding_model_name]

# 创建存储向量数据库的目录
# 存储的文件格式
# structure: ./data/vector_base
# - source data
# - embeddings
# - faiss_index
store_path = os.getcwd() + "/data/vector_base/"
if not os.path.exists(store_path + project_name):
if not os.path.exists(store_path):
os.makedirs(store_path)
project_path = os.path.join(store_path, project_name)
project_path = store_path
source_data = os.path.join(project_path, "source_data")
embeddings_data = os.path.join(project_path, "embeddings")
index_data = os.path.join(project_path, "faiss_index")
os.makedirs(source_data) # ./project/source_data
os.makedirs(embeddings_data) # ./project/embeddings
os.makedirs(index_data) # ./project/faiss_index
os.makedirs(source_data) # ./vector_base/source_data
os.makedirs(embeddings_data) # ./vector_base/embeddings
os.makedirs(index_data) # ./vector_base/faiss_index
else:
raise ValueError(f"向量数据库文件夹重名,请删除重名文件夹后再启动。")
logging.warning(
"向量数据库已存在,默认加载旧的向量数据库。如果需要加载新的数据,请删除data目录下的vector_base,再重新启动")
logging.info("正在加载已存在的向量数据库文件")
db = load_exist_faiss_file(store_path)
if db is None:
logging.error("加载旧数据库为空,数据库文件可能存在异常。请彻底删除vector_base文件夹后,再重新导入数据")
exit(-1)
return db

# 解压数据包
with zipfile.ZipFile(path_to_zip_file, 'r') as zip_ref:
# extract everything to "source_data"
zip_ref.extractall(source_data)

# 组装数据库元信息
db_meta = {"project_name": project_name,
"pdf_loader": pdf_loader.__name__, "chunk_size": chunk_size,
db_meta = {"pdf_loader": pdf_loader.__name__, "chunk_size": chunk_size,
"chunk_overlap": chunk_overlap,
"embedding_model": embeddings_str,
"files": os.listdir(source_data),
@@ -178,34 +185,28 @@ def find_file_dir(file_name, directory):
     return None  # If the file was not found


-# Load local data
-def load_faiss_index_from_zip(path_to_zip_file):
-    # Extract the zip file. Read the db_meta
-    # base_name = os.path.basename(path_to_zip_file)
-    path_to_extract = os.path.join(os.getcwd())
-    with zipfile.ZipFile(path_to_zip_file, 'r') as zip_ref:
-        zip_ref.extractall(path_to_extract)
-
-    db_meta_json = find_file("db_meta.json", path_to_extract)
+# Load the local vector database
+def load_exist_faiss_file(path):
+    # Read the metadata
+    db_meta_json = find_file("db_meta.json", path)
     if db_meta_json is not None:
         with open(db_meta_json, "r", encoding="utf-8") as f:
             db_meta_dict = json.load(f)
     else:
-        raise ValueError("Cannot find `db_meta.json` in the .zip file. ")
-
-    try:
-        embeddings = EMBEDDINGS_MAPPING[db_meta_dict["embedding_model"]]
-    except:
-        from langchain.embeddings.openai import OpenAIEmbeddings
-        embeddings = OpenAIEmbeddings(model="text-embedding-ada-002")
-
-    # locate index.faiss
-    index_path = find_file_dir("index.faiss", path_to_extract)
-    if index_path is not None:
-        db = FAISS.load_local(index_path, embeddings)
+        logging.error("The vector_base database is corrupted. Delete the folder completely, then re-import the data!")
+        exit(-1)
+
+    # Look up the embedding model
+    embedding = EMBEDDINGS_MAPPING[db_meta_dict["embedding_model"]]
+
+    # Load index.faiss
+    faiss_path = find_file_dir("index.faiss", path)
+    if faiss_path is not None:
+        db = FAISS.load_local(faiss_path, embedding)
         return db
     else:
-        raise ValueError("Failed to find `index.faiss` in the .zip file.")
+        logging.error("Failed to load index.faiss; the index is corrupted. Delete the vector_base folder completely, then re-import the data.")
+        exit(-1)


 # Test code
@@ -223,5 +224,6 @@ def load_faiss_index_from_zip(path_to_zip_file):
                                         model_kwargs=model_kwargs,
                                         encode_kwargs=encode_kwargs)
     create_faiss_index_from_zip(path_to_zip_file=zip_file_path, pdf_loader=PyPDFLoader, embeddings=embeddings)
-
-    db = load_faiss_index_from_zip(zip_file_path)
+    db = load_exist_faiss_file(zip_file_path)
+    if db is not None:
+        logging.info("Local database loaded successfully!")
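The net effect of the faiss_handler.py changes is that create_faiss_index_from_zip becomes create-or-load: the first run unpacks the archive and builds ./data/vector_base, while later runs detect the existing directory and return the index loaded by load_exist_faiss_file instead of raising. A usage sketch under those assumptions (the zip path is illustrative):

    from langchain.document_loaders import PyPDFLoader
    from utils.faiss_handler import create_faiss_index_from_zip

    # First run: builds ./data/vector_base from the archive
    # Later runs: warns, then returns the already-built FAISS index
    db = create_faiss_index_from_zip("data/knowledge.zip", pdf_loader=PyPDFLoader)
    docs = db.similarity_search("test query", k=3)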
47 changes: 13 additions & 34 deletions utils/langchain_pdf_local.py
@@ -7,13 +7,13 @@
 @Date : 2023/06/17 4:44 PM
 @Description : Localized vector database, implementing langchain_pdf
 """


+import logging
-import uuid
 from langchain.document_loaders import PyPDFLoader
-from langchain.embeddings import HuggingFaceEmbeddings

 from utils.claude import Claude
+from utils.embeddings import EMBEDDINGS_MAPPING
 from utils.faiss_handler import load_faiss_index_from_zip, create_faiss_index_from_zip
 from utils.my_handle import My_handle
@@ -22,10 +22,10 @@
 # The returned data is well-structured, so the content field is easy to extract
 def get_content(data: str):
     prefix = "{'content': "
-    surfix = ", 'chunk'"
+    suffix = ", 'chunk'"

     start = data.find(prefix)
-    end = data.find(surfix)
+    end = data.find(suffix)
     return data[start:end]


Expand Down Expand Up @@ -60,7 +60,7 @@ def __init__(self, data, chat_type="langchain_pdf_local"):
self.langchain_pdf_question_prompt = data["question_prompt"]
self.langchain_pdf_max_query = data["max_query"]

print(f"pdf文件路径{self.langchain_pdf_data_path}")
print(f"本地数据文件路径{self.langchain_pdf_data_path}")

# 加载pdf并生成向量数据库
self.load_zip_as_db(self.langchain_pdf_data_path, self.pdf_loader,
@@ -75,64 +75,43 @@ def load_zip_as_db(self, zip_file,
                         chunk_size=300,
                         chunk_overlap=20):
         if chunk_size <= chunk_overlap:
-            print("ERROR: chunk_size is smaller than chunk_overlap. Creation failed.")
+            logging.error("chunk_size is smaller than chunk_overlap. Creation failed.")
             return
         if zip_file is None:
-            print("ERROR: The file is empty. Creation failed.")
+            logging.error("The file is empty. Creation failed.")
             return

-        project_name = uuid.uuid4().hex
-
-        model_kwargs = {'device': 'cpu'}
-        encode_kwargs = {'normalize_embeddings': False}
-        embeddings = HuggingFaceEmbeddings(
-            model_name=model_name,
-            model_kwargs=model_kwargs,
-            encode_kwargs=encode_kwargs
-        )
-
         self.local_db = create_faiss_index_from_zip(
             path_to_zip_file=zip_file,
-            embeddings=embeddings,
+            embedding_model_name=self.langchain_pdf_embedding_model,
             pdf_loader=pdf_loader,
             chunk_size=chunk_size,
-            chunk_overlap=chunk_overlap,
-            project_name=project_name
+            chunk_overlap=chunk_overlap
         )

-        print("Vector knowledge base created successfully!")
+        logging.info("Vector knowledge base created successfully!")

     # Query the local vector database for related information
     def get_local_database_data(self, message):
-        print(f"Starting to query the local vector database for information about '{message}'........")
+        logging.info(f"Starting to query the local vector database for information about '{message}'........")

         contents = []
         docs = self.local_db.similarity_search(message, k=self.langchain_pdf_max_query)
         for i in range(self.langchain_pdf_max_query):
             # Preprocess the chunk
             content = docs[i].page_content.replace('\n', ' ')
-            print(f"No.{i} related information: {content}")
+            logging.info(f"No.{i} related information: {content}")
             data = get_content(content)
             # Update contents
             contents.append(data)

-        print("Related information retrieved from the local vector database: {}".format(contents))
+        logging.info("Related information retrieved from the local vector database: {}".format(contents))
         if len(contents) == 0 or contents is None:
             return
         related_data = "\n---\n".join(contents) + "\n---\n"
         return related_data

-    def load_local_db(self, zip_file):
-        if zip_file is None:
-            return "The file is empty. Creation failed.", None
-        self.local_db = load_faiss_index_from_zip(zip_file)
-
-        print("Knowledge base loaded successfully")
-
     def get_langchain_pdf_local_resp(self, chat_type="langchain_pdf", question=""):
-        if self.local_db is None:
-            self.load_local_db(self.langchain_pdf_data_path)
-
         related_data = self.get_local_database_data(question)
         if related_data is None or len(related_data) <= 0:
             content = question
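As a quick worked example of get_content as defined above: the slice starts at the beginning of prefix rather than past it, so the returned string keeps the "{'content': " wrapper (the sample value is illustrative):

    sample = "{'content': 'hello world', 'chunk': 1}"
    print(get_content(sample))   # -> "{'content': 'hello world'"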
