-
Notifications
You must be signed in to change notification settings - Fork 810
/
Copy pathllm_internlm.py
141 lines (120 loc) · 5.23 KB
/
llm_internlm.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
import torch
from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList, GenerationConfig
from transformers import AutoModelForCausalLM, AutoTokenizer
import warnings
import copy
from typing import List, Optional, Callable, Optional
from dataclasses import dataclass, asdict
import torch.nn as nn
from plugins.common import settings
def chat_init(history):
    """Render past chat turns into a single prompt-prefix string.

    Args:
        history: iterable of mappings carrying a ``'role'`` key. Turns whose
            role is ``"user"`` or ``"AI"`` are rendered through the module-level
            ``user_prompt`` / ``robot_prompt`` templates with their
            ``'content'`` substituted for the placeholder. Turns with any other
            role are skipped, and their ``'content'`` is never read.

    Returns:
        The rendered turns concatenated in their original order; ``""`` when
        no turn is rendered.
    """
    def render(turn):
        role = turn['role']
        if role == "user":
            return user_prompt.replace("{user}", turn['content'])
        if role == "AI":
            return robot_prompt.replace("{robot}", turn['content'])
        # Unknown roles (e.g. system messages) contribute nothing.
        return ""

    return "".join(render(turn) for turn in history)
def chat_one(prompt, history, max_length, top_p, temperature, data, *, do_sample=True):
    """Stream a reply to *prompt*, given an already-rendered conversation prefix.

    Args:
        prompt: the new user message. It is substituted into
            ``cur_query_prompt``.
        history: the prior conversation as a string (e.g. the output of
            ``chat_init``). It is prepended verbatim.
        max_length: cap on prompt + generated tokens (``GenerationConfig.max_length``).
        top_p: nucleus-sampling threshold.
        temperature: sampling temperature.
        data: accepted for interface compatibility; not used here.
        do_sample: sample the next token (default). Pass ``False`` for greedy
            argmax decoding, in which case *top_p* and *temperature* have no
            effect.

    Yields:
        The decoded reply generated so far, growing by one token per step.
    """
    generation_config = GenerationConfig(
        max_length=max_length,
        top_p=top_p,
        temperature=temperature,
        # transformers.GenerationConfig defaults do_sample to False. In that
        # case generate_interactive takes the argmax, and top_p/temperature
        # cannot change the result. Enable sampling so the caller's knobs apply.
        do_sample=do_sample,
        repetition_penalty=1.05,
    )
    full_prompt = history + cur_query_prompt.replace("{user}", prompt)
    # 103028 is an extra stop token. NOTE(review): presumably the id of "<eoa>"
    # (the end-of-answer marker in robot_prompt); confirm against the tokenizer.
    yield from generate_interactive(
        full_prompt, generation_config, additional_eos_token_id=103028
    )
def load_model():
    """Load the checkpoint at ``settings.llm.path`` into module globals.

    Sets the module-level ``model`` (cast to bfloat16 and moved to CUDA) and
    ``tokenizer``. Both are used by ``generate_interactive``. Both loads pass
    ``trust_remote_code=True``, presumably because the checkpoint ships custom
    modelling/tokenizer code.
    """
    global model, tokenizer
    # Request bfloat16 weights directly instead of materializing them in the
    # default dtype (float32) and casting afterwards. The cast-afterwards
    # approach roughly doubles peak host memory during loading. The trailing
    # .to() keeps the final dtype of every parameter and buffer exactly as
    # before.
    model = AutoModelForCausalLM.from_pretrained(
        settings.llm.path, torch_dtype=torch.bfloat16, trust_remote_code=True
    ).to(torch.bfloat16).cuda()
    tokenizer = AutoTokenizer.from_pretrained(
        settings.llm.path, trust_remote_code=True
    )
def _collect_eos_token_ids(generation_config, additional_eos_token_id=None):
    """Return the de-duplicated token ids that end generation, skipping None."""
    candidates = []
    configured = generation_config.eos_token_id
    # GenerationConfig.eos_token_id may be a single int, a list of ints, or None.
    if isinstance(configured, int):
        candidates.append(configured)
    elif configured is not None:
        candidates.extend(configured)
    # A bare GenerationConfig (as built by chat_one) carries no eos id, so also
    # honour the tokenizer's own end-of-sequence token.
    candidates.append(getattr(tokenizer, "eos_token_id", None))
    candidates.append(additional_eos_token_id)
    eos_token_ids = []
    for token_id in candidates:
        if token_id is not None and token_id not in eos_token_ids:
            eos_token_ids.append(token_id)
    return eos_token_ids


def _hits_eos(next_tokens, eos_token_ids):
    """Return a bool tensor shaped like *next_tokens*: True where a token is an eos id."""
    hits = torch.zeros_like(next_tokens, dtype=torch.bool)
    for token_id in eos_token_ids:
        hits |= next_tokens == token_id
    return hits


