
Commit 70dbb62

- steel_attention.metal (new) was missing from the build
1 parent 7f02cd8

3 files changed: +60 -2 lines changed
steel/attn/kernels/steel_attention.metal (new)

+31
@@ -0,0 +1,31 @@
+// Copyright © 2024 Apple Inc.
+
+// clang-format off
+#include "../../../utils.h"
+
+#include "../../../steel/attn/attn.h"
+#include "../../../steel/attn/kernels/steel_attention.h"
+
+#define instantiate_attn(tname, dtype, bq, bk, bd, wm, wn) \
+template [[host_name("steel_attention_" #tname "_bq" #bq "_bk" #bk "_bd" #bd "_wm" #wm "_wn" #wn)]] \
+[[kernel]] void attention<dtype, bq, bk, bd, wm, wn, float>( \
+    const device dtype* Q [[buffer(0)]], \
+    const device dtype* K [[buffer(1)]], \
+    const device dtype* V [[buffer(2)]], \
+    device dtype* O [[buffer(3)]], \
+    const constant AttnParams* params [[buffer(4)]], \
+    uint simd_lane_id [[thread_index_in_simdgroup]], \
+    uint simd_group_id [[simdgroup_index_in_threadgroup]], \
+    uint3 tid [[threadgroup_position_in_grid]], \
+    uint3 lid [[thread_position_in_threadgroup]]);
+
+#define instantiate_attn_shapes_helper(iname, itype) \
+    instantiate_attn(iname, itype, 32, 16, 128, 4, 1) \
+    instantiate_attn(iname, itype, 32, 32, 80, 4, 1) \
+    instantiate_attn(iname, itype, 32, 32, 64, 4, 1)
+
+instantiate_attn_shapes_helper(float16, half);
+instantiate_attn_shapes_helper(bfloat16, bfloat16_t);
+
+instantiate_attn_shapes_helper(float32, float);
+// clang-format on
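Note: each instantiate_attn(...) entry above stamps out one specialization of the attention kernel template from steel_attention.h and exports it under a predictable host name (bq/bk are presumably the query/key block-tile sizes, bd the head dimension, and wm/wn the simdgroup layout, with float as the accumulation type). As a sketch, the first entry of instantiate_attn_shapes_helper(float16, half) expands to:

// Preprocessor expansion of instantiate_attn(float16, half, 32, 16, 128, 4, 1);
// the stringized arguments concatenate into the exported host name.
template [[host_name("steel_attention_float16_bq32_bk16_bd128_wm4_wn1")]]
[[kernel]] void attention<half, 32, 16, 128, 4, 1, float>(
    const device half* Q [[buffer(0)]],
    const device half* K [[buffer(1)]],
    const device half* V [[buffer(2)]],
    device half* O [[buffer(3)]],
    const constant AttnParams* params [[buffer(4)]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 lid [[thread_position_in_threadgroup]]);

The three shapes instantiated per dtype cover head dimensions 128, 80, and 64, so the Dk = 64 case in the test below can exercise one of these kernels.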

Tests/MLXTests/MLXFastKernelTests.swift

+26
@@ -70,4 +70,30 @@ class MLXFastKernelTests: XCTestCase {
         XCTAssertTrue(allClose(out[0], full([2, 2], values: 14.0484)).all().item())
         XCTAssertTrue(allClose(out[1], full([3, 2], values: -2)).all().item())
     }
+
+    func testFastSDPA() {
+        // https://github.com/ml-explore/mlx-swift/issues/172
+        // this will just make sure the MLXFast.scaled_dot_product_attention is
+        // callable in the various cases, based on
+        // https://github.com/ml-explore/mlx/blob/main/python/tests/test_fast_sdpa.py#L65-L87
+
+        let Dk = 64
+        let scale = 1.0 / sqrt(Float(Dk))
+        let dTypes = [DType.float32, DType.float16]
+        for SEQUENCE_LENGTH in [63, 129, 400] {
+            for dtype in dTypes {
+                let B = 2
+                let H = 24
+                let q = MLXRandom.normal([B, H, SEQUENCE_LENGTH, Dk]).asType(dtype)
+                let k = MLXRandom.normal([B, H, SEQUENCE_LENGTH, Dk]).asType(dtype)
+                let v = MLXRandom.normal([B, H, SEQUENCE_LENGTH, Dk]).asType(dtype)
+
+                let result = MLXFast.scaledDotProductAttention(
+                    queries: q, keys: k, values: v, scale: scale, mask: nil,
+                    memoryEfficientThreshold: 2)
+
+                eval(result)
+            }
+        }
+    }
 }

tools/fix-metal-includes.sh

+3 -2
@@ -23,11 +23,12 @@ KERNEL_LIST=" \
 arg_reduce.metal \
 conv.metal \
 gemv.metal \
+layer_norm.metal \
 random.metal \
 rms_norm.metal \
-layer_norm.metal \
 rope.metal \
-scaled_dot_product_attention.metal"
+scaled_dot_product_attention.metal \
+steel/attn/kernels/steel_attention.metal"

 # We fixup all the header files AND the listed kernel files
 HEADERS=$(find "${KERNELS_DIR}" -name "*.h")
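For context on the list above: fix-metal-includes.sh rewrites each listed kernel's includes so they resolve relative to the file itself rather than to the mlx repository root, which is why the new steel_attention.metal shown earlier carries ../../../-style includes. A rough sketch of the effect for this file (the "before" lines are an assumption about the upstream ml-explore/mlx include style; the "after" lines match the committed file):

// before (assumed upstream repo-rooted include style):
//   #include "mlx/backend/metal/kernels/utils.h"
//   #include "mlx/backend/metal/kernels/steel/attn/attn.h"
// after the fixup, relative to steel/attn/kernels/steel_attention.metal,
// which sits three directories below the kernels root:
#include "../../../utils.h"
#include "../../../steel/attn/attn.h"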
