[Codegen][GPU] Add support for transpose distribution with nested layouts (iree-org#16630)

Transposition of a nested layout simply involves transposing all
non-basis entries by the permutation of the transpose.
qedawkins authored Mar 1, 2024
1 parent 6ff9a3d · commit 71f87af
Showing 3 changed files with 107 additions and 2 deletions.
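Concretely, every shape field of the nested layout (subgroups, batches, outers, threads, and elements per dimension) and every order field is permuted with the same permutation as the vector.transpose, while the subgroup/thread basis and active-id fields pass through untouched. The following is a minimal standalone sketch of that per-field rule, using the layout from the new lit test and the permutation [1, 0]; applyPerm is a stand-in with the same semantics as MLIR's applyPermutation (result[i] = input[perm[i]]), and plain std::vector stands in for the attribute getters.

#include <cassert>
#include <cstdint>
#include <vector>

// Same semantics as MLIR's applyPermutation: result[i] = input[perm[i]].
static std::vector<int64_t> applyPerm(const std::vector<int64_t> &input,
                                      const std::vector<int64_t> &perm) {
  std::vector<int64_t> result;
  result.reserve(perm.size());
  for (int64_t idx : perm)
    result.push_back(input[idx]);
  return result;
}

int main() {
  const std::vector<int64_t> perm = {1, 0}; // the vector.transpose permutation
  // Shape fields of the layout used in the new lit test are permuted...
  assert((applyPerm({2, 4}, perm) == std::vector<int64_t>{4, 2}));   // batches_per_subgroup
  assert((applyPerm({2, 1}, perm) == std::vector<int64_t>{1, 2}));   // outers_per_batch
  assert((applyPerm({4, 16}, perm) == std::vector<int64_t>{16, 4})); // threads_per_outer
  // ...and so are the order fields (the default [0, 1] becomes [1, 0])...
  assert((applyPerm({0, 1}, perm) == std::vector<int64_t>{1, 0}));
  // ...while subgroup_basis = [2, 2] and thread_basis = [4, 16] are untouched.
  return 0;
}

The permuted counts and orders above are exactly what the CHECK-SAME block in the second test case below expects.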
@@ -614,7 +614,8 @@ struct DistributeTranspose final : OpDistributionPattern<vector::TransposeOp> {
                                 DistributionSignature &signature,
                                 PatternRewriter &rewriter) const override {
     VectorValue value = transposeOp.getVector();
-    LayoutAttr layout = dyn_cast<LayoutAttr>(signature[value]);
+    VectorLayoutInterface layout =
+        dyn_cast<VectorLayoutInterface>(signature[value]);
     if (!layout) {
       return failure();
     }
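The only change to the pattern is this widened cast: matching LayoutAttr directly meant transposes carrying a nested layout fell through to failure(), whereas casting to VectorLayoutInterface accepts any layout attribute that implements the interface, including NestedLayoutAttr. A hypothetical sketch of how the remainder of the pattern (collapsed in this view) can stay layout-agnostic, assuming permute() is exposed on the interface as the override added below suggests:

// Hypothetical sketch, not the collapsed body of the pattern.
if (auto layout = dyn_cast<VectorLayoutInterface>(signature[value])) {
  // For a 2-D transpose the permutation is e.g. [1, 0]; permute() dispatches
  // to the concrete layout attribute, so nested layouts use the
  // NestedLayoutAttr::permute added in this commit.
  VectorLayoutInterface permuted = layout.permute(transposeOp.getPermutation());
  (void)permuted;
}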
@@ -799,3 +799,82 @@ builtin.module attributes { transform.with_named_sequence } {
// CHECK: vector.insert %[[BCAST]], %{{.*}} [0, 0, 1, 1, 0, 0] : vector<1x4x4xf16> into vector<1x2x2x2x1x1x1x4x4xf16>
// CHECK: vector.insert %[[BCAST]], %{{.*}} [0, 1, 1, 0, 0, 0] : vector<1x4x4xf16> into vector<1x2x2x2x1x1x1x4x4xf16>
// CHECK: vector.insert %[[BCAST]], %{{.*}} [0, 1, 1, 1, 0, 0] : vector<1x4x4xf16> into vector<1x2x2x2x1x1x1x4x4xf16>

// -----

#layout = #iree_vector_ext.nested_layout<
subgroups_per_workgroup = [2, 2],
batches_per_subgroup = [2, 4],
outers_per_batch = [2, 1],
threads_per_outer = [4, 16],
elements_per_thread = [2, 2],
subgroup_basis = [2, 2],
thread_basis = [4, 16]
>

func.func @transpose(%src: vector<256x64xf16>) -> (vector<64x256xf16>) {
%transp = vector.transpose %src, [1, 0] {"__vector_layout_test_anchor_result_0" = #layout}
: vector<256x64xf16> to vector<64x256xf16>
%sqrt = math.sqrt %transp : vector<64x256xf16>
return %sqrt : vector<64x256xf16>
}

builtin.module attributes { transform.with_named_sequence } {
transform.named_sequence @__transform_main(%variant_op: !transform.any_op {transform.readonly}) {
%top_level_func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
transform.iree.test_gpu_vector_distribution %top_level_func : !transform.any_op
transform.yield
}
}

// CHECK-LABEL: func @transpose
// CHECK: iree_vector_ext.to_simt %{{.*}} : vector<256x64xf16> -> vector<2x4x2x1x2x2xf16>
// CHECK: math.sqrt %{{.*}} : vector<2x4x2x1x2x2xf16>
// CHECK: iree_vector_ext.to_simd %{{.*}} : vector<2x4x2x1x2x2xf16> -> vector<64x256xf16>
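// The packed SIMT type is batches_per_subgroup x outers_per_batch x
// elements_per_thread of the anchored result layout: [2, 4] x [2, 1] x [2, 2]
// gives vector<2x4x2x1x2x2xf16>, and together with the subgroup and thread
// counts this covers the full 64x256 result (2*2*2*4*2 = 64, 2*4*1*16*2 = 256).
// Both to_simt and to_simd use the same packed type, so the distributed
// transpose needs no data-movement op of its own; it is carried by the
// permuted layout on the operand.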

// -----

#layout = #iree_vector_ext.nested_layout<
subgroups_per_workgroup = [2, 2],
batches_per_subgroup = [2, 4],
outers_per_batch = [2, 1],
threads_per_outer = [4, 16],
elements_per_thread = [2, 2],
subgroup_basis = [2, 2],
thread_basis = [4, 16]
>

func.func @transpose(%src: vector<64x256xf16>) -> (vector<256x64xf16>) {
%transp = vector.transpose %src, [1, 0] {"__vector_layout_test_anchor_operand_0" = #layout}
: vector<64x256xf16> to vector<256x64xf16>
%sqrt = math.sqrt %transp : vector<256x64xf16>
return %sqrt : vector<256x64xf16>
}

builtin.module attributes { transform.with_named_sequence } {
transform.named_sequence @__transform_main(%variant_op: !transform.any_op {transform.readonly}) {
%top_level_func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
transform.iree.test_gpu_vector_distribution %top_level_func : !transform.any_op
transform.yield
}
}

// CHECK: #[[$LAYOUT:.+]] = #iree_vector_ext.nested_layout
// CHECK-SAME: subgroups_per_workgroup = [2, 2],
// CHECK-SAME: batches_per_subgroup = [4, 2]
// CHECK-SAME: outers_per_batch = [1, 2]
// CHECK-SAME: threads_per_outer = [16, 4]
// CHECK-SAME: elements_per_thread = [2, 2]
// CHECK-SAME: subgroup_order = [1, 0]
// CHECK-SAME: batch_order = [1, 0],
// CHECK-SAME: outer_order = [1, 0]
// CHECK-SAME: thread_order = [1, 0]
// CHECK-SAME: element_order = [1, 0]
// CHECK-SAME: subgroup_basis = [2, 2]
// CHECK-SAME: thread_basis = [4, 16]

// CHECK-LABEL: func @transpose
// CHECK: iree_vector_ext.to_simt %{{.*}} : vector<64x256xf16> -> vector<2x4x2x1x2x2xf16>
// CHECK: math.sqrt %{{.*}} : vector<2x4x2x1x2x2xf16>
// CHECK: iree_vector_ext.to_simd %{{.*}} : vector<2x4x2x1x2x2xf16> -> vector<256x64xf16>
// CHECK: return {{.*}}#[[$LAYOUT]]
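// With the anchor on the transpose operand instead, the result layout is the
// operand layout permuted by [1, 0]: every count and order field checked above
// is transposed, while subgroup_basis and thread_basis remain [2, 2] and
// [4, 16]. The final CHECK on the return verifies that this permuted layout is
// the one propagated to the returned value.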
@@ -304,7 +304,32 @@ NestedLayoutAttr::project(ArrayRef<bool> droppedDims) const {
 
 VectorLayoutInterface
 NestedLayoutAttr::permute(ArrayRef<int64_t> permutation) const {
-  llvm_unreachable("Not yet implemented");
+  SmallVector<int64_t> subgroupCount =
+      applyPermutation(getSubgroupsPerWorkgroup(), permutation);
+  SmallVector<int64_t> subgroupOrder =
+      applyPermutation(getSubgroupOrder(), permutation);
+  SmallVector<int64_t> batchCount =
+      applyPermutation(getBatchesPerSubgroup(), permutation);
+  SmallVector<int64_t> batchOrder =
+      applyPermutation(getBatchOrder(), permutation);
+  SmallVector<int64_t> outerCount =
+      applyPermutation(getOutersPerBatch(), permutation);
+  SmallVector<int64_t> outerOrder =
+      applyPermutation(getOuterOrder(), permutation);
+  SmallVector<int64_t> threadCount =
+      applyPermutation(getThreadsPerOuter(), permutation);
+  SmallVector<int64_t> threadOrder =
+      applyPermutation(getThreadOrder(), permutation);
+  SmallVector<int64_t> elementCount =
+      applyPermutation(getElementsPerThread(), permutation);
+  SmallVector<int64_t> elementOrder =
+      applyPermutation(getElementOrder(), permutation);
+
+  return NestedLayoutAttr::get(
+      getContext(), subgroupCount, subgroupOrder, batchCount, batchOrder,
+      outerCount, outerOrder, threadCount, threadOrder, elementCount,
+      elementOrder, getSubgroupBasis(), getSubgroupActiveIds(),
+      getThreadBasis(), getThreadActiveIds());
 }

/// We distribute to:
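As a follow-up, a small hypothetical round-trip example (again assuming permute() is reachable through VectorLayoutInterface): applying the 2-D transpose permutation twice recovers the original layout, since only the count and order fields are permuted, [1, 0] is its own inverse, and the basis and active-id fields are forwarded unchanged by the implementation above.

// Hypothetical caller; `layout` is assumed to be the NestedLayoutAttr anchored
// on the transpose.
SmallVector<int64_t> perm = {1, 0};
VectorLayoutInterface transposed = layout.permute(perm);
VectorLayoutInterface roundTrip = transposed.permute(perm);
// Every count/order field has now been permuted twice by an involution and the
// basis fields were never touched, so `roundTrip` describes the same
// distribution as `layout`.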