Commit

Merge pull request togethercomputer#160 from togethercomputer/jue-patch-1

Adding Llama-2
zhangce authored Jul 28, 2023
2 parents 207f5c0 + 0cdc940 commit 3691351
Showing 10 changed files with 1,188 additions and 23 deletions.
57 changes: 57 additions & 0 deletions README.md
@@ -13,6 +13,10 @@ In this repo, you'll find code for:
- [Getting Started](#getting-started)
* [Requirements](#requirements)
* [Chatting with Pythia-Chat-Base-7B](#chatting-with-pythia-chat-base-7b)
- [Fine-tuning Llama-2-7B-32K-beta](#fine-tuning-llama-2-7b-32k-beta)
* [Downloading and converting the base model](#downloading-and-converting-the-base-model)
* [Fine-tuning the model](#fine-tuning-the-model)
* [Converting trained weights to Hugging Face format](#converting-trained-weights-to-hugging-face-format)
- [Reproducing Pythia-Chat-Base-7B](#reproducing-pythia-chat-base-7b)
* [Downloading training data and the base model](#downloading-training-data-and-the-base-model)
* [(Optional) 8bit Adam](#optional-8bit-adam)
@@ -104,6 +108,59 @@ The shell also supports additional commands to inspect hyperparameters, the full
Please see [the inference README](inference/README.md) for more details about arguments, running on multiple/specific GPUs, and running on consumer hardware.

# Fine-tuning Llama-2-7B-32K-beta

The Llama-2-7B-32K-beta model can be fine-tuned on a variety of datasets. In this tutorial, we use the multi-document natural questions dataset and the BookSum dataset.

## Downloading and converting the base model

To download the Llama-2-7B-32K-beta model and prepare it for fine-tuning, run this command from the root of the repository:

```shell
python pretrained/Llama-2-7B-32K-beta/prepare.py
```

The weights for this model will be in the `pretrained/Llama-2-7B-32K-beta/togethercomputer_Llama-2-7B-32K-beta` directory.


## Fine-tuning the model

The `training/finetune_llama-2-7b-32k-mqa.sh` and `training/finetune_llama-2-7b-32k-booksum.sh` scripts configure and run the training loop.

1. To fine-tune on the multi-document natural questions dataset, run:
```shell
bash training/finetune_llama-2-7b-32k-mqa.sh
```

2. To fine-tune on the BookSum dataset, run:
```shell
bash training/finetune_llama-2-7b-32k-booksum.sh
```

As the training loop runs, checkpoints are saved to the `model_ckpts` directory at the root of the repo.

Please see [the training README](training/README.md) for more details about customizing the training run.

## Converting trained weights to Hugging Face format

Before you can use this model for inference, it must be converted to the Hugging Face format. To do so, run a command like the following from the root of the repo:
```shell
mkdir huggingface_models \
&& python tools/convert_to_hf_llama.py \
--config-name togethercomputer/Llama-2-7B-32K-beta \
--ckpt-path model_ckpts/llama-2-7b-32k-mqa/checkpoint_10 \
--save-path huggingface_models/llama-2-7b-32k-mqa \
--n-stages 4 \
--n-layer-per-stage 8 \
--fp16
```
The `--fp16` flag loads and stores the model weights in fp16.

Make sure to replace `model_ckpts/llama-2-7b-32k-mqa/checkpoint_10` with the latest checkpoint in the `model_ckpts/llama-2-7b-32k-mqa` or `model_ckpts/llama-2-7b-32k-booksum` directory.
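
Once converted, the checkpoint loads like any other Hugging Face causal LM. Below is a minimal inference sketch, assuming the conversion command above has already written the model to `huggingface_models/llama-2-7b-32k-mqa`; the prompt and generation settings are purely illustrative.

```python
# Minimal sketch: load the converted checkpoint and generate a short completion.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_path = "huggingface_models/llama-2-7b-32k-mqa"  # output of convert_to_hf_llama.py above
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(
    model_path, torch_dtype=torch.float16, device_map="auto"
)

prompt = "Summarize the following document:\n..."  # illustrative prompt
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_new_tokens=128)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```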


# Reproducing Pythia-Chat-Base-7B

This tutorial walks through reproducing the Pythia-Chat-Base-7B model by fine-tuning Eleuther AI's Pythia-6.9B-deduped model using the OIG dataset.
31 changes: 16 additions & 15 deletions environment.yml
@@ -5,27 +5,28 @@ channels:
   - conda-forge
   - defaults
 dependencies:
-  - cudatoolkit=11.6.0
-  - cupy=10.4.0
+  - cudatoolkit=11.8.0
+  - cupy=12.1.0
   - faiss-gpu=1.7.2
   - fastparquet=0.5.0
-  - nccl=2.12.12.1
-  - pip=22.3.1
-  - pyarrow=8.0.0
+  - nccl=2.18.3.1
+  - pip=23.2
+  - pyarrow=12.0.1
   - python=3.10.9
   - python-snappy=0.6.1
-  - pytorch=1.13.1
-  - pytorch-cuda=11.6
+  - pytorch=2.0.1
+  - pytorch-cuda=11.8
   - snappy=1.1.9
-  - torchaudio=0.13.1
-  - torchvision=0.14.1
+  - torchaudio=2.0.2
+  - torchvision=0.15.2
   - pip:
-    - accelerate==0.17.1
+    - accelerate==0.21.0
     - boto3
-    - datasets==2.10.1
+    - datasets==2.13.1
     - loguru==0.6.0
     - netifaces==0.11.0
-    - pandas==1.5.3
-    - transformers==4.21.1
-    - wandb==0.13.10
-    - zstandard==0.20.0
+    - pandas==2.0.3
+    - transformers==4.31.0
+    - wandb==0.15.5
+    - zstandard==0.21.0
+    - sentencepiece
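
These dependency bumps (PyTorch 2.0.1 with CUDA 11.8, transformers 4.31.0, and the newly added `sentencepiece`, which the Llama tokenizer needs) only take effect once the conda environment is rebuilt. A typical way to do that, assuming conda or mamba is installed (these commands are not part of this commit):

```shell
# Create the environment from scratch, or update an existing one in place.
conda env create -f environment.yml
# or:
conda env update -f environment.yml --prune
```
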
56 changes: 56 additions & 0 deletions pretrained/Llama-2-7B-32K-beta/prepare.py
@@ -0,0 +1,56 @@
import os
import argparse
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig

DIR = os.path.dirname(os.path.abspath(__file__))
USE_AUTH_TOKEN = False

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Convert HF checkpoints')
    parser.add_argument('--model-name', type=str, default='togethercomputer/Llama-2-7B-32K-beta',
                        help='model-name')
    parser.add_argument('--save-dir', type=str, default=DIR,
                        help='model-name')
    parser.add_argument('--offload-dir', type=str, default=None,
                        help='directory to offload from memory')
    args = parser.parse_args()

    if not os.path.exists(args.save_dir):
        os.mkdir(args.save_dir)
    save_path = os.path.join(args.save_dir, args.model_name.replace('/', '_'))
    if not os.path.exists(save_path):
        os.mkdir(save_path)

    print('loading model from HF...')
    config = AutoConfig.from_pretrained(args.model_name, use_auth_token=USE_AUTH_TOKEN)
    config.save_pretrained(save_path)
    tokenizer = AutoTokenizer.from_pretrained(args.model_name, use_auth_token=USE_AUTH_TOKEN)
    tokenizer.save_pretrained(save_path)

    # offload model from memory to disk if offload-dir is specified
    if args.offload_dir is not None:
        if not os.path.exists(args.offload_dir):
            os.mkdir(args.offload_dir)
        model = AutoModelForCausalLM.from_pretrained(args.model_name, torch_dtype=torch.float16, device_map="auto", offload_folder=args.offload_dir, use_auth_token=USE_AUTH_TOKEN)
    else:
        model = AutoModelForCausalLM.from_pretrained(args.model_name, torch_dtype=torch.float16, use_auth_token=USE_AUTH_TOKEN)
    print('loaded model from HF...')

    print('converting the embedding layer...')
    item = {}
    item['embed_tokens.weight'] = model.model.embed_tokens.weight
    torch.save(item, os.path.join(save_path, 'pytorch_embs.pt'))
    print('converted the embedding layer.')

    for i in range(len(model.model.layers)):
        print(f'converting the {i}-th transformer layer...')
        torch.save(model.model.layers[i].state_dict(), os.path.join(save_path, f'pytorch_{i}.pt'))
        print(f'converted the {i}-th transformer layer.')

    print('converting the lm_head layer...')
    item = {}
    item['lm_head.weight'] = model.lm_head.weight
    item['norm.weight'] = model.model.norm.weight
    torch.save(item, os.path.join(save_path, 'pytorch_lm_head.pt'))
    print('converted the lm_head layer.')
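
All of the flags above are optional. On machines with limited CPU memory, a hedged usage sketch that offloads weights to disk while splitting the checkpoint might look like this; the offload directory path is illustrative:

```shell
# Sketch only: --model-name and --save-dir fall back to the defaults shown in the parser.
python pretrained/Llama-2-7B-32K-beta/prepare.py \
  --model-name togethercomputer/Llama-2-7B-32K-beta \
  --offload-dir /tmp/llama2_offload
```
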
155 changes: 155 additions & 0 deletions tools/convert_to_hf_llama.py
@@ -0,0 +1,155 @@
import os
import argparse
import torch

import torch
import torch.nn as nn

from transformers import LlamaForCausalLM
from transformers import AutoConfig, AutoTokenizer

from transformers.modeling_utils import no_init_weights
import os


def create_emtpy_llama(config):

    import torch
    import torch.nn as nn

    _reset_parameters_linear = nn.Linear.reset_parameters
    def dummy(*args, **kargs):
        pass
    nn.Linear.reset_parameters = dummy

    # 1. disable init for faster initialization
    # 2. avoid tie token embeddings with lm_head, as we train them separately.
    with no_init_weights(_enable=True):
        model = LlamaForCausalLM(config).eval()

    nn.Linear.reset_parameters = _reset_parameters_linear

    return model

def load_decentralized_checkpoint(model, checkpoint_path, n_stages=2, n_layer_per_stage=16, ):
    input_path = checkpoint_path

    n_layers = len(model.model.layers)
    assert n_stages * n_layer_per_stage >= len(model.model.layers)
    # assert model.lm_head.weight.data is not model.transformer.wte.weight.data

    for i in range(n_stages):

        print(f'loading stage {i}')

        checkpoint = torch.load(os.path.join(input_path, f'prank_{i}_checkpoint.pt'), map_location=torch.device("cpu"))

        if i == 0:
            _tmp = {k[len(f"{0}."):]:v for k,v in checkpoint.items() if k.startswith(f"0.")}
            # torch.save(_tmp, os.path.join(output_path, f'pytorch_embs.pt'))
            model.model.embed_tokens.weight.data[:] = _tmp['embed_tokens.weight']

            for j in range(n_layer_per_stage):
                _tmp = {k[len(f"{j+1}."):]:v for k,v in checkpoint.items() if k.startswith(f"{j+1}.")}
                if len(_tmp) == 0:
                    break
                # torch.save(_tmp, os.path.join(output_path, f'pytorch_{j}.pt'))
                ret = model.model.layers[j].load_state_dict(_tmp, strict=False)
                if len(ret.missing_keys):
                    print('The following weight keys are missing:')
                    print(ret.missing_keys)
                if len(ret.unexpected_keys):
                    print('The following weight keys are unexpected:')
                    print(ret.unexpected_keys)

        elif i == n_stages - 1:
            for j in range(n_layer_per_stage):
                if i*n_layer_per_stage + j == n_layers:
                    break
                _tmp = {k[len(f"{j}."):]:v for k,v in checkpoint.items() if k.startswith(f"{j}.")}
                if len(_tmp) == 0:
                    break
                # torch.save(_tmp, os.path.join(output_path, f'pytorch_{i*n_layer_per_stage + j}.pt'))
                ret = model.model.layers[i*n_layer_per_stage + j].load_state_dict(_tmp, strict=False)
                if len(ret.missing_keys):
                    print('The following weight keys are missing:')
                    print(ret.missing_keys)
                if len(ret.unexpected_keys):
                    print('The following weight keys are unexpected:')
                    print(ret.unexpected_keys)
            else:
                j += 1

            _tmp = {k[len(f"{j}."):]:v for k,v in checkpoint.items() if k.startswith(f"{j}.")}
            if len(_tmp) == 0:
                break
            # torch.save(_tmp, os.path.join(output_path, f'pytorch_lm_head.pt'))
            model.model.norm.weight.data[:] = _tmp['norm.weight']
            if 'norm.bias' in _tmp:
                model.model.norm.bias.data[:] = _tmp['norm.bias']
            model.lm_head.weight.data[:] = _tmp['lm_head.weight']
            if 'lm_head.bias' in _tmp:
                model.lm_head.bias.data[:] = _tmp['lm_head.bias']

        else:
            for j in range(n_layer_per_stage):
                _tmp = {k[len(f"{j}."):]:v for k,v in checkpoint.items() if k.startswith(f"{j}.")}
                if len(_tmp) == 0:
                    break
                # torch.save(_tmp, os.path.join(output_path, f'pytorch_{i*n_layer_per_stage + j}.pt'))
                ret = model.model.layers[i*n_layer_per_stage + j].load_state_dict(_tmp, strict=False)
                if len(ret.missing_keys):
                    print('The following weight keys are missing:')
                    print(ret.missing_keys)
                if len(ret.unexpected_keys):
                    print('The following weight keys are unexpected:')
                    print(ret.unexpected_keys)

    return model


if __name__ == '__main__':

    parser = argparse.ArgumentParser(description='Convert HF checkpoints')
    parser.add_argument('--config-name', type=str, default='togethercomputer/Llama-2-7B-32K-beta',
                        help='config-name')
    parser.add_argument('--ckpt-path', type=str, default=None,
                        help='ckpt-path')
    parser.add_argument('--save-path', type=str, default=None,
                        help='save-path')
    parser.add_argument('--n-stages', type=int, default=8,
                        help='pipeline group size')
    parser.add_argument('--n-layer-per-stage', type=int, default=4,
                        help='n layers per GPU device')
    parser.add_argument('--fp16', default=False, action='store_true')
    args = parser.parse_args()

    assert args.ckpt_path is not None
    assert args.save_path is not None

    if not os.path.exists(args.save_path):
        os.mkdir(args.save_path)

    # LlamaForCausalLM LlamaConfig LlamaTokenizer
    print('loading config...')
    config = AutoConfig.from_pretrained(args.config_name)
    print('loaded config.')
    print('loading tokenizer...')
    tokenizer = AutoTokenizer.from_pretrained(args.config_name)
    print('loaded tokenizer.')
    print('creating empty model...')
    model = create_emtpy_llama(config)
    if args.fp16:
        model = model.half()
    print('created empty model.')
    print('loading model ckpt...')
    load_decentralized_checkpoint(
        model, args.ckpt_path, n_stages=args.n_stages, n_layer_per_stage=args.n_layer_per_stage,
    )
    print('loaded model ckpt.')

    print('saving HF model...')
    model.save_pretrained(args.save_path)
    print(f'saved HF model to `{args.save_path}`')
    config.save_pretrained(args.save_path)
    tokenizer.save_pretrained(args.save_path)
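
Note that `load_decentralized_checkpoint` expects one `prank_<stage>_checkpoint.pt` shard per pipeline stage and asserts that `n_stages * n_layer_per_stage` covers every decoder layer. The following sanity-check sketch is not part of the commit; the checkpoint path mirrors the README example, and the 32-layer count for Llama-2-7B is an assumption:

```python
# Check that the pipeline sharding arguments match the checkpoint directory layout.
import os

n_layers = 32            # Llama-2-7B has 32 decoder layers
n_stages = 4             # --n-stages in the README example
n_layer_per_stage = 8    # --n-layer-per-stage in the README example
assert n_stages * n_layer_per_stage >= n_layers, "stages do not cover all layers"

ckpt_path = "model_ckpts/llama-2-7b-32k-mqa/checkpoint_10"  # illustrative path
for i in range(n_stages):
    shard = os.path.join(ckpt_path, f"prank_{i}_checkpoint.pt")
    assert os.path.exists(shard), f"missing pipeline shard: {shard}"
print("checkpoint layout looks consistent")
```
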
3 changes: 2 additions & 1 deletion training/dist_clm_train.py
@@ -378,7 +378,7 @@ def main():
                         help='an uuid')
 
     # Add AWS arguments for uploading checkpoints to S3
-    parser.add_argument('--checkpoint-upload-prefix', required=True, help='S3 bucket name')
+    parser.add_argument('--checkpoint-upload-prefix', default=None, help='S3 bucket name')
     add_aws_arguments(parser)
 
     args = parser.parse_args()
@@ -417,6 +417,7 @@ def main():
 
     tokenizer = build_tokenizer(args)
     tokenizer.model_max_length = args.seq_length
+    config.max_position_embeddings = args.seq_length
     # config.vocab_size = tokenizer.vocab_size
     config.bos_token_id = tokenizer.bos_token_id
     config.eos_token_id = tokenizer.eos_token_id
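
The second hunk keeps the model config's context window in sync with the training sequence length, which matters when fine-tuning at long context. Below is a minimal sketch of the same idea outside the training script; the 32K value is an assumption implied by the model name, and this snippet is not part of the commit:

```python
# Keep the positional capacity at least as large as the training sequences.
from transformers import AutoConfig

seq_length = 32768  # assumed 32K training context
config = AutoConfig.from_pretrained("togethercomputer/Llama-2-7B-32K-beta")
config.max_position_embeddings = max(config.max_position_embeddings, seq_length)
```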