-
Notifications
You must be signed in to change notification settings - Fork 6
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
0 parents
commit 78e9bfd
Showing
17 changed files
with
1,163 additions
and
0 deletions.
There are no files selected for viewing
Large diffs are not rendered by default.
Oops, something went wrong.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,107 @@ | ||
import os | ||
import PyPDF2 | ||
import tiktoken | ||
|
||
|
||
|
||
enc = tiktoken.get_encoding("cl100k_base") | ||
|
||
class ReadFile: | ||
|
||
def __init__(self, path): | ||
self.path = path | ||
|
||
|
||
def readlist(self): | ||
file_list = [] | ||
for filepath, dirnames, filenames in os.walk(self.path): | ||
# os.walk 函数将递归遍历指定文件夹 | ||
for filename in filenames: | ||
if filename.endswith(".md"): | ||
file_list.append(os.path.join(filepath, filename)) | ||
elif filename.endswith(".txt"): | ||
file_list.append(os.path.join(filepath, filename)) | ||
elif filename.endswith(".pdf"): | ||
file_list.append(os.path.join(filepath, filename)) | ||
|
||
return file_list | ||
|
||
|
||
def get_all_chunk_content(self,max_len:int=600,cover_len:int=150): | ||
docs=[] | ||
for file in self.readlist(): | ||
|
||
content=self.read_file_content(file) | ||
|
||
chunk_content=self.chunk_content(content,max_len,cover_len) | ||
|
||
docs.extend(chunk_content) | ||
|
||
return docs | ||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
@classmethod | ||
def chunk_content(cls, text: str, max_token_len: int = 600, cover_content: int = 150): | ||
chunk_text = [] | ||
curr_len = 0 | ||
curr_chunk = '' | ||
lines = text.split('\n') | ||
for line in lines: | ||
line = line.replace(' ', '') | ||
line_len = len(enc.encode(line)) | ||
if curr_len + line_len <= max_token_len: | ||
curr_chunk += line | ||
curr_chunk += '\n' | ||
curr_len += line_len | ||
curr_len += 1 | ||
else: | ||
chunk_text.append(curr_chunk) | ||
curr_chunk = curr_chunk[-cover_content:]+line | ||
curr_len = line_len + cover_content | ||
if curr_chunk: | ||
chunk_text.append(curr_chunk) | ||
return chunk_text | ||
|
||
|
||
|
||
|
||
|
||
|
||
#读取文件内容 | ||
|
||
@classmethod | ||
def read_file_content(cls, file_path: str): | ||
if file_path.endswith('.pdf'): | ||
return cls.read_pdf_content(file_path) | ||
elif file_path.endswith('.md'): | ||
return cls.read_md_content(file_path) | ||
elif file_path.endswith('.txt'): | ||
return cls.read_txt_content(file_path) | ||
|
||
@classmethod | ||
def read_md_content(cls, file_path: str): | ||
with open(file_path, 'r', encoding='utf-8') as f: | ||
return f.read() | ||
|
||
@classmethod | ||
def read_pdf_content(cls, file_path: str): | ||
text="" | ||
with open(file_path, 'rb') as f: | ||
reader=PyPDF2.PdfReader(f) | ||
for num_page in range(len(reader.pages)): | ||
text+=reader.pages[num_page].extract_text() | ||
return text | ||
|
||
|
||
|
||
@classmethod | ||
def read_txt_content(self, file_path: str): | ||
with open(file_path, 'r', encoding='utf-8') as f: | ||
return f.read() | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
from tqdm import tqdm | ||
import numpy as np | ||
from component.embedding import HFembedding,OpenAIembedding,Zhipuembedding,Jinaembedding | ||
import os | ||
import json | ||
from typing import List | ||
|
||
|
||
|
||
|
||
class VectorDB: | ||
def __init__(self,docs:List=[]) -> None: | ||
self.docs = docs | ||
|
||
def get_vector(self,EmbeddingModel)->List[List[float]]: | ||
self.vectors = [] | ||
for doc in tqdm(self.docs): | ||
self.vectors.append(EmbeddingModel.get_embedding(doc)) | ||
return self.vectors | ||
def persist(self,path:str='db')->None: | ||
if not os.path.exists(path): | ||
os.makedirs(path) | ||
with open(f"{path}/doecment.json", 'w', encoding='utf-8') as f: | ||
json.dump(self.docs, f, ensure_ascii=False) | ||
with open(f"{path}/vectors.json", 'w', encoding='utf-8') as f: | ||
json.dump(self.vectors, f) | ||
|
||
|
||
def load_vector(self,path:str='db')->None: | ||
with open(f"{path}/vectors.json", 'r', encoding='utf-8') as f: | ||
self.vectors = json.load(f) | ||
with open(f"{path}/doecment.json", 'r', encoding='utf-8') as f: | ||
self.document = json.load(f) | ||
|
||
def get_similarity(self, vector1: List[float], vector2: List[float],embedding_model) -> float: | ||
return embedding_model.compare_v(vector1, vector2) | ||
|
||
def query(self, query: str, EmbeddingModel, k: int = 1) -> List[str]: | ||
query_vector = EmbeddingModel.get_embedding(query) | ||
result = np.array([self.get_similarity(query_vector, vector,EmbeddingModel) | ||
for vector in self.vectors]) | ||
return np.array(self.document)[result.argsort()[-k:][::-1]].tolist() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,106 @@ | ||
|
||
## 这个组件,能下载的embedding模型都使用离线的embedding,不能下载的就使用api | ||
|
||
import numpy as np | ||
from transformers import AutoModel | ||
from numpy.linalg import norm | ||
from langchain.embeddings.openai import OpenAIEmbeddings | ||
from zhipuai import ZhipuAI | ||
from langchain.embeddings.huggingface import HuggingFaceEmbeddings | ||
import os | ||
from typing import List | ||
class HFembedding: | ||
def __init__(self, path:str=''): | ||
self.path = path | ||
self.embedding=HuggingFaceEmbeddings(model_name=path) | ||
def get_embedding(self,content:str=''): | ||
return self.embedding.embed_query(content) | ||
def compare(self, text1: str, text2: str): | ||
embed1=self.embedding.embed_query(text1) | ||
embed2=self.embedding.embed_query(text2) | ||
return np.dot(embed1, embed2) / (np.linalg.norm(embed1) * np.linalg.norm(embed2)) | ||
def compare_v(cls, vector1: List[float], vector2: List[float]) -> float: | ||
dot_product = np.dot(vector1, vector2) | ||
magnitude = np.linalg.norm(vector1) * np.linalg.norm(vector2) | ||
if not magnitude: | ||
return 0 | ||
return dot_product / magnitude | ||
|
||
|
||
class OpenAIembedding: | ||
def __init__(self, path:str=''): | ||
self.path = path | ||
self.embedding=OpenAIEmbeddings() | ||
def get_embedding(self,content:str=''): | ||
content = content.replace("\n", " ") | ||
return self.embedding.embed_query(content) | ||
def compare(self, text1: str, text2: str): | ||
embed1=self.embedding.embed_query(text1) | ||
embed2=self.embedding.embed_query(text2) | ||
return np.dot(embed1, embed2) / (np.linalg.norm(embed1) * np.linalg.norm(embed2)) | ||
def compare_v(cls, vector1: List[float], vector2: List[float]) -> float: | ||
dot_product = np.dot(vector1, vector2) | ||
magnitude = np.linalg.norm(vector1) * np.linalg.norm(vector2) | ||
if not magnitude: | ||
return 0 | ||
return dot_product / magnitude | ||
|
||
|
||
|
||
class Zhipuembedding: | ||
|
||
def __init__(self, path:str=''): | ||
|
||
|
||
client = ZhipuAI(api_key=os.getenv("ZHIPUAI_API_KEY")) | ||
self.embedding_model=client | ||
|
||
|
||
def get_embedding(self,content:str=''): | ||
response =self.embedding_model.embeddings.create( | ||
model="embedding-2", #填写需要调用的模型名称 | ||
input=content #填写需要计算的文本内容, | ||
) | ||
return response.data[0].embedding | ||
|
||
def compare_v(cls, vector1: List[float], vector2: List[float]) -> float: | ||
dot_product = np.dot(vector1, vector2) | ||
magnitude = np.linalg.norm(vector1) * np.linalg.norm(vector2) | ||
if not magnitude: | ||
return 0 | ||
return dot_product / magnitude | ||
def compare(self, text1: str, text2: str): | ||
|
||
embed1=self.embedding_model.embeddings.create( | ||
model="embedding-2", #填写需要调用的模型名称 | ||
input=text1 #填写需要计算的文本内容, | ||
).data[0].embedding | ||
|
||
embed2=self.embedding_model.embeddings.create( | ||
model="embedding-2", #填写需要调用的模型名称 | ||
input=text2 #填写需要计算的文本内容, | ||
).data[0].embedding | ||
|
||
return np.dot(embed1, embed2) / (np.linalg.norm(embed1) * np.linalg.norm(embed2)) | ||
|
||
|
||
|
||
|
||
class Jinaembedding: | ||
def __init__(self, path:str='jinaai/jina-embeddings-v2-base-zh'): | ||
self.path = path | ||
self.embedding_model=AutoModel.from_pretrained('jinaai/jina-embeddings-v2-base-zh', trust_remote_code=True) | ||
def get_embedding(self,content:str=''): | ||
return self.embedding_model.encode([content])[0] | ||
def compare(self, text1: str, text2: str): | ||
|
||
cos_sim = lambda a,b: (a @ b.T) / (norm(a)*norm(b)) | ||
embeddings = self.embedding_model.encode([text1, text2]) | ||
return cos_sim(embeddings[0], embeddings[1]) | ||
|
||
def compare_v(cls, vector1: List[float], vector2: List[float]) -> float: | ||
dot_product = np.dot(vector1, vector2) | ||
magnitude = np.linalg.norm(vector1) * np.linalg.norm(vector2) | ||
if not magnitude: | ||
return 0 | ||
return dot_product / magnitude |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
from langchain.schema import HumanMessage,SystemMessage | ||
from langchain_openai import ChatOpenAI,OpenAI | ||
from langchain.prompts import PromptTemplate,ChatPromptTemplate,HumanMessagePromptTemplate,SystemMessagePromptTemplate | ||
from component.embedding import Zhipuembedding,OpenAIembedding,HFembedding,Jinaembedding | ||
from component.data_chunker import ReadFile | ||
from component.databases import VectorDB | ||
import os | ||
import json | ||
from typing import Dict, List, Optional, Tuple, Union | ||
import PyPDF2 | ||
|
||
#把api_key放在环境变量中,可以在系统环境变量中设置,也可以在代码中设置 | ||
# import os | ||
# os.environ['OPENAI_API_KEY'] = '' | ||
|
||
class Openai_model: | ||
def __init__(self,model_name:str='gpt-3.5-turbo-instruct',temperature:float=0.9) -> None: | ||
|
||
self.model_name=model_name | ||
self.temperature=temperature | ||
self.model=OpenAI(model=model_name,temperature=temperature) | ||
|
||
self.db=VectorDB() | ||
self.db.load_vector() | ||
self.embedding_model=Zhipuembedding() | ||
|
||
|
||
def chat(self,question:str): | ||
template="""question:{question}\n以下列表信息供你参考,如果你觉得这些信息对回答问题没有帮助,你可以忽视它:\n info:{info}""" | ||
info=self.db.query(question,self.embedding_model,1) | ||
|
||
prompt=PromptTemplate(template=template,input_variables=["question","info"]).format(question=question,info=info) | ||
|
||
res=self.model.invoke(prompt) | ||
|
||
|
||
return res | ||
|
Oops, something went wrong.