train.py
import os

import transformers

import data
import model

# Detect whether the script was launched under torch.distributed (e.g. via torchrun);
# WORLD_SIZE > 1 means distributed data parallel (DDP) training.
is_ddp = int(os.environ.get("WORLD_SIZE", 1)) != 1

m = model.get_model()
ds = data.TrainDataset()
# Dynamically pad each batch; pad_to_multiple_of=8 keeps tensor shapes efficient for fp16.
collator = transformers.DataCollatorForSeq2Seq(ds.tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True)
# To resume from a previously saved adapter checkpoint, uncomment:
# import torch
# import peft
# adapters_weights = torch.load("./output/checkpoint-800/adapter_model.bin")
# peft.set_peft_model_state_dict(m, adapters_weights)
trainer = transformers.Trainer(
    model=m,
    train_dataset=ds,
    data_collator=collator,
    args=transformers.TrainingArguments(
        per_device_train_batch_size=4,
        num_train_epochs=1,
        learning_rate=3e-4,
        fp16=True,
        logging_steps=10,
        optim="adamw_torch",
        evaluation_strategy="no",
        save_strategy="steps",
        eval_steps=None,
        # Save a checkpoint every 200 steps and keep only the 3 most recent.
        save_steps=200,
        output_dir="./output",
        save_total_limit=3,
        # Under DDP, skip unused-parameter detection to avoid the extra graph traversal.
        ddp_find_unused_parameters=False if is_ddp else None,
    ),
)
# The KV cache is only useful for generation; disable it during training.
m.config.use_cache = False
trainer.train()
# Save the fine-tuned weights.
m.save_pretrained("./weights")
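
After training, the saved weights can be reloaded for inference. A minimal sketch, assuming model.get_model() returns a PEFT-wrapped causal LM so that ./weights contains adapter weights; "base-model-name" is a placeholder for whichever base checkpoint the project actually wraps:

import peft
import transformers

# Placeholder base checkpoint; replace with the base model that model.get_model() wraps.
base = transformers.AutoModelForCausalLM.from_pretrained("base-model-name")

# Attach the adapter weights that train.py wrote to ./weights.
m = peft.PeftModel.from_pretrained(base, "./weights")
m.eval()
m.config.use_cache = True  # re-enable the KV cache for generation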