forked from togethercomputer/OpenChatKit
Commit
Merge pull request togethercomputer#160 from togethercomputer/jue-patch-1
Adding Llama-2
Showing 10 changed files with 1,188 additions and 23 deletions.
@@ -0,0 +1,56 @@
import os
import argparse
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig

DIR = os.path.dirname(os.path.abspath(__file__))
USE_AUTH_TOKEN = False

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Convert HF checkpoints')
    parser.add_argument('--model-name', type=str, default='togethercomputer/Llama-2-7B-32K-beta',
                        help='model name on the Hugging Face Hub')
    parser.add_argument('--save-dir', type=str, default=DIR,
                        help='directory to save the converted checkpoint shards')
    parser.add_argument('--offload-dir', type=str, default=None,
                        help='directory to offload from memory')
    args = parser.parse_args()

    if not os.path.exists(args.save_dir):
        os.mkdir(args.save_dir)
    save_path = os.path.join(args.save_dir, args.model_name.replace('/', '_'))
    if not os.path.exists(save_path):
        os.mkdir(save_path)

    print('loading model from HF...')
    config = AutoConfig.from_pretrained(args.model_name, use_auth_token=USE_AUTH_TOKEN)
    config.save_pretrained(save_path)
    tokenizer = AutoTokenizer.from_pretrained(args.model_name, use_auth_token=USE_AUTH_TOKEN)
    tokenizer.save_pretrained(save_path)

    # offload model from memory to disk if offload-dir is specified
    if args.offload_dir is not None:
        if not os.path.exists(args.offload_dir):
            os.mkdir(args.offload_dir)
        model = AutoModelForCausalLM.from_pretrained(args.model_name, torch_dtype=torch.float16,
                                                     device_map="auto", offload_folder=args.offload_dir,
                                                     use_auth_token=USE_AUTH_TOKEN)
    else:
        model = AutoModelForCausalLM.from_pretrained(args.model_name, torch_dtype=torch.float16,
                                                     use_auth_token=USE_AUTH_TOKEN)
    print('loaded model from HF...')

    # save the embedding layer as its own shard
    print('converting the embedding layer...')
    item = {}
    item['embed_tokens.weight'] = model.model.embed_tokens.weight
    torch.save(item, os.path.join(save_path, 'pytorch_embs.pt'))
    print('converted the embedding layer.')

    # save each transformer layer as a separate shard
    for i in range(len(model.model.layers)):
        print(f'converting the {i}-th transformer layer...')
        torch.save(model.model.layers[i].state_dict(), os.path.join(save_path, f'pytorch_{i}.pt'))
        print(f'converted the {i}-th transformer layer.')

    # save the final norm and the lm_head together as the last shard
    print('converting the lm_head layer...')
    item = {}
    item['lm_head.weight'] = model.lm_head.weight
    item['norm.weight'] = model.model.norm.weight
    torch.save(item, os.path.join(save_path, 'pytorch_lm_head.pt'))
    print('converted the lm_head layer.')
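The script above writes one shard per component into `<save-dir>/<model-name with '/' replaced by '_'>`: `pytorch_embs.pt`, `pytorch_0.pt` through `pytorch_{n-1}.pt`, and `pytorch_lm_head.pt`, alongside the saved config and tokenizer. A minimal sketch of how a downstream loader could sanity-check those shards; the `save_path` value below is an assumption based on the default `--model-name` and `--save-dir`:

import os
import torch

# assumed output directory for the default arguments; adjust to your --save-dir / --model-name
save_path = 'togethercomputer_Llama-2-7B-32K-beta'

# the embedding shard holds a single tensor under 'embed_tokens.weight'
embs = torch.load(os.path.join(save_path, 'pytorch_embs.pt'), map_location='cpu')
print(embs['embed_tokens.weight'].shape)  # (vocab_size, hidden_size)

# the last shard holds the final norm weight and the lm_head weight
head = torch.load(os.path.join(save_path, 'pytorch_lm_head.pt'), map_location='cpu')
print(sorted(head.keys()))  # ['lm_head.weight', 'norm.weight']

# per-layer shards are full state_dicts of each decoder layer, named pytorch_<i>.pt
n_layer_shards = len([f for f in os.listdir(save_path)
                      if f.startswith('pytorch_') and f[len('pytorch_'):-len('.pt')].isdigit()])
print(f'{n_layer_shards} transformer layer shards found')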
@@ -0,0 +1,155 @@
import os
import argparse

import torch
import torch.nn as nn

from transformers import LlamaForCausalLM
from transformers import AutoConfig, AutoTokenizer
from transformers.modeling_utils import no_init_weights


# Build a LlamaForCausalLM without running weight initialization, so construction is fast
# and every weight is overwritten from the checkpoint afterwards.
def create_emtpy_llama(config):

    _reset_parameters_linear = nn.Linear.reset_parameters
    def dummy(*args, **kargs):
        pass
    nn.Linear.reset_parameters = dummy

    # 1. disable init for faster initialization
    # 2. avoid tying token embeddings to lm_head, as we train them separately.
    with no_init_weights(_enable=True):
        model = LlamaForCausalLM(config).eval()

    nn.Linear.reset_parameters = _reset_parameters_linear

    return model


# Load a pipeline-parallel checkpoint (one prank_{i}_checkpoint.pt file per stage, stored
# under checkpoint_path) back into a full HF LlamaForCausalLM.
def load_decentralized_checkpoint(model, checkpoint_path, n_stages=2, n_layer_per_stage=16):
    input_path = checkpoint_path

    n_layers = len(model.model.layers)
    assert n_stages * n_layer_per_stage >= len(model.model.layers)
    # assert model.lm_head.weight.data is not model.transformer.wte.weight.data

    for i in range(n_stages):

        print(f'loading stage {i}')

        checkpoint = torch.load(os.path.join(input_path, f'prank_{i}_checkpoint.pt'), map_location=torch.device("cpu"))

        if i == 0:
            # the first stage holds the embedding layer at module index 0, followed by decoder layers
            _tmp = {k[len(f"{0}."):]: v for k, v in checkpoint.items() if k.startswith(f"0.")}
            # torch.save(_tmp, os.path.join(output_path, f'pytorch_embs.pt'))
            model.model.embed_tokens.weight.data[:] = _tmp['embed_tokens.weight']

            for j in range(n_layer_per_stage):
                _tmp = {k[len(f"{j+1}."):]: v for k, v in checkpoint.items() if k.startswith(f"{j+1}.")}
                if len(_tmp) == 0:
                    break
                # torch.save(_tmp, os.path.join(output_path, f'pytorch_{j}.pt'))
                ret = model.model.layers[j].load_state_dict(_tmp, strict=False)
                if len(ret.missing_keys):
                    print('The following weight keys are missing:')
                    print(ret.missing_keys)
                if len(ret.unexpected_keys):
                    print('The following weight keys are unexpected:')
                    print(ret.unexpected_keys)

        elif i == n_stages - 1:
            # the last stage holds the remaining decoder layers plus the final norm and lm_head
            for j in range(n_layer_per_stage):
                if i * n_layer_per_stage + j == n_layers:
                    break
                _tmp = {k[len(f"{j}."):]: v for k, v in checkpoint.items() if k.startswith(f"{j}.")}
                if len(_tmp) == 0:
                    break
                # torch.save(_tmp, os.path.join(output_path, f'pytorch_{i*n_layer_per_stage + j}.pt'))
                ret = model.model.layers[i * n_layer_per_stage + j].load_state_dict(_tmp, strict=False)
                if len(ret.missing_keys):
                    print('The following weight keys are missing:')
                    print(ret.missing_keys)
                if len(ret.unexpected_keys):
                    print('The following weight keys are unexpected:')
                    print(ret.unexpected_keys)
            else:
                # for-else: if the loop finished without breaking, the norm/lm_head module
                # sits at the next index after the last decoder layer
                j += 1

            _tmp = {k[len(f"{j}."):]: v for k, v in checkpoint.items() if k.startswith(f"{j}.")}
            if len(_tmp) == 0:
                break
            # torch.save(_tmp, os.path.join(output_path, f'pytorch_lm_head.pt'))
            model.model.norm.weight.data[:] = _tmp['norm.weight']
            if 'norm.bias' in _tmp:
                model.model.norm.bias.data[:] = _tmp['norm.bias']
            model.lm_head.weight.data[:] = _tmp['lm_head.weight']
            if 'lm_head.bias' in _tmp:
                model.lm_head.bias.data[:] = _tmp['lm_head.bias']

        else:
            # intermediate stages hold only decoder layers
            for j in range(n_layer_per_stage):
                _tmp = {k[len(f"{j}."):]: v for k, v in checkpoint.items() if k.startswith(f"{j}.")}
                if len(_tmp) == 0:
                    break
                # torch.save(_tmp, os.path.join(output_path, f'pytorch_{i*n_layer_per_stage + j}.pt'))
                ret = model.model.layers[i * n_layer_per_stage + j].load_state_dict(_tmp, strict=False)
                if len(ret.missing_keys):
                    print('The following weight keys are missing:')
                    print(ret.missing_keys)
                if len(ret.unexpected_keys):
                    print('The following weight keys are unexpected:')
                    print(ret.unexpected_keys)

    return model


if __name__ == '__main__':

    parser = argparse.ArgumentParser(description='Convert HF checkpoints')
    parser.add_argument('--config-name', type=str, default='togethercomputer/Llama-2-7B-32K-beta',
                        help='HF config/tokenizer name')
    parser.add_argument('--ckpt-path', type=str, default=None,
                        help='path to the decentralized (per-stage) checkpoint')
    parser.add_argument('--save-path', type=str, default=None,
                        help='directory to save the converted HF model')
    parser.add_argument('--n-stages', type=int, default=8,
                        help='pipeline group size')
    parser.add_argument('--n-layer-per-stage', type=int, default=4,
                        help='n layers per GPU device')
    parser.add_argument('--fp16', default=False, action='store_true')
    args = parser.parse_args()

    assert args.ckpt_path is not None
    assert args.save_path is not None

    if not os.path.exists(args.save_path):
        os.mkdir(args.save_path)

    # LlamaForCausalLM LlamaConfig LlamaTokenizer
    print('loading config...')
    config = AutoConfig.from_pretrained(args.config_name)
    print('loaded config.')
    print('loading tokenizer...')
    tokenizer = AutoTokenizer.from_pretrained(args.config_name)
    print('loaded tokenizer.')
    print('creating empty model...')
    model = create_emtpy_llama(config)
    if args.fp16:
        model = model.half()
    print('created empty model.')
    print('loading model ckpt...')
    load_decentralized_checkpoint(
        model, args.ckpt_path, n_stages=args.n_stages, n_layer_per_stage=args.n_layer_per_stage,
    )
    print('loaded model ckpt.')

    print('saving HF model...')
    model.save_pretrained(args.save_path)
    print(f'saved HF model to `{args.save_path}`')
    config.save_pretrained(args.save_path)
    tokenizer.save_pretrained(args.save_path)
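This second script goes in the opposite direction: it expects one `prank_{i}_checkpoint.pt` file per pipeline stage under `--ckpt-path`, requires `--n-stages * --n-layer-per-stage` to cover the model's layer count, and writes a standard Hugging Face checkpoint to `--save-path`. A hedged sketch of loading the converted checkpoint afterwards; the `save_path` value is an assumption, since the actual directory depends on how the script is invoked:

import torch
from transformers import LlamaForCausalLM, AutoTokenizer

save_path = 'Llama-2-7B-32K-beta-hf'  # assumed value of --save-path

# the converted directory is a regular HF checkpoint, so it loads like any other model
model = LlamaForCausalLM.from_pretrained(save_path, torch_dtype=torch.float16)
if torch.cuda.is_available():
    model = model.cuda()
tokenizer = AutoTokenizer.from_pretrained(save_path)

inputs = tokenizer('Hello, my name is', return_tensors='pt').to(model.device)
with torch.no_grad():
    out = model.generate(**inputs, max_new_tokens=16)
print(tokenizer.decode(out[0], skip_special_tokens=True))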