converting_otter_to_lora.py
import argparse

from modeling_otter import OtterForConditionalGeneration
from peft import LoraConfig, TaskType, get_peft_model
# Map the language encoder's architecture name (from the model config) to a
# short name used to look up the LoRA target modules below.
MODEL_CLASSES = {
    "LlamaForCausalLM": "llama",
    "OPTForCausalLM": "opt",
    "GPTJForCausalLM": "gptj",
    "GPTNeoXForCausalLM": "gpt_neox",
    "MPTForCausalLM": "mpt",
}
# Define the argument parser
parser = argparse.ArgumentParser(description="Load an Otter checkpoint, add LoRA adapters to its language encoder, and save the converted model.")
parser.add_argument(
    "--checkpoint_path",
    type=str,
    default="/data/bli/checkpoints/OTTER-MPT7B-Instruct0705",
    help="Path to the pre-trained model checkpoint.",
)
parser.add_argument(
    "--save_path",
    type=str,
    default="/data/bli/checkpoints/OTTER-MPT7B-Instruct0705-LoRA",
    help="Path to save the converted model checkpoint.",
)

# Parse the input arguments
args = parser.parse_args()
# Load the model
model = OtterForConditionalGeneration.from_pretrained(args.checkpoint_path, device_map="auto")
# Add LoRA adapters to the language encoder. Most decoder architectures expose
# separate q_proj/v_proj projections, while GPT-NeoX and MPT use a single fused
# query-key-value projection instead.
standard_modules = ["q_proj", "v_proj"]
lang_encoder_short_name = MODEL_CLASSES[model.config.text_config.architectures[0]]
model_to_lora_modules = {
    "llama": standard_modules,
    "opt": standard_modules,
    "gptj": standard_modules,
    "gpt_neox": ["query_key_value"],
    "mpt": ["Wqkv"],
}
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    task_type=TaskType.CAUSAL_LM,
    target_modules=model_to_lora_modules[lang_encoder_short_name],
)
# Record the LoRA hyperparameters in the model config so they are written out
# with the checkpoint, then wrap the language encoder with PEFT
model.config.update({"lora_config": {"r": 16, "lora_alpha": 32, "lora_dropout": 0.05}})
model.lang_encoder = get_peft_model(model.lang_encoder, lora_config)
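
# Optional sanity check using PEFT's PeftModel API: report how many parameters
# the LoRA wrapping left trainable relative to the full language encoder.
model.lang_encoder.print_trainable_parameters()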
# Save the converted model
model.save_pretrained(args.save_path)
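
# A minimal sketch (an assumption, not part of the original script) of how the
# saved checkpoint might be loaded back. Whether from_pretrained restores the
# PEFT-wrapped lang_encoder automatically depends on the Otter revision, so
# verify this before relying on it:
#
#   model = OtterForConditionalGeneration.from_pretrained(
#       args.save_path, device_map="auto"
#   )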