llm_prompt.py
import torch
from transformers import AutoTokenizer, pipeline
from huggingface_hub import InferenceClient
from langchain.document_loaders import WikipediaLoader, OnlinePDFLoader, UnstructuredURLLoader
from langchain.embeddings import SentenceTransformerEmbeddings
from langchain.vectorstores import Chroma
from langchain.text_splitter import CharacterTextSplitter
from langchain.docstore.document import Document

def get_prompt_embeddingllm():
    """Build the RAG chat prompt template and the sentence-transformer embeddings."""
    model_id = "HuggingFaceH4/zephyr-7b-beta"
    tokenizer = AutoTokenizer.from_pretrained(model_id)

    prompt_in_chat_format = [
        {
            "role": "system",
            "content": """Using the information contained in the context,
give a comprehensive answer to the question.
Respond only to the question asked, response should be concise and relevant to the question.
Provide the number of the source document when relevant.
If the answer cannot be deduced from the context, do not give an answer.""",
        },
        {
            "role": "user",
            "content": """Context:
{context}
---
Now here is the question you need to answer.
Question: {question}""",
        },
    ]
    # Render the chat messages into a single prompt string with the model's chat template.
    RAG_PROMPT_TEMPLATE = tokenizer.apply_chat_template(
        prompt_in_chat_format, tokenize=False, add_generation_prompt=True
    )

    # Embeddings used to index and retrieve the context documents.
    model_name = "sentence-transformers/all-MiniLM-L6-v2"
    embedding_llm = SentenceTransformerEmbeddings(model_name=model_name)

    return RAG_PROMPT_TEMPLATE, embedding_llm

def get_text_chunks_langchain(text):
    """Split raw text into overlapping chunks and wrap each chunk as a LangChain Document."""
    text_splitter = CharacterTextSplitter(chunk_size=300, chunk_overlap=50)
    docs = [
        Document(page_content=x, metadata={"document": i})
        for i, x in enumerate(text_splitter.split_text(text))
    ]
    return docs

def get_answer(embedding_llm, RAG_PROMPT_TEMPLATE, docs, question):
    """Retrieve the most relevant chunk and ask the hosted LLM to answer the question."""
    # Vector database (in-memory; uncomment persist_directory to save it to disk).
    save_to_dir = "/content/wiki_chroma_db"
    vector_db = Chroma.from_documents(
        docs,
        embedding_llm,
        # persist_directory=save_to_dir,
    )

    similar_docs = vector_db.similarity_search(question, k=1)
    retrieved_docs_text = [doc.page_content for doc in similar_docs]  # We only need the text of the documents

    context = "\nExtracted documents:\n"
    context += "".join([f"Document {str(i)}:::\n" + doc for i, doc in enumerate(retrieved_docs_text)])

    final_prompt = RAG_PROMPT_TEMPLATE.format(question=question, context=context)

    # Generate an answer with the hosted inference API.
    # final_answer = READER_LLM(final_prompt)[0]["generated_text"]
    client = InferenceClient()
    response = client.text_generation(
        prompt=final_prompt,
        model="HuggingFaceH4/zephyr-7b-beta",
        temperature=0.8,
        max_new_tokens=500,
        seed=42,
        return_full_text=False,
    )
    return response
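

# --- Usage sketch (not part of the original module) -------------------------
# A minimal example of how the three helpers above fit together. The
# `sample_text` and `question` below are illustrative placeholders, and the
# call assumes a Hugging Face token is available to InferenceClient (for
# example via the HF_TOKEN environment variable) so the hosted
# zephyr-7b-beta endpoint can be reached.
if __name__ == "__main__":
    sample_text = (
        "The Eiffel Tower is a wrought-iron lattice tower in Paris, France. "
        "It was completed in 1889 and is about 330 metres tall."
    )

    # Build the prompt template and the embedding model once.
    rag_prompt_template, embedding_llm = get_prompt_embeddingllm()

    # Split the source text into overlapping chunks wrapped as Documents.
    docs = get_text_chunks_langchain(sample_text)

    # Retrieve the best-matching chunk and generate an answer.
    question = "How tall is the Eiffel Tower?"
    answer = get_answer(embedding_llm, rag_prompt_template, docs, question)
    print(answer)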