Commit

Fix Pissa and OLoRA (modelscope#1852)
tastelikefeet authored Aug 29, 2024
1 parent 8a1606d commit 14f10be
Showing 5 changed files with 109 additions and 72 deletions.
2 changes: 1 addition & 1 deletion swift/llm/rlhf.py
@@ -162,7 +162,7 @@ def llm_rlhf(args: RLHFArguments) -> Dict[str, Any]:
         ref_model = None
 
     if hasattr(model, 'hf_device_map'):
-        logger.info(f'model.hf_device_map: {json.dumps(model.hf_device_map)}')
+        logger.info(f'model.hf_device_map: {model.hf_device_map}')
 
     train_dataset, val_dataset = _get_train_val_dataset(args)
     if val_dataset is None:
2 changes: 1 addition & 1 deletion swift/llm/sft.py
@@ -218,7 +218,7 @@ def llm_sft(args: SftArguments) -> Dict[str, Any]:
         is_training=True,
         **kwargs)
     if hasattr(model, 'hf_device_map'):
-        logger.info(f'model.hf_device_map: {json.dumps(model.hf_device_map)}')
+        logger.info(f'model.hf_device_map: {model.hf_device_map}')
     for k in ['gptq', 'awq', 'aqlm']:
         if getattr(model, f'is_{k}', None):
             args.quant_method = k
72 changes: 2 additions & 70 deletions swift/llm/tuner.py
@@ -13,6 +13,7 @@
 from swift.utils import activate_model_parameters, freeze_model_parameters, get_logger, use_torchacc
 from swift.utils.module_mapping import MODEL_KEYS_MAPPING
 from .utils import SftArguments, find_all_linears, find_embedding, find_ln, is_adapter
+from .utils.callbacks import DynamicLayerActivationCallback, TrainerAdapterCallback
 
 logger = get_logger()
 
@@ -321,82 +322,13 @@ def prepare_model(model, args: SftArguments):
     callbacks = []
     if args.lisa_activated_layers > 0:
         assert args.sft_type == 'full', 'LISA only supports full parameter training.'
-
-        class DynamicLayerActivationCallback(TrainerCallback):
-
-            def __init__(self, n_layers: int, step_interval: int, model: torch.nn.Module):
-                super().__init__()
-                self.n_layers = n_layers
-                self.step_interval = step_interval
-                self.model = model
-                layers_name = None
-                layers = None
-                for name, module in model.named_modules():
-                    if isinstance(module, torch.nn.ModuleList):
-                        layers_name = name
-                        layers = module
-                        break
-                assert layers_name is not None
-                self.layers_attribute = layers_name
-                self.total_layers = len(layers)
-
-                # Freeze all layers upon initialization
-                self.freeze_all_layers()
-                self.active_layers_indices = []
-
-            def freeze_all_layers(self):
-                layers = self.model.get_submodule(self.layers_attribute)
-                for layer in layers:
-                    for param in layer.parameters():
-                        param.requires_grad = False
-
-            def on_step_begin(self, args, state, control, **kwargs):
-                # Check if it's time to switch active layers, including at step 0
-                if state.global_step % self.step_interval == 0 or state.global_step == 1:
-                    self.switch_active_layers()
-
-            def switch_active_layers(self):
-                # First, disable gradients for all layers
-                self.freeze_all_layers()
-
-                # Randomly select n_layers to activate
-                layers = self.model.get_submodule(self.layers_attribute)
-                self.active_layers_indices = np.random.choice(range(self.total_layers), self.n_layers, replace=False)
-                # Enable gradients only for the selected layers
-                for idx in self.active_layers_indices:
-                    for param in layers[idx].parameters():
-                        param.requires_grad = True
-
         lisa_callback = DynamicLayerActivationCallback(
             n_layers=args.lisa_activated_layers,  # Number of layers to activate
             step_interval=args.lisa_step_interval,  # Step interval to update active layers
             model=model)
         lisa_callback.switch_active_layers()  # Make trainable parameters printing a correct value
         callbacks.append(lisa_callback)
 
-    class TrainerAdapterCallback(TrainerCallback):
-
-        def __init__(self):
-            self.global_step = 0
-
-        # offload original_modules to cpu, to save memory
-        def on_train_begin(self, _args, state, control, **kwargs):
-            if hasattr(model, 'set_active_adapters'):
-                model.set_active_adapters(model.adapters.keys(), offload='cpu')
-            if args.sft_type == 'adalora':
-                model.peft_config['default'].total_step = state.max_steps
-
-                def zero_grad(_self, *args, **kwargs):
-                    _self.update_and_allocate(self.global_step + 1)
-                    _self._zero_grad(*args, **kwargs)
-
-                model._zero_grad = model.zero_grad
-                model.zero_grad = types.MethodType(zero_grad, model)
-
-        def on_step_end(self, _args, state, control, **kwargs):
-            if args.sft_type == 'adalora':
-                self.global_step = state.global_step
-
     if is_adapter(args.sft_type) and args.tuner_backend == 'swift':
-        callbacks.append(TrainerAdapterCallback())
+        callbacks.append(TrainerAdapterCallback(args))
     return model, callbacks
80 changes: 80 additions & 0 deletions swift/llm/utils/callbacks.py
@@ -0,0 +1,80 @@
+import os
+import types
+
+import numpy as np
+import torch
+from peft import PeftModel
+from transformers import TrainerCallback
+from transformers.modeling_utils import unwrap_model
+
+
+class TrainerAdapterCallback(TrainerCallback):
+
+    def __init__(self, args):
+        self.global_step = 0
+        self.args = args
+
+    # offload original_modules to cpu, to save memory
+    def on_train_begin(self, _args, state, control, **kwargs):
+        model = kwargs['model']
+        if hasattr(model, 'set_active_adapters'):
+            model.set_active_adapters(model.adapters.keys(), offload='cpu')
+        if self.args.sft_type == 'adalora':
+            model.peft_config['default'].total_step = state.max_steps
+
+            def zero_grad(_self, *args, **kwargs):
+                _self.update_and_allocate(self.global_step + 1)
+                _self._zero_grad(*args, **kwargs)
+
+            model._zero_grad = model.zero_grad
+            model.zero_grad = types.MethodType(zero_grad, model)
+
+    def on_step_end(self, _args, state, control, **kwargs):
+        if self.args.sft_type == 'adalora':
+            self.global_step = state.global_step
+
+
+class DynamicLayerActivationCallback(TrainerCallback):
+
+    def __init__(self, n_layers: int, step_interval: int, model: torch.nn.Module):
+        super().__init__()
+        self.n_layers = n_layers
+        self.step_interval = step_interval
+        self.model = model
+        layers_name = None
+        layers = None
+        for name, module in model.named_modules():
+            if isinstance(module, torch.nn.ModuleList):
+                layers_name = name
+                layers = module
+                break
+        assert layers_name is not None
+        self.layers_attribute = layers_name
+        self.total_layers = len(layers)
+
+        # Freeze all layers upon initialization
+        self.freeze_all_layers()
+        self.active_layers_indices = []
+
+    def freeze_all_layers(self):
+        layers = self.model.get_submodule(self.layers_attribute)
+        for layer in layers:
+            for param in layer.parameters():
+                param.requires_grad = False
+
+    def on_step_begin(self, args, state, control, **kwargs):
+        # Check if it's time to switch active layers, including at step 0
+        if state.global_step % self.step_interval == 0 or state.global_step == 1:
+            self.switch_active_layers()
+
+    def switch_active_layers(self):
+        # First, disable gradients for all layers
+        self.freeze_all_layers()
+
+        # Randomly select n_layers to activate
+        layers = self.model.get_submodule(self.layers_attribute)
+        self.active_layers_indices = np.random.choice(range(self.total_layers), self.n_layers, replace=False)
+        # Enable gradients only for the selected layers
+        for idx in self.active_layers_indices:
+            for param in layers[idx].parameters():
+                param.requires_grad = True
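
For orientation, here is a minimal, self-contained sketch of how the new DynamicLayerActivationCallback behaves on its own (the 'gpt2' checkpoint and the hyperparameter values are illustrative, not taken from this commit); the resulting object is exactly what prepare_model now appends to the trainer's callbacks list:

# Illustrative usage only: the model choice and hyperparameters are placeholders.
from transformers import AutoModelForCausalLM

from swift.llm.utils.callbacks import DynamicLayerActivationCallback

model = AutoModelForCausalLM.from_pretrained('gpt2')
lisa_callback = DynamicLayerActivationCallback(
    n_layers=2,        # keep 2 transformer blocks trainable at a time
    step_interval=20,  # re-sample the active blocks every 20 optimizer steps
    model=model)
lisa_callback.switch_active_layers()  # freeze all blocks, then unfreeze 2 random ones

# Only the sampled blocks (plus modules outside the ModuleList) require grad now.
print(sorted(lisa_callback.active_layers_indices))
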
25 changes: 25 additions & 0 deletions swift/trainers/mixin.py
@@ -329,6 +329,28 @@ def _save_optimizer_and_scheduler(self, output_dir):
 
         ta_save_optimizer_and_scheduler(self.optimizer, self.lr_scheduler, output_dir)
 
+    def _save_initial_model(self, output_dir):
+        model = unwrap_model(self.model)
+        if isinstance(model, PeftModel):
+            config = model.peft_config.get('default', {})
+            init_lora_weights = getattr(config, 'init_lora_weights', '')
+            if 'pissa' in init_lora_weights or 'olora' in init_lora_weights:
+                config.init_lora_weights = True
+                model.save_pretrained(os.path.join(output_dir, 'initial_model'))
+                config.init_lora_weights = init_lora_weights
+
+    def _save_converted_model(self, output_dir):
+        model = unwrap_model(self.model)
+        if isinstance(model, PeftModel):
+            config = model.peft_config.get('default', {})
+            init_lora_weights = getattr(config, 'init_lora_weights', '')
+            if 'pissa' in init_lora_weights or 'olora' in init_lora_weights:
+                model.save_pretrained(
+                    os.path.join(output_dir, 'converted'),
+                    path_initial_model_for_weight_conversion=os.path.join(os.path.dirname(output_dir), 'initial_model'),
+                )
+                config.init_lora_weights = init_lora_weights
+
     def _load_optimizer_and_scheduler(self, checkpoint):
         if not (use_torchacc() and self.sft_args.fsdp_num > 1):
             if self._resume_only_model:
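
The two helpers above lean on PEFT's PiSSA/OLoRA weight-conversion path: the adapter snapshot written before training (with init_lora_weights temporarily reset to True) is the reference that save_pretrained(..., path_initial_model_for_weight_conversion=...) uses to rewrite the trained adapter so that it loads onto the unmodified base weights. A rough standalone sketch of that flow, under assumed names (the base model, rank, and output paths are placeholders, and a recent PEFT release with PiSSA/OLoRA support is assumed):

# Sketch only: 'facebook/opt-125m', r=8 and the paths below are illustrative.
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM

base = AutoModelForCausalLM.from_pretrained('facebook/opt-125m')
config = LoraConfig(init_lora_weights='pissa', r=8, target_modules=['q_proj', 'v_proj'])
model = get_peft_model(base, config)

# 1) Snapshot the freshly initialized adapter; reset the flag so the saved config
#    does not re-run the (costly) PiSSA initialization when it is loaded back.
model.peft_config['default'].init_lora_weights = True
model.save_pretrained('output/initial_model')
model.peft_config['default'].init_lora_weights = 'pissa'

# ... normal training on the adapter would happen here ...

# 2) Save a converted adapter that applies cleanly to the *original* base weights.
model.save_pretrained(
    'output/checkpoint-100/converted',
    path_initial_model_for_weight_conversion='output/initial_model')
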
@@ -428,6 +450,7 @@ def _save(self, output_dir: Optional[str] = None, state_dict=None):
                     shutil.copy(src_path, dst_path)
                 elif os.path.isdir(src_path):
                     shutil.copytree(src_path, dst_path)
+        self._save_converted_model(output_dir)
 
     def _save_checkpoint(self, model, trial, metrics=None):
         self.state.last_model_checkpoint = os.path.join(self.args.output_dir, f'checkpoint-{self.state.global_step}')
@@ -549,6 +572,8 @@ def train(self, resume_from_checkpoint: Optional[Union[str, bool]] = None, *args
             resume_from_checkpoint = None
             if self._resume_from_checkpoint is not None and not is_sagemaker_mp_enabled() and not self.is_fsdp_enabled:
                 self._load_from_checkpoint(self._resume_from_checkpoint)
+
+        self._save_initial_model(self.args.output_dir)
         res = super().train(resume_from_checkpoint, *args, **kwargs)
         self._resume_from_checkpoint = None
         if self.max_memory != 0:
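
Once training finishes, each checkpoint's converted directory is an ordinary LoRA adapter, so it can be applied to the untouched base model with the stock PEFT API; a brief sketch (the model name and checkpoint path are illustrative):

# Illustrative paths; _save() writes the converted adapter next to each checkpoint.
from peft import PeftModel
from transformers import AutoModelForCausalLM

base = AutoModelForCausalLM.from_pretrained('facebook/opt-125m')
model = PeftModel.from_pretrained(base, 'output/checkpoint-100/converted')
model = model.merge_and_unload()  # optionally fold the adapter into the base weights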
