SFT Loss unable to decrease on MATH data #482

Closed
yxchng opened this issue Dec 10, 2024 · 7 comments

Comments

@yxchng commented Dec 10, 2024

I am trying to do SFT following this doc: https://github.com/allenai/open-instruct/blob/main/docs/tulu3.md.

My data is formatted as follows:

{
  "id": "0_1",
  "messages": [
    {
      "content": "You are a helpful assistant.",
      "role": "system"
    },
    {
      "content": "Chandra has four bowls, each a different color: red, blue, yellow, and green. She also has exactly one glass of each corresponding color. If she chooses a bowl and a glass from the cupboard, how many pairings are possible? For example, one such pairing is a blue bowl and a yellow glass.",
      "role": "user"
    },
    {
      "content": "To solve this problem, we determine the number of possible pairings between the four bowls and the four glasses, given that:\n\n- Each bowl has a unique color: red, blue, yellow, or green.\n- Each glass also has a unique color corresponding to the bowls.\n\n### Steps to Solve:\n\n1. **Select a Bowl:**  \n   There are 4 ways to choose a bowl (red, blue, yellow, or green).\n\n2. **Select a Glass:**  \n   For each chosen bowl, there are 4 ways to choose a glass (one for each color).\n\n### Total Pairings:\n\nThe total number of pairings is the product of these choices:\n\n\\[\n\\text{Total pairings} = 4 \\times 4 = 16\n\\]\n\n### Conclusion:\n\nThe total number of unique pairings is:\n\n\\[\n\\boxed{16}\n\\]",
      "role": "assistant"
    }
  ],
  "source": "MATH"
}
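
As a quick sanity check on this format, here is a minimal sketch (an illustration only, not open-instruct's actual preprocessing; the chat template string below is a made-up stand-in, and any tokenizer with a chat template would do) that renders one example through a Hugging Face tokenizer's chat template:

from transformers import AutoTokenizer

# Tokenizer name taken from the launch command below (gated model);
# the template string is illustrative, open-instruct sets its own.
tok = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B")
tok.chat_template = (
    "{% for m in messages %}<|{{ m['role'] }}|>\n{{ m['content'] }}\n{% endfor %}"
)

example_messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "How many bowl/glass pairings are possible?"},
    {"role": "assistant", "content": "4 x 4 = 16, so the answer is 16."},
]

# Render without tokenizing to eyeball that roles and contents are well-formed.
print(tok.apply_chat_template(example_messages, tokenize=False))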

However, my loss is as follows and is not decreasing:

[Screenshot: training loss curve, 2024-12-10 19-26-48]

I am following the Tulu 3 SFT launch command exactly, as follows:

MACHINE_RANK=0
MAIN_PROCESS_IP=localhost
NUM_MACHINES=1
NUM_PROCESSES=4
PER_DEVICE_TRAIN_BATCH_SIZE=2
GRADIENT_ACCUMULATION_STEPS=16
OMP_NUM_THREADS=4  accelerate launch \
    --mixed_precision bf16 \
    --num_machines $NUM_MACHINES \
    --num_processes $NUM_PROCESSES \
    --machine_rank $MACHINE_RANK \
    --main_process_ip $MAIN_PROCESS_IP \
    --main_process_port 29402 \
    --use_deepspeed \
    --deepspeed_config_file configs/ds_configs/stage3_offloading_accelerate.conf \
    --deepspeed_multinode_launcher standard open_instruct/finetune.py \
    --model_name_or_path meta-llama/Llama-3.1-8B \
    --tokenizer_name meta-llama/Llama-3.1-8B \
    --use_slow_tokenizer \
    --use_flash_attn \
    --max_seq_length 4096 \
    --preprocessing_num_workers 16 \
    --per_device_train_batch_size $PER_DEVICE_TRAIN_BATCH_SIZE \
    --gradient_accumulation_steps $GRADIENT_ACCUMULATION_STEPS \
    --learning_rate 5e-06 \
    --lr_scheduler_type linear \
    --warmup_ratio 0.03 \
    --weight_decay 0.0 \
    --num_train_epochs 2 \
    --output_dir output/Llama-3-1-8B-Instruct_MATH_turn1 \
    --with_tracking \
    --report_to wandb \
    --logging_steps 1 \
    --reduce_loss sum \
    --model_revision main \
    --dataset_mixer_list data/train/Llama-3-1-8B-Instruct_MATH_turn1.json 1.0 \
    --checkpointing_steps epoch \
    --dataset_mix_dir output/Llama-3-1-8B-Instruct_MATH_turn1 \
    --exp_name Llama-3-1-8B-Instruct_MATH_turn1-sft \
    --seed 123 
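
(For reference, the effective batch size implied by these settings is NUM_PROCESSES × PER_DEVICE_TRAIN_BATCH_SIZE × GRADIENT_ACCUMULATION_STEPS = 4 × 2 × 16 = 128 sequences per optimizer step.)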

Could you let me know if I am doing anything wrong?

@natolambert (Collaborator) commented Dec 18, 2024

I'd say this looks normal, @yxchng! The loss goes down very slowly for SFT; the biggest delta is seen between epochs.

This screenshot is from one of our OLMo 2 13B Instruct SFT runs for 2 epochs over the SFT data -- see the number of steps.
[Screenshot: loss curve from a 2-epoch OLMo 2 13B Instruct SFT run]

@vwxyzjn closed this as not planned on Jan 8, 2025
@berserkr commented Feb 4, 2025

@natolambert I am seeing the same thing, too. Just to confirm, is this expected behavior? I assume it's because of the sum loss, right?
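
For anyone comparing raw numbers: my understanding (an assumption about the flag, not confirmed against the code) is that --reduce_loss sum sums the per-token cross-entropy instead of averaging it, so the logged value scales with the number of assistant tokens in a batch and sits well above a typical mean loss. A minimal PyTorch sketch of the scaling:

import torch
import torch.nn.functional as F

# Dummy shapes just to show the scaling; this is not open-instruct code.
vocab_size, seq_len = 32000, 12
logits = torch.randn(1, seq_len, vocab_size)
labels = torch.randint(0, vocab_size, (1, seq_len))

flat_logits = logits.view(-1, vocab_size)
flat_labels = labels.view(-1)

mean_loss = F.cross_entropy(flat_logits, flat_labels, reduction="mean")
sum_loss = F.cross_entropy(flat_logits, flat_labels, reduction="sum")

print(mean_loss.item())  # roughly log(32000) ≈ 10.4 for random logits
print(sum_loss.item())   # about seq_len times larger; grows with tokens per batch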

@natolambert (Collaborator)

@berserkr can you share more -- the loss doesn't go down a ton in IFT? Can you share plots?

@berserkr commented Feb 5, 2025

@natolambert here is the loss for one of the models I am testing. It is an MoE variant.

[Image: training loss curve]

@natolambert (Collaborator)

Hey @berserkr and others here. A few things:

  • In general, this looks as expected.
  • Epoch boundaries will have bigger loss drops (how many epochs is this?).
  • Try training on a smaller dataset for multiple epochs.
  • Can you share evaluation scores for the models?

@berserkr commented Feb 5, 2025

@natolambert I will wait for it to complete and perhaps try another epoch; I am using the default 70B parameters. I see MMLU, for example, go from 67.7 to 67.8 -- very little improvement. I will complete another full run before I report more :) Thank you!

@natolambert (Collaborator)

Also @berserkr MMLU is a tricky eval for SFT. Something more specific is usually easier to debug :)
