Add pass/ops to expand functions to accept/return dynamic dimensions.
* This pass runs early and sets up each function for the later materialization passes.
* Some limited canonicalization that will elide get_ranked_shape ops when trivially resolvable.
* Skeleton of a doc outlining where this is going.

PiperOrigin-RevId: 291452226
Stella Laurenzo authored and copybara-github committed Jan 24, 2020
1 parent d0a0629 commit 3501288
Showing 16 changed files with 686 additions and 15 deletions.
166 changes: 166 additions & 0 deletions docs/dynamic_shapes.md
@@ -0,0 +1,166 @@
# Dynamic Shapes

NOTE: Effort is being made to make this facility generic so that it can be
eventually upstreamed to MLIR in some fashion. However, because MLIR lacks a set
of frontend ops and generally does not currently have any frontend oriented
infrastructure, it is being prototyped within IREE in order to find a working
set of ops and algorithms.

## Levels of dynamism

In general, there are three levels of shape information that can be present in
the input IR (or trivially derived by applying some form of shape inference
algorithm). Each additional level imposes more work on the compiler and runtime,
so the implementation generally addresses each level once the previous one is
well established:

1. Fully static shapes: No tensors have dynamic dimensions. All tensors are
   ranked.
2. Ranked dynamism: All tensors have ranks, but some dimensions may be
   unspecified.
3. Unranked dynamism: Some tensors have indeterminate ranks.

At this stage, *Dynamic Shapes* in IREE refers to supporting ranked dynamic
tensors, where some dimensions are left unspecified at public function
boundaries. It is expected that once this is solid, some support can be
considered for unranked dynamism, and it is further expected that this will
entail new ops, algorithms and runtime support, apart from what is needed for
ranked dynamism.

Within the category of ranked dynamism, it is well known that some dynamic
dimensions are easier to deal with than others: in common DNN use, dynamic
outer dimensions are much more common, and much easier to handle with respect
to code generation and kernel fanout, than dynamic inner dimensions.

While the shape handling machinery is relatively generic, we expect that real
backends will be limited with respect to how much they support all combinations
of dynamic dimensions. Eventually, IREE intends to solve this by having
relatively robust CPU fallback for fully dynamic cases and actionable warnings
that pinpoint when more specificity could increase performance.

## Compiler Frontend

In general, the IREE compiler frontend should accept modules containing
functions with operands/results that have dynamic dimensions. Such functions may
also have runtime dependent shapes in the form of `GetShape`-style ops which get
a shape from an arbitrary tensor, perform some arithmetic on it and use the
results elsewhere.
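As a purely illustrative sketch (the function name and shapes here are
hypothetical, not from this commit), such an input function might look like:

```mlir
// A public function whose leading (batch) dimension is dynamic.
func @predict(%arg0 : tensor<?x128xf32>) -> tensor<?x10xf32>
```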

### Shape dialect and lowering

IREE is introducing a `shape` dialect with a handful of ops and transformations
that are useful for materializing dynamic shape computations in terms of high
level ops on tensors.

#### Types:

* `ranked_shape`: This value type represents the dynamic dimensions of a
  partially known, ranked shape. It is used early in the compiler to represent
  anywhere that dynamic dimensions need to be passed (i.e. function
  args/results, etc). At lower levels of the compiler, it will generally be
  disaggregated into loose SSA values. This type also carries the datatype
  used to represent the dimensions. This is currently fixed to i32 but may
  eventually allow smaller integer types when they are known to be legal.
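For example, the test cases added in this commit spell such a type as:

```mlir
// A rank-4 shape whose dimensions 1 and 3 are dynamic, with dimension
// values carried as i32.
!shape.ranked_shape<1x?x2x?xi32>
```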

#### Ops:

* `get_ranked_shape`: Takes a tensor SSA value and returns a corresponding
`ranked_shape`. Early in the compilation flow, anything that needs a ranked
shape should add such ops so that the compiler can later determine which
shape computations to materialize. Getting the `ranked_shape` of a static
tensor should yield a constant.
* `tie_shape`: Takes tensor and ranked_shape SSA values and returns the
tensor. This is used as a junction point by the shape materialization passes
to know at various points precisely what the shape is.
* ... TODO: need `get_shape_dim` and conversions to/from 1D tensors and loose
SSA values.
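Taken together, the usage forms shown in the op definitions and tests in this
commit look like:

```mlir
// Get the ranked shape of a tensor with one dynamic dimension.
%shape = shape.get_ranked_shape %arg0
    : tensor<2x?x4xf32> -> !shape.ranked_shape<2x?x4xi32>
// Tie the shape back to the tensor as a junction point for the later
// materialization passes.
%tied = shape.tie_shape %arg0, %shape
    : tensor<2x?x4xf32>, !shape.ranked_shape<2x?x4xi32>
```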

### Materialization

#### Function signature expansion

Early in the process, all functions should have their arguments and results
expanded so that any dynamic tensors in their signature will gain a new
argument/result for the corresponding `ranked_shape`. This is done by expanding
the signatures and for arguments, inserting placeholder `tie_shape` ops which
preserve the association for later materialization. For results,
`get_ranked_shape` ops are inserted.

This is carried out by the `iree-shape-expand-function-dynamic-dims` pass, which
uses the conversion framework under the hood to perform type expansion.

This pass is typically done early in the compiler flow.
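As a sketch of the intended transformation (the exact output of
`iree-shape-expand-function-dynamic-dims` may differ; this function is
illustrative):

```mlir
// Before expansion:
func @f(%arg0 : tensor<?x4xf32>) -> tensor<?x4xf32>

// After expansion (sketch):
func @f(%arg0 : tensor<?x4xf32>, %arg1 : !shape.ranked_shape<?x4xi32>)
    -> (tensor<?x4xf32>, !shape.ranked_shape<?x4xi32>) {
  // Placeholder tie_shape preserves the arg/shape association for later
  // materialization.
  %0 = shape.tie_shape %arg0, %arg1
      : tensor<?x4xf32>, !shape.ranked_shape<?x4xi32>
  ...
}
```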

#### Shape dependent codegen

A lot of scheduling logic will need to access shapes (i.e. allocation, workgroup
size calculation, etc). In general, this should all be done based on a
`get_ranked_shape` op and corresponding `get_shape_dim` ops. For fully static
cases, these should reduce down to constants. For dynamic dimensions, the
`get_ranked_shape` ops serve as anchors where later parts of the compiler know
they need to materialize shape values.
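Note that `get_shape_dim` is still a TODO in the op list above; as a purely
hypothetical sketch of the intended access pattern (this syntax does not exist
in this commit):

```mlir
// Hypothetical: get_shape_dim is not yet implemented.
%shape = shape.get_ranked_shape %arg0
    : tensor<?x4xf32> -> !shape.ranked_shape<?x4xi32>
%dim0 = shape.get_shape_dim %shape[0] : !shape.ranked_shape<?x4xi32> -> i32
```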

#### Materializing shape computations

TODO: We have a sketch of this but are still proving it out.

Generally, it should be possible, for any `get_ranked_shape` op, to trace up the
use-def chain and materialize shape manipulation arithmetic. Once materialized,
a `tie_shape` op should be inserted to memorialize the junction. Eventually,
every `get_ranked_shape` op should follow a `tie_shape` op, and the
canonicalization rules will elide the `get_ranked_shape`. There is complexity
around blocks, control flow, etc, but this basic algorithm should be workable.
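The canonicalization step at the end is exercised by the test added in this
commit: once a `get_ranked_shape` is fed by a `tie_shape`, it folds away to
the `tie_shape`'s shape operand.

```mlir
// Before canonicalization:
%0 = shape.tie_shape %arg0, %arg1
    : tensor<1x?x2x?xf32>, !shape.ranked_shape<1x?x2x?xi32>
%1 = shape.get_ranked_shape %0
    : tensor<1x?x2x?xf32> -> !shape.ranked_shape<1x?x2x?xi32>
// After canonicalization, uses of %1 are replaced by %arg1 and the
// get_ranked_shape op is elided. The tie_shape is deliberately kept, since
// removing it would lose information needed for materialization.
```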

Work is ongoing upstream to provide a facility to register shape functions with
ops, which would provide a dynamic, dialect independent way to know what
arithmetic to materialize. However, in most cases this is not necessary. The
built-in traits around types and sizes will allow most propagation to happen
without shape functions. We intend to start with a static set of cases for the
rest in order to prove the concept.

#### Scalarization

TODO: We have a sketch of this but are still proving it out.

It is quite common in real-world DNN usage to get the 1D tensor representing a
shape and perform arbitrary tensor ops on it (usually basic arithmetic, slicing,
concatenating, tiling, etc). While this is perfectly acceptable from a correctness
standpoint, it is usually not performant: shapes are typically very small one
dimensional vectors, and computations on them are usually trivial to reduce to
small sequences of scalar machine code of a form that CPUs are very good at
executing. Further, we often want to run these calculations eagerly when
dispatching functions, etc (i.e. to pre-allocate buffers) and having them
isolated (versus treating them as arbitrary dense computations) can be quite
valuable.

We expect that the type bracketing that happens with `ranked_shape` and the
corresponding ops will make it easy to write some simple DRR patterns to
identify such shape manipulation sequences and lower them directly to regions of
`vm` ops operating on scalars. Such regions can be retained and directly emitted
when lowering to the `vm` dialect and/or CPU code generation and would run with
low overhead along with any other scheduling code.

While an optimization, we suspect this is an important one.

### Shape inference

TODO: This is mostly placeholder

There is work happening upstream to implement MLIR-integrated shape inference.
In the meantime, IREE expects that the input to the compiler has already had
some shape inference performed on it. In practice, for TensorFlow, there is a
pass which applies TensorFlow's pre-MLIR shape inference mechanisms to derive
such things. This has limitations but is a reasonable starting point.

## Compiler Backends

TODO: This is mostly placeholder.

Much of the newer structured-ops based codegen is capable of working (within
bounds) with ranked dynamic shapes without much work. Given the lack of an e2e
story, much of this has been done "by way of code review" and there are
certainly issues to be resolved.

In addition, there are several ABI issues and negotiations with the backend that
still need to be fleshed out.
1 change: 1 addition & 0 deletions iree/compiler/Dialect/Shape/CMakeLists.txt
@@ -13,3 +13,4 @@
# limitations under the License.

add_subdirectory(IR)
add_subdirectory(Transforms)
79 changes: 78 additions & 1 deletion iree/compiler/Dialect/Shape/IR/ShapeOps.cpp
@@ -21,6 +21,7 @@
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/PatternMatch.h"
@@ -34,6 +35,77 @@ namespace mlir {
namespace iree_compiler {
namespace Shape {

//===----------------------------------------------------------------------===//
// Canonicalization
//===----------------------------------------------------------------------===//

class ElideTiedGetRankedShapePattern
: public OpRewritePattern<GetRankedShapeOp> {
using OpRewritePattern::OpRewritePattern;
PatternMatchResult matchAndRewrite(GetRankedShapeOp op,
PatternRewriter &rewriter) const override {
// If the immediate predecessor is a TieShapeOp, then this op can be
// erased in favor of the input to the tie op.
if (!matchPattern(op.operand(), m_Op<TieShapeOp>())) {
return matchFailure();
}

auto tieOp = cast<TieShapeOp>(op.operand().getDefiningOp());
rewriter.replaceOp(op, tieOp.shape(), op.operand());

return matchSuccess();
}
};

//===----------------------------------------------------------------------===//
// iree.tie_shape
//===----------------------------------------------------------------------===//

static ParseResult parseTieShapeOp(OpAsmParser &parser, OperationState &state) {
SmallVector<OpAsmParser::OperandType, 2> operands;
SmallVector<Type, 2> operandTypes;
if (parser.parseOperandList(operands) ||
parser.parseColonTypeList(operandTypes) ||
parser.parseOptionalAttrDict(state.attributes) ||
parser.resolveOperands(operands, operandTypes, parser.getNameLoc(),
state.operands)) {
return failure();
}

// The result type is the same as the first operand.
if (state.operands.empty()) return failure();
state.types.push_back(state.operands.front().getType());
return success();
}

static void printTieShapeOp(OpAsmPrinter &p, TieShapeOp op) {
p << op.getOperationName() << " ";
p.printOperands(op.getOperands());
p << " : ";
interleaveComma(op.getOperandTypes(), p);
p.printOptionalAttrDict(op.getOperation()->getAttrs());
}

static LogicalResult verifyTieShapeOp(TieShapeOp op) {
if (op.operand().getType() != op.result().getType()) {
return op.emitOpError("operand and result must be the same type");
}

// tie_shape currently only supports ranked tensors.
auto rankedTensorType = op.operand().getType().dyn_cast<RankedTensorType>();
auto rsType = op.shape().getType().dyn_cast<RankedShapeType>();
if (!rankedTensorType || !rsType) {
return op.emitOpError("currently only ranked tensors are supported");
}

SmallVector<int64_t, 4> rsDims;
rsType.getAllDims(rsDims);
if (!rankedTensorType.getShape().equals(rsDims)) {
return op.emitOpError("dims must match between tensor and shape");
}
return success();
}

//===----------------------------------------------------------------------===//
// iree.get_ranked_shape
//===----------------------------------------------------------------------===//
@@ -49,7 +121,7 @@ static ParseResult parseGetRankedShapeOp(OpAsmParser &parser,
}

static void printGetRankedShapeOp(OpAsmPrinter &p, GetRankedShapeOp op) {
p << "shape.get_ranked_shape ";
p << op.getOperationName() << " ";
p.printOperand(op.operand());
p << " : ";
p.printType(op.operand().getType());
@@ -72,6 +144,11 @@ static LogicalResult verifyGetRankedShapeOp(GetRankedShapeOp op) {
return success();
}

void GetRankedShapeOp::getCanonicalizationPatterns(
OwningRewritePatternList &patterns, MLIRContext *context) {
patterns.insert<ElideTiedGetRankedShapePattern>(context);
}

#define GET_OP_CLASSES
#include "iree/compiler/Dialect/Shape/IR/ShapeOps.cpp.inc"

28 changes: 28 additions & 0 deletions iree/compiler/Dialect/Shape/IR/ShapeOps.td
@@ -34,6 +34,23 @@ class Shape_PureOp<string mnemonic, list<OpTrait> traits = []> :
// Dynamic shape support
//===----------------------------------------------------------------------===//

def Shape_TieShapeOp : Shape_PureOp<"tie_shape"> {
let summary = "Ties a tensor and a shape together.";
let description = [{
Ties a specific tensor and its shape together in the IR, allowing further
conversions to re-associate the two. This has no runtime implication and
will be removed late in conversion.

Usage:
%0 = shape.tie_shape %1, %2 : tensor<...>, shape.ranked_shape<...>
}];

let arguments = (ins AnyTensor:$operand, Shape_RankedShape:$shape);
let results = (outs AnyTensor:$result);

let verifier = [{ return verify$cppClass(*this); }];
}

def Shape_GetRankedShapeOp : Shape_PureOp<"get_ranked_shape"> {
let summary = "Gets the RankedShape associated with the given Tensor.";
let description = [{
@@ -43,12 +60,23 @@ def Shape_GetRankedShapeOp : Shape_PureOp<"get_ranked_shape"> {

Getting the RankedShape of a statically shaped tensor will canonicalize
to a static_ranked_shape op and will never cause a further SSA dependency.

Usage:
%0 = shape.get_ranked_shape %arg0 : tensor<2x?xf32> ->
!shape.ranked_shape<2x?xf32>

Canonicalization: This op includes a canonicalization pattern such that
if its operand is supplied by a tie_shape op, then it will replace itself
with the tie_shape's shape() operand. In this way, a function with all
shapes materialized and tied to intermediate tensors should canonicalize
to contain no get_ranked_shape ops.
}];

let arguments = (ins AnyTensor:$operand);
let results = (outs Shape_RankedShape:$shape);

let verifier = [{ return verify$cppClass(*this); }];
let hasCanonicalizer = 1;
}

#endif // IREE_DIALECT_SHAPE_OPS
17 changes: 17 additions & 0 deletions iree/compiler/Dialect/Shape/IR/test/canonicalize.mlir
@@ -0,0 +1,17 @@
// RUN: iree-opt -split-input-file -verify-diagnostics -canonicalize %s | IreeFileCheck %s


// CHECK-LABEL: @elideTiedGetRankedShape
// CHECK-SAME: %[[T:[^:[:space:]]+]]: tensor<1x?x2x?xf32>
// CHECK-SAME: %[[SHAPE:[^:[:space:]]+]]: !shape.ranked_shape<1x?x2x?xi32>
func @elideTiedGetRankedShape(%arg0: tensor<1x?x2x?xf32>, %arg1: !shape.ranked_shape<1x?x2x?xi32>) -> (tensor<1x?x2x?xf32>, !shape.ranked_shape<1x?x2x?xi32>) {
// Note that canonicalization does *not* remove tie_shape. That must be
// removed manually once all shape materialization is complete (otherwise,
// information needed to materialize would be lost).
// CHECK: %[[TIE_T:.+]] = shape.tie_shape %[[T]], %[[SHAPE]]
%0 = shape.tie_shape %arg0, %arg1 : tensor<1x?x2x?xf32>, !shape.ranked_shape<1x?x2x?xi32>
// CHECK-NOT: shape.get_ranked_shape
%1 = shape.get_ranked_shape %0 : tensor<1x?x2x?xf32> -> !shape.ranked_shape<1x?x2x?xi32>
// CHECK-DAG: return %[[TIE_T]], %[[SHAPE]]
return %0, %1 : tensor<1x?x2x?xf32>, !shape.ranked_shape<1x?x2x?xi32>
}
22 changes: 22 additions & 0 deletions iree/compiler/Dialect/Shape/IR/test/op_verification.mlir
@@ -0,0 +1,22 @@
// RUN: iree-opt -split-input-file -verify-diagnostics %s

// -----
func @tie_shape_mismatch_type(%arg0 : tensor<2x?x4xf32>, %arg1 : !shape.ranked_shape<1xi32>) {
// expected-error @+1 {{dims must match between tensor and shape}}
%0 = shape.tie_shape %arg0, %arg1 : tensor<2x?x4xf32>, !shape.ranked_shape<1xi32>
return
}

// -----
func @get_ranked_shape_same_rank(%arg0 : tensor<2x?x4xf32>) {
// expected-error @+1 {{op operand and result must be of same rank}}
%0 = shape.get_ranked_shape %arg0 : tensor<2x?x4xf32> -> !shape.ranked_shape<2xi32>
return
}

// -----
func @get_ranked_shape_not_equal_dims(%arg0 : tensor<2x?x4xf32>) {
// expected-error @+1 {{op operand tensor and result shape must be equal}}
%0 = shape.get_ranked_shape %arg0 : tensor<2x?x4xf32> -> !shape.ranked_shape<2x2x4xi32>
return
}
8 changes: 8 additions & 0 deletions iree/compiler/Dialect/Shape/IR/test/parse_print.mlir
@@ -1,5 +1,13 @@
// RUN: iree-opt -split-input-file %s | iree-opt -split-input-file | IreeFileCheck %s

// -----
// CHECK-LABEL: @parse_print_tie_shape
func @parse_print_tie_shape(%arg0 : tensor<2x?x4xf32>, %arg1 : !shape.ranked_shape<2x?x4xi32>) {
%0 = shape.tie_shape %arg0, %arg1 : tensor<2x?x4xf32>, !shape.ranked_shape<2x?x4xi32>
return
}


// -----
// CHECK-LABEL: @parse_print_get_ranked_shape
func @parse_print_get_ranked_shape(%arg0 : tensor<2x?x4xi32>) {
14 changes: 0 additions & 14 deletions iree/compiler/Dialect/Shape/IR/test/ranked_shape_type.mlir
@@ -25,17 +25,3 @@ func @parseDynamicShape(%arg0 : !shape.ranked_shape<1x?x2x?xi32>) {
func @error(%arg0 : !shape.ranked_shape<1x?xf32>) {
return
}

// -----
func @get_ranked_shape_same_rank(%arg0 : tensor<2x?x4xf32>) {
// expected-error @+1 {{op operand and result must be of same rank}}
%0 = shape.get_ranked_shape %arg0 : tensor<2x?x4xf32> -> !shape.ranked_shape<2xi32>
return
}

// -----
func @get_ranked_shape_not_equal_dims(%arg0 : tensor<2x?x4xf32>) {
// expected-error @+1 {{op operand tensor and result shape must be equal}}
%0 = shape.get_ranked_shape %arg0 : tensor<2x?x4xf32> -> !shape.ranked_shape<2x2x4xi32>
return
}
