Skip to content

Commit

Permalink
Fix illegal memory access in Xformers kernel
Browse files Browse the repository at this point in the history
  • Loading branch information
neverix authored Feb 16, 2025
1 parent 77e211b commit a177bb4
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion sparsify/xformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def embedding_bag_k(
for bag in range(0, bag_size):
my_index = tl.load(indices_ptr + out_idx * bag_size + bag).to(tl.int64)
my_scaling = tl.load(per_sample_weights + out_idx * bag_size + bag)
my_weight = tl.load(weight_ptr + tl.arange(0, dim_padded) + my_index * dim)
my_weight = tl.load(weight_ptr + tl.arange(0, dim_padded) + my_index * dim, mask=dim_mask)
out_value = out_value + my_weight.to(tl.float32) * my_scaling
tl.store(out_ptr + out_idx * dim + tl.arange(0, dim_padded), out_value,
mask=dim_mask)
Expand Down

0 comments on commit a177bb4

Please sign in to comment.