forked from shanggangli/ChatGLM-6B-finetuning
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathpredict_d2q.py
95 lines (85 loc) · 4.15 KB
/
predict_d2q.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
# -*- coding:utf-8 -*-
# @project: ChatGLM-6b-int4-Finetuning
# @filename: predict_d2q
# @author: SgangX
# @time: 2023/5/15 13:51
"""
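Generate questions from documents (d2q) with a fine-tuned ChatGLM model.

Reads one JSON object per line from the test file, generates several candidate
questions for each document's "text" field, and writes all results to a single
JSON file.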
"""
import torch
import json
from modeling_chatglm import ChatGLMForConditionalGeneration
from tokenization_chatglm import ChatGLMTokenizer
from peft import PeftModel
from tqdm import tqdm
import time
import os
import argparse
def _str2bool(value):
    """Parse a command-line string into a bool.

    ``argparse`` with ``type=bool`` calls ``bool("False")``, which is ``True``
    for every non-empty string, so a boolean flag could never be switched off.
    This converter interprets the text instead.

    Raises:
        argparse.ArgumentTypeError: if the text is not a recognised boolean.
    """
    lowered = value.strip().lower()
    if lowered in ("true", "t", "yes", "y", "1"):
        return True
    # The empty string is kept falsy, matching what bool("") used to yield.
    if lowered in ("false", "f", "no", "n", "0", ""):
        return False
    raise argparse.ArgumentTypeError(
        "expected a boolean value (true/false), got {!r}".format(value))


def set_args():
    """Build the command-line parser for the prediction script and parse argv.

    Returns:
        argparse.Namespace: the parsed options (see the per-argument notes
        below for how main() consumes each one).
    """
    parser = argparse.ArgumentParser()
    # Input file: main() reads it line by line, one JSON object per line.
    parser.add_argument('--test_path', default='data/d2q_1.json', type=str, help='')
    # CUDA device index; main() formats it into "cuda:{}".
    parser.add_argument('--device', default='0', type=str, help='')
    # "lora" selects the PEFT loading path; any other value loads model_dir directly.
    parser.add_argument('--method', default='freeze', type=str, help='')
    # Base model directory, only used when method == "lora".
    parser.add_argument('--ori_model_dir', default='../chatglm-6b-int4', type=str, help='')
    # Fine-tuned checkpoint (and tokenizer) directory.
    parser.add_argument('--model_dir', default="output_dir_freeze/global_step-3600/", type=str, help='')
    # main() derives the generation budget as max_len - max_src_len - 3.
    parser.add_argument('--max_len', type=int, default=768, help='')
    # Token budget shared by the prompt and the (truncated) source document.
    parser.add_argument('--max_src_len', type=int, default=450, help='')
    # Instruction text that main() places in front of every document.
    parser.add_argument('--prompt_text', type=str, default="你现在是一个问题生成模型,请根据下面文档生成一个问题,文档:",
                        help='')
    # Sampling options forwarded to the model's generate call.
    parser.add_argument('--top_p', type=float, default=0.95, help='')
    # Parsed with _str2bool so that "--do_sample False" really disables sampling.
    parser.add_argument('--do_sample', type=_str2bool, default=True, help='')
    parser.add_argument('--num_return_sequences', type=int, default=4, help='')
    # Output JSON file that receives all generated results.
    parser.add_argument('--save_path', type=str, default="d2q_result_data/d2q_freeze.json", help='')
    return parser.parse_args()
def main():
    """Generate candidate questions for each document and save them as JSON.

    Steps:
      1. Load the model and tokenizer (PEFT/LoRA path when ``--method lora``,
         otherwise directly from ``--model_dir``).
      2. For every non-empty line of ``--test_path`` (one JSON object per
         line with "text" and "answer" keys), build
         ``prompt + truncated text + ["[gMASK]", "<sop>"]`` and generate
         ``--num_return_sequences`` continuations.
      3. Write the collected results to ``--save_path`` and print the elapsed
         time.
    """
    args = set_args()
    device = "cuda:{}".format(args.device)

    if args.method == "lora":
        model = ChatGLMForConditionalGeneration.from_pretrained(args.ori_model_dir)
        tokenizer = ChatGLMTokenizer.from_pretrained(args.model_dir)
        model.eval()
        model = PeftModel.from_pretrained(model, args.model_dir, torch_dtype=torch.float32)
        model.half().to(device)
    else:
        model = ChatGLMForConditionalGeneration.from_pretrained(args.model_dir)
        tokenizer = ChatGLMTokenizer.from_pretrained(args.model_dir)
        model.half().to(device)
    # eval() is idempotent, so calling it here covers both loading paths.
    model.eval()

    # Create the output directory up front so a long generation run cannot
    # fail at the very end because the folder is missing.
    save_dir = os.path.dirname(args.save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    # Loop-invariant values, computed once instead of per sample.
    max_tgt_len = args.max_len - args.max_src_len - 3
    prompt_tokens = tokenizer.tokenize(args.prompt_text)
    # Clamp at zero: a negative bound would slice from the end of the
    # document instead of truncating it.
    src_budget = max(args.max_src_len - len(prompt_tokens), 0)
    generation_kwargs = {
        "min_length": 5,
        "max_new_tokens": max_tgt_len,
        "top_p": args.top_p,
        "do_sample": args.do_sample,
        "num_return_sequences": args.num_return_sequences,
    }

    save_data = []
    s_time = time.time()
    with open(args.test_path, "r", encoding="utf-8") as fh, torch.no_grad():
        for line in tqdm(fh, desc="iter"):
            line = line.strip()
            if not line:
                # Tolerate blank lines (e.g. a trailing newline) in the input.
                continue
            sample = json.loads(line)

            src_tokens = tokenizer.tokenize(sample["text"])[:src_budget]
            tokens = prompt_tokens + src_tokens + ["[gMASK]", "<sop>"]
            input_ids = tokenizer.convert_tokens_to_ids(tokens)
            input_ids = torch.tensor([input_ids]).to(device)

            if args.method == "lora":
                response = model.generate_one(input_ids, **generation_kwargs)
            else:
                response = model.generate(input_ids, **generation_kwargs)

            # Convert once; each returned row starts with the prompt tokens,
            # which are cut off before decoding.
            sequences = response.tolist()
            prompt_len = input_ids.shape[1]
            res = [
                tokenizer.decode(sequences[i_r][prompt_len:]).replace("<eop>", "")
                for i_r in range(args.num_return_sequences)
            ]
            save_data.append({"text": sample["text"], "ori_answer": sample["answer"], "gen_answer": res})
    e_time = time.time()
    print("总耗时:{}s".format(e_time - s_time))

    # The context manager closes the file even if serialization fails.
    with open(args.save_path, "w", encoding="utf-8") as fin:
        json.dump(save_data, fin, ensure_ascii=False, indent=4)
# Run the prediction pipeline only when executed as a script, not on import.
if __name__ == "__main__":
    main()