Support GaLore (modelscope#532)
tastelikefeet authored Mar 11, 2024
1 parent dda4c64 commit 45ada3e
Showing 20 changed files with 1,058 additions and 33 deletions.
1 change: 1 addition & 0 deletions README.md
@@ -64,6 +64,7 @@ Users can check the [documentation of SWIFT](docs/source/GetStarted/快速使用


## 🎉 News
- 🔥2024.03.11: Support [GaLore](https://arxiv.org/abs/2403.03507), which can efficiently reduce memory usage (to roughly half of the original) when training the full model.
- 🔥2024.03.10: For the end-to-end best practice from fine-tuning to deployment of Qwen1.5-7B-Chat and Qwen1.5-72B-Chat, refer to the [Qwen1.5 Full Workflow Best Practice](https://github.com/modelscope/swift/blob/main/docs/source/LLM/Qwen1.5%E5%85%A8%E6%B5%81%E7%A8%8B%E6%9C%80%E4%BD%B3%E5%AE%9E%E8%B7%B5.md).
- 🔥2024.03.09: Support training and inference of the MAMBA series; use [this script](https://github.com/modelscope/swift/blob/main/examples/pytorch/llm/scripts/mamba-1.4b/lora/sft.sh) to begin.
- 2024.03.09: Support training and inference of AQLM quantized models; use [this script](https://github.com/modelscope/swift/blob/main/examples/pytorch/llm/scripts/llama2_7b_aqlm_2bit_1x16/lora/sft.sh) to begin.
1 change: 1 addition & 0 deletions README_CN.md
@@ -62,6 +62,7 @@ SWIFT (Scalable lightWeight Infrastructure for Fine-Tuning) is a scalable
Users can check the [official SWIFT documentation](docs/source/GetStarted/快速使用.md) for detailed information.

## 🎉 News
- 🔥2024.03.11: Support [GaLore](https://arxiv.org/abs/2403.03507), which effectively reduces GPU memory usage to roughly half of the original during full-parameter training.
- 🔥2024.03.10: For the end-to-end best practice from fine-tuning to deployment of Qwen1.5-7B-Chat and Qwen1.5-72B-Chat, see the [Full Workflow Best Practice](https://github.com/modelscope/swift/blob/main/docs/source/LLM/Qwen1.5%E5%85%A8%E6%B5%81%E7%A8%8B%E6%9C%80%E4%BD%B3%E5%AE%9E%E8%B7%B5.md).
- 🔥2024.03.09: Support training and inference of MAMBA models; use [this script](https://github.com/modelscope/swift/blob/main/examples/pytorch/llm/scripts/mamba-1.4b/lora/sft.sh) to start training!
- 2024.03.09: Support training and inference of AQLM quantized models; use [this script](https://github.com/modelscope/swift/blob/main/examples/pytorch/llm/scripts/llama2_7b_aqlm_2bit_1x16/lora/sft.sh) to start training!
11 changes: 11 additions & 0 deletions docs/source/LLM/命令行参数.md
@@ -107,6 +107,17 @@

- `--lora_lr_ratio`: Default `None`, recommended value `10~16`. Specify this parameter when using LoRA to enable LoRA+.

### GaLore Fine-Tuning Parameters

- `--use_galore: bool`: Default `False`. Whether to use GaLore (a usage sketch follows this list).
- `--galore_target_modules: Union[str, List[str]]`: Default `None`. If not specified, GaLore is applied to the attention and MLP modules.
- `--galore_rank: int`: Default `128`. The rank of the GaLore low-rank projection.
- `--galore_update_proj_gap: int`: Default `50`. The update interval (in training steps) of the decomposed projection matrices.
- `--galore_scale: float`: Default `1.0`. The scale factor applied to the GaLore matrix updates.
- `--galore_proj_type: str`: Default `std`. The matrix decomposition type used by GaLore.
- `--galore_optim_per_parameter: bool`: Default `False`. Whether to create a separate optimizer for each GaLore target Parameter.
- `--galore_with_embedding: bool`: Default `False`. Whether to apply GaLore to the embedding layers.
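
A minimal sketch of enabling GaLore programmatically, assuming the `SftArguments`/`sft_main` interface exported by `swift.llm`; the model, dataset, and hyperparameter values mirror the example script added in this commit and are illustrative.

```python
# Hedged sketch: full-parameter SFT with GaLore via the Python API.
# Assumes swift.llm exports SftArguments and sft_main; the galore_* fields
# follow the parameters documented above.
from swift.llm import SftArguments, sft_main

sft_args = SftArguments(
    model_type='qwen1half-7b-chat',
    sft_type='full',
    dataset=['codefuse-evol-instruction-zh'],
    use_galore=True,             # turn on GaLore for full-parameter training
    galore_rank=128,             # rank of the low-rank gradient projection
    galore_update_proj_gap=400,  # refresh the projection matrices every 400 steps
    galore_with_embedding=False, # keep the embedding layers out of GaLore
    learning_rate=1e-5,
    max_length=4096)
sft_main(sft_args)
```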

### LLaMA-PRO Fine-Tuning Parameters

- `--llamapro_num_new_blocks`: Default `4`. The total number of new layers to insert.
2 changes: 1 addition & 1 deletion examples/pytorch/llm/scripts/qwen1half_7b_chat/full/sft.sh
@@ -5,7 +5,7 @@ swift sft \
--model_type qwen1half-7b-chat \
--sft_type full \
--train_dataset_sample -1 \
--eval_steps 100 \
--eval_steps 1000 \
--output_dir output \
--num_train_epochs 1 \
--max_length 4096 \
@@ -0,0 +1,7 @@
# Experimental environment: A100
CUDA_VISIBLE_DEVICES=0 \
swift infer \
--ckpt_dir "output/qwen1half-7b-chat/vx-xxx/checkpoint-xxx" \
--load_dataset_config true \
--max_length 4096 \
--use_flash_attn true \
18 changes: 18 additions & 0 deletions examples/pytorch/llm/scripts/qwen1half_7b_chat/galore/sft.sh
@@ -0,0 +1,18 @@
# Experimental environment: A100
# 40GB GPU memory
CUDA_VISIBLE_DEVICES=0 \
swift sft \
--model_type qwen1half-7b-chat \
--sft_type full \
--use_galore true \
--galore_update_proj_gap 400 \
--train_dataset_sample -1 \
--eval_steps 1000 \
--output_dir output \
--num_train_epochs 1 \
--max_length 4096 \
--learning_rate 1e-5 \
--use_flash_attn true \
--save_only_model true \
--dataset codefuse-evol-instruction-zh \
--preprocess_num_proc 4 \
18 changes: 18 additions & 0 deletions swift/llm/tuner.py
@@ -176,6 +176,24 @@ def prepare_model(model, args: SftArguments):
        model = Swift.prepare_model(model, {'neftune': neftune_config})
        logger.info(f'neftune_config: {neftune_config}')

    if args.use_galore:
        from swift.trainers.optimizers.galore import GaLoreConfig
        model_type = args.model_type
        for key in MODEL_KEYS_MAPPING.keys():
            if key in model_type.lower():
                model_type = key
                break
        args.training_args.galore_config = GaLoreConfig(
            model_type=model_type,
            target_modules=args.galore_target_modules,
            rank=args.galore_rank,
            update_proj_gap=args.galore_update_proj_gap,
            galore_scale=args.galore_scale,
            proj_type=args.galore_proj_type,
            optim_per_parameter=args.galore_optim_per_parameter,
            with_embedding=args.galore_with_embedding,
        )

    class TrainerAdapterCallback(TrainerCallback):

        def __init__(self):
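
The hunk above normalizes `args.model_type` to a key of `MODEL_KEYS_MAPPING` and packs the GaLore arguments into a `GaLoreConfig` attached to the training arguments. A minimal sketch of building the same config by hand, using only the field names visible in this diff (the `model_type` value is an illustrative assumption):

```python
# Sketch only: constructing a GaLoreConfig directly. Field names are taken
# from the diff above; 'qwen1half' is assumed to be a valid key of
# MODEL_KEYS_MAPPING.
from swift.trainers.optimizers.galore import GaLoreConfig

galore_config = GaLoreConfig(
    model_type='qwen1half',
    target_modules=None,        # None -> attention and MLP modules
    rank=128,
    update_proj_gap=50,
    galore_scale=1.0,
    proj_type='std',
    optim_per_parameter=False,
    with_embedding=False)
# The trainer later picks this up from `args.training_args.galore_config`
# inside create_optimizer_and_scheduler (see swift/trainers/mixin.py below).
```
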
10 changes: 10 additions & 0 deletions swift/llm/utils/argument.py
@@ -120,6 +120,16 @@ class SftArguments:
    lora_loftq_config: Dict = field(default_factory=dict)
    use_dora: bool = False

    # galore
    use_galore: bool = False
    galore_rank: int = 128
    galore_target_modules: Union[str, List[str]] = None
    galore_update_proj_gap: int = 50
    galore_scale: float = 1.0
    galore_proj_type: str = 'std'
    galore_optim_per_parameter: bool = False
    galore_with_embedding: bool = False

    # adalora
    adalora_target_r: int = 8
    adalora_init_r: int = 12
7 changes: 5 additions & 2 deletions swift/llm/utils/utils.py
@@ -36,6 +36,7 @@
                          TextStreamer, trainer)

from swift.hub import ModelScopeConfig
from swift.tuners.module_mapping import MODEL_KEYS_MAPPING, ModelKeys
from swift.utils import (get_dist_setting, get_logger, is_ddp_plus_mp, is_dist,
                         is_local_master, is_master, stat_array, upper_bound)
from .template import History, StopWords, StopWordsCriteria, Template
@@ -367,8 +368,10 @@ def find_all_linears(model: Module, quantization_bit: int,
                     model_type: str) -> List[str]:
    """ref: https://github.com/artidoro/qlora"""
    head_module_name = 'lm_head'
    if model_type.startswith('chatglm'):
        head_module_name = 'output_layer'
    if model_type in MODEL_KEYS_MAPPING:
        output = MODEL_KEYS_MAPPING[model_type].output
        idx = output.rfind('.')
        head_module_name = output[idx + 1:]
    if quantization_bit == 4:
        from bitsandbytes.nn import Linear4bit
        linear_cls = [Linear4bit]
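
The change above replaces the hard-coded `chatglm` special case: when the model type is present in `MODEL_KEYS_MAPPING`, the head module name is taken from the last component of the mapping's `output` path, so it is excluded when collecting linear layers. A small illustration with a hypothetical mapping entry:

```python
# Illustration only: the mapping entry below is a hypothetical stand-in for a
# real MODEL_KEYS_MAPPING value from swift.tuners.module_mapping.
class _Keys:
    output = 'transformer.output_layer'

MODEL_KEYS_MAPPING = {'chatglm': _Keys}

model_type = 'chatglm'
head_module_name = 'lm_head'
if model_type in MODEL_KEYS_MAPPING:
    output = MODEL_KEYS_MAPPING[model_type].output
    head_module_name = output[output.rfind('.') + 1:]  # -> 'output_layer'
print(head_module_name)
```
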
50 changes: 26 additions & 24 deletions swift/trainers/mixin.py
@@ -1,11 +1,12 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
# Part of the implementation is borrowed from huggingface/transformers.
import importlib
import os
import re
import shutil
from pathlib import Path
from types import MethodType
from typing import Callable, Dict, List, Optional, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import json
import numpy as np
@@ -17,7 +18,8 @@
from peft import PeftModel
from torch import nn
from torch.nn import Module
from transformers import PreTrainedModel, PreTrainedTokenizerBase
from transformers import (PreTrainedModel, PreTrainedTokenizerBase,
                          is_bitsandbytes_available)
from transformers.data.data_collator import DataCollator
from transformers.modeling_utils import unwrap_model
from transformers.trainer import ADAPTER_CONFIG_NAME # noqa
@@ -35,6 +37,7 @@
from swift.tuners import SwiftModel
from swift.utils import check_json_format, create_ms_repo, get_logger
from swift.utils.constants import Invoke
from .optimizers.galore import create_optimizer_and_scheduler
from .utils import (can_return_loss, find_labels, get_function,
                    is_instance_of_ms_model)

@@ -586,6 +589,23 @@ def _maybe_log_save_evaluate(self, tr_loss, *args, **kwargs):
            self.log(logs)
        super()._maybe_log_save_evaluate(tr_loss, *args, **kwargs)

    def create_optimizer_and_scheduler(self, num_training_steps: int):
        if hasattr(self.args, 'galore_config'):
            optimizer, lr_scheduler = create_optimizer_and_scheduler(
                self.model,
                self.args,
                self.args.galore_config,
                num_training_steps,
                lr=self.args.learning_rate,
                weight_decay=self.args.weight_decay)
            self.optimizer = optimizer
            self.lr_scheduler = lr_scheduler
        else:
            self.create_optimizer()
            self.create_scheduler(
                num_training_steps=num_training_steps,
                optimizer=self.optimizer)

    def create_optimizer(self):
        opt_model = self.model

@@ -596,13 +616,15 @@ def create_optimizer(self):
                    f'If you are using lora+, please remember using transformers>=4.34.0, '
                    f'but now is {transformers.__version__}')
                return super().create_optimizer()
            else:
                decay_parameters = self.get_decay_parameter_names(opt_model)

            decay_parameters = self.get_decay_parameter_names(opt_model)
            if isinstance(self.model, SwiftModel):
                # Lora+ parameter groups (or a default one)
                optimizer_grouped_parameters = self.model.create_optimizer_param_groups(
                    lr=self.args.learning_rate,
                    weight_decay=self.args.weight_decay)
            else:
                # Default parameter groups
                optimizer_grouped_parameters = [
                    {
                        'params': [
@@ -624,26 +646,6 @@

            optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(
                self.args)

            self.optimizer = optimizer_cls(optimizer_grouped_parameters,
                                           **optimizer_kwargs)
            if optimizer_cls.__name__ == 'Adam8bit':
                import bitsandbytes

                manager = bitsandbytes.optim.GlobalOptimManager.get_instance()

                skipped = 0
                for module in opt_model.modules():
                    if isinstance(module, nn.Embedding):
                        skipped += sum({
                            p.data_ptr(): p.numel()
                            for p in module.parameters()
                        }.values())
                        logger.info(
                            f'skipped {module}: {skipped/2**20}M params')
                        manager.register_module_override(
                            module, 'weight', {'optim_bits': 32})
                        logger.debug(
                            f'bitsandbytes: will optimize {module} in fp32')
                logger.info(f'skipped: {skipped/2**20}M params')
        return self.optimizer
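
With this change the trainer hands optimizer construction to `create_optimizer_and_scheduler` from `.optimizers.galore` whenever a `galore_config` is attached to the training arguments. For context, here is a rough, illustrative sketch of the core GaLore idea from the paper (arXiv:2403.03507), not swift's implementation: gradients of the target weight matrices are projected into a low-rank subspace, optimizer state is kept at that reduced size, and the projector is refreshed every `update_proj_gap` steps.

```python
# Illustrative sketch of GaLore-style low-rank gradient projection; not the
# optimizer code shipped in swift/trainers/optimizers/galore.
import torch

def galore_update(grad: torch.Tensor, projector, step: int,
                  rank: int = 128, update_proj_gap: int = 50,
                  scale: float = 1.0):
    # Refresh the projector from the current gradient every `update_proj_gap` steps.
    if projector is None or step % update_proj_gap == 0:
        U, _, _ = torch.linalg.svd(grad, full_matrices=False)
        projector = U[:, :rank]                   # (m, rank)
    low_rank_grad = projector.T @ grad            # optimizer state lives at (rank, n)
    # ... an inner optimizer (e.g. AdamW) would transform low_rank_grad here ...
    full_rank_update = projector @ low_rank_grad  # project back to (m, n)
    return scale * full_rank_update, projector
```
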
28 changes: 28 additions & 0 deletions swift/trainers/optimizers/galore/__init__.py
@@ -0,0 +1,28 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

from typing import TYPE_CHECKING

from swift.utils.import_utils import _LazyModule

if TYPE_CHECKING:
    from .utils import create_optimizer_and_scheduler, GaLoreConfig
    from .adafactor import GaLoreAdafactor
    from .adamw8bit import GaLoreAdamW8bit
    from .adamw import GaLoreAdamW
else:
    _import_structure = {
        'utils': ['GaLoreConfig', 'create_optimizer_and_scheduler'],
        'adafactor': ['GaLoreAdafactor'],
        'adamw8bit': ['GaLoreAdamW8bit'],
        'adamw': ['GaLoreAdamW'],
    }

    import sys

    sys.modules[__name__] = _LazyModule(
        __name__,
        globals()['__file__'],
        _import_structure,
        module_spec=__spec__,
        extra_objects={},
    )
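
The `_LazyModule` registration defers importing the optimizer submodules until one of the registered names is actually accessed, so importing the package does not immediately pull in optional dependencies such as bitsandbytes. A small usage sketch based on the names registered above (the `GaLoreConfig` keyword arguments are assumed from swift/llm/tuner.py):

```python
# Importing these names triggers the lazy load of the corresponding submodule
# ('utils', 'adamw', ...) only at this point.
from swift.trainers.optimizers.galore import GaLoreAdamW, GaLoreConfig

config = GaLoreConfig(model_type='qwen1half', rank=128, update_proj_gap=50)
print(GaLoreAdamW)  # resolved from .adamw on first access
```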