Address nits from previous PR stack + bump version
ghstack-source-id: 128d2ff4077d15fa24ceae532bfc5bd743e71b31
Pull Request resolved: facebookresearch#518
danthe3rd committed Nov 10, 2022
1 parent 86e3db2 · commit 8367685
Showing 10 changed files with 106 additions and 122 deletions.

setup.py (2 changes: 0 additions & 2 deletions)
@@ -133,8 +133,6 @@ def get_flash_attention_extensions(cuda_version: int, extra_compile_args):
             "nvcc": extra_compile_args.get("nvcc", [])
             + [
                 "-O3",
-                "-U__CUDA_NO_HALF_OPERATORS__",
-                "-U__CUDA_NO_HALF_CONVERSIONS__",
                 "--expt-relaxed-constexpr",
                 "--expt-extended-lambda",
                 "--use_fast_math",

tests/test_swiglu.py (27 changes: 18 additions & 9 deletions)
@@ -3,14 +3,16 @@
 # This source code is licensed under the BSD license found in the
 # LICENSE file in the root directory of this source tree.
 
+import copy
+import functools
 import random
 from contextlib import nullcontext
 from typing import ContextManager, Optional, Sequence, cast
 
 import pytest
 import torch
 
-import xformers.ops.swiglu as xsw
+import xformers.ops.swiglu_op as xsw
 
 torch.backends.cuda.matmul.allow_tf32 = False
 cuda_only = pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA")
@@ -106,12 +108,17 @@ def generate_test_shapes():
 _ops: Sequence[xsw.SwiGLUOp] = [xsw.SwiGLUFusedOp, xsw.SwiGLUPackedFusedOp]
 
 
-@pytest.mark.parametrize("bias", [False, True], ids=["nobias", "bias"])
+@functools.lru_cache(maxsize=1)
+def create_module_cached(**kwargs) -> xsw.SwiGLU:
+    return xsw.SwiGLU(**kwargs)
+
+
 @pytest.mark.parametrize("autocast", [False, True], ids=["regular", "autocast"])
-@pytest.mark.parametrize("pack_weights", [False, True], ids=["regular", "packed"])
 @pytest.mark.parametrize("op", _ops, ids=[x.NAME for x in _ops])
 @pytest.mark.parametrize("dtype", _dtypes, ids=[str(x) for x in _dtypes])
 @pytest.mark.parametrize("device", _devices)
+@pytest.mark.parametrize("bias", [False, True], ids=["nobias", "bias"])
+@pytest.mark.parametrize("pack_weights", [False, True], ids=["regular", "packed"])
 @pytest.mark.parametrize(
     "shape",
     _test_shapes,
@@ -154,11 +161,13 @@ def test_forward_backward(
     inp_model_dtype = torch.float if autocast else dtype
     x = torch.randn(shape[:2], device=device, dtype=inp_model_dtype)
 
-    module = xsw._SwiGLUModule(
-        in_features=shape[1],
-        hidden_features=shape[2],
-        pack_weights=pack_weights,
-        bias=bias,
+    module = copy.deepcopy(
+        create_module_cached(
+            in_features=shape[1],
+            hidden_features=shape[2],
+            bias=bias,
+            _pack_weights=pack_weights,
+        )
     )
     x_f32: Optional[torch.Tensor]
     ref_f32: Optional[torch.Tensor]
@@ -180,7 +189,7 @@ def test_forward_backward(
     )
     with cm:
         ref = module(x)
-        out = xsw.functional_swiglu(x, *module._ordered_params_for_op(), op=op)
+        out = xsw.swiglu(x, *module._ordered_params(), op=op)
 
     if ref_f32 is None:
         ref_f32 = ref
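
The test change above swaps per-test module construction for a cached factory plus a deepcopy, so parametrized runs do not rebuild an identical module for every case. A minimal sketch of the same pattern outside xformers (the small MLP below is a stand-in, not the real SwiGLU module):

import copy
import functools

import torch


@functools.lru_cache(maxsize=1)
def build_module_cached(in_features: int, hidden_features: int) -> torch.nn.Module:
    # Construction (and random init) runs once per unique configuration.
    return torch.nn.Sequential(
        torch.nn.Linear(in_features, hidden_features),
        torch.nn.GELU(),
        torch.nn.Linear(hidden_features, in_features),
    )


def fresh_module(in_features: int, hidden_features: int) -> torch.nn.Module:
    # Each test gets its own copy, so .to(...), autocast, or backward in one
    # test cannot leak state into the cached instance used by the next one.
    return copy.deepcopy(build_module_cached(in_features, hidden_features))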

version.txt (2 changes: 1 addition & 1 deletion)
@@ -1 +1 @@
-0.0.14.dev
+0.0.15.dev

xformers/benchmarks/benchmark_swiglu.py (25 changes: 11 additions & 14 deletions)
@@ -13,7 +13,7 @@
 from torch.utils import benchmark
 from utils import benchmark_main_helper
 
-import xformers.ops.swiglu as xsw
+import xformers.ops.swiglu_op as xsw
 
 min_run_time = 0.5
 device = torch.device("cuda")
@@ -68,9 +68,7 @@ def benchmark_swiglu(shape, dtype, bias: bool):
 
     x = torch.randn(shape[:2], device=device, dtype=inp_dtype)
     module = (
-        xsw._SwiGLUModule(
-            in_features=shape[1], hidden_features=shape[2], pack_weights=True, bias=bias
-        )
+        xsw.SwiGLU(in_features=shape[1], hidden_features=shape[2], bias=bias)
         .to(device)
         .to(model_dtype)
     )
@@ -79,25 +77,26 @@ def benchmark_swiglu(shape, dtype, bias: bool):
     bstr = "bias" if bias else "nobi"
     sub_label = f"{dtype_str} B={shape[0]}, I={shape[1]}, H={shape[2]} {bstr}"
 
-    params = module._ordered_params_for_op()
+    params = module._ordered_params()
 
     PREFIX = 'with torch.autocast("cuda", dtype=torch.half):\n    ' if autocast else ""
     yield benchmark.Timer(
         stmt=f"{PREFIX}fn(x, *args)",
         globals={
             "x": x,
             "args": params,
-            "fn": partial(xsw.functional_swiglu, op=OP),
+            "fn": partial(xsw.swiglu, op=OP),
         },
         label="swiglu_fw",
         description=OP.NAME,
         sub_label=sub_label,
     )
     yield benchmark.Timer(
-        stmt=f"{PREFIX}fn(x)",
+        stmt=f"{PREFIX}fn(x, *args)",
         globals={
             "x": x,
-            "fn": module,
+            "args": params,
+            "fn": partial(xsw.swiglu, op=xsw.SwiGLUEagerOp),
         },
         label="swiglu_fw",
         description="eager",
@@ -116,9 +115,7 @@ def benchmark_swiglu_bw(shape, dtype, bias: bool):
     x = torch.randn(shape[:2], device=device, dtype=inp_dtype)
     x.requires_grad_()
    module = (
-        xsw._SwiGLUModule(
-            in_features=shape[1], hidden_features=shape[2], pack_weights=True, bias=bias
-        )
+        xsw.SwiGLU(in_features=shape[1], hidden_features=shape[2], bias=bias)
         .to(device)
         .to(model_dtype)
     )
@@ -127,9 +124,9 @@ def benchmark_swiglu_bw(shape, dtype, bias: bool):
     bstr = "bias" if bias else "nobi"
     sub_label = f"{dtype_str} B={shape[0]}, I={shape[1]}, H={shape[2]} {bstr}"
 
-    params = module._ordered_params_for_op()
+    params = module._ordered_params()
     with cm():
-        out = xsw.functional_swiglu(x, *params, op=OP)
+        out = xsw.swiglu(x, *params, op=OP)
     grad = torch.zeros_like(out)
 
     yield benchmark.Timer(
@@ -145,7 +142,7 @@ def benchmark_swiglu_bw(shape, dtype, bias: bool):
     del out
 
     with cm():
-        out = module(x)
+        out = xsw.swiglu(x, *params, op=xsw.SwiGLUEagerOp)
 
     yield benchmark.Timer(
         stmt="out.backward(grad, retain_graph=True)",
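
Both benchmark functions drive torch.utils.benchmark.Timer with the same stmt/globals pattern. A stripped-down, self-contained sketch of that pattern follows; the eager_swiglu stand-in and the shapes are illustrative, not the fused xformers operator:

import torch
import torch.nn.functional as F
from torch.utils import benchmark


def eager_swiglu(x, w1, w2, w3):
    # Plain-PyTorch stand-in: gate silu(x @ w1) with (x @ w2), then project with w3.
    return (F.silu(x @ w1) * (x @ w2)) @ w3


x = torch.randn(1024, 512)
w1, w2 = torch.randn(512, 1024), torch.randn(512, 1024)
w3 = torch.randn(1024, 512)

timer = benchmark.Timer(
    stmt="fn(x, w1, w2, w3)",
    globals={"fn": eager_swiglu, "x": x, "w1": w1, "w2": w2, "w3": w3},
    label="swiglu_fw",
    description="eager",
    sub_label="f32 B=1024, I=512, H=1024",
)
print(timer.blocked_autorange(min_run_time=0.5))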

xformers/benchmarks/benchmark_transformer.py (11 changes: 2 additions & 9 deletions)
@@ -57,22 +57,15 @@ class TimmSwiGLU(nn.Module):
     def __init__(self, mlp: TimmMlp, op=None) -> None:
         super().__init__()
         self.fc1 = mlp.fc1
-        self.swiglu = xops.swiglu._SwiGLUModule(
+        self.swiglu = xops.SwiGLU(
             in_features=mlp.fc1.in_features,
             hidden_features=mlp.fc1.out_features,
-            pack_weights=True,
             bias=True,
         )
         self.op = op
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        B, M, C = x.shape
-
-        x = x.reshape([B * M, C])
-        x = xops.functional_swiglu(x, *self.swiglu._ordered_params_for_op(), op=self.op)
-        x = x.reshape([B, M, C])
-
-        return x
+        return self.swiglu(x)
 
 
 def mod_memeff_attn(model: nn.Module, op=None) -> nn.Module:

@@ -93,7 +93,7 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> dual_gemm_silu_identity_mul_(
   >;
   {
     cudaDeviceProp* p = at::cuda::getDeviceProperties(x.device().index());
-    TORCH_CHECK(p->major * 10 + p->minor >= ArchTag::kMinComputeCapability, "GPU not supported");
+    TORCH_CHECK(p->major * 10 + p->minor >= ArchTag::kMinComputeCapability, "Only A100+ GPUs are supported");
   }
 
   int split_k_slices = DualGemm::kSplitKSerial ? 2 : 1;

xformers/components/swiglu/cuda/gemm_fused_operand_sum.cu (2 changes: 1 addition & 1 deletion)
@@ -104,7 +104,7 @@ void gemm_fused_operand_sum_(
   >;
   {
     cudaDeviceProp* p = at::cuda::getDeviceProperties(a.device().index());
-    TORCH_CHECK(p->major * 10 + p->minor >= SmArch::kMinComputeCapability, "GPU not supported");
+    TORCH_CHECK(p->major * 10 + p->minor >= SmArch::kMinComputeCapability, "Only A100+ GPUs are supported");
   }
 
   // Below is the reduction kernel used in the case of parallel split-k
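
The reworded TORCH_CHECK messages make the hardware requirement explicit: both kernels gate on kMinComputeCapability, and the new text indicates that only sm80-class GPUs (A100 or newer) pass the check. Below is a rough Python-side equivalent of the same gate, with the sm80 threshold stated as an assumption rather than read from the kernel code:

import torch


def fused_swiglu_supported(device_index: int = 0, min_compute_capability: int = 80) -> bool:
    # Mirrors the C++ check: major * 10 + minor >= kMinComputeCapability.
    if not torch.cuda.is_available():
        return False
    major, minor = torch.cuda.get_device_capability(device_index)
    return major * 10 + minor >= min_compute_capability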

xformers/info.py (2 changes: 1 addition & 1 deletion)
@@ -20,7 +20,7 @@ def get_features_status() -> Dict[str, str]:
     features = {}
     for op in ALL_OPS:
         features[f"memory_efficient_attention.{op.NAME}"] = op.info()
-    for k, v in ops.swiglu._info().items():
+    for k, v in ops.swiglu_op._info().items():
         features[f"swiglu.{k}"] = v
     features["is_triton_available"] = str(_is_triton_available())
     features["is_functorch_available"] = str(_is_functorch_available)

xformers/ops/__init__.py (5 changes: 3 additions & 2 deletions)
@@ -16,14 +16,15 @@
     MemoryEfficientAttentionOp,
     memory_efficient_attention,
 )
-from .swiglu import (  # noqa: F401
+from .swiglu_op import (  # noqa: F401
+    SwiGLU,
     SwiGLUEagerOp,
     SwiGLUFusedOp,
     SwiGLUOp,
     SwiGLUOpDispatch,
     SwiGLUPackedFusedOp,
     _info,
-    functional_swiglu,
+    swiglu,
 )
 from .unbind import get_stack_strides, stack_or_none, unbind  # noqa: F401
 
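
With SwiGLU and swiglu re-exported here, callers can stay on the public xformers.ops names instead of the private module. A minimal usage sketch, assuming a CUDA device, default operator dispatch, and the default output size (equal to in_features); the shapes are arbitrary:

import torch

import xformers.ops as xops

# Module interface: two gated input projections (SiLU gate) plus an output projection.
m = xops.SwiGLU(in_features=512, hidden_features=1024, bias=True).cuda().half()
x = torch.randn(8, 197, 512, device="cuda", dtype=torch.half)
y = m(x)  # keeps the leading dimensions of x; last dim matches in_features here

The functional form, xops.swiglu(x, *params, op=...), takes the unpacked weights and biases in the order returned by the module's _ordered_params(), as exercised in the test and benchmark changes above.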