Fix LLaVa mcore loss convergence issue (#385)
Co-authored-by: 同润 <[email protected]>
jerryli1981 and 同润 authored Nov 21, 2024
1 parent 51d49f0 commit a797b0c
Showing 5 changed files with 17 additions and 18 deletions.
8 changes: 4 additions & 4 deletions examples/llava_mcore/README.md
@@ -126,9 +126,9 @@ sh run_mcore_llava.sh \
dsw \
7B \
1 \
64 \
256 \
0.00015 \
1e-5 \
1e-6 \
576 \
1024 \
bf16 \
@@ -143,7 +143,7 @@ false \
/mnt/llava-datasets/LLaVA-Pretrain/wds \
/mnt/llava-datasets/LLaVA-Pretrain/wds \
/mnt/mistral-clip-ckpts/Mistral-7B-Instruct-v0.3-to-mcore-tp4-pp1 \
10000 \
100 \
20000 \
200 \
/workspace/output_mcore_llava_pretrain
```
4 changes: 2 additions & 2 deletions examples/llava_mcore/run_mcore_llava.sh
@@ -191,7 +191,7 @@ megatron_options=" \
--lr ${LR} \
--min-lr ${MIN_LR} \
--lr-decay-style cosine \
--weight-decay 0.1 \
--weight-decay 1e-2 \
--adam-beta1 0.9 \
--adam-beta2 0.95 \
--clip-grad 1.0 \
@@ -225,7 +225,7 @@ megatron_options=" \
--no-load-optim \
--no-load-rng \
--num-workers 2 \
--patch-tokenizer-type LLama3Tokenizer \
--patch-tokenizer-type MistralTokenizer \
--swiglu \
--normalization RMSNorm \
--norm-epsilon 1e-05 \
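
With the lower learning-rate pair from the README (1e-5 and 1e-6), `--lr-decay-style cosine` anneals the optimizer from `--lr` down to `--min-lr` over training. Below is a minimal sketch of that schedule in plain Python, not the Megatron-LM scheduler; the 20,000-step horizon is illustrative and warmup is omitted:

```python
import math

def cosine_lr(step, total_steps, lr=1e-5, min_lr=1e-6):
    """Cosine decay from lr to min_lr over total_steps (no warmup)."""
    progress = min(step / total_steps, 1.0)
    return min_lr + 0.5 * (lr - min_lr) * (1.0 + math.cos(math.pi * progress))

# Learning rate at the start, midpoint, and end of an illustrative run.
for step in (0, 10_000, 20_000):
    print(step, f"{cosine_lr(step, 20_000):.2e}")
```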
13 changes: 7 additions & 6 deletions examples/mistral/pretrain_mcore_mistral.py
@@ -18,12 +18,10 @@
from functools import partial
from typing import Union

from megatron import get_args
from megatron import get_timers
from megatron.training import get_args, get_timers
from megatron.core import mpu, tensor_parallel
from megatron.core.enums import ModelType
import megatron.model
from megatron.utils import (
from megatron.training.utils import (
get_batch_on_this_tp_rank,
get_batch_on_this_cp_rank,
average_losses_across_data_parallel_group
@@ -32,7 +30,7 @@
from megatron.training import pretrain
from megatron.core.datasets.gpt_dataset import GPTDatasetConfig
from megatron.core.datasets.gpt_dataset import GPTDataset
from megatron.arguments import core_transformer_config_from_args
from megatron.training.arguments import core_transformer_config_from_args

from megatron_patch.data import build_pretrain_dataset_from_original
from megatron_patch.data.utils import get_batch_on_this_tp_rank_original
@@ -42,7 +40,10 @@
from megatron_patch.model.mixtral.layer_specs import get_gpt_layer_with_transformer_engine_spec


def model_provider(pre_process=True, post_process=True) -> Union[GPTModel, megatron.model.GPTModel]:
def model_provider(
pre_process=True, post_process=True
) -> Union[GPTModel]:

args = get_args()
build_tokenizer(args)
config = core_transformer_config_from_args(get_args())
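
The import changes above follow the newer Megatron-LM layout, where `get_args`, `get_timers`, the training utilities, and `core_transformer_config_from_args` live under `megatron.training` instead of the top-level `megatron` package. A hedged compatibility sketch, not part of this commit, that tolerates both layouts:

```python
# Prefer the newer megatron.training layout; fall back to the legacy
# top-level modules found in older Megatron-LM checkouts.
try:
    from megatron.training import get_args, get_timers
    from megatron.training.arguments import core_transformer_config_from_args
    from megatron.training.utils import average_losses_across_data_parallel_group
except ImportError:
    from megatron import get_args, get_timers
    from megatron.arguments import core_transformer_config_from_args
    from megatron.utils import average_losses_across_data_parallel_group
```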
@@ -91,7 +91,7 @@ run_cmd="torchrun ${DISTRIBUTED_ARGS} hf2mcore_llava.py \
--no-bias-swiglu-fusion \
--seq-length 1 \
--no-async-tensor-model-parallel-allreduce \
--patch-tokenizer-type LLamaTokenizer \
--patch-tokenizer-type MistralTokenizer \
--extra-vocab-size ${EXTRA_VOCAB_SIZE} \
--untie-embeddings-and-output-weights \
--no-rope-fusion \
8 changes: 3 additions & 5 deletions toolkits/model_checkpoints_convertor/llava/hf2mcore_llava.py
@@ -26,7 +26,7 @@
os.path.dirname(os.path.dirname(os.path.dirname(os.path.realpath(__file__))))
)
sys.path.append(os.path.join(path_dir, "examples"))
from llama3_1.pretrain_llama import model_provider
from mistral.pretrain_mcore_mistral import model_provider
from megatron_patch.arguments import get_patch_args

torch.backends.cudnn.deterministic = True
@@ -603,9 +603,7 @@ def convert_checkpoint_from_transformers_to_megatron(hfmodel, mgmodel, args):
)

# 4. final layernorm
# NOTE: we move the final layernorm out of decoder to apply LLaMARMSNorm
# mgmodel.decoder.final_layernorm.weight.copy_(hfmodel.model.norm.weight)
mgmodel.final_layernorm.weight.copy_(hfmodel.model.norm.weight)
mgmodel.decoder.final_layernorm.weight.copy_(hfmodel.model.norm.weight)
# 5. output layer
mgmodel.output_layer.weight.copy_(hfmodel.lm_head.weight)
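
The layernorm fix above writes HF's final norm into `mgmodel.decoder.final_layernorm` rather than a standalone `final_layernorm` attribute. As a general pattern, such copies belong inside `torch.no_grad()` and should be shape-checked; a hedged sketch of that pattern, reusing the `hfmodel`/`mgmodel` names from the converter (the real script also handles vocabulary padding and tensor-parallel splitting, which this omits):

```python
import torch

def copy_final_norm_and_head(hfmodel, mgmodel):
    """Illustrative copy pattern: no autograd tracking, in-place copy_()
    after a shape check. Not a drop-in replacement for the converter."""
    with torch.no_grad():
        pairs = [
            (hfmodel.model.norm.weight, mgmodel.decoder.final_layernorm.weight),
            (hfmodel.lm_head.weight, mgmodel.output_layer.weight),
        ]
        for src, dst in pairs:
            assert src.shape == dst.shape, (src.shape, dst.shape)
            dst.copy_(src)
```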

@@ -730,7 +728,7 @@ def convert_clip(download_root, output_path, tensor_parallel_size, use_te):

for i in range(tensor_parallel_size):
output_dir_tp = os.path.join(output_path, "clip_release", f"mp_rank_0{i}")
os.makedirs(output_dir_tp)
os.makedirs(output_dir_tp, exist_ok=True)
output_path_tp = os.path.join(output_dir_tp, "model_optim_rng.pt")
torch.save(new_state_dicts[i], output_path_tp)
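
The `os.makedirs(..., exist_ok=True)` change makes the per-rank CLIP export idempotent: rerunning the conversion no longer aborts with `FileExistsError` when the `mp_rank_0*` directories already exist from a previous run. A small sketch of the difference (hypothetical path):

```python
import os

out_dir = "/tmp/clip_release/mp_rank_00"  # hypothetical output directory

os.makedirs(out_dir)                 # first run: creates the directory
# os.makedirs(out_dir)               # rerun without exist_ok: FileExistsError
os.makedirs(out_dir, exist_ok=True)  # rerun with exist_ok=True: no error
```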

