support dynamic_eos (modelscope#1947)
Jintao-Huang authored Sep 5, 2024
1 parent fb24aa7 commit 6523758
Showing 3 changed files with 24 additions and 16 deletions.
swift/llm/utils/argument.py (7 changes: 4 additions & 3 deletions)
@@ -135,7 +135,8 @@ def check_flash_attn(self: Union['SftArguments', 'InferArguments']) -> None:
 def handle_generation_config(self: Union['SftArguments', 'InferArguments']) -> None:
     if self.temperature == 0:
         self.do_sample = False
-    if self.do_sample is False:
+    if self.do_sample is False and (isinstance(self, InferArguments) and self.infer_backend == 'pt'
+                                    or isinstance(self, SftArguments)):
         # fix warning
         self.temperature = 1.
         self.top_p = 1.
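For context, a minimal sketch of the warning this guard suppresses (not part of the diff; behavior as in recent transformers releases, values invented): with do_sample=False, transformers warns that sampling knobs such as temperature and top_p are ignored unless they sit at their neutral defaults, so resetting them to 1.0 keeps greedy decoding quiet.

# Sketch, not from the diff: why the reset avoids a transformers warning.
from transformers import GenerationConfig

# Greedy decoding plus a non-default temperature triggers a UserWarning like:
#   "`do_sample` is set to `False`. However, `temperature` is set to 0.3 --
#    this flag is only used in sample-based generation modes."
GenerationConfig(do_sample=False, temperature=0.3, top_p=0.7)

# With the knobs reset to their neutral defaults, no warning is emitted.
GenerationConfig(do_sample=False, temperature=1.0, top_p=1.0)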
@@ -994,7 +995,6 @@ def __post_init__(self) -> None:
         self.dataset_seed = self.seed
         self.set_model_type()
         self.check_flash_attn()
-        self.handle_generation_config()
         self.handle_lr_scheduler_kwargs()
         self.is_multimodal = self._is_multimodal(self.model_type)
         self.is_vision = self._is_vision(self.model_type)
@@ -1170,6 +1170,7 @@ def __post_init__(self) -> None:
         self.logging_dir = f'{self.output_dir}/runs'
         if self.train_backend == 'transformers':
             self.training_args.logging_dir = self.logging_dir
+        self.handle_generation_config()
 
     def _init_training_args(self) -> None:
         additional_saved_files = []
@@ -1460,7 +1461,6 @@ def __post_init__(self) -> None:
         self.handle_custom_dataset_info()
         self.set_model_type()
         self.check_flash_attn()
-        self.handle_generation_config()
         self.is_multimodal = self._is_multimodal(self.model_type)
         self.prepare_ms_hub()
 
@@ -1492,6 +1492,7 @@ def __post_init__(self) -> None:
             self.sft_type = 'full'
 
         self.handle_infer_backend()
+        self.handle_generation_config()
 
     def handle_infer_backend(self):
         model_info = MODEL_MAPPING[self.model_type]
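The call-site moves above all serve the new guard: it reads self.infer_backend, which presumably still holds its unresolved 'auto' default early in __post_init__, so handle_generation_config now runs after handle_infer_backend (and after training-args init on the SFT side). A hypothetical reduction of the hazard (class and field names invented, not repo code):

# Hypothetical reduction, not repo code: reading infer_backend before it is
# resolved would make the reset silently skip.
class Args:
    def __init__(self) -> None:
        self.do_sample = False
        self.temperature, self.top_p = 0.3, 0.7
        self.infer_backend = 'auto'
        # Old order ran handle_generation_config() here: the guard saw 'auto'
        # and never fired. New order resolves the backend first:
        self.handle_infer_backend()
        self.handle_generation_config()

    def handle_infer_backend(self) -> None:
        if self.infer_backend == 'auto':
            self.infer_backend = 'pt'

    def handle_generation_config(self) -> None:
        if self.do_sample is False and self.infer_backend == 'pt':
            self.temperature, self.top_p = 1., 1.

args = Args()
assert (args.temperature, args.top_p) == (1., 1.)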
swift/llm/utils/template.py (29 changes: 18 additions & 11 deletions)
@@ -806,6 +806,19 @@ def _encode_context_list(
             loss_scale.extend([loss_weight] * len(token_list))
         return input_ids, labels, loss_scale, tokenizer_kwargs
 
+    @staticmethod
+    def use_dynamic_eos(labels: List[int], suffix_tokens_id: List[int]) -> None:
+        suffix_len = len(suffix_tokens_id)
+        start = 0
+        for i in range(1, len(labels)):
+            if labels[i - 1] >= 0 and labels[i] == -100:
+                start = i
+            if start > 0 and labels[i - 1] == -100 and labels[i] >= 0:
+                # [0, 1, 2, -100(start), -100, 3(i), 4]
+                length = i - start
+                if length >= suffix_len:
+                    labels[start:start + suffix_len] = suffix_tokens_id
+
     def _concat_and_tokenize(self,
                              query: str,
                              query_role: str,
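A worked example of the new method (a sketch, not from the diff; token ids are invented, -100 is the usual cross-entropy ignore index, and the import path assumes this repo's layout): each masked run that sits between supervised tokens and is at least len(suffix_tokens_id) long gets its first positions relabeled with the suffix ids, so the model is trained to emit eos at internal turn boundaries, not just after the final response.

# Sketch: use_dynamic_eos rewrites labels in place. Ids invented; -100 marks
# positions excluded from the loss.
from swift.llm.utils.template import Template  # assumed import path

labels = [15, 16, 17, -100, -100, -100, 23, 24, 2]
suffix_tokens_id = [2]  # e.g. the encoded eos suffix of the template

Template.use_dynamic_eos(labels, suffix_tokens_id)
# The internal masked run now begins with the suffix id, so eos is supervised
# at the turn boundary as well as after the final response:
print(labels)  # [15, 16, 17, 2, -100, -100, 23, 24, 2]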
@@ -840,18 +853,10 @@ def _concat_and_tokenize(self,
         history.append([query, response])
         history_roles.append([query_role, 'assistant'])
 
-        # Set the loss_scale of chat_sep or suffix to 1 if efficient_eos.
-        efficient_eos = False
-        if self.chat_sep is not None and len(self.chat_sep) > 0:
-            if isinstance(self.chat_sep[0], str) and isinstance(self.suffix[0], str) and self.chat_sep[0].startswith(
-                    self.suffix[0]):
-                efficient_eos = True
-            elif isinstance(self.chat_sep[0], list) and self.chat_sep[0] == self.suffix[0]:
-                efficient_eos = True
-
         for i, ((q, r), (qr, rr)) in enumerate(zip(history, history_roles)):
             context_list = self.tool_prompt.copy() if qr == 'tool' else prompt.copy()
             extra_context_list = []
+            is_suffix = False
             if i < len(history) - 1:
                 context_list = [context for context in context_list if '{{SYSTEM}}' not in context]
                 context_list.append('{{RESPONSE}}')
@@ -861,14 +866,16 @@
                 # last response
                 context_list.append('{{RESPONSE}}')
                 extra_context_list = self.suffix
-                efficient_eos = True
+                is_suffix = True
             if q or r:
                 self._concat_context_list(
                     context_list, res_context_list, loss_scale_list, query=q, response=r, system=system, round0=i)
                 res_context_list += extra_context_list
-                loss_scale_list += ([1.] if efficient_eos else [0.]) * len(extra_context_list)
+                loss_scale_list += ([1.] if is_suffix else [0.]) * len(extra_context_list)
         res_context_list, loss_scale_list = self._simplify_context_list(res_context_list, loss_scale_list, **kwargs)
         input_ids, labels, loss_scale, tokenizer_kwargs = self._encode_context_list(res_context_list, loss_scale_list)
+        if labels is not None:
+            self.use_dynamic_eos(labels, self._encode_context_list(self.suffix)[0])
 
         if response is None:
             labels = None
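For contrast, a sketch of what the deleted static check did and what replaces it (template field values invented): previously eos was kept supervised in earlier turns only for templates whose chat_sep literally began with the suffix; now use_dynamic_eos inspects the tokenized labels themselves, so any internal masked gap wide enough to hold the suffix gets it, independent of how chat_sep is spelled.

# Sketch of the behavior change; the field values below are invented.
chat_sep = ['<|im_end|>\n']
suffix = ['<|im_end|>']

# Old, static: decided once per template from string structure.
efficient_eos = chat_sep[0].startswith(suffix[0])  # True here, False for many templates

# New, dynamic: decided per sample from the tokenized labels via
# use_dynamic_eos above, with no dependence on the chat_sep/suffix spelling.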
swift/llm/utils/utils.py (4 changes: 2 additions & 2 deletions)
@@ -382,8 +382,8 @@ def _is_special(token: int) -> bool:
         if not _is_special(input_ids[i]) and _is_special(input_ids[i - 1]):
             e = i
             result_str += f'[{input_ids[i - 1]} * {e - s}]'
-    if _is_special(input_ids[-1]):
-        result_str += f'[{input_ids[i - 1]} * {len(input_ids) - s}]'
+    if _is_special(input_ids[i]):
+        result_str += f'[{input_ids[i]} * {len(input_ids) - s}]'
     else:
         result_str += tokenizer.decode(input_ids[e:], **tokenizer_kwargs)
     return result_str
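The fix matters when the sequence ends in a run of special tokens: after the loop, i is the last index, and the old code printed input_ids[i - 1], the token just before the run's tail, as the repeated id. A reduced sketch of the tail handling (stand-in _is_special and invented ids, not repo code):

# Reduced sketch, not repo code: the tail check after the run-compression
# loop. Ids are invented; >= 100000 plays the role of "special".
def _is_special(token: int) -> bool:
    return token >= 100000

input_ids = [5, 6, 151645]  # ends in a special-token run of length 1
i = len(input_ids) - 1      # value of the loop variable after the loop
s = i                       # start index of the trailing special run

old = f'[{input_ids[i - 1]} * {len(input_ids) - s}]'  # '[6 * 1]' -- wrong id
new = f'[{input_ids[i]} * {len(input_ids) - s}]'      # '[151645 * 1]' -- correct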