Skip to content

Commit

Permalink
Integrate LLVM at llvm/llvm-project@fd0c6f53913f
Browse files Browse the repository at this point in the history
Updates LLVM usage to match
[fd0c6f53913f](llvm/llvm-project@fd0c6f53913f)

PiperOrigin-RevId: 423461771
Change-Id: I053cc1ae6c55aea3de63fd1afccdfca97f2f7ac3
  • Loading branch information
tensorflower-gardener committed Jan 22, 2022
1 parent 42ecc9d commit 91997c8
Show file tree
Hide file tree
Showing 49 changed files with 355 additions and 133 deletions.
1 change: 1 addition & 0 deletions tensorflow/compiler/mlir/hlo/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -882,6 +882,7 @@ cc_library(
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:TensorDialect",
"@llvm-project//mlir:TensorUtils",
"@llvm-project//mlir:Transforms",
],
)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This file implements logic for injecting execution context to the entry
// function.
//
// Below is an example. Before Conversion:
// ```
// func @main(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>) ->
// memref<?x?xf32> {
// %0 = memref.alloc(...)
// "lmhlo.add"(%arg0, %arg1, %0) : (memref<?x?xf32>, memref<?x?xf32>,
// memref<?x?xf32>) -> memref<?x?xf32> return %0 : memref<?x?xf32>
// }
// ```
// After conversion:
// ```
// func @main(%ctx: !disc_ral.context) {
// %c0 = arith.constant 0 : index
// %c1 = arith.constant 1 : index
// "disc_ral.recv_input"(%ctx, %c0) : (!disc_ral.context, index) ->
// memref<?x?xf32> "disc_ral.recv_input"(%ctx, %c1) : (!disc_ral.context,
// index) -> memref<?x?xf32> %0 = memref.alloc(...) "lmhlo.add"(%arg0,
// %arg1, %0) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) ->
// memref<?x?xf32> "disc_ral.send_output"(%ctx, %c0, %0) :
// (!disc_ral.context, index, memref<?x?xf32>) -> ()
// }
// ```

// 1. rewrite entry function (supposed that no other function directly calls the
// entry function)
// - function signature rewrite
// - return-like ops rewrite.
// 2. Currently we suppose that functions except the entry function are inlined
// to the entry function. Thus, we don't rewrite all call ops and other
// functions a.t.m. Re-visit this assumption if necessary.

#include "mlir-hlo/Dialect/disc-ral/IR/disc_ral_ops.h"
#include "mlir-hlo/Dialect/disc-ral/transforms/PassDetail.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Operation.h"
#include "mlir/Pass/Pass.h"

