privateGPT.py
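"""Query local documents with a local LLM.

Loads a persisted Chroma vector store (built with LlamaCpp embeddings), wires
it to a GPT4All model through a RetrievalQA chain, and answers questions read
from stdin until the user types "exit".
"""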
from langchain.chains import RetrievalQA
from langchain.embeddings import LlamaCppEmbeddings
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.vectorstores import Chroma
from langchain.llms import GPT4All
from constants import PERSIST_DIRECTORY
from constants import CHROMA_SETTINGS


def main():
    # Load stored vectorstore
    llama = LlamaCppEmbeddings(model_path="./models/ggml-model-q4_0.bin")
    db = Chroma(persist_directory=PERSIST_DIRECTORY, embedding_function=llama, client_settings=CHROMA_SETTINGS)
    retriever = db.as_retriever()

    # Prepare the LLM (tokens stream to stdout as they are generated)
    callbacks = [StreamingStdOutCallbackHandler()]
    llm = GPT4All(model='./models/ggml-gpt4all-j-v1.3-groovy.bin', backend='gptj', callbacks=callbacks, verbose=False)
    qa = RetrievalQA.from_chain_type(llm=llm, chain_type="stuff", retriever=retriever, return_source_documents=True)

    # Interactive questions and answers
    while True:
        query = input("\nEnter a query: ")
        if query == "exit":
            break

        # Get the answer from the chain
        res = qa(query)
        answer, docs = res['result'], res['source_documents']

        # Print the result
        print("\n\n> Question:")
        print(query)
        print("\n> Answer:")
        print(answer)

        # Print the relevant sources used for the answer
        for document in docs:
            print("\n> " + document.metadata["source"] + ":")
            print(document.page_content)


if __name__ == "__main__":
    main()
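
# ---------------------------------------------------------------------------
# For reference: this script imports PERSIST_DIRECTORY and CHROMA_SETTINGS
# from a sibling `constants` module that is not shown here. A minimal sketch
# of what that module might look like follows, assuming the layout used by
# the upstream zylon-ai/private-gpt repo at this point in its history; the
# exact values are assumptions, not part of this file:
#
#   import os
#   from chromadb.config import Settings
#
#   PERSIST_DIRECTORY = os.environ.get('PERSIST_DIRECTORY', 'db')
#   CHROMA_SETTINGS = Settings(
#       chroma_db_impl='duckdb+parquet',
#       persist_directory=PERSIST_DIRECTORY,
#       anonymized_telemetry=False,
#   )
# ---------------------------------------------------------------------------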