support parallel adapter and gpt-j model
HZQ950419 committed Apr 1, 2023
1 parent e14ef9f commit 26dd513
Showing 6 changed files with 66 additions and 13 deletions.
16 changes: 13 additions & 3 deletions finetune.py
@@ -21,7 +21,7 @@
prepare_model_for_int8_training,
set_peft_model_state_dict,
)
from transformers import LlamaForCausalLM, LlamaTokenizer # noqa: F402
from transformers import AutoModelForCausalLM, AutoTokenizer, LlamaTokenizer # noqa: F402


def train(
@@ -50,6 +50,7 @@ def train(
bottleneck_size : int = 256,
non_linearity: str = "tanh",
adapter_dropout: float = 0.0,
use_parallel_adapter: bool = False,
# llm hyperparams
train_on_inputs: bool = True, # if False, masks out inputs in loss
group_by_length: bool = False, # faster, but produces an odd training loss curve
@@ -76,6 +77,10 @@ def train(
f"lora_alpha: {lora_alpha}\n"
f"lora_dropout: {lora_dropout}\n"
f"lora_target_modules: {lora_target_modules}\n"
f"bottleneck_size: {bottleneck_size}\n"
f"non_linearity: {non_linearity}\n"
f"adapter_dropout: {adapter_dropout}\n"
f"use_parallel_adapter: {use_parallel_adapter}\n"
f"train_on_inputs: {train_on_inputs}\n"
f"group_by_length: {group_by_length}\n"
f"wandb_project: {wandb_project}\n"
@@ -108,14 +113,18 @@ def train(
if len(wandb_log_model) > 0:
os.environ["WANDB_LOG_MODEL"] = wandb_log_model

model = LlamaForCausalLM.from_pretrained(
model = AutoModelForCausalLM.from_pretrained(
base_model,
load_in_8bit=True,
torch_dtype=torch.float16,
device_map=device_map,
)

tokenizer = LlamaTokenizer.from_pretrained(base_model)
if model.config.model_type == "llama":
# Due to the LLaMA tokenizer not resolving reliably through AutoTokenizer, load LlamaTokenizer explicitly
tokenizer = LlamaTokenizer.from_pretrained(base_model)
else:
tokenizer = AutoTokenizer.from_pretrained(base_model)

tokenizer.pad_token_id = (
0 # unk. we want this to be different from the eos token
@@ -174,6 +183,7 @@ def generate_and_tokenize_prompt(data_point):
bottleneck_size=bottleneck_size,
non_linearity=non_linearity,
adapter_dropout=adapter_dropout,
use_parallel_adapter=use_parallel_adapter,
bias="none",
task_type="CAUSAL_LM",
)
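
Note (not part of the commit): a minimal sketch of the adapter config that the block above builds when the new flag is on, assuming the vendored peft package exports BottleneckConfig and leaving target_modules unset so the mapping logic in peft/src/peft/mapping.py fills it in:

from peft import BottleneckConfig  # assumed export from the vendored copy under peft/src

config = BottleneckConfig(
    bottleneck_size=256,
    non_linearity="tanh",
    adapter_dropout=0.0,
    use_parallel_adapter=True,  # new flag added in this commit
    bias="none",
    task_type="CAUSAL_LM",
)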
2 changes: 2 additions & 0 deletions generate.py
@@ -1,9 +1,11 @@
import os
import sys

import fire
import gradio as gr
import torch
import transformers
sys.path.append(os.path.join(os.getcwd(), "peft/src"))
from peft import PeftModel
from transformers import GenerationConfig, LlamaForCausalLM, LlamaTokenizer
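
Note (not part of the commit): the new sys.path line makes the repository's vendored peft importable; a quick, hypothetical check that the intended copy resolves when running from the repo root:

import os, sys
sys.path.append(os.path.join(os.getcwd(), "peft/src"))
import peft
print(peft.__file__)  # should point into ./peft/src/peft unless an installed peft shadows it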

Binary file modified peft/src/peft/__pycache__/mapping.cpython-39.pyc
18 changes: 15 additions & 3 deletions peft/src/peft/mapping.py
@@ -61,8 +61,14 @@

TRANSFORMERS_MODELS_TO_BOTTLENECK_TARGET_MODULES_MAPPING = {
"llama": ["gate_proj", "up_proj", "down_proj"],
"gptj": ["fc_in", "fc_out"],

}

TRANSFORMERS_MODELS_TO_PARALLEL_TARGET_MODULES_MAPPING = {
"llama": ["v_proj", "down_proj"],
"gptj": ["v_proj", "fc_out"],
}



@@ -134,9 +140,15 @@ def _prepare_lora_config(peft_config, model_config):

def _prepare_bottleneck_config(peft_config, model_config):
if peft_config.target_modules is None:
if model_config["model_type"] not in TRANSFORMERS_MODELS_TO_BOTTLENECK_TARGET_MODULES_MAPPING:
raise ValueError("Please specify `target_modules` in `peft_config`")
peft_config.target_modules = TRANSFORMERS_MODELS_TO_BOTTLENECK_TARGET_MODULES_MAPPING[model_config["model_type"]]
if peft_config.use_parallel_adapter:
if model_config["model_type"] not in TRANSFORMERS_MODELS_TO_PARALLEL_TARGET_MODULES_MAPPING:
raise ValueError("Please specify `target_modules` in `peft_config`")
peft_config.target_modules = TRANSFORMERS_MODELS_TO_PARALLEL_TARGET_MODULES_MAPPING[model_config["model_type"]]
else:
if model_config["model_type"] not in TRANSFORMERS_MODELS_TO_BOTTLENECK_TARGET_MODULES_MAPPING:
raise ValueError("Please specify `target_modules` in `peft_config`")
peft_config.target_modules = TRANSFORMERS_MODELS_TO_BOTTLENECK_TARGET_MODULES_MAPPING[model_config["model_type"]]

return peft_config
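
Note (not part of the commit): a sketch of what the branch above resolves to for GPT-J, calling the private helper directly for illustration; assumes BottleneckConfig is exported by the vendored peft and that its remaining fields keep their defaults, and the model_config dict here is illustrative:

from peft import BottleneckConfig
from peft.mapping import _prepare_bottleneck_config

cfg = BottleneckConfig(use_parallel_adapter=True, task_type="CAUSAL_LM")
cfg = _prepare_bottleneck_config(cfg, {"model_type": "gptj"})
print(cfg.target_modules)  # ["v_proj", "fc_out"] from the parallel mapping; ["fc_in", "fc_out"] with the flag off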


Binary file modified peft/src/peft/tuners/__pycache__/bottleneck.cpython-39.pyc
43 changes: 36 additions & 7 deletions peft/src/peft/tuners/bottleneck.py
@@ -13,6 +13,12 @@
from ..utils import PeftConfig, PeftType, transpose
from transformers.activations import ACT2FN


TRANSFORMERS_MODELS_TO_ADAPTER_TYPE_MAPPING = {
"llama": {"gate_proj": "mh_adapter", "up_proj":"mh_adapter", "down_proj":"output_adapter"},
"gptj": {"fc_in":"mh_adapter", "fc_out":"output_adapter"},
}

def is_bnb_available():
return importlib.util.find_spec("bitsandbytes") is not None

@@ -30,6 +36,7 @@ class BottleneckConfig(PeftConfig):
non_linearity (`str`): The non-linearity to apply to the bottleneck.
dropout (`float`, optional): The dropout probability of the bottleneck. Default to 0.0
bias ('str'): Bias type for Bottleneck. Can be 'none', 'all' or 'adapter_only'. Default to 'none'.
use_parallel_adapter (:obj:`bool`, optional): Whether to use parallel adapter. Defaults to False.
scaling (:obj:`float` or :obj:`str`, optional):
Scaling factor to use for scaled addition of adapter outputs as done by He et al. (2021). Can be either a
constant factor (float) or the string "learned", in which case the scaling factor is learned. Defaults to
@@ -50,6 +57,7 @@ class BottleneckConfig(PeftConfig):
"For example, ['q', 'v'] or '.*decoder.*(SelfAttention|EncDecAttention).*(q|v)$' "
},
)
use_parallel_adapter: bool = field(default=False, metadata={"help": "Whether to use parallel adapter"})
scaling: Union[float, str] = 1.0
bias: str = field(default="none", metadata={"help": "Bias type for Bottleneck. Can be 'none', 'all' or 'adapter_only'"})
init_weights: str = field(default="bert", metadata={"help": "Initialization method for the weights of the adapter modules."})
@@ -128,13 +136,11 @@ def _find_and_replace(self):
is_target_modules_in_base_model = True
parent, target, target_name = self._get_submodules(key)
# determine the type of adapter to be used; this will affect the forward pass
if self.model.config.model_type == "llama":
if target_name == "gate_proj" or target_name == "up_proj":
adapter_type = "mh_adapter"
kwargs.update({"adapter_type": adapter_type})
elif target_name == "down_proj":
adapter_type = "output_adapter"
kwargs.update({"adapter_type": adapter_type})
if self.peft_config.use_parallel_adapter:
adapter_type = "parallel_adapter"
else:
adapter_type = TRANSFORMERS_MODELS_TO_ADAPTER_TYPE_MAPPING[self.model.config.model_type][target_name]
kwargs.update({"adapter_type": adapter_type})

bias = target.bias is not None
if loaded_in_8bit and isinstance(target, bnb.nn.Linear8bitLt):
@@ -150,11 +156,15 @@ def _find_and_replace(self):
new_module = Linear8bitLt(target.in_features, target.in_features, bias=bias, **kwargs)
elif adapter_type == "output_adapter":
new_module = Linear8bitLt(target.out_features, target.out_features, bias=bias, **kwargs)
elif adapter_type == "parallel_adapter":
new_module = Linear8bitLt(target.in_features, target.out_features, bias=bias, **kwargs)
elif isinstance(target, torch.nn.Linear):
if adapter_type == "mh_adapter":
new_module = Linear(target.in_features, target.in_features, bias=bias, **kwargs)
elif adapter_type == "output_adapter":
new_module = Linear(target.out_features, target.out_features, bias=bias, **kwargs)
elif adapter_type == "parallel_adapter":
new_module = Linear(target.in_features, target.out_features, bias=bias, **kwargs)
self._replace_module(parent, target_name, new_module, target)
if not is_target_modules_in_base_model:
raise ValueError(
@@ -352,6 +362,13 @@ def forward(self, x: torch.Tensor):
output = self.adapter_up(self.act_fn(self.adapter_down(self.adapter_dropout(x)))) * self.adapter_scaling

result = output + residual
elif self.adapter_type == "parallel_adapter":
# for parallel_adapter, x goes through the linear layer and through the adapter layer in parallel;
# the adapter output is then added to the output of the linear layer
result = F.linear(x, self.weight, bias=self.bias)
output = self.adapter_up(self.act_fn(self.adapter_down(self.adapter_dropout(x)))) * self.adapter_scaling

result = result + output
return result
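
Note (not part of the commit): the parallel branch computes result = W x + s * U f(D x), i.e. the adapter runs beside the wrapped linear layer on the same input rather than before or after it; a standalone sketch with illustrative names (dropout omitted):

import torch.nn.functional as F

def parallel_adapter_forward(x, weight, bias, adapter_down, adapter_up, act_fn, scaling):
    result = F.linear(x, weight, bias)                      # base projection of the wrapped nn.Linear
    output = adapter_up(act_fn(adapter_down(x))) * scaling  # adapter applied to the same input x
    return result + output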


@@ -460,6 +477,18 @@ def forward(self, x: torch.Tensor):
residual = result_pre_forward
output = self.adapter_up(self.act_fn(self.adapter_down(self.adapter_dropout(result_pre_forward)))) * self.adapter_scaling
result = output + residual
elif self.adapter_type == "parallel_adapter":
if not torch.is_autocast_enabled():
expected_dtype = result_pre_forward.dtype

if x.dtype != torch.float32:
x = x.float()

output = self.adapter_up(self.act_fn(self.adapter_down(self.adapter_dropout(x)))).to(expected_dtype) * self.adapter_scaling
result = result_pre_forward + output
else:
output = self.adapter_up(self.act_fn(self.adapter_down(self.adapter_dropout(x)))) * self.adapter_scaling
result = result_pre_forward + output

return result
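
Note (not part of the commit): in the 8-bit branch above, when autocast is off the adapter input is upcast to float32 and the adapter output is cast back to the dtype of result_pre_forward before the parallel addition (presumably because the adapter parameters are kept in float32 while the quantized layer's output is half precision); under autocast, no manual cast is needed.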