namespace mlir {
namespace disc_ral {

namespace {

struct RalInjectExecutionContextPass
: public RalInjectExecutionContextPassBase<RalInjectExecutionContextPass> {
explicit RalInjectExecutionContextPass(const std::string& entry_func_name)
: RalInjectExecutionContextPassBase<RalInjectExecutionContextPass>::
RalInjectExecutionContextPassBase() {
this->entry_func_name_ = entry_func_name;
}

void getDependentDialects(DialectRegistry& registry) const override {
registry.insert<RalDialect>();
}

void runOnOperation() override {
ModuleOp m = getOperation();
FuncOp main = m.lookupSymbol<FuncOp>(entry_func_name_);
if (!main) {
m.emitError("entry func: " + entry_func_name_ + " not found");
signalPassFailure();
}

Location loc = main.getLoc();
FunctionType funcType = main.getType();
OpBuilder b(&main.getBody());
Block* entry_block = &main.getBody().front();
Type ctx_type = RalExecutionContextType::get(b.getContext());

// 1. Prepend context to the entry block arguments
Value ctx = entry_block->insertArgument(0u, ctx_type, loc);

// 2. remap original arguments to recv_input ops
for (auto&& en : llvm::enumerate(
llvm::zip(funcType.getInputs(),
entry_block->getArguments().drop_front(1)))) {
Value idx = b.create<arith::ConstantIndexOp>(loc, en.index());
Type argType = std::get<0>(en.value());
Value oldArgument = std::get<1>(en.value());
Value newInput = b.create<RecvInputOp>(loc, argType, ctx, idx);
oldArgument.replaceAllUsesWith(newInput);
}

// 3. remap all return-like ops to send_output ops
for (auto& block : main.getBody()) {
if (block.empty()) continue;
Operation& operation = block.back();
if (!operation.hasTrait<OpTrait::ReturnLike>()) continue;
b.setInsertionPoint(&operation);
for (auto& en : llvm::enumerate(operation.getOperands())) {
Value idx = b.create<arith::ConstantIndexOp>(loc, en.index());
b.create<SendOutputOp>(loc, ctx, idx, en.value());
}
operation.eraseOperands(0, operation.getNumOperands());
}

// 4. remove unused block arguments of entry block
for (int i = 0, e = funcType.getInputs().size(); i < e; ++i) {
// continue to remove the 1st (starting from zero) argument
entry_block->eraseArgument(1);
}

// 5. set entry func to new type
main.setType(b.getFunctionType({ctx_type}, {}));
}
};

} // namespace

std::unique_ptr<OperationPass<ModuleOp>> createRalInjectExecutionContextPass(
const std::string& entry_func_name) {
return std::make_unique<RalInjectExecutionContextPass>(entry_func_name);
}

} // namespace disc_ral
} // namespace mlir
9 changes: 6 additions & 3 deletions tensorflow/compiler/mlir/hlo/lib/Dialect/lhlo/IR/lhlo_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -255,9 +255,12 @@ struct RemoveCopyInReduceBody : public OpRewritePattern<ReduceOp> {
if (!the_only_copy) return failure();

auto new_reduce = rewriter.cloneWithoutRegions(reduce);
Block* new_block =
rewriter.createBlock(&new_reduce.body(), new_reduce.body().end(),
reduce.body().front().getArgumentTypes());
auto& old_reduce_body = reduce.body().front();
Block* new_block = rewriter.createBlock(
&new_reduce.body(), new_reduce.body().end(),
old_reduce_body.getArgumentTypes(),
SmallVector<Location>(old_reduce_body.getNumArguments(),
reduce.getLoc()));

mlir::BlockAndValueMapping bvm;
for (auto item : llvm::zip(reduce.body().front().getArguments(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,17 +76,18 @@ class ExpandHloTuplesPass
int argument_index = original_argument_index;
SmallVector<Value, 4> flattened_operands;
// insert the flattened tuples after the original tuple.
Location loc = func.body().getLoc();
for (auto flattened_type : tuple_type.getTypes()) {
expanded_input_types.push_back(flattened_type);
func.insertArgument(++argument_index, flattened_type, {});
func.insertArgument(++argument_index, flattened_type, {}, loc);
flattened_operands.push_back(func.getArgument(argument_index));
}

// Construct a new tuple and rewire it.
OpBuilder builder(func.body());
builder.setInsertionPointToStart(&func.body().front());
auto new_tuple = builder.create<mhlo::TupleOp>(
func.body().getLoc(), tuple_type, flattened_operands);
auto new_tuple =
builder.create<mhlo::TupleOp>(loc, tuple_type, flattened_operands);
func.getArgument(original_argument_index).replaceAllUsesWith(new_tuple);

// Now the original argument has been rewired, we should be able to
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,8 @@ void LowerIfOp(mlir::mhlo::IfOp if_op) {
ReplaceTerminators(&if_op.true_branch(), tail_block, loc, mapper, &builder);
ReplaceTerminators(&if_op.false_branch(), tail_block, loc, mapper, &builder);

tail_block->addArguments(if_op.getResultTypes());
tail_block->addArguments(if_op.getResultTypes(),
SmallVector<Location>(if_op.getNumResults(), loc));
for (auto it : llvm::zip(if_op.getResults(), tail_block->getArguments()))
std::get<0>(it).replaceAllUsesWith(std::get<1>(it));

Expand Down Expand Up @@ -193,7 +194,9 @@ LogicalResult LowerWhileOp(mlir::mhlo::WhileOp while_op) {
}

// Erase the original while loop.
tail_block->addArguments(while_op.getOperandTypes());
tail_block->addArguments(
while_op.getOperandTypes(),
SmallVector<Location>(while_op.getNumOperands(), loc));
for (auto it : llvm::zip(while_op.getResults(), tail_block->getArguments()))
std::get<0>(it).replaceAllUsesWith(std::get<1>(it));

Expand Down Expand Up @@ -233,7 +236,8 @@ void LowerCaseOp(mlir::mhlo::CaseOp case_op) {

// The tail block has block arguments for each result.
TypeRange result_types = case_op.getResultTypes();
tail_block->addArguments(result_types);
tail_block->addArguments(result_types,
SmallVector<Location>(result_types.size(), loc));
for (auto it : llvm::zip(case_op->getResults(), tail_block->getArguments())) {
Value orig_result = std::get<0>(it);
Value new_value = std::get<1>(it);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ limitations under the License.
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Utils/Utils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
Expand Down Expand Up @@ -1691,7 +1692,7 @@ struct PadOpConversion : public OpConversionPattern<mhlo::PadOp> {
loc, std::get<1>(it).getZExtValue()));
}
Type result_type = op.getResult().getType();
auto pad_tensor_op = linalg::PadTensorOp::createPadScalarOp(
auto pad_tensor_op = tensor::createPadScalarOp(
result_type, adaptor.operand(), padding_val, low, high,
/*nofold=*/false, loc, rewriter);
rewriter.replaceOp(op, pad_tensor_op.getResult());
Expand Down Expand Up @@ -1728,7 +1729,7 @@ static Value applyPad(Location loc, Value input, ArrayRef<int64_t> pad,

Value padValue = rewriter.create<arith::ConstantOp>(loc, padAttr);

return linalg::PadTensorOp::createPadScalarOp(
return tensor::createPadScalarOp(
RankedTensorType::get(paddedShape, inputETy), input, padValue, lowIndices,
highIndices, /*nofold=*/false, loc, rewriter);
}
Expand Down Expand Up @@ -2066,7 +2067,7 @@ struct ReduceWindowOpOnTensorsGenericConversion
Value zero = rewriter.create<arith::ConstantOp>(
loc, rewriter.getZeroAttr(
input.getType().cast<ShapedType>().getElementType()));
auto pad_op = rewriter.create<linalg::PadTensorOp>(
auto pad_op = rewriter.create<tensor::PadOp>(
loc, input, static_lows, static_highs, ValueRange{}, ValueRange{});

SmallVector<Type, 4> block_arg_types;
Expand All @@ -2075,8 +2076,10 @@ struct ReduceWindowOpOnTensorsGenericConversion
auto& region = pad_op.region();

OpBuilder::InsertionGuard guard(rewriter);
rewriter.createBlock(&region, region.end(), block_arg_types);
rewriter.create<linalg::YieldOp>(loc, zero);
rewriter.createBlock(
&region, region.end(), block_arg_types,
SmallVector<Location>(block_arg_types.size(), loc));
rewriter.create<tensor::YieldOp>(loc, zero);

input = pad_op.getResult();
}
Expand Down Expand Up @@ -2416,8 +2419,9 @@ struct TorchIndexSelectOpConversion
body_arg_types.push_back(
block_args.getType().cast<ShapedType>().getElementType());
}
block->addArguments(body_arg_types);
block->addArguments(result_type.getElementType());
block->addArguments(body_arg_types,
SmallVector<Location>(body_arg_types.size(), loc));
block->addArguments(result_type.getElementType(), loc);
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToEnd(block);

Expand Down Expand Up @@ -2528,7 +2532,7 @@ struct GatherConversion : public OpConversionPattern<mhlo::GatherOp> {
// Now populate the linalg generic region
auto* region = &linalgOp.region();
auto* block = rewriter.createBlock(region, region->end());
block->addArguments(resultType.getElementType());
block->addArguments(resultType.getElementType(), loc);
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToEnd(block);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ Value CreateTupleValue(OpBuilder &builder, Location loc,

// Flattens the tuples in the region's arguments and returning values.
void FlattenTupleInRegion(Region &region, PatternRewriter &rewriter) {
Location loc = region.getLoc();
OpBuilder regionOpBuilder(region);

// Flatten tuples in arguments. The order of arguments must match the order
Expand All @@ -101,12 +102,12 @@ void FlattenTupleInRegion(Region &region, PatternRewriter &rewriter) {
llvm::SmallVector<Value, 4> newArguments;
FlattenTupleType(argument, newTypes);
for (auto type : newTypes) {
newArguments.push_back(region.addArgument(type));
newArguments.push_back(region.addArgument(type, loc));
}

// Replaces uses of the replacing argument.
auto tupleValue = CreateTupleValue(regionOpBuilder, region.getLoc(),
newArguments, argument.getType());
auto tupleValue = CreateTupleValue(regionOpBuilder, loc, newArguments,
argument.getType());
argument.replaceAllUsesWith(tupleValue);
}
// Removes old tuple arguments.
Expand All @@ -126,7 +127,7 @@ void FlattenTupleInRegion(Region &region, PatternRewriter &rewriter) {
for (auto operand : returnOp.getOperands()) {
FlattenTupleValue(builder, returnOp.getLoc(), operand, results);
}
builder.create<mhlo::ReturnOp>(region.getLoc(), results);
builder.create<mhlo::ReturnOp>(loc, results);
rewriter.eraseOp(returnOp);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,9 @@ struct RankSpecializationClusterPattern : public RewritePattern {
// Create body block.
auto operand_types = llvm::to_vector<16>(
llvm::map_range(operand_set, [](Value v) { return v.getType(); }));
Block *block = rewriter.createBlock(&cluster_op.body(), {}, operand_types);
Block *block =
rewriter.createBlock(&cluster_op.body(), {}, operand_types,
SmallVector<Location>(operand_types.size(), loc));

// Copy operations into the body.
BlockAndValueMapping bvm;
Expand Down Expand Up @@ -215,7 +217,9 @@ struct MergeRankSpecializationClusterOpsPattern
loc, result_types, new_operands);
auto operand_types = llvm::to_vector<16>(
llvm::map_range(new_operands, [](Value v) { return v.getType(); }));
Block *new_body = rewriter.createBlock(&new_op.body(), {}, operand_types);
Block *new_body =
rewriter.createBlock(&new_op.body(), {}, operand_types,
SmallVector<Location>(operand_types.size(), loc));
rewriter.setInsertionPointToStart(new_body);

// Map operands and copy operations of the preceding cluster into the new
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2474,8 +2474,8 @@ func @pad_cst(%arg0: tensor<12x4xf32>) -> tensor<18x12xf32> {
// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
// CHECK-DAG: %[[C5:.+]] = arith.constant 5 : index
// CHECK-DAG: %[[C3:.+]] = arith.constant 3 : index
// CHECK: linalg.pad_tensor %[[ARG0]] low[%[[C4]], %[[C5]]] high[%[[C2]], %[[C3]]]
// CHECK: linalg.yield %[[CST]] : f32
// CHECK: tensor.pad %[[ARG0]] low[%[[C4]], %[[C5]]] high[%[[C2]], %[[C3]]]
// CHECK: tensor.yield %[[CST]] : f32
// CHECK: } : tensor<12x4xf32> to tensor<18x12xf32>

// -----
Expand All @@ -2496,8 +2496,8 @@ func @pad_tensor(%arg0: tensor<12x4xf32>, %arg1: tensor<f32>) -> tensor<18x12xf3
// CHECK-DAG: %[[C5:.+]] = arith.constant 5 : index
// CHECK-DAG: %[[C3:.+]] = arith.constant 3 : index
// CHECK-DAG: %[[PAD:.+]] = tensor.extract %[[ARG1]][] : tensor<f32>
// CHECK: linalg.pad_tensor %[[ARG0]] low[%[[C4]], %[[C5]]] high[%[[C2]], %[[C3]]]
// CHECK: linalg.yield %[[PAD]] : f32
// CHECK: tensor.pad %[[ARG0]] low[%[[C4]], %[[C5]]] high[%[[C2]], %[[C3]]]
// CHECK: tensor.yield %[[PAD]] : f32
// CHECK: } : tensor<12x4xf32> to tensor<18x12xf32>

// -----
Expand Down Expand Up @@ -2671,9 +2671,9 @@ func @linalg.conv_2D_padding_test1(%arg0: tensor<1x33x1x1xf16>, %arg1: tensor<40
// CHECK-NEXT: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f16
// CHECK-NEXT: %[[FILL:.*]] = linalg.fill(%[[ZERO]], %[[INIT]]) : f16, tensor<400x1024x1024x1xf16> -> tensor<400x1024x1024x1xf16>
// CHECK-NEXT: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f16
// CHECK-NEXT: %[[PAD:.*]] = linalg.pad_tensor %[[INPUT]] low[0, 0, 16, 0] high[0, 0, 16, 0] {
// CHECK-NEXT: %[[PAD:.*]] = tensor.pad %[[INPUT]] low[0, 0, 16, 0] high[0, 0, 16, 0] {
// CHECK-NEXT: ^bb0(%{{.*}}: index, %{{.*}}: index, %{{.*}}: index, %{{.*}}: index):
// CHECK-NEXT: linalg.yield %[[ZERO]] : f16
// CHECK-NEXT: tensor.yield %[[ZERO]] : f16
// CHECK-NEXT: } : tensor<400x1024x1024x1xf16> to tensor<400x1024x1056x1xf16>
// CHECK-NEXT: %[[RESULT:.*]] = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%[[PAD]], %[[FILTER]] : tensor<400x1024x1056x1xf16>, tensor<1x33x1x1xf16>) outs(%[[FILL]] : tensor<400x1024x1024x1xf16>) -> tensor<400x1024x1024x1xf16>
// CHECK-NEXT: return %[[RESULT]] : tensor<400x1024x1024x1xf16>
Expand All @@ -2691,9 +2691,9 @@ func @linalg.conv_2D_padding_test2(%arg0: tensor<1x33x1x1xf16>, %arg1: tensor<40
// CHECK-NEXT: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f16
// CHECK-NEXT: %[[FILL:.*]] = linalg.fill(%[[ZERO]], %[[INIT]]) : f16, tensor<400x1024x1024x1xf16> -> tensor<400x1024x1024x1xf16>
// CHECK-NEXT: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f16
// CHECK-NEXT: %[[PAD:.*]] = linalg.pad_tensor %[[INPUT]] low[0, 8, 16, 0] high[0, 8, 16, 0] {
// CHECK-NEXT: %[[PAD:.*]] = tensor.pad %[[INPUT]] low[0, 8, 16, 0] high[0, 8, 16, 0] {
// CHECK-NEXT: ^bb0(%{{.*}}: index, %{{.*}}: index, %{{.*}}: index, %{{.*}}: index):
// CHECK-NEXT: linalg.yield %[[ZERO]] : f16
// CHECK-NEXT: tensor.yield %[[ZERO]] : f16
// CHECK-NEXT: } : tensor<400x1024x1024x1xf16> to tensor<400x1040x1056x1xf16>
// CHECK-NEXT: %[[RESULT:.*]] = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%2, %arg0 : tensor<400x1040x1056x1xf16>, tensor<1x33x1x1xf16>) outs(%1 : tensor<400x1024x1024x1xf16>) -> tensor<400x1024x1024x1xf16>
// CHECK-NEXT: return %[[RESULT]] : tensor<400x1024x1024x1xf16>
Expand Down Expand Up @@ -3051,9 +3051,9 @@ func @reduce_window_sum_ndhwc(%arg0: tensor<1x17x17x17x64xf32>,
// CHECK: linalg.yield %arg2 : f32

// CHECK: %cst = arith.constant 0.000000e+00 : f32
// CHECK: %[[PAD:.+]] = linalg.pad_tensor %arg0 low[0, 1] high[3, 2]
// CHECK: %[[PAD:.+]] = tensor.pad %arg0 low[0, 1] high[3, 2]
// CHECK: ^bb0(%arg2: index, %arg3: index):
// CHECK: linalg.yield %cst : f32
// CHECK: tensor.yield %cst : f32

// CHECK: %[[WINDOW:.+]] = linalg.init_tensor [2] : tensor<2xf32>
// CHECK: %[[REDUCE:.+]] = linalg.generic {indexing_maps = [#[[MAP2]], #[[MAP3]], #[[MAP4]]], iterator_types = ["parallel", "parallel", "reduction"]} ins(%[[PAD]], %[[WINDOW]] : tensor<7x9xf32>, tensor<2xf32>) outs(%[[FILL]] : tensor<4x7xf32>) {
Expand Down
Loading

0 comments on commit 91997c8

Please sign in to comment.