
Commit 9f1bfe6

fix

guoerjun committed Jun 19, 2024
1 parent 31e0b42 commit 9f1bfe6

Showing 3 changed files with 122 additions and 41 deletions.
46 changes: 19 additions & 27 deletions detectron/demo/embedding.py
@@ -4,14 +4,15 @@
 from langchain.callbacks.manager import (
     CallbackManagerForLLMRun
 )
-from pydantic import Field
 from langchain_core.embeddings import Embeddings
+import torch

 class Embedding(Embeddings):

     def __init__(self,**kwargs):
         self.model=AutoModel.from_pretrained('BAAI/bge-small-zh-v1.5')
         self.tokenizer = AutoTokenizer.from_pretrained('BAAI/bge-small-zh-v1.5')
+        self.model.eval()

     @property
     def _llm_type(self) -> str:
@@ -23,31 +24,21 @@ def model_name(self) -> str:

     def _call(
         self,
-        prompt: Union[str,List[str]],
+        prompt: List[str],
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> str:
-        batch_data = self.tokenizer(
-            text=prompt,
-            padding="longest",
-            return_tensors="pt",
-            max_length=1024,
-            truncation=True,
-        )
+        encoded_input = self.tokenizer(prompt, padding=True, truncation=True, return_tensors='pt')

-        attention_mask = batch_data["attention_mask"]
-        # batch_data.to('cuda')
-        model_output = self.model(**batch_data)
-        # model_output = model_output.cpu()
-        last_hidden = model_output.last_hidden_state.masked_fill(~attention_mask[..., None].bool(), 0.0)
-        vectors = last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
-
-        vectors = vectors.detach().numpy()
-        # L2-normalize each row vector
-        vectors = normalize(vectors, norm="l2", axis=1)
-        print("_call",vectors.shape)
-        return vectors
+        with torch.no_grad():
+            model_output = self.model(**encoded_input)
+            # Perform pooling. In this case, cls pooling.
+            sentence_embeddings = model_output[0][:, 0]
+        print(sentence_embeddings.shape)
+        # normalize embeddings
+        sentence_embeddings = torch.nn.functional.normalize(sentence_embeddings, p=2, dim=1)
+        return sentence_embeddings.numpy()

     @property
     def _identifying_params(self) -> Mapping[str, Any]:
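The pooling change above is the core of this hunk: instead of a mask-weighted mean over last_hidden_state followed by sklearn normalization, the embedding is now the final hidden state of the first ([CLS]) token, normalized with torch. A tiny shape-level illustration with made-up tensor sizes (the real hidden size of bge-small-zh-v1.5 is 512):

import torch

# Hypothetical batch: 2 sentences, 5 tokens each, hidden size 512.
last_hidden_state = torch.randn(2, 5, 512)

cls_vectors = last_hidden_state[:, 0]  # first token per sentence -> shape (2, 512)
cls_vectors = torch.nn.functional.normalize(cls_vectors, p=2, dim=1)
print(cls_vectors.shape)  # torch.Size([2, 512])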
@@ -65,12 +56,13 @@ def embed_documents(self, texts) -> List[List[float]]:

     def embed_query(self, text) -> List[float]:
         # Embed a single query
-        embedding = self._call(text)
+        embedding = self._call([text])
         return embedding[0]


-if __name__ == '__main__':
-    sd = Embedding()
-    v1 = sd.embed_query("他是一个人")
-    v2 = sd.embed_query("她是一条狗")
-    print(v1 @ v2.T)
+# if __name__ == '__main__':
+#     sd = Embedding()
+#     v1 = sd.embed_query("他是一个人")
+#     v2 = sd.embed_query("他是一个好人")
+#     v3 = sd.embed_documents(["她是一条狗","他是一个人"])
+#     print(v1 @ v2.T)
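With embed_query now wrapping its argument in a list, single queries go through the same batch-oriented tokenizer path as documents. A minimal usage sketch mirroring the commented-out test block (assumes the BAAI/bge-small-zh-v1.5 weights can be fetched from Hugging Face and that this is run from detectron/demo/ so `embedding` is importable):

from embedding import Embedding

emb = Embedding()
v1 = emb.embed_query("他是一个人")    # "He is a person"
v2 = emb.embed_query("他是一个好人")  # "He is a good person"

# The vectors are L2-normalized, so the dot product is cosine similarity.
print(v1 @ v2.T)

Because both the old and new code paths return unit-length vectors, printing v1 @ v2.T directly is a valid similarity check.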
9 changes: 3 additions & 6 deletions detectron/demo/main.py
@@ -120,6 +120,7 @@ def create_ui():
                 datatype=["str", "str"],
                 row_count=2,
                 col_count=(2, "fixed"),
+                interactive=False
             )
         with gr.Column(scale=3):
             with gr.Group():
@@ -185,11 +186,7 @@ def create_event_handlers():
     components["chatbot"].like(print_like_dislike, None, None)

     components['file_upload'].upload(
-        file_handler, gradio('file_upload'), None, show_progress=False
-    )
-
-    components['db_view'].upload(
-        file_handler, gradio('db_view'), None
+        file_handler, gradio('file_upload'), gradio('db_view'), show_progress=False
     )

def do_refernce(algo_type,input_image):
@@ -313,7 +310,7 @@ def query(payload):
         return output[0]['generated_text']
     return ""

-def file_handler(file_objs,state, regenerate=False, _continue=False):
+def file_handler(file_objs):
     import shutil
     import os
     from retriever import Retriever
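The event-handler change collapses two upload bindings into one: file_handler now both ingests the files and returns the rows that refresh the (now read-only) db_view Dataframe. A self-contained sketch of that wiring, using the same flags as the demo; the handler body and the column headers here are stand-ins for the real Retriever-backed logic:

import gradio as gr

def file_handler(file_objs):
    # Stub: the demo's real handler copies the files and indexes them.
    # Return one row per upload so the Dataframe view refreshes in the same event.
    return [[f.name, "indexed"] for f in (file_objs or [])]

with gr.Blocks() as demo:
    file_upload = gr.File(file_count="multiple")
    db_view = gr.Dataframe(headers=["file", "status"], datatype=["str", "str"],
                           row_count=2, col_count=(2, "fixed"), interactive=False)
    # One event: input from file_upload, output refreshes db_view.
    file_upload.upload(file_handler, file_upload, db_view, show_progress=False)

# demo.launch()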
108 changes: 100 additions & 8 deletions detectron/demo/retriever.py
@@ -4,9 +4,104 @@
 from langchain_community.docstore.in_memory import InMemoryDocstore
 import faiss
 import os
-from typing import Any
+from typing import Any,List,Dict
 from embedding import Embedding


+class KnowledgeBaseManager:
+    def __init__(self, base_path="./knowledge_bases", embedding_dim=512, batch_size=16):
+        self.base_path = base_path
+        self.embedding_dim = embedding_dim
+        self.batch_size = batch_size
+        self.embeddings = Embedding()
+        self.knowledge_bases: Dict[str, FAISS] = {}
+        os.makedirs(self.base_path, exist_ok=True)
+
+    def create_knowledge_base(self, name: str):
+        index = faiss.IndexFlatL2(self.embedding_dim)
+        kb = FAISS(self.embeddings, index, InMemoryDocstore(), {})
+        self.knowledge_bases[name] = kb
+        self.save_knowledge_base(name)
+        print(f"Knowledge base '{name}' created.")
+
+    def delete_knowledge_base(self, name: str):
+        if name in self.knowledge_bases:
+            del self.knowledge_bases[name]
+            os.remove(os.path.join(self.base_path, f"{name}.faiss"))
+            print(f"Knowledge base '{name}' deleted.")
+        else:
+            print(f"Knowledge base '{name}' does not exist.")
+
+    def load_knowledge_base(self, name: str):
+        kb_path = os.path.join(self.base_path, f"{name}.faiss")
+        if os.path.exists(kb_path):
+            self.knowledge_bases[name] = FAISS.load_local(self.base_path, self.embeddings, name, allow_dangerous_deserialization=True)
+            print(f"Knowledge base '{name}' loaded.")
+        else:
+            print(f"Knowledge base '{name}' does not exist.")
+
+    def save_knowledge_base(self, name: str):
+        if name in self.knowledge_bases:
+            self.knowledge_bases[name].save_local(self.base_path, name)
+            print(f"Knowledge base '{name}' saved.")
+        else:
+            print(f"Knowledge base '{name}' does not exist.")
+
+    def add_documents_to_kb(self, name: str, file_paths: List[str]):
+        if name not in self.knowledge_bases:
+            print(f"Knowledge base '{name}' does not exist.")
+            return
+
+        kb = self.knowledge_bases[name]
+        documents = self.load_documents(file_paths)
+        print(f"Loaded {len(documents)} documents.")
+
+        pages = self.split_documents(documents)
+        print(f"Split documents into {len(pages)} pages.")
+
+        for i in range(0, len(pages), self.batch_size):
+            batch = pages[i:i+self.batch_size]
+            kb.add_documents(batch)
+
+        self.save_knowledge_base(name)
+
+    def load_documents(self, file_paths: List[str]):
+        documents = []
+        for file_path in file_paths:
+            loader = self.get_loader(file_path)
+            documents.extend(loader.load())
+        return documents
+
+    def get_loader(self, file_path: str):
+        if file_path.endswith('.txt'):
+            return TextLoader(file_path)
+        elif file_path.endswith('.json'):
+            return JSONLoader(file_path)
+        elif file_path.endswith('.pdf'):
+            return PyPDFLoader(file_path)
+        else:
+            raise ValueError("Unsupported file format")
+
+    def split_documents(self, documents):
+        text_splitter = RecursiveCharacterTextSplitter(chunk_size=512, chunk_overlap=50)
+        return text_splitter.split_documents(documents)
+
+    def retrieve_documents(self, names: List[str], query: str):
+        results = []
+        for name in names:
+            if name not in self.knowledge_bases:
+                print(f"Knowledge base '{name}' does not exist.")
+                continue
+
+            retriever = self.knowledge_bases[name].as_retriever(
+                search_type="mmr",
+                search_kwargs={"score_threshold": 0.5, "k": 1}
+            )
+            docs = retriever.get_relevant_documents(query)
+            results.extend([{"name": name, "content": doc.page_content} for doc in docs])
+
+        return results
+
 class Retriever():
     index_path = "./"
     index_name = "default"
@@ -15,10 +110,10 @@ def __init__(self):
         self.embeddings = Embedding()
         if os.path.exists(self.index_path+self.index_name+".faiss"):
             print("load faiss from local index ")
-            self.vector_store = FAISS.load_local(self.index_path, self.excurtor[0],self.index_name,allow_dangerous_deserialization=True)
+            self.vector_store = FAISS.load_local(self.index_path, self.embeddings,self.index_name,allow_dangerous_deserialization=True)

         else:
-            index = faiss.IndexFlatL2(1024)
+            index = faiss.IndexFlatL2(512)
             self.vector_store = FAISS(self.embeddings,index,InMemoryDocstore(),{})
             self.vector_store.save_local(self.index_path,self.index_name)

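The 1024 → 512 change matches the output width of bge-small-zh-v1.5, which produces 512-dimensional vectors; a FAISS IndexFlatL2 built with the wrong dimensionality fails on every add. A quick sanity check (sketch, same model-availability assumption as above):

from embedding import Embedding

dim = len(Embedding().embed_query("维度检查"))  # "dimension check"
print(dim)  # expected: 512, the width IndexFlatL2 must be built with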
@@ -46,18 +141,15 @@ def load_documents(self, file_paths):
         return documents

     def split_documents(self, documents):
-        text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
+        text_splitter = RecursiveCharacterTextSplitter(chunk_size=512, chunk_overlap=50)
         return text_splitter.split_documents(documents)

     def build_vector_store(self, docs):
         self.vector_store.add_documents(docs)

     def retrieve_documents(self, query):
         docs = self.retriever.get_relevant_documents(query)
-        texts = []
-        for d in docs:
-            texts.append(d.page_content)
-        return texts
+        return [doc.page_content for doc in docs]

     def run(self, input: Any,**kwargs):
         if input is None or input == "":
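The bulk of the commit is the new KnowledgeBaseManager introduced at the top of this file, which wraps named FAISS stores behind a create / ingest / retrieve lifecycle. A hypothetical end-to-end flow (the file names are placeholders; assumes the same LangChain and FAISS versions the demo already uses):

from retriever import KnowledgeBaseManager

mgr = KnowledgeBaseManager(base_path="./knowledge_bases")
mgr.create_knowledge_base("demo")

# Loaders are picked by extension (.txt / .json / .pdf), documents are split
# into 512-char chunks, and embeddings are added in batches of 16.
mgr.add_documents_to_kb("demo", ["manual.pdf", "notes.txt"])

# Query one or more named bases; each hit records which base it came from.
for hit in mgr.retrieve_documents(["demo"], "如何安装?"):  # "How to install?"
    print(hit["name"], hit["content"][:80])

Worth noting: LangChain's FAISS.save_local writes both a {name}.faiss index and a {name}.pkl docstore sidecar, while delete_knowledge_base removes only the .faiss file, so a deleted base leaves its .pkl behind.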
