[Blackwell] Enable MMA pipelining for scaled dot when TMEM copy is used #5812

Merged · 43 commits · Feb 5, 2025
Commits
62b253e
load scales in lit test
masahi Jan 30, 2025
a131bb6
stub
masahi Jan 30, 2025
f0c4a78
wip
masahi Jan 30, 2025
c6e45f7
use 5d scale
masahi Jan 30, 2025
581e7e6
working?
masahi Jan 31, 2025
9fda44f
make lit test utccp-compatible
masahi Jan 31, 2025
f263972
add back 2d scale test
masahi Jan 31, 2025
ebee5a6
reenable MMA pipe for scaled dot
masahi Jan 31, 2025
fa1b451
update test
Jan 31, 2025
d6709e1
working for swp
Jan 31, 2025
a555e54
Support tmem copy op in transitive use chain
Jan 31, 2025
3565565
minor improv in SWP
Jan 31, 2025
50c3e07
add proper logic to decide when scaled dot is safe to pipeline
Jan 31, 2025
8baf909
format
Jan 31, 2025
c8aca61
wip
Jan 31, 2025
293b65d
attempt adding explicit barrier wait after UTCCP
Feb 1, 2025
7d989b5
restore test
Feb 3, 2025
4d43bea
Merge branch 'main' into reenable-mma-pipe-bw-mxfp
Feb 3, 2025
e437762
merge fix
Feb 3, 2025
7627f87
all tests pass by adding monkey patch for ptxas disable opt
Feb 3, 2025
42b8a8b
fixed BW pipeline test
Feb 3, 2025
f28471f
add SWP test for utccp
Feb 3, 2025
3b911e4
move sync lowering pass to ttgir pipeline
Feb 3, 2025
4d05667
wip
Feb 3, 2025
aeb3be4
fix accel matmul test
Feb 3, 2025
ff49757
Merge branch 'main' into reenable-mma-pipe-bw-mxfp
masahi Feb 3, 2025
e171a7c
update accel matmul lit test
Feb 3, 2025
d7bf456
revert
Feb 3, 2025
33d3f6e
add test for MMA pipeline with utccp
Feb 3, 2025
fd5a219
precommit
Feb 3, 2025
8e73cff
add comment
Feb 4, 2025
e55e130
minor
Feb 4, 2025
b033724
improve the note on the workaround in test
Feb 4, 2025
c8e5fb9
simplify the workaround comment
Feb 4, 2025
9fe1ce9
address feedback
Feb 4, 2025
236b110
Merge branch 'main' into reenable-mma-pipe-bw-mxfp
Feb 4, 2025
00c7db3
fix
Feb 4, 2025
b339461
precommit
Feb 4, 2025
7830dfc
more comment polish
Feb 4, 2025
9930e95
Merge branch 'main' into reenable-mma-pipe-bw-mxfp
Feb 4, 2025
8c7d071
fix in lit test
Feb 5, 2025
5c737c8
workaround in accel matmul for lit test having no load on scale
Feb 5, 2025
a80e014
Update lib/Dialect/TritonGPU/Transforms/Pipeliner/TC05MMAPipeline.cpp
masahi Feb 5, 2025
85 changes: 81 additions & 4 deletions lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp
@@ -184,6 +184,25 @@ getSharedMemoryMMAOperand(Value v, mlir::PatternRewriter &rewriter, int opIdx,
return rewriter.create<LocalAllocOp>(arg.getLoc(), newType, arg);
}

static LocalAllocOp
getSharedMemoryScale(Value arg, mlir::PatternRewriter &rewriter, Location loc) {
OpBuilder::InsertionGuard g(rewriter);
auto argType = cast<RankedTensorType>(arg.getType());
assert(argType.getEncoding() && "unexpected tensor type");
auto newOrder = getOrder(argType.getEncoding());

Attribute SharedMemorySpace =
SharedMemorySpaceAttr::get(argType.getContext());
auto CTALayout = getCTALayout(argType.getEncoding());
// No swizzling for scale for now
auto newLayout = SwizzledSharedEncodingAttr::get(argType.getContext(), 1, 1,
1, newOrder, CTALayout);
auto newType = MemDescType::get(argType.getShape(), argType.getElementType(),
newLayout, SharedMemorySpace);
rewriter.setInsertionPointAfterValue(arg);
return rewriter.create<LocalAllocOp>(loc, newType, arg);
}

SmallVector<unsigned, 3>
getWarpsPerTile(DotOp dotOp, const ArrayRef<int64_t> shape, int version,
int numWarps, const SmallVector<unsigned, 3> &instrShape) {
@@ -575,6 +594,60 @@ class BlockedToMMAv5 : public mlir::OpRewritePattern<DotOp> {
}
};

Value addSmemStageToScaleLoad(Value scale, mlir::PatternRewriter &rewriter) {
/*
Rewrite load(scale) -> local_load(local_alloc(load(scale))).
This function does not add anything to the final IR when num_stages > 1,
but it makes it easy to apply TMEM copy rewriting later.

Since scales are stored in TMEM for the MMAv5 scaled dot, scale loads do
not need to be staged through SMEM. In practice, however, the software
pipeliner places scale loads into multi-buffered SMEM, at which point the
SMEM allocation created here is eliminated.
*/
OpBuilder::InsertionGuard g(rewriter);
auto op = scale.getDefiningOp();
Operation *loadConsumer = nullptr;

if (!op)
return scale;

while (!isa<LoadOp>(op)) {
if (auto reshape = dyn_cast<ReshapeOp>(op)) {
op = reshape.getSrc().getDefiningOp();
loadConsumer = reshape;
} else if (auto trans = dyn_cast<TransOp>(op)) {
op = trans.getSrc().getDefiningOp();
loadConsumer = trans;
} else if (auto cvt = dyn_cast<ConvertLayoutOp>(op)) {
op = cvt.getSrc().getDefiningOp();
loadConsumer = cvt;
} else {
// Unrecognized pattern, bail out. In practice, this implies that MMA
// pipelining will not apply to the scaled dot op, since tmem_copy would
// not be inserted before the pipeline pass.
return scale;
}
}

auto scaleAfterLoad = op->getResult(0);
auto scaleSmemAlloc =
getSharedMemoryScale(scaleAfterLoad, rewriter, op->getLoc());

rewriter.setInsertionPointAfterValue(scaleSmemAlloc);
auto localLoad = rewriter.create<LocalLoadOp>(
op->getLoc(), scaleAfterLoad.getType(), scaleSmemAlloc);

rewriter.replaceAllUsesExcept(scaleAfterLoad, localLoad.getResult(),
scaleSmemAlloc);

if (loadConsumer) {
return scale;
} else {
return localLoad;
}
}

class ScaledBlockedToMMAv5
: public mlir::OpRewritePattern<triton::DotScaledOp> {
int computeCapability;
@@ -688,10 +761,14 @@ class ScaledBlockedToMMAv5
oldScaleAType.getShape(), oldScaleAType.getElementType(), scaleALayout);
RankedTensorType newScaleBType = RankedTensorType::get(
oldScaleBType.getShape(), oldScaleBType.getElementType(), scaleBLayout);
Value newScaleA = rewriter.create<ConvertLayoutOp>(loc, newScaleAType,
dotOp.getLhsScale());
Value newScaleB = rewriter.create<ConvertLayoutOp>(loc, newScaleBType,
dotOp.getRhsScale());

auto lhsScale = addSmemStageToScaleLoad(dotOp.getLhsScale(), rewriter);
auto rhsScale = addSmemStageToScaleLoad(dotOp.getRhsScale(), rewriter);

Value newScaleA =
rewriter.create<ConvertLayoutOp>(loc, newScaleAType, lhsScale);
Value newScaleB =
rewriter.create<ConvertLayoutOp>(loc, newScaleBType, rhsScale);
Value scaleA = rewriter.create<triton::nvidia_gpu::TMEMAllocOp>(
loc, scaleAType, newScaleA);
Value scaleB = rewriter.create<triton::nvidia_gpu::TMEMAllocOp>(
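To make the intent of `addSmemStageToScaleLoad` concrete, here is an editorial sketch of the rewrite on the IR. It is not taken from the diff: `#blocked`, `#shared`, `#smem` and the 128x2 shape are placeholders borrowed from the 2-D scale lit test further down.

```mlir
// Before: the loaded scale feeds the convert_layout/tmem_alloc chain directly.
%scale = tt.load %scale_ptr : tensor<128x2x!tt.ptr<i8>, #blocked>

// After: a local_alloc/local_load pair is interposed right after the load.
// When num_stages > 1 the pipeliner later multi-buffers this allocation and
// the redundant round-trip disappears from the final IR.
%scale_reg = tt.load %scale_ptr : tensor<128x2x!tt.ptr<i8>, #blocked>
%scale_smem = ttg.local_alloc %scale_reg : (tensor<128x2xi8, #blocked>) -> !ttg.memdesc<128x2xi8, #shared, #smem>
%scale_staged = ttg.local_load %scale_smem : !ttg.memdesc<128x2xi8, #shared, #smem> -> tensor<128x2xi8, #blocked>
// All former users of the load result now consume %scale_staged instead.
```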
10 changes: 10 additions & 0 deletions lib/Dialect/TritonGPU/Transforms/Pipeliner/AssignLatencies.cpp
@@ -181,6 +181,16 @@ loadOpsToIndirectionLevel(scf::ForOp forOp, bool pipelineWithoutDot,
dfs(defOp, finalUser, distance);
}
}
if (auto tmemAlloc = dyn_cast<nvidia_gpu::TMEMAllocOp>(op)) {
if (!tmemAlloc.getSrc()) {
for (auto user : tmemAlloc.getResult().getUsers()) {
if (auto tmemCopy = dyn_cast<nvidia_gpu::TMEMCopyOp>(user)) {
dfs(tmemCopy.getSrc().getDefiningOp(), finalUser, distance);
break;
}
}
}
}
};

bool seenDot = false;
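The extra hop is needed because, once TMEM copy rewriting has run, the scale load no longer reaches the scaled MMA through plain SSA operands; it goes through an SMEM buffer that a `ttng.tmem_copy` writes into a source-less `ttng.tmem_alloc`. A hedged sketch of the chain the DFS now follows (op syntax and operand order abbreviated; the value names are illustrative):

```mlir
%scale      = tt.load %scale_ptr ...                // load that should receive a latency/stage
%scale_smem = ttg.local_alloc %scale ...            // SMEM staging added by AccelerateMatmul
%scale_tmem = ttng.tmem_alloc ...                   // mutable TMEM allocation with no source operand
ttng.tmem_copy %scale_smem, %scale_tmem ...         // copies SMEM -> TMEM; links the two chains
ttng.tc_gen5_mma_scaled ..., %scale_tmem, ...       // the scaled dot consuming the scale
// Walking back from the MMA, the DFS reaches %scale_tmem, hops through the
// tmem_copy's source, and thereby marks %scale as a pipelineable load.
```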
@@ -177,15 +177,18 @@ static int createAsyncCopy(scf::ForOp forOp, tt::LoadOp loadOp, Value alloc,
Operation *wait = builder.createWithStage<ttg::AsyncWaitOp>(
loc, stageForFirstUse, clusterForFirstUse, commit->getResult(0), 0);

auto loadIsMMAv3Shared = loadToInfo[loadOp].isMMAv3Shared;

// Extract part.
SmallVector<Value> loadOffsets(allocTy.getRank(), zero);
loadOffsets[0] = extractIdx;
auto viewLoad = builder.createWithStage<ttg::MemDescSubviewOp>(
loc, stageForFirstUse, clusterForFirstUse, subviewTy, alloc, loadOffsets);
if (loadIsMMAv3Shared) {
auto alloc = cast<ttg::LocalAllocOp>((*loadOp->getUsers().begin()));

if (loadToInfo[loadOp].isMMAv3Shared || loadToInfo[loadOp].isMMAv5Scale) {
auto user = *loadOp->getUsers().begin();
assert(isa<triton::gpu::LocalAllocOp>(user) &&
"Loading of MMAv3 operands and MMAv5 scale is expected to be "
"consumed by LocalAlloc.");
auto alloc = cast<ttg::LocalAllocOp>(user);
tt::replaceUsesAndPropagateType(builder, alloc, viewLoad.getResult());
alloc.erase();
} else {
@@ -455,6 +458,12 @@ getTransitiveUserInBlock(Operation *baseOp, scf::ForOp &forOp) {
for (Operation *user : op->getUsers())
if (user->getBlock() == op->getBlock())
dfs(user, baseOp, anyOp);
if (auto tmemCopy = dyn_cast<triton::nvidia_gpu::TMEMCopyOp>(op)) {
auto tmemAlloc =
tmemCopy.getDst()
.getDefiningOp<triton::nvidia_gpu::TMEMAllocOp>();
dfs(tmemAlloc, baseOp, anyOp);
}
};
// We are matching the behavior before refactoring:
// For loops without num_stage attributes, we check for dot users.
39 changes: 36 additions & 3 deletions lib/Dialect/TritonGPU/Transforms/Pipeliner/TC05MMAPipeline.cpp
@@ -593,6 +593,31 @@ void createBarrierAndWaitOps(IRRewriter &builder, scf::ForOp forOp,
annotateWithPipelineStage(builder, info.phase.getDefiningOp(), 0);
}

bool isSafeToPipeline(ttng::TCGen5MMAScaledOp scaledDot) {
auto getNumUsers = [](Value value) {
return std::distance(value.user_begin(), value.user_end());
};

auto isCopiedByTMEMCopy = [=](Value scale) {
if (getNumUsers(scale) != 2) {
// MMA and TMEM copy must be the only users
return false;
}

for (auto user : scale.getUsers()) {
if (!isa<ttng::TMEMCopyOp, ttng::TCGen5MMAScaledOp>(user)) {
// Only the TMEM copy that fills the scale and the scaled dot op itself
// may use it; any other user makes MMA pipelining unsafe.
return false;
}
}
return true;
};

return isCopiedByTMEMCopy(scaledDot.getAScale()) &&
isCopiedByTMEMCopy(scaledDot.getBScale());
}

// Find MMAs eligible for pipelining and lower them by:
// 1. Hoisting the accumulator allocation outside of the loop.
// 2. Creating a barrier alloc and lowering the MMA to MMA + wait barrier.
@@ -603,9 +628,17 @@ FailureOr<scf::ForOp> preProcessLoopForTC05MMAPipelining(scf::ForOp forOp,
SmallVector<Operation *> mmaOps;
forOp.walk([&](Operation *op) {
// Skip MMA nested in another forOp
if (isa<ttng::TCGen5MMAOp>(op) &&
op->getParentOfType<scf::ForOp>() == forOp) {
mmaOps.push_back(op);
if (op->getParentOfType<scf::ForOp>() == forOp) {
if (isa<ttng::TCGen5MMAOp>(op)) {
mmaOps.push_back(op);
} else if (auto scaledDot = dyn_cast<ttng::TCGen5MMAScaledOp>(op)) {
if (isSafeToPipeline(scaledDot)) {
mmaOps.push_back(op);
} else {
op->emitWarning("Skipping pipelining of an MMAv5 scaled op because "
"TMEM copy is not used.");
}
}
}
});

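For reference, the operand pattern that `isSafeToPipeline` accepts looks roughly like the editorial sketch below (types and extra operands elided; `%a_scale_smem` and the encoding names are illustrative). The `ttng.tc_gen5_mma_scaled` line mirrors the CHECK line in the accelerate-matmul lit test further down.

```mlir
// Each scale operand of the scaled MMA must have exactly two users:
%a_scale = ttng.tmem_alloc ... : !ttg.memdesc<128x4xi8, #tmem_scales, #ttng.tensor_memory, mutable>
ttng.tmem_copy %a_scale_smem, %a_scale ...          // user 1: the TMEM copy that fills the scale
ttng.tc_gen5_mma_scaled %a, %b, %acc, %a_scale, %b_scale, %true, %true lhs = e4m3 rhs = e4m3
                                                    // user 2: the scaled MMA itself
// Any other user of %a_scale or %b_scale makes the op ineligible, and
// preProcessLoopForTC05MMAPipelining emits the "Skipping pipelining" warning
// instead of adding the op to mmaOps.
```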
41 changes: 22 additions & 19 deletions python/test/unit/language/test_matmul.py
@@ -352,12 +352,9 @@ def test_mxfp(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, NUM_STAGES, device):
rtol = 0.0001
torch.testing.assert_close(ref_out, output, atol=atol, rtol=rtol)

if NUM_STAGES > 1:
# TODO: Remove this check once MMA pipelining is working for these cases
if M >= BLOCK_M and N >= BLOCK_N and K >= BLOCK_K:
# Verify that MMA pipelining has been applied
# FIXME: Scaled dot pipelining is DISABLED
assert "ttng.wait_barrier" not in out.asm["ttgir"]
# Pipelining of dot_scaled requires tmem_copy to be used, which in turn
# requires the scales to be in the blocked layout in global memory.
assert "ttng.wait_barrier" not in out.asm["ttgir"]


def _knob_promote_lhs_to_tmem(monkeypatch):
@@ -437,13 +434,21 @@ def block_scale_mxfp_matmul( #
tl.store(output_ptrs, accumulator, mask=c_mask)


def _knob_disable_ptxas_opt(monkeypatch):
monkeypatch.setenv("DISABLE_PTXAS_OPT", "1")


@pytest.mark.parametrize("M, N, K", [(1024, 512, 512), (998, 111, 512), (63, 128, 512)])
@pytest.mark.parametrize("BLOCK_M, BLOCK_N, BLOCK_K", [(128, 128, 128), (256, 128, 128), (128, 256, 128),
(128, 128, 256), (128, 256, 256)])
@pytest.mark.parametrize("NUM_STAGES", [1, 2, 4])
@pytest.mark.parametrize("USE_2D_SCALE_LOAD", [False, True])
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 10, reason="Requires compute capability >= 10")
def test_blocked_scale_mxfp(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, NUM_STAGES, USE_2D_SCALE_LOAD, device):
def test_blocked_scale_mxfp(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, NUM_STAGES, USE_2D_SCALE_LOAD, device, monkeypatch):
if NUM_STAGES == 1 and USE_2D_SCALE_LOAD:
# Disabling ptxas optimization as a temporary workaround, otherwise the test does not pass
_knob_disable_ptxas_opt(monkeypatch)

if BLOCK_N == 256 and BLOCK_K == 256:
NUM_STAGES = min(NUM_STAGES, 2)
elif BLOCK_K == 256:
@@ -467,6 +472,7 @@ def test_blocked_scale_mxfp(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, NUM_STAGES, USE_
a_scale.stride(2), a_scale.stride(3), a.stride(0), a.stride(1), b.stride(0),
b.stride(1), output.stride(0), output.stride(1), BLOCK_M, BLOCK_N, BLOCK_K,
NUM_STAGES=NUM_STAGES, USE_2D_SCALE_LOAD=USE_2D_SCALE_LOAD)
ttgir = out.asm["ttgir"]

def flatten_scale(scale):
num_chunk_m, num_chunk_k, _, _, _ = scale.shape
@@ -488,30 +494,27 @@ def flatten_scale(scale):
rtol = 0.0001
torch.testing.assert_close(ref_out, output, atol=atol, rtol=rtol)

if NUM_STAGES > 1:
ttgir = out.asm["ttgir"]
if USE_2D_SCALE_LOAD:
# Due to an issue in the coalescing pass, tmem_copy can not be generated for the 5D load.
# The issue is fixed using the patch from https://github.com/triton-lang/triton/pull/4914
assert "tmem_copy" in ttgir

if NUM_STAGES > 1:
if BLOCK_M == BLOCK_K and BLOCK_N == BLOCK_K:
load_pipelined = ttgir.count(f"ttg.local_alloc : () -> !ttg.memdesc<{NUM_STAGES}x{BLOCK_M}x{BLOCK_K}") == 2
else:
load_pipelined = (ttgir.count(f"ttg.local_alloc : () -> !ttg.memdesc<{NUM_STAGES}x{BLOCK_M}x{BLOCK_K}") and
ttgir.count(f"ttg.local_alloc : () -> !ttg.memdesc<{NUM_STAGES}x{BLOCK_K}x{BLOCK_N}"))

if load_pipelined:
# If load is pipelined, MMA pipelining should also kick in
# FIXME: Scaled dot pipelining is DISABLED
assert "ttng.wait_barrier" not in ttgir
else:
if load_pipelined and USE_2D_SCALE_LOAD:
# If load is pipelined and tmem_copy is used, MMA pipelining should also kick in
assert "ttng.wait_barrier" in ttgir
elif not load_pipelined:
# The behavior of load pipelining seems to depend on the size of input tensors.
# In this test, it fails to pipeline the RHS tensor when N is not a multiple of 128. Pipelining of the LHS tensor
# does not seem to be affected by the value of M, though.
print(f"SWP failed for M = {M}, N = {N}")

if USE_2D_SCALE_LOAD:
# Due to an issue in the coalescing pass, tmem_copy can not be generated for the 5D load.
# The issue is fixed using the patch from https://github.com/triton-lang/triton/pull/4914
assert "tmem_copy" in ttgir


@triton.jit
def lhs_in_tmem_kernel( #
44 changes: 43 additions & 1 deletion test/TritonGPU/accelerate-matmul.mlir
@@ -302,12 +302,18 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ
// CHECK-DAG: %[[TRUE:.+]] = arith.constant true
// CHECK-DAG: %[[A:.+]] = ttg.local_alloc %{{.*}} : (tensor<128x64xi8, #{{.*}}>) -> !ttg.memdesc<128x64xi8, #{{.*}}, #smem
// CHECK-DAG: %[[B:.+]] = ttg.local_alloc %{{.*}} : (tensor<64x128xi8, #{{.*}}>) -> !ttg.memdesc<64x128xi8, #{{.*}}, #smem
// CHECK-DAG: %[[SCALEA_LOCAL:.+]] = ttg.local_alloc %{{.*}} : (tensor<128x2xi8, #{{.*}}>) -> !ttg.memdesc<128x2xi8, #{{.*}}, #smem>
// CHECK: ttg.local_load %[[SCALEA_LOCAL]] : !ttg.memdesc<128x2xi8, #{{.*}}, #smem> -> tensor<128x2xi8, #{{.*}}>
// CHECK-DAG: %[[SCALEB_LOCAL:.+]] = ttg.local_alloc %{{.*}} : (tensor<128x2xi8, #{{.*}}>) -> !ttg.memdesc<128x2xi8, #{{.*}}, #smem>
// CHECK: ttg.local_load %[[SCALEB_LOCAL]] : !ttg.memdesc<128x2xi8, #{{.*}}, #smem> -> tensor<128x2xi8, #{{.*}}>
// CHECK-DAG: %[[ACC:.+]] = ttng.tmem_alloc %{{.*}} : (tensor<128x128xf32, #{{.*}}>) -> !ttg.memdesc<128x128xf32, #{{.*}}, #ttng.tensor_memory, mutable>
// CHECK: %[[SCALEA:.+]] = ttng.tmem_alloc %{{.*}} : (tensor<128x2xi8, #{{.*}}>) -> !ttg.memdesc<128x2xi8, #[[$TMEM1]], #ttng.tensor_memory>
// CHECK: %[[SCALEB:.+]] = ttng.tmem_alloc %{{.*}} : (tensor<128x2xi8, #{{.*}}>) -> !ttg.memdesc<128x2xi8, #[[$TMEM1]], #ttng.tensor_memory>
// CHECK: ttng.tc_gen5_mma_scaled %[[A]], %[[B]], %[[ACC]], %[[SCALEA]], %[[SCALEB]], %[[TRUE]], %[[TRUE]] lhs = e4m3 rhs = e4m3
tt.func public @mmav5_block_scaled(%a: tensor<128x64xi8, #blocked2>, %scale_a: tensor<128x2xi8, #blocked1>, %b: tensor<64x128xi8, #blocked>, %scale_b: tensor<128x2xi8, #blocked1>) -> tensor<128x128xf32, #blocked> {
tt.func public @mmav5_block_scaled(%a: tensor<128x64xi8, #blocked2>, %scale_a_ptr: tensor<128x2x!tt.ptr<i8>, #blocked1>, %b: tensor<64x128xi8, #blocked>, %scale_b_ptr: tensor<128x2x!tt.ptr<i8>, #blocked1>) -> tensor<128x128xf32, #blocked> {
%cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked>
%scale_a = tt.load %scale_a_ptr: tensor<128x2x!tt.ptr<i8>, #blocked1>
%scale_b = tt.load %scale_b_ptr: tensor<128x2x!tt.ptr<i8>, #blocked1>
%d = tt.dot_scaled %a scale %scale_a, %b scale %scale_b, %cst lhs = e4m3 rhs = e4m3 {fastMath = false} : tensor<128x64xi8, #blocked2>, tensor<128x2xi8, #blocked1> * tensor<64x128xi8, #blocked>, tensor<128x2xi8, #blocked1> -> tensor<128x128xf32, #blocked>
tt.return %d : tensor<128x128xf32, #blocked>
}
Expand Down Expand Up @@ -389,3 +395,39 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ
tt.return %d : tensor<128x128xf32, #blocked>
}
}

// -----

#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}>
#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [1, 0]}>
#blocked2 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}>
#blocked3 = #ttg.blocked<{sizePerThread = [1, 1, 1, 1, 4], threadsPerWarp = [1, 1, 8, 4, 1], warpsPerCTA = [1, 1, 4, 1, 1], order = [4, 3, 2, 1, 0]}>
#blocked4 = #ttg.blocked<{sizePerThread = [1, 1, 1, 1, 4], threadsPerWarp = [1, 4, 8, 1, 1], warpsPerCTA = [1, 1, 4, 1, 1], order = [4, 1, 2, 3, 0]}>
#linear = #ttg.linear<{register = [[0, 1], [0, 2]], lane = [[32, 0], [64, 0], [1, 0], [2, 0], [4, 0]], warp = [[8, 0], [16, 0]], block = []}>
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:100", "ttg.threads-per-warp" = 32 : i32} {
// CHECK-DAG: #[[$TMEM:.+]] = #ttng.tensor_memory_encoding<blockM = 128, blockN = 128, unpacked = true>
// CHECK-DAG: #[[$TMEM1:.+]] = #ttng.tensor_memory_scales_encoding
// CHECK-LABEL: mmav5_block_scaled_5d_scale
// CHECK-DAG: %[[TRUE:.+]] = arith.constant true
// CHECK-DAG: %[[B:.+]] = ttg.local_alloc %{{.*}} : (tensor<128x128xi8, #{{.*}}>) -> !ttg.memdesc<128x128xi8, #{{.*}}, #smem
// CHECK-DAG: %[[A:.+]] = ttg.local_alloc %{{.*}} : (tensor<128x128xi8, #{{.*}}>) -> !ttg.memdesc<128x128xi8, #{{.*}}, #smem
// CHECK-DAG: %[[SCALEA_LOCAL:.+]] = ttg.local_alloc
// CHECK: ttg.local_load %[[SCALEA_LOCAL]]
// CHECK-DAG: %[[SCALEB_LOCAL:.+]] = ttg.local_alloc
// CHECK: ttg.local_load %[[SCALEB_LOCAL]]
// CHECK-DAG: %[[ACC:.+]] = ttng.tmem_alloc %{{.*}} : (tensor<128x128xf32, #{{.*}}>) -> !ttg.memdesc<128x128xf32, #{{.*}}, #ttng.tensor_memory, mutable>
// CHECK: %[[SCALEA:.+]] = ttng.tmem_alloc %{{.*}} : (tensor<128x4xi8, #{{.*}}>) -> !ttg.memdesc<128x4xi8, #[[$TMEM1]], #ttng.tensor_memory>
// CHECK: %[[SCALEB:.+]] = ttng.tmem_alloc %{{.*}} : (tensor<128x4xi8, #{{.*}}>) -> !ttg.memdesc<128x4xi8, #[[$TMEM1]], #ttng.tensor_memory>
// CHECK: ttng.tc_gen5_mma_scaled %[[A]], %[[B]], %[[ACC]], %[[SCALEA]], %[[SCALEB]], %[[TRUE]], %[[TRUE]] lhs = e4m3 rhs = e4m3
tt.func public @mmav5_block_scaled_5d_scale(%a: tensor<128x128xi8, #blocked2>, %scale_a_ptr: tensor<1x1x32x4x4x!tt.ptr<i8>, #blocked3>, %b: tensor<128x128xi8, #blocked>, %scale_b_ptr: tensor<1x1x32x4x4x!tt.ptr<i8>, #blocked3>) -> tensor<128x128xf32, #blocked> {
%cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked>
%scale_a_5d = tt.load %scale_a_ptr: tensor<1x1x32x4x4x!tt.ptr<i8>, #blocked3>
%scale_a_trans = tt.trans %scale_a_5d {order = array<i32: 0, 3, 2, 1, 4>} : tensor<1x1x32x4x4xi8, #blocked3> -> tensor<1x4x32x1x4xi8, #blocked4>
%scale_a = tt.reshape %scale_a_trans : tensor<1x4x32x1x4xi8, #blocked4> -> tensor<128x4xi8, #linear>
%scale_b_5d = tt.load %scale_b_ptr: tensor<1x1x32x4x4x!tt.ptr<i8>, #blocked3>
%scale_b_trans = tt.trans %scale_b_5d {order = array<i32: 0, 3, 2, 1, 4>} : tensor<1x1x32x4x4xi8, #blocked3> -> tensor<1x4x32x1x4xi8, #blocked4>
%scale_b = tt.reshape %scale_b_trans : tensor<1x4x32x1x4xi8, #blocked4> -> tensor<128x4xi8, #linear>
%d = tt.dot_scaled %a scale %scale_a, %b scale %scale_b, %cst lhs = e4m3 rhs = e4m3 {fastMath = false} : tensor<128x128xi8, #blocked2>, tensor<128x4xi8, #linear> * tensor<128x128xi8, #blocked>, tensor<128x4xi8, #linear> -> tensor<128x128xf32, #blocked>
tt.return %d : tensor<128x128xf32, #blocked>
}
}