Skip to content

Commit

Permalink
Moving iree.* ops only used by the interpreter to an interpreter comm…
Browse files Browse the repository at this point in the history
…on ops file.

PiperOrigin-RevId: 287700345
  • Loading branch information
benvanik authored and copybara-github committed Jan 1, 2020
1 parent 940ae62 commit c956c4c
Show file tree
Hide file tree
Showing 38 changed files with 810 additions and 548 deletions.
2 changes: 1 addition & 1 deletion docs/simple_ir_walkthrough.md
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ module {
}
func @simple_mul(%arg0: memref<4xf32>, %arg1: memref<4xf32>) -> memref<4xf32>
attributes {iree.module.export} {
%0 = iree.constant dense<[4, 1, 1]> : tensor<3xi32>
%0 = iree_interp.constant dense<[4, 1, 1]> : tensor<3xi32>
%1 = "iree_hl_seq.alloc_heap"() : () -> memref<4xf32>
iree_hl_seq.dispatch simple_mul_ex_dispatch_0::simple_mul_rgn_dispatch_0[%0 : memref<3xi32>](%arg0, %arg1, %1) : (memref<4xf32>, memref<4xf32>, memref<4xf32>) -> ()
iree_hl_seq.return %1 : memref<4xf32>
Expand Down
60 changes: 3 additions & 57 deletions iree/compiler/IR/OpBase.td
Original file line number Diff line number Diff line change
Expand Up @@ -32,63 +32,9 @@ def IREE_Dialect : Dialect {
// General types and helpers
//===----------------------------------------------------------------------===//

class IREE_ScalarMemRefOf<list<Type> allowedTypes> :
MemRefRankOf<allowedTypes, [0]>;

//===----------------------------------------------------------------------===//
// High-level types
//===----------------------------------------------------------------------===//

def IREEHL_Bool :
def IREE_Bool :
AnyTypeOf<[I1, I8], "boolean-storing type (1 or 8 -bit integer)">;
def IREEHL_Element : AnyTypeOf<[AnyInteger, AnyFloat]>;

def IREEHL_MemRef : MemRefOf<[IREEHL_Element]>;
def IREEHL_BoolMemRef : MemRefOf<[IREEHL_Bool]>;
def IREEHL_IntMemRef : MemRefOf<[AnyInteger]>;
def IREEHL_FloatMemRef : MemRefOf<[AnyFloat]>;
def IREEHL_IndexMemRef : MemRefOf<[AnyInteger]>;

def IREEHL_AnyScalar : IREE_ScalarMemRefOf<[IREEHL_Element]>;
def IREEHL_BoolScalar : IREE_ScalarMemRefOf<[IREEHL_Bool]>;
def IREEHL_IntScalar : IREE_ScalarMemRefOf<[AnyInteger]>;
def IREEHL_FloatScalar : IREE_ScalarMemRefOf<[AnyFloat]>;
def IREEHL_IndexScalar : IREE_ScalarMemRefOf<[AnyInteger]>;
def IREEHL_I32Scalar : IREE_ScalarMemRefOf<[I32]>;

def IREEHL_1DIntMemRef : MemRefRankOf<[AnyInteger], [1]>;
def IREEHL_1DIndexMemRef : MemRefRankOf<[AnyInteger], [1]>;


//===----------------------------------------------------------------------===//
// Low-level types
//===----------------------------------------------------------------------===//

def IREELL_Bool : TypeAlias<I8, "boolean-storing type (8-bit integer)">;
def IREELL_Int: AnyTypeOf<[I8, I16, I32, I64], "8/16/32/64-bit integer">;
def IREELL_Float: AnyTypeOf<[F32, F64], "32/64-bit float">;
def IREELL_Index : AnyTypeOf<[I32, I64], "32/64-bit index integer">;
def IREELL_Element : AnyTypeOf<[IREELL_Int, IREELL_Float]>;

def IREELL_MemRef : MemRefOf<[IREELL_Element]>;
def IREELL_IntMemRef : MemRefOf<[IREELL_Int]>;
def IREELL_FloatMemRef : MemRefOf<[IREELL_Float]>;
def IREELL_BoolMemRef : MemRefOf<[IREELL_Bool]>;
def IREELL_IndexMemRef : MemRefOf<[IREELL_Index]>;
// For shape computation outputs, we want to consistently output I32 not I64
// TODO(b/138851470) Iron out story for index types
def IREELL_I32MemRef : MemRefOf<[I32]>;

def IREELL_ElementScalar : IREE_ScalarMemRefOf<[IREELL_Element]>;
def IREELL_IntScalar : IREE_ScalarMemRefOf<[IREELL_Int]>;
def IREELL_BoolScalar : IREE_ScalarMemRefOf<[IREELL_Bool]>;
def IREELL_FloatScalar : IREE_ScalarMemRefOf<[IREELL_Float]>;
def IREELL_IndexScalar : IREE_ScalarMemRefOf<[IREELL_Index]>;
// For shape computation outputs, we want to consistently output I32 not I64
// TODO(b/138851470) Iron out story for index types
def IREELL_I32Scalar : IREE_ScalarMemRefOf<[I32]>;

def IREELL_1DIntMemRef : MemRefRankOf<[IREELL_Int], [1]>;
def IREELL_1DIndexMemRef : MemRefRankOf<[IREELL_Index], [1]>;
def IREE_Element : AnyTypeOf<[AnyInteger, AnyFloat]>;
def IREE_MemRef : MemRefOf<[IREE_Element]>;

#endif // IREE_OP_BASE
235 changes: 0 additions & 235 deletions iree/compiler/IR/Ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,241 +29,6 @@ namespace mlir {
namespace iree_compiler {
namespace IREE {

//===----------------------------------------------------------------------===//
// iree.constant
//===----------------------------------------------------------------------===//

static ParseResult parseConstantOp(OpAsmParser &parser,
OperationState &result) {
Attribute valueAttr;
Type type;
if (parser.parseLSquare() ||
parser.parseAttribute(valueAttr, "value", result.attributes) ||
parser.parseRSquare() ||
parser.parseOptionalAttrDict(result.attributes) ||
parser.parseColonType(type))
return failure();

return parser.addTypeToList(type, result.types);
}

static void printConstantOp(OpAsmPrinter &p, ConstantOp &op) {
p << "iree.constant[";
p.printAttribute(op.getValue());
p << "] ";
p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"value"});

p << " : ";
p.printType(op.getType());
}

namespace {

// TODO(gcmn) this is duplicated from MemRefUtils to avoid a circular
// dependency. Extract op-dependent parts of memref utils to allow reuse.
MemRefType convertTypeToMemRef(Type type) {
if (type.isIntOrIndexOrFloat()) {
return MemRefType::get({}, type, {}, 0);
} else if (auto tensorType = type.dyn_cast<RankedTensorType>()) {
return MemRefType::get(tensorType.getShape(), tensorType.getElementType());
} else if (auto memRefType = type.dyn_cast<MemRefType>()) {
return MemRefType::get(memRefType.getShape(), memRefType.getElementType());
} else {
llvm_unreachable("Unconvertable type");
}
}

} // namespace

void ConstantOp::build(Builder *builder, OperationState &state,
ElementsAttr value) {
auto type = convertTypeToMemRef(value.getType());
return build(builder, state, type, value);
}

// TODO(b/134575149): enable folder when we store the correct type.
// OpFoldResult ConstantOp::fold(ArrayRef<Attribute> operands) {
// assert(operands.empty() && "constant has no operands");
// return getValue();
// }

//===----------------------------------------------------------------------===//
// iree.tensor_to_memref
//===----------------------------------------------------------------------===//

static ParseResult parseTensorToMemRefOp(OpAsmParser &parser,
OperationState &state) {
OpAsmParser::OperandType operand;
Type operandType;
Type resultType;
if (failed(parser.parseLParen()) || failed(parser.parseOperand(operand)) ||
failed(parser.parseColonType(operandType)) ||
failed(parser.resolveOperand(operand, operandType, state.operands)) ||
failed(parser.parseRParen()) ||
failed(parser.parseColonType(resultType)) ||
failed(parser.addTypeToList(resultType, state.types))) {
return failure();
}
return success();
}

static void printTensorToMemRefOp(OpAsmPrinter &p, TensorToMemRefOp &op) {
p << "iree.tensor_to_memref(";
p.printOperand(op.getOperand());
p << " : ";
p.printType(op.getOperand()->getType());
p << ") : ";
p.printType(op.getType());
}

OpFoldResult TensorToMemRefOp::fold(ArrayRef<Attribute> operands) {
if (auto memrefToTensorOp = dyn_cast_or_null<IREE::MemRefToTensorOp>(
getOperand()->getDefiningOp())) {
return memrefToTensorOp.getOperand();
}

return {};
}

void TensorToMemRefOp::build(Builder *builder, OperationState &state,
Value arg) {
build(builder, state, convertTypeToMemRef(arg->getType()), arg);
}

//===----------------------------------------------------------------------===//
// iree.memref_to_tensor
//===----------------------------------------------------------------------===//

static ParseResult parseMemRefToTensorOp(OpAsmParser &parser,
OperationState &state) {
OpAsmParser::OperandType operand;
Type operandType;
Type resultType;
if (failed(parser.parseLParen()) || failed(parser.parseOperand(operand)) ||
failed(parser.parseColonType(operandType)) ||
failed(parser.resolveOperand(operand, operandType, state.operands)) ||
failed(parser.parseRParen()) ||
failed(parser.parseColonType(resultType)) ||
failed(parser.addTypeToList(resultType, state.types))) {
return failure();
}
return success();
}

static void printMemRefToTensorOp(OpAsmPrinter &p, MemRefToTensorOp &op) {
p << "iree.memref_to_tensor(";
p.printOperand(op.getOperand());
p << " : ";
p.printType(op.getOperand()->getType());
p << ") : ";
p.printType(op.getType());
}

OpFoldResult MemRefToTensorOp::fold(ArrayRef<Attribute> operands) {
if (auto tensorToMemRefOp = dyn_cast_or_null<IREE::TensorToMemRefOp>(
getOperand()->getDefiningOp())) {
return tensorToMemRefOp.getOperand();
}

return {};
}

void MemRefToTensorOp::build(Builder *builder, OperationState &state,
Value arg) {
// TODO(gcmn) Use getTensorType from MemRefUtils when circular dependency can
// be avoided.
auto memRefType = arg->getType().cast<MemRefType>();
auto tensorType =
RankedTensorType::get(memRefType.getShape(), memRefType.getElementType());
build(builder, state, tensorType, arg);
}

//===----------------------------------------------------------------------===//
// iree.scalar_to_memref
//===----------------------------------------------------------------------===//

static ParseResult parseScalarToMemRefOp(OpAsmParser &parser,
OperationState &state) {
OpAsmParser::OperandType operand;
Type operandType;
Type resultType;
if (failed(parser.parseLParen()) || failed(parser.parseOperand(operand)) ||
failed(parser.parseColonType(operandType)) ||
failed(parser.resolveOperand(operand, operandType, state.operands)) ||
failed(parser.parseRParen()) ||
failed(parser.parseColonType(resultType)) ||
failed(parser.addTypeToList(resultType, state.types))) {
return failure();
}
return success();
}

static void printScalarToMemRefOp(OpAsmPrinter &p, ScalarToMemRefOp &op) {
p << "iree.scalar_to_memref(";
p.printOperand(op.getOperand());
p << " : ";
p.printType(op.getOperand()->getType());
p << ") : ";
p.printType(op.getType());
}

OpFoldResult ScalarToMemRefOp::fold(ArrayRef<Attribute> operands) {
if (auto memrefToScalarOp = dyn_cast_or_null<IREE::MemRefToScalarOp>(
getOperand()->getDefiningOp())) {
return memrefToScalarOp.getOperand();
}

return {};
}

void ScalarToMemRefOp::build(Builder *builder, OperationState &state,
Value arg) {
build(builder, state, convertTypeToMemRef(arg->getType()), arg);
}

//===----------------------------------------------------------------------===//
// iree.memref_to_scalar
//===----------------------------------------------------------------------===//

static ParseResult parseMemRefToScalarOp(OpAsmParser &parser,
OperationState &state) {
OpAsmParser::OperandType operand;
Type operandType;
Type resultType;
if (failed(parser.parseLParen()) || failed(parser.parseOperand(operand)) ||
failed(parser.parseColonType(operandType)) ||
failed(parser.resolveOperand(operand, operandType, state.operands)) ||
failed(parser.parseRParen()) ||
failed(parser.parseColonType(resultType)) ||
failed(parser.addTypeToList(resultType, state.types))) {
return failure();
}
return success();
}

static void printMemRefToScalarOp(OpAsmPrinter &p, MemRefToScalarOp &op) {
p << "iree.memref_to_scalar(";
p.printOperand(op.getOperand());
p << " : ";
p.printType(op.getOperand()->getType());
p << ") : ";
p.printType(op.getType());
}

OpFoldResult MemRefToScalarOp::fold(ArrayRef<Attribute> operands) {
if (auto scalarToMemRefOp = dyn_cast_or_null<IREE::ScalarToMemRefOp>(
getOperand()->getDefiningOp())) {
return scalarToMemRefOp.getOperand();
}

return {};
}

void MemRefToScalarOp::build(Builder *builder, OperationState &state,
Value arg) {
build(builder, state, getElementTypeOrSelf(arg), arg);
}

//===----------------------------------------------------------------------===//
// iree.return
//===----------------------------------------------------------------------===//
Expand Down
Loading

0 comments on commit c956c4c

Please sign in to comment.