
Commit

Merge pull request THUDM#438 from zRzRzRzRzRzRzR/main
Update part of the base model fine-tuning script code
zRzRzRzRzRzRzR authored Nov 24, 2023
2 parents 1fd657d + 7ce416f commit bf56c85
Showing 3 changed files with 6 additions and 80 deletions.
79 changes: 2 additions & 77 deletions finetune_basemodel_demo/finetune.py
@@ -155,87 +155,12 @@ def main():
     checkpoint = None
     if training_args.resume_from_checkpoint is not None:
         checkpoint = training_args.resume_from_checkpoint
-    #model.gradient_checkpointing_enable()
+    # model.gradient_checkpointing_enable()
     model.enable_input_require_grads()
     trainer.train()
     trainer.save_model() # Saves the tokenizer too for easy upload
     trainer.save_state()


 if __name__ == "__main__":
     main()
-
-# # Set seed before initializing model.
-# set_seed(training_args.seed)
-#
-# # Load pretrained model and tokenizer
-# config = AutoConfig.from_pretrained(model_args.model_name_or_path, trust_remote_code=True)
-#
-# tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path, trust_remote_code=True)
-#
-# if model_args.lora_checkpoint is not None:
-# model = AutoModel.from_pretrained(model_args.model_name_or_path, config=config, trust_remote_code=True).cuda()
-# prefix_state_dict = torch.load(os.path.join(model_args.lora_checkpoint, "pytorch_model.bin"))
-# new_prefix_state_dict = {}
-# for k, v in prefix_state_dict.items():
-# if k.startswith("transformer.prefix_encoder."):
-# new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v
-# model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)
-# else:
-# model = AutoModel.from_pretrained(model_args.model_name_or_path, config=config, trust_remote_code=True).cuda()
-#
-# if model_args.quantization_bit is not None:
-# print(f"Quantized to {model_args.quantization_bit} bit")
-# model = model.quantize(model_args.quantization_bit)
-# # model = model.float()
-#
-# with open(data_args.train_file, "r", encoding="utf-8") as f:
-# if data_args.train_file.endswith(".json"):
-# train_data = json.load(f)
-# elif data_args.train_file.endswith(".jsonl"):
-# train_data = [json.loads(line) for line in f]
-#
-# if data_args.train_format == "input-output":
-# train_dataset = InputOutputDataset(
-# train_data,
-# tokenizer,
-# data_args.max_source_length,
-# data_args.max_target_length,
-# )
-# else:
-# raise ValueError(f"Unknown train format: {data_args.train_format}")
-# print(f"Train dataset size: {len(train_dataset)}")
-# if training_args.local_rank < 1:
-# sanity_check(train_dataset[0]['input_ids'], train_dataset[0]['labels'], tokenizer)
-#
-# # Data collator
-# data_collator = DataCollatorForSeq2Seq(
-# tokenizer,
-# model=model,
-# label_pad_token_id=-100,
-# pad_to_multiple_of=None,
-# padding=False
-# )
-#
-# # Initialize our Trainer
-# trainer = LoRATrainer(
-# model=model,
-# args=training_args,
-# train_dataset=train_dataset,
-# tokenizer=tokenizer,
-# data_collator=data_collator,
-# save_changed=model_args.lora_rank is not None
-# )
-#
-# checkpoint = None
-# if training_args.resume_from_checkpoint is not None:
-# checkpoint = training_args.resume_from_checkpoint
-# model.gradient_checkpointing_enable()
-# model.enable_input_require_grads()
-# trainer.train(resume_from_checkpoint=checkpoint)
-# trainer.save_model() # Saves the tokenizer too for easy upload
-# trainer.save_state()
-#
-#
-# if __name__ == "__main__":
-# main()
-main()
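
Aside from deleting the dead commented-out block, the change leaves `model.enable_input_require_grads()` as the one piece of setup before `trainer.train()`. That transformers hook makes the input embeddings' output require gradients, which is the usual companion of adapter-style tuning and of gradient checkpointing (still commented out here): with every base weight frozen, the loss would otherwise carry no grad_fn and `backward()` would fail. A minimal illustration of the effect, using a small public model rather than ChatGLM3, is sketched below.

# Illustration only, not code from this repository: every base weight is frozen,
# as in adapter-style fine-tuning, and enable_input_require_grads() keeps the
# backward pass alive by making the embedding output require gradients.
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("gpt2")   # any causal LM works; gpt2 is simply small and public
for p in model.parameters():
    p.requires_grad_(False)                            # emulate a frozen base model

model.enable_input_require_grads()                     # without this, loss.backward() below raises a RuntimeError

ids = torch.tensor([[464, 3290, 318, 922]])
loss = model(input_ids=ids, labels=ids).loss
loss.backward()                                        # succeeds: gradients flow back to the embedding output
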
4 changes: 2 additions & 2 deletions finetune_basemodel_demo/inference.py
@@ -6,10 +6,10 @@

 # Argument Parser Setup
 parser = argparse.ArgumentParser()
-parser.add_argument("--model", type=str, default="/data/share/models/chatglm3-6b-base",
+parser.add_argument("--model", type=str, default=None,
                     help="The directory of the model")
 parser.add_argument("--tokenizer", type=str, default=None, help="Tokenizer path")
-parser.add_argument("--lora-path", type=str, default="/data/yuxuan/Code/ChatGLM3//output/chatglm-lora.pt",
+parser.add_argument("--lora-path", type=str, default=None,
                     help="Path to the LoRA model checkpoint")
 parser.add_argument("--device", type=str, default="cuda", help="Device to use for computation")
 parser.add_argument("--max-new-tokens", type=int, default=128, help="Maximum new tokens for generation")
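
With the hard-coded personal paths gone, `--model` and `--lora-path` default to `None`, so callers now have to pass the model directory (and, optionally, a LoRA checkpoint) explicitly. A hedged sketch of the post-parse handling this implies, written as a continuation of the argparse block above:

# Hypothetical continuation after the parser above; the required-path check and
# the tokenizer fallback are assumptions, not code taken from inference.py.
from transformers import AutoModel, AutoTokenizer

args = parser.parse_args()
if args.model is None:
    parser.error("--model must point to the ChatGLM3 base model directory")

tokenizer_dir = args.tokenizer or args.model           # reuse the model directory when no tokenizer path is given
tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir, trust_remote_code=True)
model = AutoModel.from_pretrained(args.model, trust_remote_code=True).to(args.device).eval()
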
3 changes: 2 additions & 1 deletion finetune_basemodel_demo/requirements.txt
@@ -1,4 +1,5 @@
 tqdm
 datasets
 fsspec
-astunparse
+astunparse
+peft
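
The new `peft` entry is the dependency behind the LoRA side of the demo. A short sketch of the usual pattern it enables follows; the Hub id and the config values are illustrative assumptions, not the ones used by finetune.py, whose deleted block above refers to a `model_args.lora_rank` argument for the same purpose.

# Sketch of how peft is typically used to attach a LoRA adapter; the Hub id and
# the config values are illustrative, not the ones used by finetune.py.
from peft import LoraConfig, TaskType, get_peft_model
from transformers import AutoModel

base = AutoModel.from_pretrained("THUDM/chatglm3-6b-base", trust_remote_code=True)
lora_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    r=8,               # LoRA rank; illustrative value
    lora_alpha=32,
    lora_dropout=0.1,
)
model = get_peft_model(base, lora_config)
model.print_trainable_parameters()                     # only the adapter weights remain trainable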
