Add static type information support to aten.mm (llvm#602)
This commit adds static type information support to `aten.mm`. This is
needed for the forward pass of Bert training.
ramiro050 authored Feb 18, 2022
1 parent abbde7d commit 2823277
Showing 3 changed files with 57 additions and 27 deletions.
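
In effect, the RefineTypes pass now propagates static shapes through `torch.aten.mm` instead of refining the result only to a rank-2 tensor of unknown sizes. An illustrative snippet (operand names are placeholders; it mirrors the new test case @g added below): multiplying a [2,3] operand by a [3,4] operand is refined to a [2,4] result,

  torch.aten.mm %lhs, %rhs : !torch.vtensor<[2,3],f32>, !torch.vtensor<[3,4],f32> -> !torch.vtensor<[2,4],f32>

whereas before this change the same op would have been refined only to !torch.vtensor<[?,?],f32>.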
1 change: 1 addition & 0 deletions e2e_testing/torchscript/xfail_sets.py
@@ -55,6 +55,7 @@
"ReturnTwoTensorF32I64_basic",
"ElementwisePowModule_basic",
"BmmModule_basic",
"MmDagModule_basic",
"Matmul_dot",
"Matmul_3d",
"RsubModule_basic",
41 changes: 16 additions & 25 deletions lib/Dialect/Torch/Transforms/RefineTypes.cpp
@@ -823,32 +823,23 @@ ChangeResult TypeAnalyzer::visitAtenMmOp(
auto &rhs = operands[1]->getValue();
auto knowledge =
ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
+
+ auto isRank2 = [](const ValueKnowledge &operand) -> bool {
+ return operand.hasSizes && operand.sizes.size() == 2;
+ };
+
+ // `aten.mm` expects both operands to be rank-2 tensors.
+ if (!isRank2(lhs) || !isRank2(rhs))
+ return getLatticeElement(op->getResult(0)).join(knowledge);
+
+ // If static information is available, check that both tensors are compatible.
+ if (lhs.sizes[1] != kUnknownSize && rhs.sizes[0] != kUnknownSize &&
+ lhs.sizes[1] != rhs.sizes[0])
+ return getLatticeElement(op->getResult(0)).join(knowledge);
+
knowledge.hasSizes = true;
- // WARNING: We could be more precise here by calculating the output
- // shape as "(lhs.shape[0], rhs.shape[1])". However, that is really tricky
- // at this stage in the compiler because we don't really have many static
- // guarantees about the input ranks because `aten` ops do dynamic error
- // checking and safely abort the program. There is nothing preventing us
- // from (correctly!) statically inferring the shapes of the operands to
- // shapes that are guaranteed to cause an error at runtime.
- //
- // Example: Suppose a user program calls `aten.mm` with two rank-0
- // operands. The program emits an error when invoked, but when running
- // this pass, we will (correctly!) infer `lhs.hasSizes && lhs.sizes.size()
- // == 0 && rhs.hasSizes && rhs.sizes.size() == 0` -- it's not safe to
- // access `lhs.sizes[0]` / `rhs.sizes[1]`! So when writing this transfer
- // function, it's not as simple as taking `lhs.sizes[0]` and
- // `rhs.sizes[1]`, as both of those might read out of bounds of the array.
- // It would require more complicated logic.
- //
- // Just knowing dtypes and ranks is sufficient at this stage
- // in the compiler. The precise per-dimension size propagation is best
- // done lower in the stack, such as at the linalg level, where we have
- // more static guarantees and more structure.
- knowledge.sizes.resize(2, kUnknownSize);
- // TODO: Investigate promotion rules if element types mismatch.
- // This is conservatively correct, assuming that if both element types are
- // the same, then the result is of that same element type.
+ knowledge.sizes = {lhs.sizes[0], rhs.sizes[1]};
+
knowledge.dtype =
getPromotedResultTypeAssumingNonZeroRank(op->getContext(), {&lhs, &rhs});
return getLatticeElement(op->getResult(0)).join(knowledge);
42 changes: 40 additions & 2 deletions test/Dialect/Torch/refine-types.mlir
@@ -41,8 +41,8 @@ builtin.func @f(%arg0: !torch.vtensor<[2,3,?],f32>) -> !torch.vtensor {
// CHECK-LABEL: func @f(
// CHECK-SAME: %[[LHS:.*]]: !torch.vtensor<[2,?],f32>,
// CHECK-SAME: %[[RHS:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor {
- // CHECK: %[[MM:.*]] = torch.aten.mm %[[LHS]], %[[RHS]] : !torch.vtensor<[2,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
- // CHECK: %[[SHAPE_ERASED:.*]] = torch.tensor_static_info_cast %[[MM]] : !torch.vtensor<[?,?],f32> to !torch.vtensor
+ // CHECK: %[[MM:.*]] = torch.aten.mm %[[LHS]], %[[RHS]] : !torch.vtensor<[2,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor<[2,?],f32>
+ // CHECK: %[[SHAPE_ERASED:.*]] = torch.tensor_static_info_cast %[[MM]] : !torch.vtensor<[2,?],f32> to !torch.vtensor
// CHECK: return %[[SHAPE_ERASED]] : !torch.vtensor
builtin.func @f(%arg0: !torch.vtensor<[2,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.vtensor {
%1 = torch.aten.mm %arg0, %arg1 : !torch.vtensor<[2,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor
@@ -51,6 +51,44 @@

// -----

+ // CHECK-LABEL: func @g(
+ // CHECK-SAME: %[[LHS:.*]]: !torch.vtensor<[2,3],f32>,
+ // CHECK-SAME: %[[RHS:.*]]: !torch.vtensor<[3,4],f32>) -> !torch.vtensor {
+ // CHECK: %[[MM:.*]] = torch.aten.mm %[[LHS]], %[[RHS]] : !torch.vtensor<[2,3],f32>, !torch.vtensor<[3,4],f32> -> !torch.vtensor<[2,4],f32>
+ // CHECK: %[[SHAPE_ERASED:.*]] = torch.tensor_static_info_cast %[[MM]] : !torch.vtensor<[2,4],f32> to !torch.vtensor
+ // CHECK: return %[[SHAPE_ERASED]] : !torch.vtensor
+ builtin.func @g(%arg0: !torch.vtensor<[2,3],f32>, %arg1: !torch.vtensor<[3,4],f32>) -> !torch.vtensor {
+ %1 = torch.aten.mm %arg0, %arg1 : !torch.vtensor<[2,3],f32>, !torch.vtensor<[3,4],f32> -> !torch.vtensor
+ return %1 : !torch.vtensor
+ }
+
+ // -----
+
+ // CHECK-LABEL: func @h(
+ // CHECK-SAME: %[[LHS:.*]]: !torch.vtensor<[2,?],f32>,
+ // CHECK-SAME: %[[RHS:.*]]: !torch.vtensor<[3,4],f32>) -> !torch.vtensor {
+ // CHECK: %[[MM:.*]] = torch.aten.mm %[[LHS]], %[[RHS]] : !torch.vtensor<[2,?],f32>, !torch.vtensor<[3,4],f32> -> !torch.vtensor<[2,4],f32>
+ // CHECK: %[[SHAPE_ERASED:.*]] = torch.tensor_static_info_cast %[[MM]] : !torch.vtensor<[2,4],f32> to !torch.vtensor
+ // CHECK: return %[[SHAPE_ERASED]] : !torch.vtensor
+ builtin.func @h(%arg0: !torch.vtensor<[2,?],f32>, %arg1: !torch.vtensor<[3,4],f32>) -> !torch.vtensor {
+ %1 = torch.aten.mm %arg0, %arg1 : !torch.vtensor<[2,?],f32>, !torch.vtensor<[3,4],f32> -> !torch.vtensor
+ return %1 : !torch.vtensor
+ }
+
+ // -----
+
+ // CHECK-LABEL: func @i(
+ // CHECK-SAME: %[[LHS:.*]]: !torch.vtensor<[2,5],f32>,
+ // CHECK-SAME: %[[RHS:.*]]: !torch.vtensor<[3,4],f32>) -> !torch.vtensor {
+ // CHECK: %[[MM:.*]] = torch.aten.mm %[[LHS]], %[[RHS]] : !torch.vtensor<[2,5],f32>, !torch.vtensor<[3,4],f32> -> !torch.vtensor
+ // CHECK: return %[[MM]] : !torch.vtensor
+ builtin.func @i(%arg0: !torch.vtensor<[2,5],f32>, %arg1: !torch.vtensor<[3,4],f32>) -> !torch.vtensor {
+ %1 = torch.aten.mm %arg0, %arg1 : !torch.vtensor<[2,5],f32>, !torch.vtensor<[3,4],f32> -> !torch.vtensor
+ return %1 : !torch.vtensor
+ }
+
+ // -----
+
// CHECK-LABEL: func @f(
// CHECK-SAME: %[[INPUT:.*]]: !torch.vtensor<[?,3],f32>,
// CHECK-SAME: %[[WEIGHT:.*]]: !torch.vtensor<[5,3],f32>,
