Add SM margin to LayerNorm in inference (NVIDIA#772)
* Add LN margin to inference

Signed-off-by: Sangkug Lym <[email protected]>

* cleanup

Signed-off-by: Sangkug Lym <[email protected]>

* Fix symbolic func registration

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

* Fix grads

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

---------

Signed-off-by: Sangkug Lym <[email protected]>
Signed-off-by: Kirthi Shankar Sivamani <[email protected]>
Co-authored-by: Kirthi Shankar Sivamani <[email protected]>
erhoo82 and ksivaman authored Apr 12, 2024
1 parent b4ef463 commit 5d34b2a
Showing 13 changed files with 62 additions and 22 deletions.
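In short, this commit threads an SM-margin argument through the inference-only LayerNorm/RMSNorm paths and reads it from a new NVTE_INF_LAYERNORM_SM_MARGIN environment variable, mirroring the existing NVTE_FWD_/NVTE_BWD_ margins. A minimal usage sketch, assuming Transformer Engine's PyTorch API on a CUDA device; the margin value and tensor shapes are illustrative, not taken from this commit:

import os

# The variable is read at module construction time (see the __init__ changes below),
# so it must be set before the layer is created.
os.environ["NVTE_INF_LAYERNORM_SM_MARGIN"] = "4"

import torch
import transformer_engine.pytorch as te

layer = te.LayerNorm(1024).cuda()
x = torch.randn(16, 1024, device="cuda")

# With grads disabled, the layer uses inf_ln_sm_margin rather than fwd_ln_sm_margin,
# leaving 4 SMs free (e.g. for kernels overlapped with communication).
with torch.no_grad():
    y = layer(x)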
3 changes: 3 additions & 0 deletions tests/pytorch/test_onnx_export.py
@@ -660,6 +660,7 @@ def forward(self, inp):
self.meta,
self.fp8_tensor,
self.fp8_type,
0,
zero_centered_gamma)

ret = cast_from_fp8(
@@ -748,6 +749,7 @@ def forward(self, inp):
self.meta,
self.fp8_tensor,
self.fp8_type,
0,
zero_centered_gamma)

ret = cast_from_fp8(
@@ -1279,6 +1281,7 @@ def forward(self, inp, weight):
self.meta,
self.fp8_tensor,
self.fp8_type,
0,
zero_centered_gamma)

x = cast_from_fp8(
3 changes: 2 additions & 1 deletion transformer_engine/paddle/layer/layernorm_linear.py
@@ -565,6 +565,7 @@ def __init__(
# communication overlap with LN.
self.fwd_ln_sm_margin = int(os.getenv("NVTE_FWD_LAYERNORM_SM_MARGIN", "0"))
self.bwd_ln_sm_margin = int(os.getenv("NVTE_BWD_LAYERNORM_SM_MARGIN", "0"))
self.inf_ln_sm_margin = int(os.getenv("NVTE_INF_LAYERNORM_SM_MARGIN", "0"))

def _te_forward(
self,
@@ -600,7 +601,7 @@ def _te_forward(
self.activation_dtype,
self.return_layernorm_output,
paddle.is_grad_enabled(),
self.fwd_ln_sm_margin,
self.fwd_ln_sm_margin if paddle.is_grad_enabled() else self.inf_ln_sm_margin,
self.bwd_ln_sm_margin,
self.zero_centered_gamma,
self.normalization,
3 changes: 2 additions & 1 deletion transformer_engine/paddle/layer/layernorm_mlp.py
@@ -824,6 +824,7 @@ def __init__(
# communication overlap with LN.
self.fwd_ln_sm_margin = int(os.getenv("NVTE_FWD_LAYERNORM_SM_MARGIN", "0"))
self.bwd_ln_sm_margin = int(os.getenv("NVTE_BWD_LAYERNORM_SM_MARGIN", "0"))
self.inf_ln_sm_margin = int(os.getenv("NVTE_INF_LAYERNORM_SM_MARGIN", "0"))

def _te_forward(
self,
@@ -865,7 +866,7 @@ def _te_forward(
self.activation_dtype,
self.return_layernorm_output,
paddle.is_grad_enabled(),
self.fwd_ln_sm_margin,
self.fwd_ln_sm_margin if paddle.is_grad_enabled() else self.inf_ln_sm_margin,
self.bwd_ln_sm_margin,
self.zero_centered_gamma,
self.normalization,
8 changes: 8 additions & 0 deletions transformer_engine/pytorch/cpp_extensions/normalization.py
@@ -66,6 +66,7 @@ def layernorm_fwd_fp8_inf(
fp8_meta_tensor: tex.FP8TensorMeta,
fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors],
otype: tex.DType,
sm_margin: int,
zero_centered_gamma,
) -> torch.Tensor:
"""LayerNorm with FP8 output.
@@ -83,6 +84,7 @@ def layernorm_fwd_fp8_inf(
fp8_meta_tensor.scale_inv,
fp8_tensor,
otype,
sm_margin,
zero_centered_gamma)
return ret

@@ -92,6 +94,7 @@ def layernorm_fwd_inf(
weight: torch.Tensor,
bias: torch.Tensor,
eps: float,
sm_margin: int,
zero_centered_gamma: bool,
) -> torch.Tensor:
"""LayerNorm with FP8 output"""
@@ -100,6 +103,7 @@ def layernorm_fwd_inf(
weight,
bias,
eps,
sm_margin,
zero_centered_gamma,
)

@@ -149,6 +153,7 @@ def rmsnorm_fwd_fp8_inf(
fp8_meta_tensor: tex.FP8TensorMeta,
fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors],
otype: tex.DType,
sm_margin: int,
zero_centered_gamma,
) -> torch.Tensor:
"""RMSNorm with FP8 output.
@@ -165,6 +170,7 @@ def rmsnorm_fwd_fp8_inf(
fp8_meta_tensor.scale_inv,
fp8_tensor,
otype,
sm_margin,
zero_centered_gamma)
return ret

@@ -173,12 +179,14 @@ def rmsnorm_fwd_inf(
inp: torch.Tensor,
weight: torch.Tensor,
eps: float,
sm_margin: int,
zero_centered_gamma: bool,
) -> torch.Tensor:
"""RMSNorm with FP8 output"""
return torch.ops.tex_ts.rmsnorm_fwd_inf_ts(
inp,
weight,
eps,
sm_margin,
zero_centered_gamma,
)
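For reference, each of the updated Python wrappers now takes the SM margin as an explicit positional argument and forwards it to the underlying TorchScript op. A minimal sketch of a direct call, assuming the module path shown in this file's diff and a CUDA build of Transformer Engine; the tensor sizes, eps, and margin value are illustrative:

import torch
from transformer_engine.pytorch.cpp_extensions.normalization import layernorm_fwd_inf

hidden_size = 1024
inp = torch.randn(32, hidden_size, device="cuda")
weight = torch.ones(hidden_size, device="cuda")
bias = torch.zeros(hidden_size, device="cuda")

# Updated signature: (inp, weight, bias, eps, sm_margin, zero_centered_gamma).
# sm_margin=2 asks the inference kernel to leave 2 SMs unused.
out = layernorm_fwd_inf(inp, weight, bias, 1e-5, 2, False)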
4 changes: 4 additions & 0 deletions transformer_engine/pytorch/csrc/extensions.h
@@ -408,6 +408,7 @@ at::Tensor layernorm_fwd_fp8_inf(const at::Tensor &input,
at::Tensor amax,
at::Tensor scale_inv,
transformer_engine::DType otype,
const int sm_margin,
const bool zero_centered_gamma
);

@@ -432,6 +433,7 @@ at::Tensor layernorm_fwd_inf(const at::Tensor &input,
const at::Tensor &weight,
const at::Tensor &bias,
float eps,
const int sm_margin,
const bool zero_centered_gamma
);

@@ -478,6 +480,7 @@ at::Tensor rmsnorm_fwd_fp8_inf(const at::Tensor &input,
at::Tensor amax,
at::Tensor scale_inv,
transformer_engine::DType otype,
const int sm_margin,
const bool zero_centered_gamma
);

@@ -499,6 +502,7 @@ std::vector<at::Tensor> rmsnorm_fwd_noalloc(const at::Tensor &input,
at::Tensor rmsnorm_fwd_inf(const at::Tensor &input,
const at::Tensor &weight,
float eps,
const int sm_margin,
const bool zero_centered_gamma
);

13 changes: 9 additions & 4 deletions transformer_engine/pytorch/csrc/extensions/normalization.cu
@@ -154,12 +154,13 @@ at::Tensor layernorm_fwd_fp8_inf(const at::Tensor &input,
at::Tensor amax,
at::Tensor scale_inv,
transformer_engine::DType otype,
const int sm_margin,
const bool zero_centered_gamma
) {
// This is a specialized version of layernorm_fwd_fp8, optimized for inference,
// which only returns the normalized output.
std::vector<at::Tensor> out = layernorm_fwd_fp8(
input, weight, bias, eps, scale, amax, scale_inv, otype, 0, zero_centered_gamma);
input, weight, bias, eps, scale, amax, scale_inv, otype, sm_margin, zero_centered_gamma);
return out[0];
}

@@ -203,11 +204,13 @@ at::Tensor layernorm_fwd_inf(const at::Tensor &input,
const at::Tensor &weight,
const at::Tensor &bias,
float eps,
const int sm_margin,
const bool zero_centered_gamma
) {
// This is a specialized version of layernorm_fwd, optimized for inference,
// which only returns the normalized output.
std::vector<at::Tensor> out = layernorm_fwd(input, weight, bias, eps, 0, zero_centered_gamma);
std::vector<at::Tensor> out = layernorm_fwd(input, weight, bias, eps, sm_margin,
zero_centered_gamma);
return out[0];
}

@@ -345,12 +348,13 @@ at::Tensor rmsnorm_fwd_fp8_inf(const at::Tensor &input,
at::Tensor amax,
at::Tensor scale_inv,
transformer_engine::DType otype,
const int sm_margin,
const bool zero_centered_gamma
) {
// This is a specialized version of rmsnorm_fwd_fp8, optimized for inference,
// which only returns the normalized output.
std::vector<at::Tensor> out = rmsnorm_fwd_fp8(
input, weight, eps, scale, amax, scale_inv, otype, 0, zero_centered_gamma);
input, weight, eps, scale, amax, scale_inv, otype, sm_margin, zero_centered_gamma);
return out[0];
}

@@ -391,10 +395,11 @@ std::vector<at::Tensor> rmsnorm_fwd_noalloc(const at::Tensor &input,
at::Tensor rmsnorm_fwd_inf(const at::Tensor &input,
const at::Tensor &weight,
float eps,
const int sm_margin,
const bool zero_centered_gamma
) {
// This is a specialized version of rmsnorm_fwd, optimized for inference,
// which only returns the normalized output.
std::vector<at::Tensor> out = rmsnorm_fwd(input, weight, eps, 0, zero_centered_gamma);
std::vector<at::Tensor> out = rmsnorm_fwd(input, weight, eps, sm_margin, zero_centered_gamma);
return out[0];
}
8 changes: 8 additions & 0 deletions transformer_engine/pytorch/csrc/ts_fp8_op.cpp
@@ -365,6 +365,7 @@ at::Tensor layernorm_fwd_fp8_inf_ts(const at::Tensor &input,
at::Tensor scale_inv,
int64_t fp8_tensor,
int64_t otype,
const int8_t sm_margin,
const bool zero_centered_gamma) {
transformer_engine::DType otype_arg = reverse_map_dtype(otype);
float eps_float = static_cast<float>(eps);
@@ -377,6 +378,7 @@ at::Tensor layernorm_fwd_fp8_inf_ts(const at::Tensor &input,
amax,
scale_inv,
otype_arg,
sm_margin,
zero_centered_gamma);

return output;
@@ -387,13 +389,15 @@ at::Tensor layernorm_fwd_inf_ts(const at::Tensor &input,
const at::Tensor &weight,
const at::Tensor &bias,
double eps,
const int8_t sm_margin,
const bool zero_centered_gamma) {
float eps_float = static_cast<float>(eps);

at::Tensor output = layernorm_fwd_inf(input,
weight,
bias,
eps_float,
sm_margin,
zero_centered_gamma);

return output;
@@ -408,6 +412,7 @@ at::Tensor rmsnorm_fwd_fp8_inf_ts(const at::Tensor &input,
at::Tensor scale_inv,
int64_t fp8_tensor,
int64_t otype,
const int8_t sm_margin,
const bool zero_centered_gamma) {
transformer_engine::DType otype_arg = reverse_map_dtype(otype);
float eps_float = static_cast<float>(eps);
@@ -419,6 +424,7 @@ at::Tensor rmsnorm_fwd_fp8_inf_ts(const at::Tensor &input,
amax,
scale_inv,
otype_arg,
sm_margin,
zero_centered_gamma);

return output;
@@ -428,12 +434,14 @@ at::Tensor rmsnorm_fwd_fp8_inf_ts(const at::Tensor &input,
at::Tensor rmsnorm_fwd_inf_ts(const at::Tensor &input,
const at::Tensor &weight,
double eps,
const int8_t sm_margin,
const bool zero_centered_gamma) {
float eps_float = static_cast<float>(eps);

at::Tensor output = rmsnorm_fwd_inf(input,
weight,
eps_float,
sm_margin,
zero_centered_gamma);

return output;
3 changes: 2 additions & 1 deletion transformer_engine/pytorch/module/_common.py
@@ -78,6 +78,7 @@ def _apply_normalization(inputmat:torch.Tensor,
fp8_meta["scaling_fwd"],
tex.FP8FwdTensors.GEMM1_INPUT,
fp8_dtype_forward,
fwd_ln_sm_margin,
zero_centered_gamma,
), None, None
else:
@@ -88,7 +89,7 @@ def _apply_normalization(inputmat:torch.Tensor,
)
else:
return normalization_func(
*inputs, eps, zero_centered_gamma
*inputs, eps, fwd_ln_sm_margin, zero_centered_gamma
), None, None
if normalization == "RMSNorm":
output = (ln_out, None, output[1])
7 changes: 5 additions & 2 deletions transformer_engine/pytorch/module/layernorm.py
@@ -34,6 +34,7 @@ def forward(
eps: float,
fwd_ln_sm_margin: int,
bwd_ln_sm_margin: int,
inf_ln_sm_margin: int,
zero_centered_gamma: bool,
is_grad_enabled: bool,
activation_dtype: torch.dtype,
@@ -58,7 +59,7 @@ def forward(
ctx.zero_centered_gamma = zero_centered_gamma
else:
ln_out, mu, rsigma = layernorm_fwd_inf(inputmat, ln_weight,
ln_bias, eps, zero_centered_gamma), None, None
ln_bias, eps, inf_ln_sm_margin, zero_centered_gamma), None, None
return ln_out.view_as(inp)

@staticmethod
Expand All @@ -72,7 +73,7 @@ def backward(
d_ln_out, inputmat, mu, rsigma, ln_weight,
ctx.bwd_ln_sm_margin, ctx.zero_centered_gamma
)
return dxmat.view(ctx.inp_shape), dgamma, dbeta, None, None, None, None, None, None
return dxmat.view(ctx.inp_shape), dgamma, dbeta, None, None, None, None, None, None, None


class LayerNorm(torch.nn.Module):
@@ -148,6 +149,7 @@ def __init__(
# communication overlap with LN.
self.fwd_ln_sm_margin = int(os.getenv("NVTE_FWD_LAYERNORM_SM_MARGIN", "0"))
self.bwd_ln_sm_margin = int(os.getenv("NVTE_BWD_LAYERNORM_SM_MARGIN", "0"))
self.inf_ln_sm_margin = int(os.getenv("NVTE_INF_LAYERNORM_SM_MARGIN", "0"))

def reset_layer_norm_parameters(self) -> None:
"""Init LN params"""
@@ -198,6 +200,7 @@ def forward(self, inp: torch.Tensor) -> torch.Tensor:
self.eps,
self.fwd_ln_sm_margin,
self.bwd_ln_sm_margin,
self.inf_ln_sm_margin,
self.zero_centered_gamma,
torch.is_grad_enabled(),
self.activation_dtype,
3 changes: 2 additions & 1 deletion transformer_engine/pytorch/module/layernorm_linear.py
@@ -999,6 +999,7 @@ def __init__(
# communication overlap with LN.
self.fwd_ln_sm_margin = int(os.getenv("NVTE_FWD_LAYERNORM_SM_MARGIN", "0"))
self.bwd_ln_sm_margin = int(os.getenv("NVTE_BWD_LAYERNORM_SM_MARGIN", "0"))
self.inf_ln_sm_margin = int(os.getenv("NVTE_INF_LAYERNORM_SM_MARGIN", "0"))

# Initialize a dummy tensor to be used as gradient hook for bwd amax reduction.
self.dummy_tensor = torch.zeros(1, device=device, requires_grad=True)
@@ -1165,7 +1166,7 @@ def forward(
self.return_layernorm_output,
self.return_layernorm_output_gathered,
torch.is_grad_enabled(),
self.fwd_ln_sm_margin,
self.fwd_ln_sm_margin if torch.is_grad_enabled() else self.inf_ln_sm_margin,
self.bwd_ln_sm_margin,
self.zero_centered_gamma,
self.normalization,
3 changes: 2 additions & 1 deletion transformer_engine/pytorch/module/layernorm_mlp.py
@@ -1427,6 +1427,7 @@ def __init__(
# communication overlap with LN.
self.fwd_ln_sm_margin = int(os.getenv("NVTE_FWD_LAYERNORM_SM_MARGIN", "0"))
self.bwd_ln_sm_margin = int(os.getenv("NVTE_BWD_LAYERNORM_SM_MARGIN", "0"))
self.inf_ln_sm_margin = int(os.getenv("NVTE_INF_LAYERNORM_SM_MARGIN", "0"))

# Initialize a dummy tensor to be used as gradient hook for bwd amax reduction.
self.dummy_tensor = torch.zeros(1, device=device, requires_grad=True)
@@ -1575,7 +1576,7 @@ def forward(
self.bias_gelu_nvfusion,
self.set_parallel_mode,
torch.is_grad_enabled(),
self.fwd_ln_sm_margin,
self.fwd_ln_sm_margin if torch.is_grad_enabled() else self.inf_ln_sm_margin,
self.bwd_ln_sm_margin,
self.zero_centered_gamma,
self.activation,
(Two of the 13 changed files are not shown here; they were not loaded on this page.)
