forked from datawhalechina/self-llm
Showing 11 changed files with 753 additions and 6 deletions.
LLM.py (new file, +42 lines):
from langchain.llms.base import LLM
from typing import Any, List, Optional
from langchain.callbacks.manager import CallbackManagerForLLMRun
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch


class Atom(LLM):
    # Custom LLM class wrapping a locally stored Atom model
    tokenizer: AutoTokenizer = None
    model: AutoModelForCausalLM = None

    def __init__(self, model_path: str):
        # model_path: local path of the Atom checkpoint,
        # e.g. /root/autodl-tmp/FlagAlpha/Atom-7B-Chat
        super().__init__()
        print("Loading model from local disk...")
        self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
        self.model = AutoModelForCausalLM.from_pretrained(
            model_path, device_map="auto", trust_remote_code=True,
            torch_dtype=torch.float16).eval()
        print("Finished loading the local model.")

    def _call(self, prompt: str, stop: Optional[List[str]] = None,
              run_manager: Optional[CallbackManagerForLLMRun] = None,
              **kwargs: Any) -> str:
        # Wrap the prompt in Atom's chat format before tokenizing
        input_ids = self.tokenizer(
            [f'<s>Human: {prompt}\n</s><s>Assistant: '],
            return_tensors="pt", add_special_tokens=False).input_ids.to('cuda')
        generate_input = {
            "input_ids": input_ids,
            "max_new_tokens": 512,
            "do_sample": True,
            "top_k": 50,
            "top_p": 0.95,
            "temperature": 0.3,
            "repetition_penalty": 1.3,
            "eos_token_id": self.tokenizer.eos_token_id,
            "bos_token_id": self.tokenizer.bos_token_id,
            "pad_token_id": self.tokenizer.pad_token_id
        }
        generate_ids = self.model.generate(**generate_input)
        text = self.tokenizer.decode(generate_ids[0], skip_special_tokens=True)
        # The decoded text echoes the prompt; return only the assistant's reply
        return text.split('Assistant: ')[-1]

    @property
    def _llm_type(self) -> str:
        return "Atom"
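A minimal sketch of using this wrapper on its own, assuming the checkpoint path used in this tutorial; in this LangChain version an LLM instance is directly callable, and the example prompt is only illustrative:

# Quick smoke test of the wrapper (adjust the path to wherever
# Atom-7B-Chat was downloaded).
from LLM import Atom

llm = Atom(model_path='/root/autodl-tmp/FlagAlpha/Atom-7B-Chat')
print(llm('Hello, please introduce yourself.'))  # __call__ routes to _call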
creat_db.py (new file, +75 lines):
# Import the required third-party libraries
from langchain.document_loaders import UnstructuredFileLoader
from langchain.document_loaders import UnstructuredMarkdownLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import Chroma
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
from tqdm import tqdm
import os
import nltk
nltk.download('punkt')  # tokenizer data needed by the Unstructured loaders


# Collect the paths of all target files
def get_files(dir_path):
    # dir_path: path of the target directory
    file_list = []
    for filepath, dirnames, filenames in os.walk(dir_path):
        # os.walk traverses the given directory recursively
        for filename in filenames:
            # Filter files by extension
            if filename.endswith(".md"):
                # Append the full path of every matching file
                file_list.append(os.path.join(filepath, filename))
            elif filename.endswith(".txt"):
                file_list.append(os.path.join(filepath, filename))
    return file_list


# Load the files as plain-text documents
def get_text(dir_path):
    # dir_path: path of the target directory
    # First collect the target file paths with the function above
    file_lst = get_files(dir_path)
    # docs holds the loaded plain-text document objects
    docs = []
    # Iterate over all target files
    for one_file in tqdm(file_lst):
        file_type = one_file.split('.')[-1]
        if file_type == 'md':
            loader = UnstructuredMarkdownLoader(one_file)
        elif file_type == 'txt':
            loader = UnstructuredFileLoader(one_file)
        else:
            # Skip files of any other type
            continue
        docs.extend(loader.load())
    return docs


# Target directories
tar_dir = [
    "/root/autodl-tmp/Llama2-Chinese",
]

# Load the target files
docs = []
for dir_path in tar_dir:
    docs.extend(get_text(dir_path))

# Split the text into chunks
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=500, chunk_overlap=150)
split_docs = text_splitter.split_documents(docs)

# Load the open-source embedding model
embeddings = HuggingFaceEmbeddings(model_name="/root/autodl-tmp/embedding_model")

# Build the vector database
# Define the persistence directory
persist_directory = 'data_base/vector_db/chroma'
# Create the database
vectordb = Chroma.from_documents(
    documents=split_docs,
    embedding=embeddings,
    persist_directory=persist_directory  # lets us save the database to disk
)
# Persist the vector database to disk
vectordb.persist()
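To sanity-check the persisted database, a sketch along these lines reloads it and runs a similarity search (same paths as above; the query string is only an example):

from langchain.vectorstores import Chroma
from langchain.embeddings.huggingface import HuggingFaceEmbeddings

# Reload the persisted database and inspect the top matches for a query
embeddings = HuggingFaceEmbeddings(model_name="/root/autodl-tmp/embedding_model")
vectordb = Chroma(persist_directory='data_base/vector_db/chroma',
                  embedding_function=embeddings)
for doc in vectordb.similarity_search("How do I fine-tune Llama2-Chinese?", k=3):
    print(doc.metadata, doc.page_content[:80])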
README for the code directory (new file, +7 lines):
Stores the code files for this project, including:

--creat_db.py Script that builds the vector database

--llm.py Script that wraps the local Atom model as a custom LLM

--run_gradio.py Script that launches the Gradio service
run_gradio.py (new file, +108 lines):
# Import the required libraries
import gradio as gr
from langchain.vectorstores import Chroma
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate
import os
from LLM import Atom


def load_chain():
    # Build the question-answering chain
    # Define the embeddings
    embeddings = HuggingFaceEmbeddings(model_name="/root/autodl-tmp/embedding_model")

    # Persistence directory of the vector database
    persist_directory = 'data_base/vector_db/chroma'

    # Load the database
    vectordb = Chroma(
        persist_directory=persist_directory,  # directory the database was persisted to
        embedding_function=embeddings
    )

    # The path must point at the checkpoint itself, matching LLM.py
    llm = Atom(model_path="/root/autodl-tmp/FlagAlpha/Atom-7B-Chat")

    template = """
    {context}
    Question: {question}
    Helpful answer:"""

    QA_CHAIN_PROMPT = PromptTemplate(input_variables=["context", "question"],
                                     template=template)

    # Assemble the retrieval QA chain
    qa_chain = RetrievalQA.from_chain_type(llm,
                                           retriever=vectordb.as_retriever(),
                                           return_source_documents=True,
                                           chain_type_kwargs={"prompt": QA_CHAIN_PROMPT})

    return qa_chain


class Model_center():
    """
    Holds the question-answering chain object
    """
    def __init__(self):
        self.chain = load_chain()

    def qa_chain_self_answer(self, question: str, chat_history: list = None):
        """
        Answer a question with the history-free QA chain
        """
        chat_history = chat_history or []
        if question is None or len(question) < 1:
            return "", chat_history
        try:
            chat_history.append(
                (question, self.chain({"query": question})["result"]))
            return "", chat_history
        except Exception as e:
            return str(e), chat_history

    def clear_history(self):
        # RetrievalQA keeps no conversational state, so there is nothing to
        # clear on the backend; the ClearButton already resets the UI component.
        pass


model_center = Model_center()

block = gr.Blocks()
with block as demo:
    with gr.Row(equal_height=True):
        with gr.Column(scale=15):
            gr.Markdown("""<h1><center>Atom</center></h1>
                <center>Llama2-Chinese (tutorial by the DataWhale self-llm team)</center>
                """)
        # gr.Image(value=LOGO_PATH, scale=1, min_width=10, show_label=False, show_download_button=False)

    with gr.Row():
        with gr.Column(scale=4):
            chatbot = gr.Chatbot(height=450, show_copy_button=True)
            # Textbox component for entering the prompt
            msg = gr.Textbox(label="Prompt/Question")

            with gr.Row():
                # Submit button
                db_wo_his_btn = gr.Button("Chat")
            with gr.Row():
                # Clear button that empties the chatbot component
                clear = gr.ClearButton(
                    components=[chatbot], value="Clear console")

        # On click, call qa_chain_self_answer with the user's message and the
        # chat history, then update the textbox and the chatbot component.
        db_wo_his_btn.click(model_center.qa_chain_self_answer, inputs=[
                            msg, chatbot], outputs=[msg, chatbot])

        # Backend callback for the Clear button (a no-op; see clear_history above)
        clear.click(model_center.clear_history)
    gr.Markdown("""Reminders:<br>
    1. Initializing the database may take a while; please be patient.
    2. If an exception occurs during use, it will be shown in the text input box; don't panic. <br>
    """)

# Close any Gradio instances left over from previous runs
gr.close_all()
# Launch with sharing enabled, using the PORT1 environment variable as the port:
# demo.launch(share=True, server_port=int(os.environ['PORT1']))
# Or launch directly:
demo.launch()
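To exercise the retrieval-augmented chain without starting the UI, a short smoke test can be run in a Python session, assuming the database from creat_db.py exists and the load_chain function above has been pasted in (importing run_gradio directly would launch the app); the query is only an example:

# Minimal sketch: query the RetrievalQA chain directly
qa_chain = load_chain()
result = qa_chain({"query": "What is Atom-7B-Chat?"})
print(result["result"])                  # the generated answer
for doc in result["source_documents"]:   # the retrieved context chunks
    print(doc.metadata)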