diff --git a/swift/tuners/adapter.py b/swift/tuners/adapter.py index 3fe1879d1..aa3329589 100644 --- a/swift/tuners/adapter.py +++ b/swift/tuners/adapter.py @@ -111,7 +111,7 @@ def _feed_forward_chunk(self, attention_output): setattr(module, f'adapter_{adapter_name}', adapter_module) logger.info(f'Adapter modules(module_key): {module_key}.adapter_{adapter_name}') - def state_dict_callback(state_dict, adapter_name: str): + def state_dict_callback(state_dict, adapter_name: str, **kwargs): return {key: value for key, value in state_dict.items() if f'adapter_{adapter_name}' in key} def mark_trainable_callback(model): diff --git a/swift/tuners/llamapro.py b/swift/tuners/llamapro.py index 84b6dd452..13371eabd 100644 --- a/swift/tuners/llamapro.py +++ b/swift/tuners/llamapro.py @@ -77,7 +77,7 @@ def prepare_model(model: nn.Module, config: LLaMAProConfig, adapter_name: str) - model.config.num_hidden_layers = len(new_module_list) LLaMAPro._set_module_list(config, model, new_module_list) - def state_dict_callback(state_dict, adapter_name): + def state_dict_callback(state_dict, adapter_name, **kwargs): model_key_mapping = LLaMAPro.get_model_key_mapping(config.model_type, config) new_module_list = [model_key_mapping.module_list + f'.{i}' for i in new_module_idx] return { diff --git a/swift/tuners/longlora/longlora.py b/swift/tuners/longlora/longlora.py index a4a3c43e7..fd02f387c 100644 --- a/swift/tuners/longlora/longlora.py +++ b/swift/tuners/longlora/longlora.py @@ -51,7 +51,7 @@ def prepare_model(model: nn.Module, config: LongLoRAConfig, adapter_name: str): """Prepare a model with `LongLoRAConfig`""" LoraModel(model, config, adapter_name) - def state_dict_callback(state_dict, adapter_name): + def state_dict_callback(state_dict, adapter_name, **kwargs): _state_dict = lora_state_dict(state_dict, adapter_name, config.bias) for name, value in state_dict.items(): if isinstance(config.embedder_and_normalizer, str): diff --git a/swift/tuners/lora.py b/swift/tuners/lora.py index d823a4c48..75367281b 100644 --- a/swift/tuners/lora.py +++ b/swift/tuners/lora.py @@ -81,7 +81,7 @@ def prepare_model(model: nn.Module, config: LoRAConfig, adapter_name: str): config.group_size = getattr(auto_gptq_config, 'group_size', None) LoraModel(model, config, adapter_name) - def state_dict_callback(state_dict, adapter_name, cfg=None): + def state_dict_callback(state_dict, adapter_name, cfg=None, **kwargs): return lora_state_dict(state_dict, adapter_name, cfg.bias if cfg else config.bias) def mark_trainable_callback(model, cfg=None): diff --git a/swift/tuners/neftune.py b/swift/tuners/neftune.py index 6b211ad94..6476283e5 100644 --- a/swift/tuners/neftune.py +++ b/swift/tuners/neftune.py @@ -49,7 +49,7 @@ def neftune_hook(module, args, output): sub_module.register_forward_hook(neftune_hook) sub_module.nef_activated = True - def state_dict_callback(state_dict, adapter_name): + def state_dict_callback(state_dict, adapter_name, **kwargs): return state_dict def mark_trainable_callback(model): diff --git a/swift/tuners/part.py b/swift/tuners/part.py index 0cbdb09c3..585af9db5 100644 --- a/swift/tuners/part.py +++ b/swift/tuners/part.py @@ -70,11 +70,14 @@ def _forward(self, *args, **kwargs): setattr(module, f'_part_{adapter_name}', new_module) new_module.requires_grad_(True) - def state_dict_callback(state_dict, adapter_name): + def state_dict_callback(state_dict, adapter_name, **kwargs): new_state_dict = {} for key, value in state_dict.items(): if f'_part_{adapter_name}.' in key: - new_key = key.replace(f'_part_{adapter_name}.', '').replace('base_layer.', '') + if kwargs.get('replace_key', True): + new_key = key.replace(f'_part_{adapter_name}.', '').replace('base_layer.', '') + else: + new_key = key new_state_dict[new_key] = value return new_state_dict @@ -90,11 +93,14 @@ def load_state_dict_callback(model: nn.Module, adapter_name: str, state_dict: Di for param_name in state_dict: if param_name.startswith(name): end = param_name[len(name):] - if hasattr(module, 'base_layer'): - new_state_dict[name + f'.base_layer._part_{adapter_name}' - + end] = state_dict[param_name] + if '_part_' not in param_name: + if hasattr(module, 'base_layer'): + new_state_dict[name + f'.base_layer._part_{adapter_name}' + + end] = state_dict[param_name] + else: + new_state_dict[name + f'._part_{adapter_name}' + end] = state_dict[param_name] else: - new_state_dict[name + f'._part_{adapter_name}' + end] = state_dict[param_name] + new_state_dict[param_name] = state_dict[param_name] return new_state_dict return SwiftOutput( diff --git a/swift/tuners/prompt.py b/swift/tuners/prompt.py index ff4c5f5fe..262c382e1 100644 --- a/swift/tuners/prompt.py +++ b/swift/tuners/prompt.py @@ -126,7 +126,7 @@ def _forward(self, *args, **kwargs): logger.info(f'Prompt modules(module_key): {module_key}.prompt_{adapter_name}') match_module_keys.append(module_key) - def state_dict_callback(state_dict, adapter_name): + def state_dict_callback(state_dict, adapter_name, **kwargs): return {key: value for key, value in state_dict.items() if f'prompt_{adapter_name}' in key} def mark_trainable_callback(model): diff --git a/swift/tuners/restuning.py b/swift/tuners/restuning.py index fe6f14514..b965b2b32 100644 --- a/swift/tuners/restuning.py +++ b/swift/tuners/restuning.py @@ -233,7 +233,7 @@ def _forward_restuning(self, origin_arg): if target_module_ins is None: raise Exception('Cannot match target modules') - def state_dict_callback(state_dict, adapter_name): + def state_dict_callback(state_dict, adapter_name, **kwargs): return {key: value for key, value in state_dict.items() if f'restuning_{adapter_name}' in key} def mark_trainable_callback(model): diff --git a/swift/tuners/rome/rome.py b/swift/tuners/rome/rome.py index 14f757355..36dcc1e8a 100644 --- a/swift/tuners/rome/rome.py +++ b/swift/tuners/rome/rome.py @@ -76,7 +76,7 @@ def prepare_model(model: nn.Module, config: RomeConfig, adapter_name: str): hparams = ROMEHyperParams.from_name(config.model_type) modified_keys = apply_rome_to_model(model, config.tokenizer, config.knowledge, hparams, config.batch_first) - def state_dict_callback(state_dict, adapter_name): + def state_dict_callback(state_dict, adapter_name, **kwargs): return {key: value for key, value in state_dict.items() if key in modified_keys} def mark_trainable_callback(model): diff --git a/swift/tuners/scetuning/scetuning.py b/swift/tuners/scetuning/scetuning.py index 039678b34..efc1c9a72 100644 --- a/swift/tuners/scetuning/scetuning.py +++ b/swift/tuners/scetuning/scetuning.py @@ -176,7 +176,7 @@ def _forward_decoder_mode(self, *args, **kwargs): if len(hint_module_ins_list) > 0: setattr(t_module, 'hint', hint_module_ins_list[tuner_id]) - def state_dict_callback(state_dict, adapter_name): + def state_dict_callback(state_dict, adapter_name, **kwargs): state_dict_new = {key: value for key, value in state_dict.items() if f'scetuner_{adapter_name}' in key} return state_dict_new diff --git a/swift/tuners/side.py b/swift/tuners/side.py index dab925025..60a317b09 100644 --- a/swift/tuners/side.py +++ b/swift/tuners/side.py @@ -107,7 +107,7 @@ def forward_seq(self, input, *args, **kwargs): setattr(tgt_module, f'side_{adapter_name}', side_module) logger.info(f'Side modules(module_key): {module_key}.side_{adapter_name}') - def state_dict_callback(state_dict, adapter_name): + def state_dict_callback(state_dict, adapter_name, **kwargs): return {key: value for key, value in state_dict.items() if f'side_{adapter_name}' in key} def mark_trainable_callback(model):