forked from pengxiao-song/LaWGPT
-
Notifications
You must be signed in to change notification settings - Fork 0
/
merge.py
74 lines (54 loc) · 2.09 KB
/
merge.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import os
import torch
import transformers
from peft import PeftModel
from transformers import LlamaForCausalLM, LlamaTokenizer # noqa: F402
import argparse
# Command-line configuration. The module-level names BASE_MODEL, LORA_MODEL and
# OUTPUT_DIR are read by the rest of the script.
parser = argparse.ArgumentParser(description='Merge Base Model and Lora')
parser.add_argument('--base_model', type=str, default="minlik/chinese-llama-7b-merged", help='base model path')
parser.add_argument('--lora_model', type=str, default="entity303/legal-lora-7b", help='lora model path')
parser.add_argument('--output_dir', type=str, default="./models/base_models/llama-7b-legal-lora-merged", help='output model path')
args = parser.parse_args()
BASE_MODEL = args.base_model
LORA_MODEL = args.lora_model
OUTPUT_DIR = args.output_dir

# Validate explicitly rather than with `assert` (which `python -O` strips).
# These values come from CLI flags, not environment variables, so the
# messages name the flag. parser.error() prints usage and exits with status 2.
if not BASE_MODEL:
    parser.error("Please specify a value for --base_model, e.g. `--base_model huggyllama/llama-7b`")
if not LORA_MODEL:
    parser.error("Please specify a value for --lora_model")
if not OUTPUT_DIR:
    parser.error("Please specify a value for --output_dir")

print(f"{'*'*20} Using base model: {BASE_MODEL} {'*'*20}")
print(f"{'*'*20} Using lora model: {LORA_MODEL} {'*'*20}")
print(f"{'*'*20} Saving to: {OUTPUT_DIR} {'*'*20}")
# Load the tokenizer and the base LLaMA weights from the same BASE_MODEL source.
# The model is loaded unquantized (load_in_8bit=False) in float16, with the
# whole model mapped onto the CPU via device_map {"": "cpu"}.
_base_load_kwargs = dict(
    load_in_8bit=False,
    torch_dtype=torch.float16,
    device_map={"": "cpu"},
)
tokenizer = LlamaTokenizer.from_pretrained(BASE_MODEL)
base_model = LlamaForCausalLM.from_pretrained(BASE_MODEL, **_base_load_kwargs)
# Keep a live reference to one base weight (layer 0 query projection) and an
# independent copy of its values. They let us verify that loading the adapter
# leaves the base weights untouched and that merging really changes them.
first_weight = base_model.model.layers[0].self_attn.q_proj.weight
first_weight_old = first_weight.clone()

lora_model = PeftModel.from_pretrained(
    base_model,
    LORA_MODEL,
    device_map={"": "cpu"},
    torch_dtype=torch.float16,
)

# Explicit checks instead of `assert`, so they still run under `python -O`.
if not torch.allclose(first_weight_old, first_weight):
    raise RuntimeError("Base model weights changed while loading the LoRA adapter")

# merge weights - new merging method from peft
lora_model = lora_model.merge_and_unload()
lora_model.train(False)

# did we do anything? Only layer 0's q_proj is inspected, so an adapter that
# does not touch that module would also trip this check.
if torch.allclose(first_weight_old, first_weight):
    raise RuntimeError(
        f"Merging LoRA adapter {LORA_MODEL!r} left the base weights unchanged; "
        "the adapter may be empty or may not target q_proj"
    )
# Build the checkpoint from the merged model's tensors:
# - skip any key containing "lora" (adapter tensors);
# - strip a "base_model.model." prefix, as a PEFT wrapper would add one.
lora_model_sd = lora_model.state_dict()
deloreanized_sd = {}
for key, tensor in lora_model_sd.items():
    if "lora" in key:
        continue
    deloreanized_sd[key.replace("base_model.model.", "")] = tensor

# The weights written are exactly those in deloreanized_sd. base_model supplies
# the save logic and config. Shards are capped at 2048MB. The tokenizer is
# written into the same directory.
base_model.save_pretrained(
    OUTPUT_DIR, state_dict=deloreanized_sd, max_shard_size="2048MB"
)
tokenizer.save_pretrained(OUTPUT_DIR)