Update Yi-6B-chat FastApi deployment and invocation code
KMnO4-zx committed Dec 20, 2023
1 parent edc4518 commit eeb65dc
Showing 4 changed files with 20 additions and 21 deletions.
6 changes: 3 additions & 3 deletions README.md
@@ -72,7 +72,7 @@
   - [x] [Qwen-7B-chat full-parameter fine-tuning](./Qwen/06-Qwen-7B-chat%20全量微调.md) @不要葱姜蒜
   - [x] [Qwen-7B-Chat langchain integration for a knowledge-base assistant](./Qwen/07-Qwen-7B-Chat%20接入langchain搭建知识库助手.md) @娇娇
   - [x] [Qwen-7B-chat low-precision training](./Qwen/08-Qwen-7B-Chat%20Lora%20低精度微调.md) @ Hongru0306 ddl=12.11
-  - [ ] Qwen-1_8B-chat CPU deployment @ 散步
+  - [x] [Qwen-1_8B-chat CPU deployment](./Qwen/09-Qwen-1_8B-chat%20CPU%20部署%20.md) @ 散步

 - [DeepSeek 深度求索](https://github.com/deepseek-ai/DeepSeek-LLM)
   - [x] [DeepSeek-7B-chat FastApi deployment and invocation](./DeepSeek/01-DeepSeek-7B-chat%20FastApi.md) @ 不要葱姜蒜
@@ -98,8 +98,8 @@
   - [x] [Baichuan2-7B-chat Lora fine-tuning](./BaiChuan/04-Baichuan2-7B-chat%2Blora%2B%E5%BE%AE%E8%B0%83.md) @ 三山时春いddl=12.15

 - [Yi 零一万物](https://github.com/01-ai/Yi.git)
-  - [ ] Yi-6B-chat FastApi deployment and invocation @ Joe ddl=12.15
-  - [ ] Yi-6B-chat langchain integration @ Joe ddl=12.15
+  - [x] [Yi-6B-chat FastApi deployment and invocation](./Yi/01-Yi-6B-Chat%20FastApi%20部署调用.md) @ Joe ddl=12.15
+  - [x] [Yi-6B-chat langchain integration](./Yi/02-Yi-6B-Chat%20接入langchain搭建知识库助手.md) @ Joe ddl=12.15
   - [x] [Yi-6B-chat WebDemo](./Yi/03-Yi-6B-chat%20WebDemo.md) @ Hongru0306 ddl=12.15
   - [x] [Yi-6B-chat Lora fine-tuning](./Yi/04-Yi-6B-Chat%20Lora%20微调.md) @ 娇娇 ddl=12.15

35 changes: 17 additions & 18 deletions Yi/01-Yi-6B-Chat FastApi 部署调用.md
@@ -74,25 +74,20 @@ async def create_item(request: Request):
     json_post = json.dumps(json_post_raw)  # Serialize the JSON data to a string
     json_post_list = json.loads(json_post)  # Parse the string back into a Python object
     prompt = json_post_list.get('prompt')  # Get the prompt from the request
-    history = json_post_list.get('history')  # Get the history from the request
-    max_length = json_post_list.get('max_length')  # Get the maximum length from the request
-    top_p = json_post_list.get('top_p')  # Get the top_p parameter from the request
-    temperature = json_post_list.get('temperature')  # Get the temperature parameter from the request
+
+    messages = [
+        {"role": "user", "content": prompt}
+    ]
+
     # Call the model to generate a reply
-    response, history = model.chat(
-        tokenizer,
-        prompt,
-        history=history,
-        max_length=max_length if max_length else 2048,  # Default to 2048 if no maximum length is provided
-        top_p=top_p if top_p else 0.7,  # Default to 0.7 if top_p is not provided
-        temperature=temperature if temperature else 0.95  # Default to 0.95 if temperature is not provided
-    )
+    input_ids = tokenizer.apply_chat_template(conversation=messages, tokenize=True, add_generation_prompt=True, return_tensors='pt')
+    output_ids = model.generate(input_ids.to('cuda'))
+    response = tokenizer.decode(output_ids[0][input_ids.shape[1]:], skip_special_tokens=True)
     now = datetime.datetime.now()  # Get the current time
     time = now.strftime("%Y-%m-%d %H:%M:%S")  # Format the time as a string
     # Build the response JSON
     answer = {
         "response": response,
-        "history": history,
         "status": 200,
         "time": time
     }
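
The replacement lines in this hunk rely on the fact that, for a decoder-only model, `model.generate` returns the prompt tokens followed by the continuation, so the handler slices off the first `input_ids.shape[1]` positions before decoding. The following is a minimal sketch of that slicing on plain Python lists rather than tensors. The helper names `build_messages` and `new_token_ids` and the token IDs are hypothetical, chosen only for illustration, and are not part of this commit.

```python
from typing import Dict, List, Sequence


def build_messages(prompt: str) -> List[Dict[str, str]]:
    """Wrap a single user prompt in the chat-message format used by the handler."""
    return [{"role": "user", "content": prompt}]


def new_token_ids(output_ids: Sequence[int], prompt_length: int) -> List[int]:
    """Drop the echoed prompt tokens and keep only the generated continuation.

    Mirrors output_ids[0][input_ids.shape[1]:] from the diff, but operates on
    a plain sequence of ints instead of a tensor.
    """
    if prompt_length < 0 or prompt_length > len(output_ids):
        raise ValueError("prompt_length must be between 0 and len(output_ids)")
    return list(output_ids[prompt_length:])


# Hypothetical token IDs, for illustration only.
prompt_ids = [101, 7592, 102]
full_output = prompt_ids + [2023, 2003, 1037]
print(new_token_ids(full_output, len(prompt_ids)))  # [2023, 2003, 1037]
```

Because the handler now builds `messages` from the prompt alone, each request is a single-turn exchange: `history`, `max_length`, `top_p`, and `temperature` sent by a client are no longer read.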
@@ -105,9 +100,10 @@ async def create_item(request: Request):
 # Main entry point
 if __name__ == '__main__':
     # Load the pretrained tokenizer and model
-    tokenizer = AutoTokenizer.from_pretrained("/root/autodl-tmp/01ai/Yi-6B-Chat", trust_remote_code=True)
-    model = AutoModelForCausalLM.from_pretrained("/root/autodl-tmp/01ai/Yi-6B-Chat", device_map="auto", trust_remote_code=True).eval()
-    model.generation_config = GenerationConfig.from_pretrained("/root/autodl-tmp/01ai/Yi-6B-Chat", trust_remote_code=True)  # Can be specified
+    model_name_or_path = '/root/autodl-tmp/01ai/Yi-6B-Chat'
+    tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, trust_remote_code=True, use_fast=False)
+    model = AutoModelForCausalLM.from_pretrained(model_name_or_path, device_map="auto", torch_dtype=torch.bfloat16, trust_remote_code=True).eval()
+    model.generation_config = GenerationConfig.from_pretrained(model_name_or_path, trust_remote_code=True)  # Can be specified
     model.eval()  # Set the model to evaluation mode
     # Start the FastAPI application
     # Using port 6006 lets AutoDL map the port to the local machine, so the API can be called locally
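@@ -125,7 +121,8 @@ python api.py

 Once loading finishes, the following output indicates that the service started successfully.

-![Service startup and loading output](images/5.png)
+![Alt text](images/5.png)
+

 By default the service is deployed on port 6006 and is called via POST. You can call it with curl, as shown below:
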
@@ -143,7 +140,7 @@ import json

 def get_completion(prompt):
     headers = {'Content-Type': 'application/json'}
-    data = {"prompt": prompt, "history": []}
+    data = {"prompt": prompt}
     response = requests.post(url='http://127.0.0.1:6006', headers=headers, data=json.dumps(data))
     return response.json()['response']
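
Since the endpoint no longer accepts or returns `history`, a client only needs to send `prompt` and read `response`. The sketch below separates building the request body from checking the reply, so both can be exercised without a running server. `build_request_body` and `extract_reply` are hypothetical helper names, and treating a `status` other than 200 as an error is an assumption rather than behavior shown in this commit.

```python
import json
from typing import Any, Dict


def build_request_body(prompt: str) -> str:
    """Serialize the single-field payload that the updated endpoint reads."""
    return json.dumps({"prompt": prompt})


def extract_reply(payload: Dict[str, Any]) -> str:
    """Return the generated text from a decoded response body.

    Assumption: a "status" other than 200 signals failure. The commit only
    shows the success case, so adjust this check to the real server behavior.
    """
    status = payload.get("status")
    if status != 200:
        raise RuntimeError(f"unexpected status in response body: {status!r}")
    return payload["response"]


# Hypothetical response body with the same fields the handler builds.
sample = {"response": "Hello!", "status": 200, "time": "2023-12-15 20:08:40"}
print(extract_reply(sample))
```

With `requests`, the body would be passed as `data=build_request_body(prompt)` and the reply read via `extract_reply(response.json())`.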

@@ -161,3 +158,5 @@ if __name__ == '__main__':
   "time":"2023-12-15 20:08:40"
 }
 ```
+
+![Alt text](images/6.png)
Binary file modified Yi/images/5.png
Binary file added Yi/images/6.png
