Skip to content

Commit

Permalink
[xla:cpu] Vectorize dot ops.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 510387472
  • Loading branch information
tyb0807 authored and TensorFlow MLIR Team committed Feb 17, 2023
1 parent 50584fa commit 5ec24f5
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 1 deletion.
8 changes: 7 additions & 1 deletion gml_st/transforms/vectorization/vectorize_for_cpu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -42,13 +42,16 @@ namespace {
#include "gml_st/transforms/passes.h.inc"

using mlir::linalg::BroadcastOp;
using mlir::linalg::DotOp;
using mlir::linalg::FillOp;
using mlir::linalg::GenericOp;
using mlir::linalg::MapOp;
using mlir::linalg::MatmulOp;
using mlir::linalg::MatvecOp;
using mlir::linalg::Mmt4DOp;
using mlir::linalg::ReduceOp;
using mlir::linalg::TransposeOp;
using mlir::linalg::VecmatOp;
using mlir::tensor::ExpandShapeOp;
using mlir::thlo::ReverseOp;
using mlir::vector::TransferReadOp;
Expand Down Expand Up @@ -208,11 +211,14 @@ struct VectorizeForCPUPass
VectorizationPattern<BroadcastOp>,
VectorizationPattern<FillOp>,
VectorizationPattern<GenericOp>,
VectorizationPattern<DotOp>,
VectorizationPattern<MapOp>,
VectorizationPattern<MatmulOp>,
VectorizationPattern<MatvecOp>,
VectorizationPattern<Mmt4DOp>,
VectorizationPattern<ReduceOp>,
VectorizationPattern<TransposeOp>
VectorizationPattern<TransposeOp>,
VectorizationPattern<VecmatOp>
>(ctx, isInsidePerfectlyTiledLoopOrSmall);
// clang-format on
populateTransferReadOfOneDimExpandShapePattern(patterns);
Expand Down
51 changes: 51 additions & 0 deletions tests/Dialect/gml_st/vectorize_for_cpu.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -222,3 +222,54 @@ func.func @perfectly_tiled_reverse_4d(%input: tensor<1x1x1x8xf32>,
// CHECK: %[[WRITE:.*]] = vector.transfer_write %[[SHUFFLE]], %[[ARG1]]
// CHECK-SAME: : vector<8xf32>, tensor<1x1x1x8xf32>
// CHECK: return %[[WRITE]]

// -----

func.func @matvec(%lhs: tensor<33x17xf32>, %rhs: tensor<17xf32>,
%output: tensor<33xf32>) -> tensor<33xf32> {
%2 = linalg.matvec ins(%lhs, %rhs : tensor<33x17xf32>, tensor<17xf32>)
outs(%output : tensor<33xf32>) -> tensor<33xf32>
return %2 : tensor<33xf32>
}

// CHECK-LABEL: @matvec
// CHECK-SAME: %[[LHS:.*]]: tensor<33x17xf32>, %[[RHS:.*]]: tensor<17xf32>, %[[OUT:.*]]: tensor<33xf32>
// CHECK: %[[LHS_READ:.*]] = vector.transfer_read %[[LHS]]
// CHECK: %[[RHS_READ:.*]] = vector.transfer_read %[[RHS]]
// CHECK: %[[OUT_READ:.*]] = vector.transfer_read %[[OUT]]
// CHECK: %[[CONTRACT:.*]] = vector.contract {{.*}}%[[LHS_READ]], %[[RHS_READ]], %[[OUT_READ]]
// CHECK: vector.transfer_write %[[CONTRACT]], %[[OUT]]

// -----

func.func @vecmat(%lhs: tensor<17xf32>, %rhs: tensor<17x33xf32>,
%output: tensor<33xf32>) -> tensor<33xf32> {
%2 = linalg.vecmat ins(%lhs, %rhs : tensor<17xf32>, tensor<17x33xf32>)
outs(%output : tensor<33xf32>) -> tensor<33xf32>
return %2 : tensor<33xf32>
}

// CHECK-LABEL: @vecmat
// CHECK-SAME: %[[LHS:.*]]: tensor<17xf32>, %[[RHS:.*]]: tensor<17x33xf32>, %[[OUT:.*]]: tensor<33xf32>
// CHECK: %[[LHS_READ:.*]] = vector.transfer_read %[[LHS]]
// CHECK: %[[RHS_READ:.*]] = vector.transfer_read %[[RHS]]
// CHECK: %[[OUT_READ:.*]] = vector.transfer_read %[[OUT]]
// CHECK: %[[CONTRACT:.*]] = vector.contract {{.*}}%[[LHS_READ]], %[[RHS_READ]], %[[OUT_READ]]
// CHECK: vector.transfer_write %[[CONTRACT]], %[[OUT]]

// -----

func.func @dot(%lhs: tensor<17xf32>, %rhs: tensor<17xf32>,
%output: tensor<f32>) -> tensor<f32> {
%2 = linalg.dot ins(%lhs, %rhs : tensor<17xf32>, tensor<17xf32>)
outs(%output : tensor<f32>) -> tensor<f32>
return %2 : tensor<f32>
}

// CHECK-LABEL: @dot
// CHECK-SAME: %[[LHS:.*]]: tensor<17xf32>, %[[RHS:.*]]: tensor<17xf32>, %[[OUT:.*]]: tensor<f32>
// CHECK: %[[LHS_READ:.*]] = vector.transfer_read %[[LHS]]
// CHECK: %[[RHS_READ:.*]] = vector.transfer_read %[[RHS]]
// CHECK: %[[OUT_READ:.*]] = vector.transfer_read %[[OUT]]
// CHECK: %[[CONTRACT:.*]] = vector.contract {{.*}}%[[LHS_READ]], %[[RHS_READ]]
// CHECK: vector.transfer_write {{.*}}, %[[OUT]]

0 comments on commit 5ec24f5

Please sign in to comment.