
Commit

Merge pull request allenai#59 from allenai/qlora
QLoRA support
yizhongw authored Sep 15, 2023
2 parents 6249fd1 + bdc1e4d commit 376ebd4
Showing 6 changed files with 375 additions and 160 deletions.
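
For orientation (not part of the diff itself): the QLoRA recipe this PR wires into open_instruct/finetune.py boils down to loading the base model 4-bit quantized, preparing it for k-bit training, and attaching LoRA adapters. Below is a condensed sketch using the same settings that appear in the hunks further down; the model id and the LoRA hyperparameters (r, alpha, dropout) are placeholders, not values taken from this PR.

import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from peft import LoraConfig, TaskType, get_peft_model, prepare_model_for_kbit_training

base_model_id = "meta-llama/Llama-2-7b-hf"  # placeholder

# 4-bit NF4 weights, double quantization, bf16 compute - the same config as in the diff.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)
model = AutoModelForCausalLM.from_pretrained(
    base_model_id,
    quantization_config=bnb_config,
    torch_dtype=torch.bfloat16,
    device_map={"": 0},  # one full copy per GPU, i.e. plain data parallelism, as in the diff
)

# Make the quantized model trainable, then attach LoRA adapters to the attention and MLP projections.
model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=True)
peft_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    inference_mode=False,
    r=64,              # placeholder hyperparameters
    lora_alpha=16,
    lora_dropout=0.1,
    target_modules=["q_proj", "o_proj", "v_proj", "k_proj", "gate_proj", "up_proj", "down_proj"],
)
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()  # only the adapter weights are trainable
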
12 changes: 9 additions & 3 deletions eval/alpaca_farm_eval.py
@@ -16,6 +16,8 @@
parser.add_argument("--openai_engine", "-o", type=str, default=None)
# where to save generations - default current directory
parser.add_argument("--save_folder", "-s", type=str, default="")
parser.add_argument("--tokenizer", "-t", type=str, default=None)
parser.add_argument("--padding_side", "-p", type=str, default="right") # llama2 requires left padding
args = parser.parse_args()

assert not (args.model and args.openai_engine), "only provide one of --model or --openai"
@@ -35,19 +37,23 @@
my_outputs = []
if not os.path.exists(os.path.join(args.save_folder, sample_filename)):
if args.openai_engine is None:
model = AutoModelForCausalLM.from_pretrained(args.model, device_map="auto", torch_dtype=torch.bfloat16)
tokenizer = AutoTokenizer.from_pretrained(args.model)
model = AutoModelForCausalLM.from_pretrained(
args.model,
device_map="auto",
torch_dtype=torch.bfloat16,
)
tokenizer = AutoTokenizer.from_pretrained(args.model if args.tokenizer is None else args.tokenizer, legacy=True, use_fast=False)
# add padding token if not already there
if tokenizer.pad_token is None:
tokenizer.add_special_tokens({"pad_token": "<pad>"})
model.resize_token_embeddings(len(tokenizer))
tokenizer.padding_side = args.padding_side
logging.info("model and data loaded!")
logging.info("generating...")
generation_config = GenerationConfig.from_pretrained(
args.model,
max_new_tokens=2048,
# top_p=0.9,
# do_sample=True,
# do_sample=False,
# num_return_sequences=1,
# temperature=1.0,
# top_k=0
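
A note on the new --padding_side flag (and why its comment says llama2 requires left padding): decoder-only models generate by continuing the last token of each sequence, so with right padding the pad tokens would sit between a prompt and its continuation in a batch. A minimal sketch of left-padded batched generation under that assumption; the model name and prompts are placeholders, not taken from this PR.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Llama-2-7b-hf"  # placeholder
tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False)
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", torch_dtype=torch.bfloat16)

# Mirror the script: add a pad token if missing and resize the embeddings to match.
if tokenizer.pad_token is None:
    tokenizer.add_special_tokens({"pad_token": "<pad>"})
    model.resize_token_embeddings(len(tokenizer))
tokenizer.padding_side = "left"  # what --padding_side left selects

prompts = ["Tell me a joke.", "Summarize Hamlet in one sentence."]
inputs = tokenizer(prompts, padding=True, return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_new_tokens=64)
# Strip the (left-padded) prompt tokens before decoding the completions.
completions = tokenizer.batch_decode(outputs[:, inputs["input_ids"].shape[1]:], skip_special_tokens=True)
print(completions)
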
133 changes: 100 additions & 33 deletions open_instruct/finetune.py
@@ -28,9 +28,10 @@
get_scheduler,
GPTNeoXTokenizerFast,
GPT2Tokenizer,
OPTForCausalLM
OPTForCausalLM,
BitsAndBytesConfig,
)
from peft import LoraConfig, TaskType, get_peft_model
from peft import LoraConfig, TaskType, get_peft_model, prepare_model_for_kbit_training

logger = get_logger(__name__)

@@ -202,6 +203,31 @@ def parse_args():
"If passed, LLM loading time and RAM consumption will be benefited."
),
)
parser.add_argument(
"--gradient_checkpointing",
action="store_true",
help=(
"Turn on gradient checkpointing. Saves memory but slows training."
),
)
parser.add_argument(
"--use_qlora",
action="store_true",
help=(
"Use qLoRA training - main thing is initialising model in quantised form. Not compatible with deepspeed."
),
)
parser.add_argument(
'--clip_grad_norm',
type=float,
default=-1,
help='Clip gradient norm. Not compatible with deepspeed (use deepspeed config instead).',
)
parser.add_argument(
'--use_8bit_optimizer',
action='store_true',
help='Use 8bit optimizer from bitsandbytes. Not compatible with deepspeed (use deepspeed config instead).',
)
args = parser.parse_args()

# Sanity checks
@@ -298,14 +324,32 @@ def _concat_messages(messages):
'labels': labels.flatten(),
'attention_mask': attention_mask.flatten(),
}


def save_with_accelerate(accelerator, model, tokenizer, output_dir, args):
unwrapped_model = accelerator.unwrap_model(model)
# When doing multi-gpu training, we need to use accelerator.get_state_dict(model) to get the state_dict.
# Otherwise, sometimes the model will be saved with only part of the parameters.
# Also, accelerator needs to use the wrapped model to get the state_dict.
state_dict = accelerator.get_state_dict(model)
if args.use_lora:
# When using lora, the unwrapped model is a PeftModel, which doesn't accept the is_main_process argument
# and has its own save_pretrained function that saves only the lora modules.
# We have to manually specify the is_main_process outside the save_pretrained function.
if accelerator.is_main_process:
unwrapped_model.save_pretrained(output_dir, state_dict=state_dict)
else:
unwrapped_model.save_pretrained(
output_dir, is_main_process=accelerator.is_main_process, save_function=accelerator.save, state_dict=state_dict
)


def main():
args = parse_args()

# A hacky way to make llama work with flash attention
if args.use_flash_attn:
from llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn
from llama2_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn
replace_llama_attn_with_flash_attn()

# Initialize the accelerator. We will let the accelerator handle device placement for us in this example.
@@ -318,7 +362,6 @@ def main():
accelerator_log_kwargs["project_dir"] = args.output_dir

accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps, **accelerator_log_kwargs)

# Make one log on every process with the configuration for debugging.
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
@@ -381,12 +424,31 @@ def main():
)

if args.model_name_or_path:
model = AutoModelForCausalLM.from_pretrained(
args.model_name_or_path,
from_tf=bool(".ckpt" in args.model_name_or_path),
config=config,
low_cpu_mem_usage=args.low_cpu_mem_usage,
)
if args.use_qlora:
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.bfloat16,
)
device_index = accelerator.process_index
device_map = {"": device_index} # force data-parallel training.
model = AutoModelForCausalLM.from_pretrained(
args.model_name_or_path,
from_tf=bool(".ckpt" in args.model_name_or_path),
config=config,
load_in_4bit=True,
quantization_config=bnb_config,
device_map=device_map,
torch_dtype=torch.bfloat16,
)
else:
model = AutoModelForCausalLM.from_pretrained(
args.model_name_or_path,
from_tf=bool(".ckpt" in args.model_name_or_path),
config=config,
low_cpu_mem_usage=args.low_cpu_mem_usage,
)
else:
logger.info("Training new model from scratch")
model = AutoModelForCausalLM.from_config(config)
@@ -417,15 +479,23 @@ def main():
model.resize_token_embeddings(len(tokenizer))

if args.use_lora:
if args.use_qlora:
model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=args.gradient_checkpointing)

logger.info("Initializing LORA model...")
peft_config = LoraConfig(
task_type=TaskType.CAUSAL_LM,
inference_mode=False,
r=args.lora_rank,
lora_alpha=args.lora_alpha,
lora_dropout=args.lora_dropout
lora_dropout=args.lora_dropout,
target_modules=["q_proj", "o_proj", "v_proj", "k_proj", "gate_proj", "up_proj", "down_proj"]
)
model = get_peft_model(model, peft_config)
# peft breaks flash attention due to casting norms to fp32. This fixes it back up.
# See https://github.com/huggingface/peft/issues/790
from llama_flash_attn_monkey_patch import upcast_layer_for_flash_attention
model = upcast_layer_for_flash_attention(model, torch.bfloat16)
model.print_trainable_parameters()

# Preprocessing the datasets.
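
The upcast_layer_for_flash_attention helper referenced above lives in the monkey-patch module and is not part of this diff. Per the linked peft issue, preparing a k-bit model casts the norm layers to fp32, and the flash attention kernel then fails on mixed dtypes; the fix is to cast those layers back to the compute dtype. A purely illustrative sketch of such a helper, assuming that is all it does; the real implementation may differ.

import torch

def upcast_layer_for_flash_attention(model, torch_dtype):
    # Cast the modules peft upcasts to fp32 (norms, embeddings, lm_head)
    # back to the compute dtype so flash attention sees uniform bf16 tensors.
    for name, module in model.named_modules():
        if "norm" in name.lower() or "lm_head" in name or "embed_tokens" in name:
            if hasattr(module, "weight") and module.weight is not None:
                module.to(torch_dtype)
    return model
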
@@ -483,7 +553,16 @@ def main():
"weight_decay": 0.0,
},
]
optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=args.learning_rate)
if args.use_qlora:
from bitsandbytes.optim import AdamW
optimizer = AdamW(
optimizer_grouped_parameters,
lr=args.learning_rate,
optim_bits=8 if args.use_8bit_optimizer else 32,
is_paged=True
)
else:
optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=args.learning_rate)

# Scheduler and math around the number of training steps.
overrode_max_train_steps = False
@@ -600,20 +679,22 @@ def main():
active_dataloader = train_dataloader
for step, batch in enumerate(active_dataloader):
with accelerator.accumulate(model):
outputs = model(**batch, use_cache=False)
outputs = model(**batch, use_cache=False)
loss = outputs.loss
# We keep track of the loss at each logged step
total_loss += loss.detach().float()
accelerator.backward(loss)
# clip gradient norm. don't do this with deepspeed
if accelerator.sync_gradients and args.clip_grad_norm > 0:
accelerator.clip_grad_norm_(model.parameters(), args.clip_grad_norm)
optimizer.step()
optimizer.zero_grad()
lr_scheduler.step()

# # Checks if the accelerator has performed an optimization step behind the scenes
# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
progress_bar.update(1)
completed_steps += 1

if args.logging_steps and completed_steps % args.logging_steps == 0:
avg_loss = accelerator.gather(total_loss).mean().item() / args.gradient_accumulation_steps / args.logging_steps
logger.info(f" Step: {completed_steps}, LR: {lr_scheduler.get_last_lr()[0]}, Loss: {avg_loss}")
@@ -632,15 +713,16 @@ def main():
output_dir = f"step_{completed_steps}"
if args.output_dir is not None:
output_dir = os.path.join(args.output_dir, output_dir)
accelerator.save_state(output_dir)
save_with_accelerate(accelerator, model, tokenizer, output_dir, args)

if completed_steps >= args.max_train_steps:
break

if args.checkpointing_steps == "epoch":
output_dir = f"epoch_{epoch}"
if args.output_dir is not None:
output_dir = os.path.join(args.output_dir, output_dir)
accelerator.save_state(output_dir)
save_with_accelerate(accelerator, model, tokenizer, output_dir, args)

if args.with_tracking:
accelerator.end_training()
@@ -649,22 +731,7 @@ def main():
accelerator.wait_for_everyone()
if accelerator.is_main_process:
tokenizer.save_pretrained(args.output_dir)
unwrapped_model = accelerator.unwrap_model(model)
# When doing multi-gpu training, we need to use accelerator.get_state_dict(model) to get the state_dict.
# Otherwise, sometimes the model will be saved with only part of the parameters.
# Also, accelerator needs to use the wrapped model to get the state_dict.
state_dict = accelerator.get_state_dict(model)
if args.use_lora:
# When using lora, the unwrapped model is a PeftModel, which doesn't support the is_main_process
# and has its own save_pretrained function for only saving lora modules.
# We have to manually specify the is_main_process outside the save_pretrained function.
if accelerator.is_main_process:
unwrapped_model.save_pretrained(args.output_dir, state_dict=state_dict)
else:
unwrapped_model.save_pretrained(
args.output_dir, is_main_process=accelerator.is_main_process, save_function=accelerator.save, state_dict=state_dict
)

save_with_accelerate(accelerator, model, tokenizer, args.output_dir, args)


if __name__ == "__main__":
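
Not part of this PR, but for context on what save_with_accelerate produces: with --use_lora (or --use_qlora) only the adapter weights are written, alongside the tokenizer saved at the end of training, so inference needs the original base model plus the adapter directory. A sketch of how such a checkpoint is typically loaded back with peft; the paths and the base model id are placeholders.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

base_model_id = "meta-llama/Llama-2-7b-hf"  # placeholder: the model that was fine-tuned
adapter_dir = "output/qlora_run"            # placeholder: args.output_dir from training

base = AutoModelForCausalLM.from_pretrained(base_model_id, torch_dtype=torch.bfloat16, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(adapter_dir)  # the tokenizer was saved next to the adapters
base.resize_token_embeddings(len(tokenizer))            # training added a <pad> token, so resize first
model = PeftModel.from_pretrained(base, adapter_dir)
model = model.merge_and_unload()  # optional: fold the LoRA weights into the base model for plain inference
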
(Diffs for the remaining 4 changed files are not shown here.)
