Merge branch 'ko3n1g/refactor/make-get_mlp_module_spec-public' into 'main'

refactor: Make `get_mlp_module_spec` public

See merge request ADLR/megatron-lm!2534
ko3n1g committed Jan 17, 2025
2 parents e02a860 + 4e87b4c commit fa35226
Showing 6 changed files with 78 additions and 16 deletions.
8 changes: 5 additions & 3 deletions megatron/core/inference/modelopt_support/gpt/model_specs.py
@@ -1,8 +1,10 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.

from typing import Optional

from megatron.core.extensions.transformer_engine import TEDotProductAttention, TENorm
from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add
from megatron.core.models.gpt.gpt_layer_specs import _get_mlp_module_spec
from megatron.core.models.gpt.gpt_layer_specs import get_mlp_module_spec
from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear
from megatron.core.transformer.attention import SelfAttention, SelfAttentionSubmodules
from megatron.core.transformer.enums import AttnMaskType
@@ -13,7 +15,7 @@

# Use this spec for ModelOpt PTQ and TensorRT-LLM export
def get_gpt_layer_modelopt_spec(
num_experts: int = None,
num_experts: Optional[int] = None,
moe_grouped_gemm: bool = False,
remap_te_layernorm: bool = False,
qk_layernorm: bool = False,
@@ -24,7 +26,7 @@ def get_gpt_layer_modelopt_spec(
is using TENorm from Transformer-Engine. The issue is that FusedLayerNorm from apex
has stopped supporting RMSNorm needed by llama.
"""
mlp = _get_mlp_module_spec(
mlp = get_mlp_module_spec(
use_te=False, num_experts=num_experts, moe_grouped_gemm=moe_grouped_gemm, fp8=False
)
sharded_state_dict_keys_map = {}
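For orientation, a minimal usage sketch of the updated ModelOpt spec. It assumes Transformer Engine is installed (the module imports TEDotProductAttention and TENorm); the keyword values are illustrative, not prescribed by this commit:

```python
# Illustrative only: build the ModelOpt GPT layer spec with the updated
# Optional[int] signature. num_experts=None selects the dense MLP branch of
# get_mlp_module_spec(); passing an int requests the MoE variant.
from megatron.core.inference.modelopt_support.gpt.model_specs import (
    get_gpt_layer_modelopt_spec,
)

layer_spec = get_gpt_layer_modelopt_spec(num_experts=None, qk_layernorm=False)
```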
25 changes: 23 additions & 2 deletions megatron/core/models/gpt/gpt_layer_specs.py
@@ -79,7 +79,7 @@ def get_gpt_layer_with_transformer_engine_spec(
' and will be removed soon. Please update your code accordingly.'
)

mlp = _get_mlp_module_spec(
mlp = get_mlp_module_spec(
use_te=True,
num_experts=num_experts,
moe_grouped_gemm=moe_grouped_gemm,
@@ -169,7 +169,7 @@ def get_gpt_layer_local_spec(
' and will be removed soon. Please update your code accordingly.'
)

mlp = _get_mlp_module_spec(
mlp = get_mlp_module_spec(
use_te=False,
num_experts=num_experts,
moe_grouped_gemm=moe_grouped_gemm,
@@ -236,6 +236,27 @@ def _get_mlp_module_spec(
moe_grouped_gemm: Optional[bool] = False,
fp8: Optional[str] = None, # pylint: disable=unused-arguments
moe_use_legacy_grouped_gemm: Optional[bool] = False,
):
warnings.warn(
"""This private function is on a deprecation track. Please switch to `get_mlp_module_spec`
since it will be removed in a future release."""
)

return get_mlp_module_spec(
use_te=use_te,
num_experts=num_experts,
moe_grouped_gemm=moe_grouped_gemm,
fp8=fp8,
moe_use_legacy_grouped_gemm=moe_use_legacy_grouped_gemm,
)


def get_mlp_module_spec(
use_te: Optional[bool] = True,
num_experts: Optional[int] = None,
moe_grouped_gemm: Optional[bool] = False,
fp8: Optional[str] = None, # pylint: disable=unused-arguments
moe_use_legacy_grouped_gemm: Optional[bool] = False,
) -> ModuleSpec:
"""Helper function to get module spec for MLP/MoE"""
if fp8 is not None:
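For downstream callers the change is a one-line import swap; a minimal migration sketch (the call site itself is hypothetical, the arguments mirror the signature shown above):

```python
# Before: private helper. It still works, but now emits a UserWarning and
# forwards to the public function.
from megatron.core.models.gpt.gpt_layer_specs import _get_mlp_module_spec

mlp_spec = _get_mlp_module_spec(use_te=False, num_experts=None, moe_grouped_gemm=False)

# After: public helper with the identical signature and defaults
# (use_te=True, num_experts=None, moe_grouped_gemm=False, fp8=None,
#  moe_use_legacy_grouped_gemm=False).
from megatron.core.models.gpt.gpt_layer_specs import get_mlp_module_spec

mlp_spec = get_mlp_module_spec(use_te=False, num_experts=None, moe_grouped_gemm=False)
```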
14 changes: 8 additions & 6 deletions megatron/core/models/multimodal/llava_spec.py
@@ -1,12 +1,14 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
from typing import Optional

from megatron.core.extensions.transformer_engine import (
TEDotProductAttention,
TELayerNormColumnParallelLinear,
TENorm,
TERowParallelLinear,
)
from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add
from megatron.core.models.gpt.gpt_layer_specs import _get_mlp_module_spec
from megatron.core.models.gpt.gpt_layer_specs import get_mlp_module_spec
from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear
from megatron.core.transformer.attention import SelfAttention, SelfAttentionSubmodules
from megatron.core.transformer.dot_product_attention import DotProductAttention
@@ -27,15 +29,15 @@

from megatron.core.transformer.torch_norm import WrappedTorchNorm

warnings.warn(f'Apex is not installed. Falling back to Torch Norm')
warnings.warn('Apex is not installed. Falling back to Torch Norm')
LNImpl = WrappedTorchNorm


def decoder_model_with_transformer_engine_default_spec(
num_experts: int = None, moe_grouped_gemm: bool = False, qk_layernorm: bool = False
num_experts: Optional[int] = None, moe_grouped_gemm: bool = False, qk_layernorm: bool = False
) -> ModuleSpec:
"""LLava decoder TE spec (uses Transformer Engine components)."""
mlp = _get_mlp_module_spec(
mlp = get_mlp_module_spec(
use_te=True, num_experts=num_experts, moe_grouped_gemm=moe_grouped_gemm
)
return ModuleSpec(
@@ -60,10 +62,10 @@ def decoder_model_with_transformer_engine_default_spec(


def decoder_model_with_local_default_spec(
num_experts: int = None, moe_grouped_gemm: bool = False, qk_layernorm: bool = False
num_experts: Optional[int] = None, moe_grouped_gemm: bool = False, qk_layernorm: bool = False
) -> ModuleSpec:
"""LLava decoder local spec."""
mlp = _get_mlp_module_spec(
mlp = get_mlp_module_spec(
use_te=False, num_experts=num_experts, moe_grouped_gemm=moe_grouped_gemm
)
return ModuleSpec(
2 changes: 1 addition & 1 deletion megatron/core/models/vision/vit_layer_specs.py
@@ -27,7 +27,7 @@

from megatron.core.transformer.torch_norm import WrappedTorchNorm

warnings.warn(f'Apex is not installed. Falling back to Torch Norm')
warnings.warn('Apex is not installed. Falling back to Torch Norm')
LNImpl = WrappedTorchNorm


41 changes: 39 additions & 2 deletions tests/unit_tests/models/test_gpt_model.py
@@ -1,11 +1,15 @@
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.

import inspect
import os

import pytest
import torch

from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec
from megatron.core.models.gpt.gpt_layer_specs import (
get_gpt_layer_with_transformer_engine_spec,
get_mlp_module_spec,
)
from megatron.core.models.gpt.gpt_model import GPTModel
from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed
from megatron.core.transformer.transformer_config import TransformerConfig
@@ -59,7 +63,7 @@ def test_set_input_tensor(self):

@pytest.mark.internal
def test_post_process_forward(self):
config: TransformerConfig = self.gpt_model.config
_ = self.gpt_model.config
sequence_length = self.gpt_model.max_sequence_length
micro_batch_size = 2

@@ -79,3 +83,36 @@ def test_post_process_forward(self):
assert logits.shape[0] == micro_batch_size
assert logits.shape[1] == sequence_length
assert logits.shape[2] == self.gpt_model.vocab_size


def test_get_mlp_module_spec_interface():
# Get the function signature
sig = inspect.signature(get_mlp_module_spec)

# Define the expected signature
expected_params = {
"use_te": inspect.Parameter.POSITIONAL_OR_KEYWORD,
"num_experts": inspect.Parameter.POSITIONAL_OR_KEYWORD,
"moe_grouped_gemm": inspect.Parameter.POSITIONAL_OR_KEYWORD,
"fp8": inspect.Parameter.POSITIONAL_OR_KEYWORD,
"moe_use_legacy_grouped_gemm": inspect.Parameter.POSITIONAL_OR_KEYWORD,
}

expected_defaults = {
"use_te": True,
"num_experts": None,
"moe_grouped_gemm": False,
"fp8": None,
"moe_use_legacy_grouped_gemm": False,
}

# Check parameter kinds
for param_name, param in sig.parameters.items():
assert param_name in expected_params.keys(), f"Unexpected parameter: {param_name}"
assert param.kind is expected_params[param_name], f"Wrong kind for parameter: {param_name}"

# Check default values
defaults = {
k: v.default for k, v in sig.parameters.items() if v.default is not inspect.Parameter.empty
}
assert defaults == expected_defaults, "Default values do not match the expected ones."
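
A possible companion check for the deprecation path, sketched under the assumption that `ModuleSpec` exposes `module` and `submodules` fields and that the wrapper's `warnings.warn` uses the default `UserWarning` category (this test is not part of the commit):

```python
import pytest

from megatron.core.models.gpt.gpt_layer_specs import (
    _get_mlp_module_spec,
    get_mlp_module_spec,
)


def test_private_mlp_module_spec_still_forwards():
    # Hypothetical companion test: the deprecated private wrapper should warn
    # and return the same spec as the public helper. use_te=False keeps the
    # sketch runnable without Transformer Engine.
    with pytest.warns(UserWarning):
        deprecated_spec = _get_mlp_module_spec(use_te=False)

    public_spec = get_mlp_module_spec(use_te=False)
    assert deprecated_spec.module is public_spec.module
    assert deprecated_spec.submodules == public_spec.submodules
```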
4 changes: 2 additions & 2 deletions tests/unit_tests/models/test_multimodal_projector.py
@@ -3,7 +3,7 @@
import pytest
import torch

from megatron.core.models.gpt.gpt_layer_specs import _get_mlp_module_spec
from megatron.core.models.gpt.gpt_layer_specs import get_mlp_module_spec
from megatron.core.models.vision.multimodal_projector import MultimodalProjector
from megatron.core.tensor_parallel.layers import ColumnParallelLinear
from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed
@@ -20,7 +20,7 @@ def setup_method(self, method):
transformer_config = TransformerConfig(
num_layers=1, hidden_size=64, num_attention_heads=4, use_cpu_initialization=True
)
mlp_layer_spec = _get_mlp_module_spec().submodules
mlp_layer_spec = get_mlp_module_spec().submodules

affine_layer_spec = MLPSubmodules(linear_fc1=ColumnParallelLinear, linear_fc2=None)
self.mlp = MultimodalProjector(
