Update finetune_moss.py
xyltt authored Apr 24, 2023
1 parent 89fee5d commit 24c2f16
Showing 1 changed file with 11 additions and 9 deletions.
20 changes: 11 additions & 9 deletions finetune_moss.py
@@ -1,4 +1,4 @@
"""Code for moss-moon-003-sft"""
"""Code for moss-sft"""

import os
import copy
@@ -7,15 +7,15 @@
 import logging
 import argparse
 
-from tqdm import tqdm
 import torch.distributed as dist
 
+from tqdm import tqdm
 from accelerate import Accelerator
 from torch.utils.data import Dataset, DataLoader
 from torch.utils.tensorboard import SummaryWriter
+from accelerate import Accelerator, DeepSpeedPlugin
 from transformers import set_seed, get_cosine_schedule_with_warmup
 
-from models.modeling_moss import MossForCausalLM
-from models.tokenization_moss import MossTokenizer
+from transformers import AutoTokenizer, AutoModelForCausalLM
 
 
 logger = logging.getLogger(__name__)
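
The import rearrangement pairs with the DeepSpeed wiring in the next hunk. A minimal sketch of how DeepSpeedPlugin and Accelerator fit together, with assumed values (ZeRO stage, batch size) that are not part of this commit, and assuming the script is started through a DeepSpeed-enabled accelerate launch:

from accelerate import Accelerator, DeepSpeedPlugin

# The plugin carries the DeepSpeed settings into the Accelerator; individual
# keys can still be overridden on the live config afterwards, which is what
# train() does below with train_micro_batch_size_per_gpu.
deepspeed_plugin = DeepSpeedPlugin(zero_stage=2, gradient_accumulation_steps=1)
accelerator = Accelerator(deepspeed_plugin=deepspeed_plugin)
accelerator.state.deepspeed_plugin.deepspeed_config['train_micro_batch_size_per_gpu'] = 4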
@@ -174,8 +174,10 @@ def train(args):
 
     accelerator.state.deepspeed_plugin.deepspeed_config['train_micro_batch_size_per_gpu'] = args.train_bsz_per_gpu
 
-    tokenizer = MossTokenizer.from_pretrained(args.model_path)
-    model = MossForCausalLM.from_pretrained(args.model_path, use_cache=False)
+    tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, trust_remote_code=True)
+    tokenizer.eos_token_id = 106068  # The eos_token_id of the base model is 106028; we need to map the eos token to <eom> (its token id is 106068)
+
+    model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path, trust_remote_code=True, use_cache=False)
 
     model.transformer.gradient_checkpointing = True
     assert model.transformer.gradient_checkpointing is True
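
Read on its own, the new loading path amounts to the sketch below, assuming the checkpoint directory bundles the custom MOSS code (hence trust_remote_code=True); the path is just the script's default, and the <eom> check is a hypothetical sanity test, not part of the commit:

from transformers import AutoTokenizer, AutoModelForCausalLM

# trust_remote_code=True resolves the tokenizer/model classes shipped with
# the checkpoint instead of requiring the repo's local models/ package.
tokenizer = AutoTokenizer.from_pretrained("./ckpts/moss-16B-base", trust_remote_code=True)

# The base model's eos_token_id is 106028; SFT ends each turn with <eom>,
# so eos is remapped to that token's id, 106068.
tokenizer.eos_token_id = 106068
assert tokenizer.convert_ids_to_tokens(106068) == "<eom>"  # assumed vocab entry

# use_cache is turned off because the script enables gradient checkpointing,
# and the two are incompatible during training.
model = AutoModelForCausalLM.from_pretrained("./ckpts/moss-16B-base", trust_remote_code=True, use_cache=False)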
@@ -254,7 +256,7 @@ def train(args):
 
             val_acc, val_loss = val_metric.get_metric()
 
-            if accelerator.is_main_process:
+            if accelerator.is_local_main_process:
                 writer.add_scalar(f'val_loss', val_loss, global_step=global_step)
                 writer.add_scalar(f'val_acc', val_acc, global_step=global_step)
                 accelerator.print(f"Epoch: {epoch}, Step: {batch_cnt}, Val loss: {val_loss}, Val acc: {val_acc}")
@@ -272,7 +274,7 @@ def train(args):
     parser = argparse.ArgumentParser(description='Args of sft')
 
     # Model Args
-    parser.add_argument('--model_path', default='./ckpts/moss-16B-base', type=str)
+    parser.add_argument('--model_name_or_path', default='./ckpts/moss-16B-base', type=str)
 
     # Data Args
     parser.add_argument('--data_dir', default='./data/sft', type=str)
