Skip to content

Commit

Permalink
support loading lora from hub
Browse files Browse the repository at this point in the history
  • Loading branch information
hiyouga committed Jun 15, 2023
1 parent 0cee6ad commit 0574b59
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 25 deletions.
27 changes: 14 additions & 13 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@

## Changelog

[23/06/15] Now we support training the baichuan-7B model in this repo. Try `--model_name_or_path baichuan-inc/baichuan-7B` argument to use the baichuan-7B model.

[23/06/03] Now we support quantized training and inference (aka [QLoRA](https://github.com/artidoro/qlora)). Try `--quantization_bit 4/8` argument to work with quantized model. (experimental feature)

[23/05/31] Now we support training the BLOOM & BLOOMZ models in this repo. Try `--model_name_or_path bigscience/bloomz-7b1-mt` argument to use the BLOOMZ model.
Expand Down Expand Up @@ -111,7 +113,7 @@ python -m transformers.models.llama.convert_llama_weights_to_hf \

```bash
CUDA_VISIBLE_DEVICES=0 python src/train_pt.py \
--model_name_or_path path_to_llama_model \
--model_name_or_path path_to_your_model \
--do_train \
--dataset wiki_demo \
--finetuning_type lora \
Expand All @@ -132,11 +134,10 @@ CUDA_VISIBLE_DEVICES=0 python src/train_pt.py \

```bash
CUDA_VISIBLE_DEVICES=0 python src/train_sft.py \
--model_name_or_path path_to_llama_model \
--model_name_or_path path_to_your_model \
--do_train \
--dataset alpaca_gpt4_en \
--finetuning_type lora \
--checkpoint_dir path_to_pt_checkpoint \
--output_dir path_to_sft_checkpoint \
--overwrite_cache \
--per_device_train_batch_size 4 \
Expand All @@ -146,7 +147,6 @@ CUDA_VISIBLE_DEVICES=0 python src/train_sft.py \
--save_steps 1000 \
--learning_rate 5e-5 \
--num_train_epochs 3.0 \
--resume_lora_training False \
--plot_loss \
--fp16
```
Expand All @@ -155,11 +155,10 @@ CUDA_VISIBLE_DEVICES=0 python src/train_sft.py \

```bash
CUDA_VISIBLE_DEVICES=0 python src/train_rm.py \
--model_name_or_path path_to_llama_model \
--model_name_or_path path_to_your_model \
--do_train \
--dataset comparison_gpt4_en \
--finetuning_type lora \
--checkpoint_dir path_to_pt_checkpoint \
--output_dir path_to_rm_checkpoint \
--per_device_train_batch_size 4 \
--gradient_accumulation_steps 4 \
Expand All @@ -176,11 +175,11 @@ CUDA_VISIBLE_DEVICES=0 python src/train_rm.py \

```bash
CUDA_VISIBLE_DEVICES=0 python src/train_ppo.py \
--model_name_or_path path_to_llama_model \
--model_name_or_path path_to_your_model \
--do_train \
--dataset alpaca_gpt4_en \
--finetuning_type lora \
--checkpoint_dir path_to_pt_checkpoint,path_to_sft_checkpoint \
--checkpoint_dir path_to_sft_checkpoint \
--reward_model path_to_rm_checkpoint \
--output_dir path_to_ppo_checkpoint \
--per_device_train_batch_size 2 \
Expand All @@ -205,7 +204,7 @@ accelerate launch src/train_XX.py # arguments (same as above)

```bash
CUDA_VISIBLE_DEVICES=0 python src/train_sft.py \
--model_name_or_path path_to_llama_model \
--model_name_or_path path_to_your_model \
--do_eval \
--dataset alpaca_gpt4_en \
--checkpoint_dir path_to_checkpoint \
Expand All @@ -215,28 +214,28 @@ CUDA_VISIBLE_DEVICES=0 python src/train_sft.py \
--predict_with_generate
```

We recommend using `--per_device_eval_batch_size=1` and `--max_target_length 128` in INT8 evaluation.
We recommend using `--per_device_eval_batch_size=1` and `--max_target_length 128` at 4/8-bit evaluation.

### CLI Demo

```bash
python src/cli_demo.py \
--model_name_or_path path_to_llama_model \
--model_name_or_path path_to_your_model \
--checkpoint_dir path_to_checkpoint
```

### Web Demo
```bash
python src/web_demo.py \
--model_name_or_path path_to_llama_model \
--model_name_or_path path_to_your_model \
--checkpoint_dir path_to_checkpoint
```

### Export model

```bash
python src/export_model.py \
--model_name_or_path path_to_llama_model \
--model_name_or_path path_to_your_model \
--checkpoint_dir path_to_checkpoint \
--output_dir path_to_export
```
Expand All @@ -249,6 +248,8 @@ Please follow the [Model Card](https://github.com/facebookresearch/llama/blob/ma

Please follow the [RAIL License](https://huggingface.co/spaces/bigscience/license) to use the BLOOM & BLOOMZ models.

Please follow the [baichuan-7B License](https://huggingface.co/baichuan-inc/baichuan-7B/resolve/main/baichuan-7B%20%E6%A8%A1%E5%9E%8B%E8%AE%B8%E5%8F%AF%E5%8D%8F%E8%AE%AE.pdf) to use the baichuan-7B model.

## Citation

If this work is helpful, please cite as:
Expand Down
13 changes: 7 additions & 6 deletions src/utils/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
get_peft_model
)

from peft.utils import CONFIG_NAME
from peft.utils import CONFIG_NAME, WEIGHTS_NAME

from trl import AutoModelForCausalLMWithValueHead

Expand Down Expand Up @@ -103,8 +103,10 @@ def _init_adapter(
lastest_checkpoint = None

if model_args.checkpoint_dir is not None:
assert os.path.exists(os.path.join(model_args.checkpoint_dir[0], CONFIG_NAME)), \
"The given checkpoint is not a LoRA checkpoint, please specify `--finetuning_type full/freeze` instead."
if os.path.exists(os.path.join(model_args.checkpoint_dir[0], WEIGHTS_NAME)) and \
not os.path.exists(os.path.join(model_args.checkpoint_dir[0], CONFIG_NAME)):
raise ValueError("The given checkpoint may be not a LoRA checkpoint, \
please specify `--finetuning_type full/freeze` instead.")

if (is_trainable and model_args.resume_lora_training) or (not is_mergeable): # continually train on the lora weights
checkpoints_to_merge, lastest_checkpoint = model_args.checkpoint_dir[:-1], model_args.checkpoint_dir[-1]
Expand Down Expand Up @@ -170,8 +172,7 @@ def load_pretrained(
**config_kwargs
)
tokenizer.pad_token_id = 0 if tokenizer.pad_token_id is None else tokenizer.pad_token_id # set as the <unk> token
if tokenizer.pad_token_id == 64000:
tokenizer.pad_token_id = 0 # for baichuan model (need fix)
tokenizer.pad_token_id = 0 if tokenizer.pad_token_id == 64000 else tokenizer.pad_token_id # for baichuan model (older version)

config = AutoConfig.from_pretrained(model_args.model_name_or_path, **config_kwargs)
is_mergeable = True
Expand Down Expand Up @@ -212,7 +213,7 @@ def load_pretrained(
low_cpu_mem_usage=True,
**config_kwargs
)
model = prepare_model_for_training(model) if is_trainable else model
model = prepare_model_for_training(model, finetuning_args.finetuning_type) if is_trainable else model
model = _init_adapter(model, model_args, finetuning_args, is_trainable, is_mergeable)

if stage == "rm" or stage == "ppo": # add value head
Expand Down
8 changes: 5 additions & 3 deletions src/utils/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,8 @@ class FinetuningArguments:
default="mlp",
metadata={"help": "Name of trainable modules for Freeze fine-tuning. \
LLaMA choices: [\"mlp\", \"self_attn\"], \
BLOOM choices: [\"mlp\", \"self_attention\"]"}
BLOOM choices: [\"mlp\", \"self_attention\"], \
Baichuan choices: [\"mlp\", \"self_attn\"]"}
)
lora_rank: Optional[int] = field(
default=8,
Expand All @@ -212,8 +213,9 @@ class FinetuningArguments:
lora_target: Optional[str] = field(
default="q_proj,v_proj",
metadata={"help": "Name(s) of target modules to apply LoRA. Use comma to separate multiple modules. \
LLaMA choices: [\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\", \"up_proj\", \"gate_proj\", \"down_proj\"], \
BLOOM choices: [\"query_key_value\", \"self_attention.dense\", \"mlp.dense\"]"}
LLaMA choices: [\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"], \
BLOOM choices: [\"query_key_value\", \"self_attention.dense\", \"mlp.dense\"], \
Baichuan choices: [\"W_pack\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"]"}
)

def __post_init__(self):
Expand Down
7 changes: 4 additions & 3 deletions src/utils/other.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ def get_logits_processor() -> LogitsProcessorList:
# Inspired by: https://github.com/huggingface/peft/blob/c0209c35abbf88c63aa267800d98a8e212ed0a42/src/peft/utils/other.py#L35
def prepare_model_for_training(
model: PreTrainedModel,
finetuning_type: str,
output_embedding_layer_name: Optional[str] = "lm_head",
use_gradient_checkpointing: Optional[bool] = True,
layer_norm_names: Optional[List[str]] = ["norm", "ln_f"] # for LLaMA and BLOOM setting
Expand All @@ -93,13 +94,13 @@ def make_inputs_require_grad(module, input, output):
model.gradient_checkpointing_enable()
model.config.use_cache = False # turn off when gradient checkpointing is enabled

if hasattr(model, output_embedding_layer_name):
output_embedding_layer = getattr(model, output_embedding_layer_name)
if finetuning_type != "full" and hasattr(model, output_embedding_layer_name):
output_embedding_layer: torch.nn.Linear = getattr(model, output_embedding_layer_name)
input_dtype = output_embedding_layer.weight.dtype

class CastOutputToFloat(torch.nn.Sequential):

def forward(self, x):
def forward(self, x: torch.Tensor) -> torch.Tensor:
return super().forward(x.to(input_dtype)).to(torch.float32)

setattr(model, output_embedding_layer_name, CastOutputToFloat(output_embedding_layer))
Expand Down

0 comments on commit 0574b59

Please sign in to comment.