Long context full SFT validation causes OOM #7041

Yixi-Rao opened this issue Feb 23, 2025 · 1 comment
Labels: bug (Something isn't working), pending (This problem is yet to be addressed)

Comments

@Yixi-Rao

Reminder

  • I have read the above rules and searched the existing issues.

System Info

I am doing long-context full SFT. Training itself runs fine with the following memory-saving settings:

bf16: true
gradient_checkpointing: true
disable_gradient_checkpointing: false
enable_liger_kernel: true
use_unsloth_gc: true
flash_attn: fa2
torch_empty_cache_steps: 10

However, OOM happens during the validation stage, even though I have already set the eval batch size to 1:

# eval
val_size: 0.02
per_device_eval_batch_size: 1
eval_strategy: steps
eval_steps: 24

I have to run validation during SFT because of the specific task I am fine-tuning for.
Are there any suggestions for solving this validation OOM problem?
Thanks in advance!
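
For reference, below is a minimal sketch of eval-side settings that are sometimes used to reduce validation memory. These are standard Hugging Face TrainingArguments options (prediction_loss_only, eval_accumulation_steps, bf16_full_eval); whether your LLaMA-Factory version passes them through to the trainer is an assumption to verify, not a confirmed fix:

# eval (hypothetical memory-reduction variant, not the original config above)
val_size: 0.02
per_device_eval_batch_size: 1
eval_strategy: steps
eval_steps: 24
prediction_loss_only: true   # avoid materializing and concatenating full logits during eval
eval_accumulation_steps: 1   # move accumulated eval outputs to CPU after every step
bf16_full_eval: true         # run evaluation in bf16 rather than fp32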

Reproduction

### model
model_name_or_path:

### method
stage: sft
do_train: true
finetuning_type: full
deepspeed: examples/deepspeed/ds_z3_config.json

### dataset
dataset:
template: qwen
cutoff_len: 120000
overwrite_cache: true
preprocessing_num_workers: 90

### output
output_dir:
report_to: tensorboard
logging_dir:
logging_steps: 1
save_steps: 190
plot_loss: true
overwrite_output_dir: true

### train
per_device_train_batch_size: 1
gradient_accumulation_steps: 4
learning_rate: 1.0e-6
num_train_epochs: 2.0
lr_scheduler_type: cosine
warmup_ratio: 0.1
max_grad_norm: 1.0
bf16: true
gradient_checkpointing: true
disable_gradient_checkpointing: false
enable_liger_kernel: true
use_unsloth_gc: true
flash_attn: fa2
torch_empty_cache_steps: 10
ddp_timeout: 180000000
save_only_model: true

### eval
val_size: 0.02
per_device_eval_batch_size: 1
eval_strategy: steps
eval_steps: 24

Others

No response

@JackLingjie

I encountered the same issue, and I was only able to resolve it by disabling evaluation during training.

For context, I'm doing full SFT of a 1.5B model. With a 24k context length and ZeRO-3 enabled on 8x H100 80GB GPUs, training works fine. However, on 4 nodes with 8x A100 40GB each, I run into OOM. I'm not sure whether this is a bug related to Llama Factory or something else, but it might be worth looking into.

Hope this helps!
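
For reference, the workaround described above (disabling evaluation during training) would look roughly like this in the LLaMA-Factory YAML; this is a sketch of the workaround, not a fix for the underlying validation OOM:

# eval (disabled, per the workaround above)
# val_size: 0.02
# per_device_eval_batch_size: 1
# eval_steps: 24
eval_strategy: "no"   # quoted so YAML does not parse it as the boolean false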