@torch.inference_mode()
def generate_interactive(
    prompt,
    generation_config: Optional[GenerationConfig] = None,
    logits_processor: Optional[LogitsProcessorList] = None,
    stopping_criteria: Optional[StoppingCriteriaList] = None,
    prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
    additional_eos_token_id: Optional[int] = None,
    **kwargs,
):
    """Generate a reply to *prompt* token by token, yielding the text so far.

    Uses the module-level ``model`` and ``tokenizer`` (set by ``load_model``).
    The prompt is tokenized and moved to CUDA. Each step does one forward pass,
    applies the logits processors and warper, and picks the next token, by
    sampling when ``generation_config.do_sample`` is true and by argmax
    otherwise.

    Args:
        prompt: full prompt text.
        generation_config: decoding settings. Defaults to
            ``model.generation_config``. It is deep-copied, so the caller's
            object is never mutated.
        logits_processor: extra processors merged into the model's defaults.
        stopping_criteria: extra criteria merged into the model's defaults.
        prefix_allowed_tokens_fn: optional constrained-decoding callback.
        additional_eos_token_id: extra token id that also ends generation. It
            is added to the config's and tokenizer's own eos ids.
        **kwargs: first applied to the config copy via
            ``GenerationConfig.update``. Whatever it does not consume is
            forwarded as model kwargs to ``prepare_inputs_for_generation``.

    Yields:
        The decoded generated text (prompt excluded) after every new token. A
        trailing end-of-sequence token is stripped before decoding.

    NOTE(review): ``_get_logits_processor``, ``_get_stopping_criteria``,
    ``_get_logits_warper`` and ``_update_model_kwargs_for_generation`` are
    underscore-prefixed (non-public) transformers helpers. Their signatures
    may differ between transformers releases; pin the version.
    """
    inputs = tokenizer([prompt], padding=True, return_tensors="pt")
    input_length = len(inputs["input_ids"][0])
    for k, v in inputs.items():
        inputs[k] = v.cuda()
    input_ids = inputs["input_ids"]
    input_ids_seq_length = input_ids.shape[-1]

    if generation_config is None:
        generation_config = model.generation_config
    generation_config = copy.deepcopy(generation_config)
    # Apply kwargs BEFORE building processors/warpers/criteria so generation
    # settings passed this way (temperature, top_p, max_length, ...) take effect.
    model_kwargs = generation_config.update(**kwargs)
    eos_token_ids = _collect_eos_token_ids(generation_config, additional_eos_token_id)

    logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
    stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
    logits_processor = model._get_logits_processor(
        generation_config=generation_config,
        input_ids_seq_length=input_ids_seq_length,
        encoder_input_ids=input_ids,
        prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
        logits_processor=logits_processor,
    )
    stopping_criteria = model._get_stopping_criteria(
        generation_config=generation_config, stopping_criteria=stopping_criteria
    )
    logits_warper = model._get_logits_warper(generation_config)

    # 1 for every batch row that has not yet produced an eos token.
    unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
    scores = None
    while True:
        model_inputs = model.prepare_inputs_for_generation(input_ids, **model_kwargs)
        outputs = model(
            **model_inputs,
            return_dict=True,
            output_attentions=False,
            output_hidden_states=False,
        )
        next_token_logits = outputs.logits[:, -1, :]

        next_token_scores = logits_processor(input_ids, next_token_logits)
        next_token_scores = logits_warper(input_ids, next_token_scores)

        probs = nn.functional.softmax(next_token_scores, dim=-1)
        if generation_config.do_sample:
            next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
        else:
            next_tokens = torch.argmax(probs, dim=-1)

        input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
        model_kwargs = model._update_model_kwargs_for_generation(
            outputs, model_kwargs, is_encoder_decoder=False
        )
        # Element-wise check, so it works for any batch size.
        unfinished_sequences = unfinished_sequences.mul(
            (~_hits_eos(next_tokens, eos_token_ids)).long()
        )

        # Only the first row is decoded; the prompt is always a batch of one.
        output_token_ids = input_ids[0].cpu().tolist()[input_length:]
        # Generation stops right after the first eos, so at most one trailing
        # eos token can be present.
        if output_token_ids and output_token_ids[-1] in eos_token_ids:
            output_token_ids = output_token_ids[:-1]
        yield tokenizer.decode(output_token_ids)

        # Stop when every sequence is finished or a stopping criterion fires
        # (e.g. the max length is reached).
        if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
            break
# Prompt-template fragments for the "<|User|>/<|Bot|>" chat format.
# "{user}" and "{robot}" are placeholders that callers fill with str.replace,
# not str.format.
user_prompt = "<|User|>:{user}<eoh>\n"   # one completed user turn
robot_prompt = "<|Bot|>:{robot}<eoa>\n"  # one completed model turn
# The live query: a user turn followed by an open bot turn that generation
# continues from.
cur_query_prompt = user_prompt + "<|Bot|>:"