support part tuner replace_key False (modelscope#2438)
tastelikefeet authored Nov 13, 2024
1 parent d13c431 commit 490349a
Showing 11 changed files with 22 additions and 16 deletions.
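
In short: every tuner's state_dict_callback now accepts **kwargs, and the Part tuner reads an optional replace_key flag from those kwargs (defaulting to True). The sketch below only illustrates that pass-through pattern; the save_adapter helper and the part_callback variable are hypothetical and not part of the swift API.

    from typing import Callable, Dict

    def save_adapter(state_dict: Dict, adapter_name: str,
                     state_dict_callback: Callable, **kwargs) -> Dict:
        # Hypothetical caller: extra options such as replace_key are forwarded
        # untouched; tuners that do not use them simply absorb them via **kwargs.
        return state_dict_callback(state_dict, adapter_name, **kwargs)

    # e.g. keep the wrapped '_part_<adapter_name>' keys instead of mapping them
    # back to the original module names:
    # saved = save_adapter(model.state_dict(), 'default', part_callback, replace_key=False)
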
2 changes: 1 addition & 1 deletion swift/tuners/adapter.py
@@ -111,7 +111,7 @@ def _feed_forward_chunk(self, attention_output):
                 setattr(module, f'adapter_{adapter_name}', adapter_module)
                 logger.info(f'Adapter modules(module_key): {module_key}.adapter_{adapter_name}')

-        def state_dict_callback(state_dict, adapter_name: str):
+        def state_dict_callback(state_dict, adapter_name: str, **kwargs):
             return {key: value for key, value in state_dict.items() if f'adapter_{adapter_name}' in key}

         def mark_trainable_callback(model):
2 changes: 1 addition & 1 deletion swift/tuners/llamapro.py
@@ -77,7 +77,7 @@ def prepare_model(model: nn.Module, config: LLaMAProConfig, adapter_name: str) -
         model.config.num_hidden_layers = len(new_module_list)
         LLaMAPro._set_module_list(config, model, new_module_list)

-        def state_dict_callback(state_dict, adapter_name):
+        def state_dict_callback(state_dict, adapter_name, **kwargs):
             model_key_mapping = LLaMAPro.get_model_key_mapping(config.model_type, config)
             new_module_list = [model_key_mapping.module_list + f'.{i}' for i in new_module_idx]
             return {
2 changes: 1 addition & 1 deletion swift/tuners/longlora/longlora.py
@@ -51,7 +51,7 @@ def prepare_model(model: nn.Module, config: LongLoRAConfig, adapter_name: str):
         """Prepare a model with `LongLoRAConfig`"""
         LoraModel(model, config, adapter_name)

-        def state_dict_callback(state_dict, adapter_name):
+        def state_dict_callback(state_dict, adapter_name, **kwargs):
             _state_dict = lora_state_dict(state_dict, adapter_name, config.bias)
             for name, value in state_dict.items():
                 if isinstance(config.embedder_and_normalizer, str):
2 changes: 1 addition & 1 deletion swift/tuners/lora.py
@@ -81,7 +81,7 @@ def prepare_model(model: nn.Module, config: LoRAConfig, adapter_name: str):
             config.group_size = getattr(auto_gptq_config, 'group_size', None)
         LoraModel(model, config, adapter_name)

-        def state_dict_callback(state_dict, adapter_name, cfg=None):
+        def state_dict_callback(state_dict, adapter_name, cfg=None, **kwargs):
             return lora_state_dict(state_dict, adapter_name, cfg.bias if cfg else config.bias)

         def mark_trainable_callback(model, cfg=None):
2 changes: 1 addition & 1 deletion swift/tuners/neftune.py
@@ -49,7 +49,7 @@ def neftune_hook(module, args, output):
             sub_module.register_forward_hook(neftune_hook)
             sub_module.nef_activated = True

-        def state_dict_callback(state_dict, adapter_name):
+        def state_dict_callback(state_dict, adapter_name, **kwargs):
             return state_dict

         def mark_trainable_callback(model):
18 changes: 12 additions & 6 deletions swift/tuners/part.py
@@ -70,11 +70,14 @@ def _forward(self, *args, **kwargs):
             setattr(module, f'_part_{adapter_name}', new_module)
             new_module.requires_grad_(True)

-        def state_dict_callback(state_dict, adapter_name):
+        def state_dict_callback(state_dict, adapter_name, **kwargs):
             new_state_dict = {}
             for key, value in state_dict.items():
                 if f'_part_{adapter_name}.' in key:
-                    new_key = key.replace(f'_part_{adapter_name}.', '').replace('base_layer.', '')
+                    if kwargs.get('replace_key', True):
+                        new_key = key.replace(f'_part_{adapter_name}.', '').replace('base_layer.', '')
+                    else:
+                        new_key = key
                     new_state_dict[new_key] = value

             return new_state_dict
@@ -90,11 +93,14 @@ def load_state_dict_callback(model: nn.Module, adapter_name: str, state_dict: Di
                 for param_name in state_dict:
                     if param_name.startswith(name):
                         end = param_name[len(name):]
-                        if hasattr(module, 'base_layer'):
-                            new_state_dict[name + f'.base_layer._part_{adapter_name}'
-                                           + end] = state_dict[param_name]
+                        if '_part_' not in param_name:
+                            if hasattr(module, 'base_layer'):
+                                new_state_dict[name + f'.base_layer._part_{adapter_name}'
+                                               + end] = state_dict[param_name]
+                            else:
+                                new_state_dict[name + f'._part_{adapter_name}' + end] = state_dict[param_name]
                         else:
-                            new_state_dict[name + f'._part_{adapter_name}' + end] = state_dict[param_name]
+                            new_state_dict[param_name] = state_dict[param_name]
             return new_state_dict

         return SwiftOutput(
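
To illustrate the new branch in the Part tuner's state_dict_callback: with replace_key=True (the default) the '_part_<adapter_name>.' and 'base_layer.' fragments are stripped so the saved keys match the original module names, while replace_key=False keeps the wrapped keys verbatim. A minimal, self-contained sketch that mirrors the logic above; the function name part_keys and the sample key are invented for illustration:

    # Mirrors the key handling added to the Part tuner's state_dict_callback.
    # The sample state_dict key below is invented for illustration only.
    def part_keys(state_dict, adapter_name, **kwargs):
        new_state_dict = {}
        for key, value in state_dict.items():
            if f'_part_{adapter_name}.' in key:
                if kwargs.get('replace_key', True):
                    # default: restore the original module name
                    new_key = key.replace(f'_part_{adapter_name}.', '').replace('base_layer.', '')
                else:
                    # replace_key=False: keep the wrapped key untouched
                    new_key = key
                new_state_dict[new_key] = value
        return new_state_dict

    sd = {'model.layers.0.base_layer._part_default.weight': 0}
    print(part_keys(sd, 'default'))                     # {'model.layers.0.weight': 0}
    print(part_keys(sd, 'default', replace_key=False))  # key kept as-is

The matching change in load_state_dict_callback then routes keys that already contain '_part_' straight through, so checkpoints saved either way can be loaded.
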
2 changes: 1 addition & 1 deletion swift/tuners/prompt.py
@@ -126,7 +126,7 @@ def _forward(self, *args, **kwargs):
                 logger.info(f'Prompt modules(module_key): {module_key}.prompt_{adapter_name}')
                 match_module_keys.append(module_key)

-        def state_dict_callback(state_dict, adapter_name):
+        def state_dict_callback(state_dict, adapter_name, **kwargs):
             return {key: value for key, value in state_dict.items() if f'prompt_{adapter_name}' in key}

         def mark_trainable_callback(model):
2 changes: 1 addition & 1 deletion swift/tuners/restuning.py
@@ -233,7 +233,7 @@ def _forward_restuning(self, origin_arg):
         if target_module_ins is None:
             raise Exception('Cannot match target modules')

-        def state_dict_callback(state_dict, adapter_name):
+        def state_dict_callback(state_dict, adapter_name, **kwargs):
             return {key: value for key, value in state_dict.items() if f'restuning_{adapter_name}' in key}

         def mark_trainable_callback(model):
2 changes: 1 addition & 1 deletion swift/tuners/rome/rome.py
@@ -76,7 +76,7 @@ def prepare_model(model: nn.Module, config: RomeConfig, adapter_name: str):
         hparams = ROMEHyperParams.from_name(config.model_type)
         modified_keys = apply_rome_to_model(model, config.tokenizer, config.knowledge, hparams, config.batch_first)

-        def state_dict_callback(state_dict, adapter_name):
+        def state_dict_callback(state_dict, adapter_name, **kwargs):
             return {key: value for key, value in state_dict.items() if key in modified_keys}

         def mark_trainable_callback(model):
2 changes: 1 addition & 1 deletion swift/tuners/scetuning/scetuning.py
@@ -176,7 +176,7 @@ def _forward_decoder_mode(self, *args, **kwargs):
             if len(hint_module_ins_list) > 0:
                 setattr(t_module, 'hint', hint_module_ins_list[tuner_id])

-        def state_dict_callback(state_dict, adapter_name):
+        def state_dict_callback(state_dict, adapter_name, **kwargs):
             state_dict_new = {key: value for key, value in state_dict.items() if f'scetuner_{adapter_name}' in key}
             return state_dict_new

2 changes: 1 addition & 1 deletion swift/tuners/side.py
@@ -107,7 +107,7 @@ def forward_seq(self, input, *args, **kwargs):
             setattr(tgt_module, f'side_{adapter_name}', side_module)
             logger.info(f'Side modules(module_key): {module_key}.side_{adapter_name}')

-        def state_dict_callback(state_dict, adapter_name):
+        def state_dict_callback(state_dict, adapter_name, **kwargs):
             return {key: value for key, value in state_dict.items() if f'side_{adapter_name}' in key}

         def mark_trainable_callback(model):
