
Commit

Fix rlhf ref model (modelscope#2003)
Jintao-Huang authored Sep 10, 2024
1 parent 38ae2e4 commit cd4fe61
Showing 8 changed files with 58 additions and 84 deletions.
7 changes: 3 additions & 4 deletions swift/llm/rlhf.py
@@ -8,7 +8,6 @@
from transformers import IntervalStrategy
from transformers.integrations import is_deepspeed_zero3_enabled
from transformers.utils import is_torch_npu_available
from trl.models import create_reference_model

from swift.trainers import RLHFTrainerFactory, get_preprocess_func, get_preprocessed_rlhf_dataset, patch_trl
from swift.utils import (append_to_jsonl, check_json_format, get_dist_setting, get_logger, get_main, get_model_info,
@@ -154,7 +153,7 @@ def llm_rlhf(args: RLHFArguments) -> Dict[str, Any]:
args.ref_model_type or args.model_type,
args.torch_dtype,
model_kwargs,
model_id_or_path=args.ref_model_id_or_path or args.model_id_or_path,
model_id_or_path=args.ref_model_id_or_path if args.ref_model_type else args.model_id_or_path,
revision=args.model_revision,
quant_method=args.quant_method,
**kwargs)
@@ -228,9 +227,9 @@ def llm_rlhf(args: RLHFArguments) -> Dict[str, Any]:
is_encoder_decoder=is_encoder_decoder)
td0, tkwargs0 = preprocess_func(train_dataset[0]), {}
print_example(td0, tokenizer, tkwargs0)
train_dataset = LazyLLMDataset(train_dataset, template, encode_func=preprocess_func)
train_dataset = LazyLLMDataset(train_dataset, preprocess_func)
if val_dataset is not None:
val_dataset = LazyLLMDataset(val_dataset, template, encode_func=preprocess_func)
val_dataset = LazyLLMDataset(val_dataset, preprocess_func)
else:
train_dataset, val_dataset = get_preprocessed_rlhf_dataset(
train_dataset,
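The core fix of this commit is the `model_id_or_path=` change above: the reference model's checkpoint is now chosen based on whether a separate `ref_model_type` was given, instead of being used whenever `ref_model_id_or_path` happens to be set. Since the model type passed alongside it is `args.ref_model_type or args.model_type`, the new rule keeps the type/path pair consistent. A minimal, self-contained sketch of the two selection rules (the `Args` dataclass below is a stand-in, not swift's `RLHFArguments`):

```python
# Minimal sketch of the changed selection rule; `Args` is a hypothetical stand-in for
# swift's RLHFArguments, reduced to the four fields the expression touches.
from dataclasses import dataclass
from typing import Optional


@dataclass
class Args:
    model_type: str = 'policy-model'
    model_id_or_path: str = 'org/policy-model'
    ref_model_type: Optional[str] = None
    ref_model_id_or_path: Optional[str] = None


def ref_model_path_before(args: Args) -> str:
    # Old rule: any ref_model_id_or_path wins, even when no ref_model_type was set.
    return args.ref_model_id_or_path or args.model_id_or_path


def ref_model_path_after(args: Args) -> str:
    # New rule: the ref checkpoint is only used together with an explicit ref_model_type,
    # matching the `args.ref_model_type or args.model_type` passed next to it.
    return args.ref_model_id_or_path if args.ref_model_type else args.model_id_or_path


args = Args(ref_model_id_or_path='org/some-ref-checkpoint')  # ref path given, type omitted
print(ref_model_path_before(args))  # org/some-ref-checkpoint
print(ref_model_path_after(args))   # org/policy-model
```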
12 changes: 6 additions & 6 deletions swift/llm/sft.py
@@ -78,9 +78,9 @@ def llm_sft_megatron(args: SftArguments) -> Dict[str, Any]:
train_dataset, val_dataset = _get_train_val_dataset(args)
td0, tkwargs0 = template.encode(train_dataset[0])
print_example(td0, tokenizer, tkwargs0)
train_dataset = LazyLLMDataset(train_dataset, template)
train_dataset = LazyLLMDataset(train_dataset, template.encode)
if val_dataset is not None:
val_dataset = LazyLLMDataset(val_dataset, template)
val_dataset = LazyLLMDataset(val_dataset, template.encode)

res = MegatronArguments.load_megatron_config(tokenizer.model_dir)
res.update(MegatronArguments.from_sft_args(args, train_dataset, val_dataset))
@@ -326,6 +326,8 @@ def llm_sft(args: SftArguments) -> Dict[str, Any]:
else:
template.model = None
logger.info(f'Using num_proc: {args.preprocess_num_proc}')
td0, tkwargs0 = template.encode(train_dataset[0])
print_example(td0, tokenizer, tkwargs0)
train_dataset = dataset_map(train_dataset, template.encode, args.preprocess_num_proc, streaming=streaming)
if val_dataset is not None:
val_dataset = dataset_map(val_dataset, template.encode, args.preprocess_num_proc, streaming=streaming)
@@ -340,18 +342,16 @@ def llm_sft(args: SftArguments) -> Dict[str, Any]:
raise AttributeError('Failed to access dataset attributes,train_dataset is None. This might be because:\n'
'(1) The dataset contains None for input or labels;\n'
"(2) The 'max_length' setting is too short causing data truncation.")
td0, tkwargs0 = train_dataset.data[0] if not streaming else (next(iter(train_dataset)), {})
print_example(td0, tokenizer, tkwargs0)
dataset_info['train_dataset'] = stat_dataset(train_dataset) if not streaming else None
if val_dataset is not None:
dataset_info['val_dataset'] = stat_dataset(val_dataset) if not streaming else None
else:
dataset_info = None
td0, tkwargs0 = template.encode(train_dataset[0])
print_example(td0, tokenizer, tkwargs0)
train_dataset = LazyLLMDataset(train_dataset, template)
train_dataset = LazyLLMDataset(train_dataset, template.encode)
if val_dataset is not None:
val_dataset = LazyLLMDataset(val_dataset, template)
val_dataset = LazyLLMDataset(val_dataset, template.encode)
if val_dataset is None:
training_args.evaluation_strategy = IntervalStrategy.NO
training_args.eval_strategy = IntervalStrategy.NO
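Across sft.py the same pattern repeats: `LazyLLMDataset` is now handed the encoding callable itself (`template.encode` here, a preprocess function on the RLHF side) rather than a `Template` object plus an optional `encode_func` keyword, and the sample printed by `print_example` is encoded up front. A self-contained toy mirroring the new constructor shape; the classes below are illustrative stand-ins, not swift's `LazyLLMDataset` or `Template`:

```python
# Toy mirror of the new "dataset + encode callable" constructor; for illustration only.
from typing import Any, Callable, Dict, List


class ToyLazyDataset:

    def __init__(self, dataset: List[Dict[str, Any]], encode_func: Callable) -> None:
        self.dataset = dataset
        self.encode_func = encode_func  # e.g. template.encode or an RLHF preprocess_func

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        # Encoding happens lazily, per sample, at access time.
        return self.encode_func(self.dataset[idx])

    def __len__(self) -> int:
        return len(self.dataset)


class ToyTemplate:

    def encode(self, example: Dict[str, Any]) -> Dict[str, Any]:
        # Stand-in for Template.encode: pretend each character is a token id.
        return {'input_ids': [ord(c) for c in example['query']],
                'labels': [ord(c) for c in example['response']]}


template = ToyTemplate()
train = ToyLazyDataset([{'query': 'hi', 'response': 'ok'}], template.encode)
print(train[0])  # {'input_ids': [104, 105], 'labels': [111, 107]}
```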
1 change: 0 additions & 1 deletion swift/llm/utils/argument.py
@@ -1220,7 +1220,6 @@ def _init_training_args(self) -> None:
metric_for_best_model='rouge-l' if self.predict_with_generate else 'loss',
greater_is_better=self.predict_with_generate,
full_determinism=self.full_determinism,
sortish_sampler=True,
optim=self.optim,
adam_beta1=self.adam_beta1,
adam_beta2=self.adam_beta2,
3 changes: 0 additions & 3 deletions swift/llm/utils/callbacks.py
@@ -1,11 +1,8 @@
import os
import types

import numpy as np
import torch
from peft import PeftModel
from transformers import TrainerCallback
from transformers.modeling_utils import unwrap_model


class TrainerAdapterCallback(TrainerCallback):
5 changes: 4 additions & 1 deletion swift/llm/utils/model.py
@@ -1296,7 +1296,10 @@ def get_model_tokenizer_paligemma_vision(model_dir: str,


def _clone_hook(module, input, output):
return output.requires_grad_(True).clone()
if module.training:
return output.requires_grad_(True).clone()
else:
return output


@register_model(
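`_clone_hook` is a forward hook; after this change it only forces a gradient-tracking clone while the module is in training mode and passes the output through untouched at inference. A self-contained sketch of the same guard on a toy frozen layer (the hook and layer below are illustrative, not the hook swift registers):

```python
# Sketch of a forward hook that only forces gradients while training; toy stand-ins,
# but the module.training guard mirrors the commit.
import torch
from torch import nn


def clone_hook(module: nn.Module, args, output: torch.Tensor) -> torch.Tensor:
    if module.training:
        # Training: return a gradient-tracking clone so later layers can backprop
        # through an otherwise frozen (requires_grad=False) block.
        return output.requires_grad_(True).clone()
    # Eval/inference: leave the output alone and skip the extra copy.
    return output


layer = nn.Linear(4, 4)
layer.weight.requires_grad_(False)
layer.bias.requires_grad_(False)
layer.register_forward_hook(clone_hook)

x = torch.randn(2, 4)

layer.train()
print(layer(x).requires_grad)  # True: the hook re-enabled grad tracking

layer.eval()
with torch.no_grad():
    print(layer(x).requires_grad)  # False: the output is passed through unchanged
```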
38 changes: 20 additions & 18 deletions swift/llm/utils/template.py
@@ -321,7 +321,7 @@ def _pre_forward_hook(module, args, kwargs):
res_extra = []
data = kwargs.pop('_data')
for d in data:
res_extra.append(self._post_encode(d))
res_extra.append(self._post_encode(module, d))
kwargs.update(to_device(self.data_collator(res_extra), module.device))
if 'inputs_embeds' in kwargs:
kwargs.pop('input_ids', None)
@@ -332,7 +332,7 @@ def _pre_forward_hook(module, args, kwargs):
return args, kwargs

parameters = inspect.signature(self.model.register_forward_pre_hook).parameters
handle = None
handle, handle2 = None, None
deepspeed = None
if 'with_kwargs' in parameters:
handle = self.model.register_forward_pre_hook(_pre_forward_hook, with_kwargs=True)
@@ -355,6 +355,8 @@ def _initialize(*args, **kwargs):
self._is_training = False
if handle:
handle.remove()
if handle2:
handle2.remove()
if deepspeed:
deepspeed.initialize = _old_initialize

@@ -370,7 +372,7 @@ def lmdeploy_context(self):
yield
self._is_lmdeploy = False

def _post_encode(self, data: Any) -> Dict[str, Any]:
def _post_encode(self, model, data: Any) -> Dict[str, Any]:
return {}

def check_example(self, example: Dict[str, Any]) -> None:
@@ -540,7 +542,7 @@ def encode(self, example: Dict[str, Any], streaming: bool = False) -> Tuple[Dict
if not self._is_training and '_data' in inputs:
data = inputs.pop('_data')
data = to_device(data, self.model.device)
inputs.update(self._post_encode(data))
inputs.update(self._post_encode(self.model, data))
return res if not streaming else inputs

async def prepare_lmdeploy_inputs(self, inputs: Dict[str, Any]) -> None:
@@ -1778,7 +1780,7 @@ def _encode(self, example: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, An
inputs['_data'] = {'input_ids': inputs['input_ids'], 'labels': inputs['labels'], 'images': images}
return inputs, {}

def _post_encode(self, data: Any) -> Dict[str, Any]:
def _post_encode(self, model, data: Any) -> Dict[str, Any]:
input_ids = data['input_ids']
labels = data['labels']
images = data['images']
@@ -1799,13 +1801,13 @@ def _post_encode(self, data: Any) -> Dict[str, Any]:
res_labels = []
wrap_im_mask = []
pre_i, i, idx = 0, 0, 0
device = self.model.device
internlm2_model = self.model.model
device = model.device
internlm2_model = model.model
if not hasattr(internlm2_model, 'tok_embeddings'):
internlm2_model = internlm2_model.model
tok_embeddings = internlm2_model.tok_embeddings
if len(images) > 0:
images = self.model.img2emb(images)[0]
images = model.img2emb(images)[0]
while i < len(input_ids):
if input_ids[i] == 2: # replace_token
res_input_ids = torch.tensor([1] + input_ids[pre_i:i], device=device)
@@ -1914,20 +1916,20 @@ def _encode(self, example: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, An
inputs.pop('loss_scale', None)
return inputs, {}

def _post_encode(self, data: Any) -> Dict[str, Any]:
embedding = self.model.get_input_embeddings()
def _post_encode(self, model, data: Any) -> Dict[str, Any]:
embedding = model.get_input_embeddings()
device = embedding.weight.device
input_ids = data['input_ids']
inputs_embeds = embedding(input_ids[None])[0].to(device=device)
pixel_values = data['pixel_values']
if pixel_values is not None:
pixel_values = pixel_values.to(device=device)
vit_embeds = self.model.extract_feature(pixel_values).to(device=device)
vit_embeds = model.extract_feature(pixel_values).to(device=device)
selected = (input_ids == self.tokenizer.encode('<IMG_CONTEXT>', add_special_tokens=False)[0])
inputs_embeds[selected] = vit_embeds.reshape(-1, vit_embeds.shape[-1])
elif is_deepspeed_zero3_enabled():
dummy_pixel_values = torch.zeros((1, 3, 32, 32), device=device, dtype=inputs_embeds.dtype)
vit_embeds = self.model.extract_feature(dummy_pixel_values).to(device=device)
vit_embeds = model.extract_feature(dummy_pixel_values).to(device=device)
inputs_embeds += vit_embeds.mean() * 0.
return {'inputs_embeds': inputs_embeds}

@@ -2693,8 +2695,8 @@ def _encode(self, example: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, An
inputs = {'input_ids': new_input_ids, 'labels': new_labels, '_data': batched_output}
return inputs, {}

def _post_encode(self, data: Any) -> Dict[str, Any]:
inputs_embeds = self.model.prepare_inputs_embeds(**data)[0]
def _post_encode(self, model, data: Any) -> Dict[str, Any]:
inputs_embeds = model.prepare_inputs_embeds(**data)[0]
return {'inputs_embeds': inputs_embeds}

@staticmethod
@@ -2936,8 +2938,8 @@ def _encode(self, example: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, An
}
return inputs, {}

def _post_encode(self, data: Any) -> Dict[str, Any]:
inputs_embeds, _ = self.model.get_vllm_embedding(data)
def _post_encode(self, model, data: Any) -> Dict[str, Any]:
inputs_embeds, _ = model.get_vllm_embedding(data)
return {'inputs_embeds': inputs_embeds[0]}

@staticmethod
@@ -3193,8 +3195,8 @@ def _encode(self, example: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, An
inputs['labels'] = labels
return inputs, {}

def _post_encode(self, data: Any) -> Dict[str, Any]:
image_embeds = self.model.forward_image(data['pixel_values'])
def _post_encode(self, model, data: Any) -> Dict[str, Any]:
image_embeds = model.forward_image(data['pixel_values'])
return {'image_embeds': image_embeds}


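The template changes are mechanical but central: `_post_encode` now takes the model as an explicit first argument, the base `Template._pre_forward_hook` passes in the `module` it fires on, and every multimodal override in this file gains the same parameter. That way the code works when the model running the forward pass is not the one stored on `self.model`, for example the RLHF reference model or a wrapped module; the added `handle2` bookkeeping in `_initialize` appears to simply remove a second hook handle on teardown when one was registered. A reduced sketch of the flow, with made-up `MiniTemplate`/`ToyLM` classes standing in for swift's `Template` and a real multimodal model:

```python
# Reduced sketch: the pre-forward hook hands the hooked module to _post_encode instead
# of the template reaching for a cached self.model. Illustrative stand-ins only.
from typing import Any, Dict

import torch
from torch import nn


class ToyLM(nn.Module):

    def __init__(self, vocab_size: int = 32, hidden: int = 8) -> None:
        super().__init__()
        self.emb = nn.Embedding(vocab_size, hidden)
        self.head = nn.Linear(hidden, vocab_size)

    def get_input_embeddings(self) -> nn.Embedding:
        return self.emb

    def forward(self, input_ids=None, inputs_embeds=None, **kwargs):
        if inputs_embeds is None:
            inputs_embeds = self.emb(input_ids)
        return self.head(inputs_embeds)


class MiniTemplate:

    def _post_encode(self, model: nn.Module, data: Dict[str, Any]) -> Dict[str, Any]:
        # Use whichever model is actually running this forward pass (policy, ref model,
        # or a wrapped module), not a single model captured at construction time.
        embedding = model.get_input_embeddings()
        return {'inputs_embeds': embedding(data['input_ids'])}

    def register(self, model: nn.Module):

        def _pre_forward_hook(module, args, kwargs):
            data = kwargs.pop('_data')
            kwargs.update(self._post_encode(module, data))  # module == the hooked model
            if 'inputs_embeds' in kwargs:
                kwargs.pop('input_ids', None)
            return args, kwargs

        # with_kwargs=True requires PyTorch >= 2.0
        return model.register_forward_pre_hook(_pre_forward_hook, with_kwargs=True)


model = ToyLM()
handle = MiniTemplate().register(model)
logits = model(input_ids=None, _data={'input_ids': torch.tensor([[1, 2, 3]])})
print(logits.shape)  # torch.Size([1, 3, 32])
handle.remove()
```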
74 changes: 24 additions & 50 deletions swift/llm/utils/utils.py
@@ -22,7 +22,6 @@
from datasets import Dataset as HfDataset
from datasets import IterableDataset as HfIterableDataset
from modelscope.utils.config_ds import MS_CACHE_HOME
from torch import device as Device
from torch.nn import Linear, Module
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import Dataset, IterableDataset
@@ -143,10 +142,10 @@ def __init__(self, data: List[Dict[str, Any]]) -> None:

def __getitem__(self, idx: Union[int, str]) -> Dict[str, Any]:
if isinstance(idx, int):
data, _ = self.data[idx]
data = self.data[idx]
return data
elif isinstance(idx, str):
return [d[0][idx] for d in self.data]
return [d[idx] for d in self.data]
else:
raise ValueError(f'idx: {idx}')

@@ -250,36 +249,32 @@ class LazyLLMDataset(Dataset):

def __init__(self,
dataset: HfDataset,
template: Template,
encode_func: Callable[[Dict[str, Any]], Union[Tuple[Dict[str, Any], Dict[str, Any]], Dict[str, Any]]],
*,
try_fetch_time: int = 20,
encode_func: Callable = None) -> None:
try_fetch_time: int = 20) -> None:
self.dataset = dataset
self.template = template
self.try_fetch_time = min(try_fetch_time, len(self.dataset))
self.encode_func = encode_func
self.try_fetch_time = min(try_fetch_time, len(self.dataset))
assert self.try_fetch_time >= 1

def __getitem__(self, idx: int) -> Dict[str, Any]:
res = self._try_fetch(idx)
if res is not None:
data, _ = res
return data
return res
raise ValueError('Please check if the max_length is appropriate.')

def _try_fetch(self, first_idx: int) -> Optional[Dict[str, Any]]:
idx = np.random.permutation(len(self))[:self.try_fetch_time - 1]
for i in [first_idx] + idx.tolist():
data = self.dataset[i]
try:
if self.encode_func:
res = self.encode_func(data), {}
else:
res = self.template.encode(data)
res = self.encode_func(data)
if isinstance(res, (tuple, list)) and len(res) == 2:
res = res[0]
except Exception as e:
logger.error(f'Error occurs in lazy tokenize: {e}')
continue
if len(res[0]) > 0:
if len(res) > 0:
return res

def __len__(self) -> int:
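`LazyLLMDataset` now takes `encode_func` as its single required preprocessing argument and accepts either a plain feature dict or the older `(features, tokenizer_kwargs)` pair, unwrapping the latter before the empty-sample check. A standalone sketch of just that normalization (a helper written for illustration, not swift's exact code path):

```python
# Standalone sketch of the return-value normalization applied to encode_func results.
from typing import Any, Dict, Optional, Tuple, Union

EncodeResult = Union[Dict[str, Any], Tuple[Dict[str, Any], Dict[str, Any]]]


def normalize_encoded(res: EncodeResult) -> Optional[Dict[str, Any]]:
    # Template.encode returns (inputs, tokenizer_kwargs); RLHF preprocess functions may
    # return the feature dict directly. Accept both, keep the features, drop empty ones.
    if isinstance(res, (tuple, list)) and len(res) == 2:
        res = res[0]
    return res if len(res) > 0 else None


print(normalize_encoded(({'input_ids': [1, 2, 3]}, {})))  # {'input_ids': [1, 2, 3]}
print(normalize_encoded({'input_ids': [1, 2, 3]}))        # {'input_ids': [1, 2, 3]}
print(normalize_encoded(({}, {})))                        # None (skipped; another index is tried)
```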
@@ -290,8 +285,8 @@ def __len__(self) -> int:


def _single_map(d: Dict[str, Any], map_func: MapFunc) -> Optional[Dict[str, Any]]:
d = map_func(d)
if len(d[0]) == 0:
d = map_func(d)[0]
if len(d) == 0:
return None
return d

@@ -358,7 +353,9 @@ def stat_dataset(llm_dataset: Dataset) -> str:
_token_len.append(len(ii))
else:
for d in llm_dataset:
_token_len.append(len(d['input_ids']))
for k, v in d.items():
if k == 'input_ids' or k.endswith('_input_ids'): # sft, rlhf
_token_len.append(len(v))
_, stat_str = stat_array(_token_len)
logger.info(f'Dataset Token Length: {stat_str}')
return stat_str
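`stat_dataset` now collects token lengths from every key named `input_ids` or ending in `_input_ids`, so RLHF features such as `chosen_input_ids` and `rejected_input_ids` are counted alongside plain SFT samples. A tiny sketch of the key filter on made-up records:

```python
# Sketch of the key filter used when collecting token lengths; the rows are made-up
# examples of SFT and RLHF features.
rows = [
    {'input_ids': [1, 2, 3, 4], 'labels': [-100, -100, 3, 4]},               # sft
    {'chosen_input_ids': [1, 2, 3], 'rejected_input_ids': [1, 2, 3, 4, 5]},  # rlhf
]

token_len = []
for row in rows:
    for key, value in row.items():
        if key == 'input_ids' or key.endswith('_input_ids'):
            token_len.append(len(value))

print(token_len)  # [4, 3, 5]  (labels ignored; both chosen and rejected sides counted)
```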
@@ -403,36 +400,13 @@ def print_example(example: Dict[str, Any],
tokenizer_kwargs: Optional[Dict[str, Any]] = None) -> None:
if tokenizer_kwargs is None:
tokenizer_kwargs = {}
input_ids = example.get('input_ids')
chosen_input_ids = example.get('chosen_input_ids')
chosen_labels = example.get('chosen_labels')
rejected_input_ids = example.get('rejected_input_ids')
rejected_labels = example.get('rejected_labels')
labels = example.get('labels')
if input_ids is not None:
logger.info(f'[INPUT_IDS] {input_ids}')
input_str = safe_tokenizer_decode(tokenizer, input_ids, **tokenizer_kwargs)
logger.info(f'[INPUT] {input_str}')
if chosen_input_ids is not None:
logger.info(f'[CHOSEN_INPUT_IDS] {chosen_input_ids}')
input_str = safe_tokenizer_decode(tokenizer, chosen_input_ids, **tokenizer_kwargs)
logger.info(f'[CHOSEN_INPUT] {input_str}')
if rejected_input_ids is not None:
logger.info(f'[REJECTED_INPUT_IDS] {rejected_input_ids}')
input_str = safe_tokenizer_decode(tokenizer, rejected_input_ids, **tokenizer_kwargs)
logger.info(f'[REJECTED_INPUT] {input_str}')
if labels is not None:
logger.info(f'[LABELS_IDS] {labels}')
labels_str = safe_tokenizer_decode(tokenizer, labels, **tokenizer_kwargs)
logger.info(f'[LABELS] {labels_str}')
if chosen_labels is not None:
logger.info(f'[CHOSEN_LABELS_IDS] {chosen_labels}')
labels_str = safe_tokenizer_decode(tokenizer, chosen_labels, **tokenizer_kwargs)
logger.info(f'[CHOSEN_LABELS] {labels_str}')
if rejected_labels is not None:
logger.info(f'[REJECTED_LABELS_IDS] {rejected_labels}')
labels_str = safe_tokenizer_decode(tokenizer, rejected_labels, **tokenizer_kwargs)
logger.info(f'[REJECTED_LABELS] {labels_str}')
for key in ['input', 'chosen_input', 'rejected_input', 'labels', 'chosen_labels', 'rejected_labels']:
val = example.get(key) or example.get(f'{key}_ids')
if val is not None:
key_upper = key.upper()
logger.info(f'[{key_upper}_IDS] {val}')
val_str = safe_tokenizer_decode(tokenizer, val, **tokenizer_kwargs)
logger.info(f'[{key_upper}] {val_str}')

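The per-key logging in `print_example` collapses into one loop; each name is looked up directly (`labels`) or with an `_ids` suffix (`input` resolves to `input_ids`). A small standalone illustration of that lookup, using a made-up RLHF example and plain `print` in place of the logger and decode step:

```python
# Standalone illustration of the lookup order; the example dict is made up and print()
# stands in for logger.info / safe_tokenizer_decode.
example = {'chosen_input_ids': [1, 2, 3], 'chosen_labels': [-100, 2, 3]}

for key in ['input', 'chosen_input', 'rejected_input', 'labels', 'chosen_labels', 'rejected_labels']:
    val = example.get(key) or example.get(f'{key}_ids')  # 'chosen_input' -> 'chosen_input_ids'
    if val is not None:
        print(f'[{key.upper()}_IDS] {val}')
# [CHOSEN_INPUT_IDS] [1, 2, 3]
# [CHOSEN_LABELS_IDS] [-100, 2, 3]
```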

def _find_layers(model: Module, module_cls: type) -> List[str]:
@@ -535,7 +509,7 @@ def sort_by_max_length(llm_dataset: LLMDataset, num_dataset: int) -> LLMDataset:
return llm_dataset.select(idx)


def to_device(inputs: Any, device: Device) -> Any:
def to_device(inputs: Any, device: torch.device) -> Any:
if callable(getattr(inputs, 'to', None)):
return inputs.to(device=device)

@@ -1124,7 +1098,7 @@ def get_rope_scaling(config: PretrainedConfig):
@wraps(infer_auto_device_map)
def _infer_auto_device_map_patch(model: Module,
max_memory: Optional[Dict[Union[int, str], Union[int, str]]] = None,
**kwargs) -> Dict[str, Union[int, str, Device]]:
**kwargs) -> Dict[str, Union[int, str, torch.device]]:
"""The auxiliary function for supports DDP+MP. Monkey Patching.
add feat in accelerate to support DDP + MP"""
verbose = kwargs.pop('verbose', False)
2 changes: 1 addition & 1 deletion swift/trainers/arguments.py
@@ -17,11 +17,11 @@
class SwiftArgumentsMixin:
# ckpt only save model
save_only_model: bool = False
train_sampler_random: bool = True
acc_strategy: str = field(default='token', metadata={'choices': ['token', 'sentence']})
loss_name: Optional[str] = field(default=None, metadata={'help': f'loss_func choices: {list(LOSS_MAPPING.keys())}'})
additional_saved_files: Optional[List[str]] = None
# torchacc
train_sampler_random: bool = True
metric_warmup_step: Optional[float] = 0
train_dataset_sample: Optional[int] = -1

