Let UnfuseBatchNorm handle dynamic shapes.
Also add an empty broadcast_dimensions attribute when broadcasting
a scalar value. This is needed to legalize the broadcast further
down and follows the pre-existing convention for how broadcasts
of scalars are represented.

PiperOrigin-RevId: 299297553
Change-Id: I51a64f1d7b4ce3349d3e02e743faf66e29aaead1
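
For reference, the scalar-broadcast convention mentioned above is visible in the
updated checks of unfuse_batch_norm.mlir below; a minimal excerpt (SSA names are
illustrative, not the FileCheck captures used in the test):

  %eps = xla_hlo.constant dense<1.001000e-05> : tensor<f32>
  // The explicit, empty broadcast_dimensions attribute marks this as a scalar
  // broadcast so later legalization can handle it.
  %eps_bcast = "xla_hlo.broadcast_in_dim"(%eps)
      {broadcast_dimensions = dense<[]> : tensor<0xi64>}
      : (tensor<f32>) -> tensor<256xf32>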
akuegel authored and tensorflower-gardener committed Mar 6, 2020
1 parent 0dbec67 commit 53ed9a3
Showing 4 changed files with 115 additions and 22 deletions.
2 changes: 2 additions & 0 deletions tensorflow/compiler/mlir/xla/BUILD
@@ -354,7 +354,9 @@ cc_library(
],
deps = [
":hlo",
"@llvm-project//llvm:support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Transforms",
],
)
45 changes: 44 additions & 1 deletion tensorflow/compiler/mlir/xla/tests/unfuse_batch_norm.mlir
@@ -11,7 +11,7 @@ func @batchNormInference_2D_inner_features(
%mean: tensor<256xf32>, %variance: tensor<256xf32>)
-> (tensor<4x256xf32>) {
// CHECK-DAG: %[[EPS:.+]] = xla_hlo.constant dense<1.001000e-05> : tensor<f32>
// CHECK-DAG: %[[EPS_BCAST:.+]] = "xla_hlo.broadcast_in_dim"(%[[EPS]]) : (tensor<f32>) -> tensor<256xf32>
// CHECK-DAG: %[[EPS_BCAST:.+]] = "xla_hlo.broadcast_in_dim"(%[[EPS]]) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<f32>) -> tensor<256xf32>
// CHECK-DAG: %[[VARIANCE_EPS:.+]] = xla_hlo.add %[[VARIANCE]], %[[EPS_BCAST]] : tensor<256xf32>
// CHECK-DAG: %[[STDDEV:.+]] = "xla_hlo.sqrt"(%[[VARIANCE_EPS]]) : (tensor<256xf32>) -> tensor<256xf32>
// CHECK-DAG: %[[STDDEV_BCAST:.+]] = "xla_hlo.broadcast_in_dim"(%[[STDDEV]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<256xf32>) -> tensor<4x256xf32>
@@ -92,3 +92,46 @@ func @batchNormInference_f16_overflow(
tensor<256xf16>) -> tensor<4x256xf16>
return %0 : tensor<4x256xf16>
}

// -----
// CHECK-LABEL: @batchNormInference_dynamic_shape
// Validate that dynamic shapes are handled properly.
// CHECK-SAME: %[[X:[^:[:space:]]+]]
// CHECK-SAME: %[[SCALE:[^:[:space:]]+]]
// CHECK-SAME: %[[OFFSET:[^:[:space:]]+]]
// CHECK-SAME: %[[MEAN:[^:[:space:]]+]]
// CHECK-SAME: %[[VARIANCE:[^:[:space:]]+]]
func @batchNormInference_dynamic_shape(
%x: tensor<?x?x?x?xf32>, %scale: tensor<?xf32>, %offset: tensor<?xf32>,
%mean: tensor<?xf32>, %variance: tensor<?xf32>)
-> tensor<?x?x?x?xf32> {
// CHECK-DAG: %[[EPS:.+]] = xla_hlo.constant dense<1.000000e-03> : tensor<f32>
// CHECK-DAG: %[[DIM:.+]] = dim %[[VARIANCE]], 0 : tensor<?xf32>
// CHECK-DAG: %[[INDEX_CAST:.+]] = index_cast %[[DIM]] : index to i32
// CHECK-DAG: %[[TO_DIM_TENSOR:.+]] = "xla_hlo.scalars_to_dimension_tensor"(%[[INDEX_CAST]]) : (i32) -> tensor<1xi32>
// CHECK-DAG: %[[EPS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[EPS]], %[[TO_DIM_TENSOR]]) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<f32>, tensor<1xi32>) -> tensor<?xf32>
// CHECK-DAG: %[[VARIANCE_EPS:.+]] = xla_hlo.add %[[VARIANCE]], %[[EPS_BCAST]] : tensor<?xf32>
// CHECK-DAG: %[[STDDEV:.+]] = "xla_hlo.sqrt"(%[[VARIANCE_EPS]]) : (tensor<?xf32>) -> tensor<?xf32>
// CHECK-DAG: %[[INPUT_DIM_0:.+]] = dim %[[X]], 0 : tensor<?x?x?x?xf32>
// CHECK-DAG: %[[INPUT_INDEX_CAST_0:.+]] = index_cast %[[INPUT_DIM_0]] : index to i32
// CHECK-DAG: %[[INPUT_DIM_1:.+]] = dim %[[X]], 1 : tensor<?x?x?x?xf32>
// CHECK-DAG: %[[INPUT_INDEX_CAST_1:.+]] = index_cast %[[INPUT_DIM_1]] : index to i32
// CHECK-DAG: %[[INPUT_DIM_2:.+]] = dim %[[X]], 2 : tensor<?x?x?x?xf32>
// CHECK-DAG: %[[INPUT_INDEX_CAST_2:.+]] = index_cast %[[INPUT_DIM_2]] : index to i32
// CHECK-DAG: %[[INPUT_DIM_3:.+]] = dim %[[X]], 3 : tensor<?x?x?x?xf32>
// CHECK-DAG: %[[INPUT_INDEX_CAST_3:.+]] = index_cast %[[INPUT_DIM_3]] : index to i32
// CHECK-DAG: %[[TO_INPUT_DIM_TENSOR:.+]] = "xla_hlo.scalars_to_dimension_tensor"(%[[INPUT_INDEX_CAST_0]], %[[INPUT_INDEX_CAST_1]], %[[INPUT_INDEX_CAST_2]], %[[INPUT_INDEX_CAST_3]]) : (i32, i32, i32, i32) -> tensor<4xi32>
// CHECK-DAG: %[[STDDEV_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[STDDEV]], %[[TO_INPUT_DIM_TENSOR]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?xf32>, tensor<4xi32>) -> tensor<?x?x?x?xf32>
// CHECK-DAG: %[[SCALE_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[SCALE]], %[[TO_INPUT_DIM_TENSOR]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?xf32>, tensor<4xi32>) -> tensor<?x?x?x?xf32>
// CHECK-DAG: %[[OFFSET_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[OFFSET]], %[[TO_INPUT_DIM_TENSOR]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?xf32>, tensor<4xi32>) -> tensor<?x?x?x?xf32>
// CHECK-DAG: %[[MEAN_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[MEAN]], %[[TO_INPUT_DIM_TENSOR]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?xf32>, tensor<4xi32>) -> tensor<?x?x?x?xf32>
// CHECK-DAG: %[[X_CENTER:.+]] = xla_hlo.sub %[[X]], %[[MEAN_BCAST]] : tensor<?x?x?x?xf32>
// CHECK-DAG: %[[X_SCALED:.+]] = xla_hlo.mul %[[X_CENTER]], %[[SCALE_BCAST]] : tensor<?x?x?x?xf32>
// CHECK-DAG: %[[X_NORMED:.+]] = xla_hlo.div %[[X_SCALED]], %[[STDDEV_BCAST]] : tensor<?x?x?x?xf32>
// CHECK-DAG: %[[RESULT:.+]] = xla_hlo.add %[[X_NORMED]], %[[OFFSET_BCAST]] : tensor<?x?x?x?xf32>
%0 = "xla_hlo.batch_norm_inference"(%x, %scale, %offset, %mean, %variance)
{epsilon = 0.001 : f32, feature_index = 1 : i64} :
(tensor<?x?x?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>,
tensor<?xf32>) -> tensor<?x?x?x?xf32>
return %0 : tensor<?x?x?x?xf32>
}
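
Condensed, the dynamic path checked above queries each input dimension, casts it
to i32, packs the sizes into a dimension tensor, and uses that tensor as the
target shape of a dynamic broadcast. A self-contained sketch of that IR for the
scale operand (SSA names are illustrative, not the ones produced by the pass):

func @broadcast_scale_dynamic(%x: tensor<?x?x?x?xf32>, %scale: tensor<?xf32>)
    -> tensor<?x?x?x?xf32> {
  // Materialize the runtime shape of %x as a tensor<4xi32>.
  %d0 = dim %x, 0 : tensor<?x?x?x?xf32>
  %d1 = dim %x, 1 : tensor<?x?x?x?xf32>
  %d2 = dim %x, 2 : tensor<?x?x?x?xf32>
  %d3 = dim %x, 3 : tensor<?x?x?x?xf32>
  %i0 = index_cast %d0 : index to i32
  %i1 = index_cast %d1 : index to i32
  %i2 = index_cast %d2 : index to i32
  %i3 = index_cast %d3 : index to i32
  %shape = "xla_hlo.scalars_to_dimension_tensor"(%i0, %i1, %i2, %i3)
      : (i32, i32, i32, i32) -> tensor<4xi32>
  // Broadcast the 1-D scale along the feature dimension (dimension 1 here).
  %bcast = "xla_hlo.dynamic_broadcast_in_dim"(%scale, %shape)
      {broadcast_dimensions = dense<1> : tensor<1xi64>}
      : (tensor<?xf32>, tensor<4xi32>) -> tensor<?x?x?x?xf32>
  return %bcast : tensor<?x?x?x?xf32>
}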
88 changes: 67 additions & 21 deletions tensorflow/compiler/mlir/xla/transforms/unfuse_batch_norm.cc
@@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "llvm/ADT/SmallVector.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
#include "mlir/IR/Attributes.h" // TF:llvm-project
#include "mlir/IR/Builders.h" // TF:llvm-project
#include "mlir/IR/MLIRContext.h" // TF:llvm-project
@@ -28,20 +30,47 @@ namespace xla_hlo {

namespace {

// Broadcasts the 1D value tensor to rank.
Value broadcastToFeatureDim(Location loc, Type result_type, Value value_1d,
// Broadcasts the 1D value tensor 'value_1d' to the shape of 'result_type'. If
// 'shape_value' is initialized, creates a dynamic broadcast, otherwise creates
// a static broadcast.
Value BroadcastToFeatureDim(Location loc, RankedTensorType result_type,
Value value_1d, Value shape_value,
int64_t feature_dim,
ConversionPatternRewriter& rewriter) {
ConversionPatternRewriter& rewriter) { // NOLINT
Builder b(rewriter.getContext());
auto dims_type = RankedTensorType::get({1}, b.getIntegerType(64));
auto dims = DenseIntElementsAttr::get(dims_type, {feature_dim});
if (shape_value) {
return rewriter.createOrFold<xla_hlo::DynamicBroadcastInDimOp>(
loc, result_type, value_1d, shape_value, dims);
}
assert(result_type.hasStaticShape());
return rewriter.create<xla_hlo::BroadcastInDimOp>(loc, result_type, value_1d,
dims);
}

// Calculate the shape value of operand, assuming it is a dynamic shape with
// static rank.
Value CalculateShapeValue(Location loc, Value operand,
ConversionPatternRewriter& rewriter) { // NOLINT
RankedTensorType result_type = operand.getType().dyn_cast<RankedTensorType>();
llvm::SmallVector<Value, 4> shape_values;
int64_t rank = result_type.getRank();
shape_values.reserve(rank);
for (int64_t i = 0; i < rank; ++i) {
auto index_value = rewriter.create<mlir::DimOp>(loc, operand, i);
shape_values.push_back(rewriter.create<mlir::IndexCastOp>(
loc, index_value, rewriter.getIntegerType(32)));
}
Type shape_element_type = shape_values.front().getType();
return rewriter.create<ScalarsToDimensionTensorOp>(
loc, RankedTensorType::get({rank}, shape_element_type), shape_values);
}

Value MaterializeEpsilon(Operation* op, FloatAttr epsilon_attr,
FloatType fp_type, Type broadcast_to_type,
ConversionPatternRewriter& rewriter) {
FloatType fp_type, Value variance,
RankedTensorType broadcast_to_type,
ConversionPatternRewriter& rewriter) { // NOLINT
Builder b(rewriter.getContext());
if (epsilon_attr.getType() != fp_type) {
// Need to convert.
@@ -66,9 +95,16 @@ Value MaterializeEpsilon(Operation* op, FloatAttr epsilon_attr,
DenseElementsAttr::get(scalar_type, {epsilon_attr.cast<Attribute>()});
Value epsilon =
rewriter.create<xla_hlo::ConstOp>(op->getLoc(), epsilon_tensor_attr);
epsilon = rewriter.create<xla_hlo::BroadcastInDimOp>(
op->getLoc(), broadcast_to_type, epsilon, /*broadcast_dims=*/nullptr);
return epsilon;
auto dims_type = RankedTensorType::get({0}, b.getIntegerType(64));
auto dims = DenseIntElementsAttr::get(dims_type, SmallVector<int64_t, 1>{});
if (broadcast_to_type.hasStaticShape()) {
return rewriter.create<xla_hlo::BroadcastInDimOp>(
op->getLoc(), broadcast_to_type, epsilon, /*broadcast_dims=*/dims);
}
Value shape_value = CalculateShapeValue(op->getLoc(), variance, rewriter);
return rewriter.createOrFold<xla_hlo::DynamicBroadcastInDimOp>(
op->getLoc(), broadcast_to_type, epsilon, shape_value,
/*broadcast_dims=*/dims);
}

class UnfuseBatchNormInferencePattern
@@ -84,9 +120,10 @@ class UnfuseBatchNormInferencePattern
// Enforce type invariants.
// Note that we deduce the actual element type from the variance,
// which should not be subject to quantization at a higher level.
auto input_type = operands.operand().getType();
auto variance_type = operands.variance().getType().dyn_cast<ShapedType>();
if (!variance_type) {
auto input_type = operands.operand().getType().dyn_cast<RankedTensorType>();
auto variance_type =
operands.variance().getType().dyn_cast<RankedTensorType>();
if (!input_type || !variance_type) {
return matchFailure();
}
auto fp_type = variance_type.getElementType().dyn_cast<FloatType>();
@@ -97,8 +134,9 @@

// Add epsilon to the variance and sqrt to get stddev:
// stddev = sqrt(variance + epsilon)
auto epsilon = MaterializeEpsilon(bn_op.getOperation(), bn_op.epsilonAttr(),
fp_type, variance_type, rewriter);
auto epsilon =
MaterializeEpsilon(bn_op.getOperation(), bn_op.epsilonAttr(), fp_type,
operands.variance(), variance_type, rewriter);
if (!epsilon) {
return matchFailure();
}
@@ -108,14 +146,22 @@
stddev = rewriter.create<xla_hlo::SqrtOp>(bn_op.getLoc(), stddev);

// Broadcast all terms.
auto broadcast_scale = broadcastToFeatureDim(
bn_op.getLoc(), input_type, operands.scale(), feature_dim, rewriter);
auto broadcast_offset = broadcastToFeatureDim(
bn_op.getLoc(), input_type, operands.offset(), feature_dim, rewriter);
auto broadcast_mean = broadcastToFeatureDim(
bn_op.getLoc(), input_type, operands.mean(), feature_dim, rewriter);
auto broadcast_stddev = broadcastToFeatureDim(
bn_op.getLoc(), input_type, stddev, feature_dim, rewriter);
Value shape_value;
if (!input_type.hasStaticShape()) {
shape_value =
CalculateShapeValue(bn_op.getLoc(), operands.operand(), rewriter);
}
auto broadcast_scale =
BroadcastToFeatureDim(bn_op.getLoc(), input_type, operands.scale(),
shape_value, feature_dim, rewriter);
auto broadcast_offset =
BroadcastToFeatureDim(bn_op.getLoc(), input_type, operands.offset(),
shape_value, feature_dim, rewriter);
auto broadcast_mean =
BroadcastToFeatureDim(bn_op.getLoc(), input_type, operands.mean(),
shape_value, feature_dim, rewriter);
auto broadcast_stddev = BroadcastToFeatureDim(
bn_op.getLoc(), input_type, stddev, shape_value, feature_dim, rewriter);

// Compute:
// scale * (input - mean) / stddev + offset
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
#include "mlir/IR/MLIRContext.h" // TF:llvm-project
#include "mlir/IR/Operation.h" // TF:llvm-project
#include "mlir/IR/PatternMatch.h" // TF:llvm-project
@@ -33,6 +34,7 @@ struct TestUnfuseBatchNormPass : public FunctionPass<TestUnfuseBatchNormPass> {

// Consider the xla_hlo dialect legal for tests.
conversionTarget.addLegalDialect<XlaHloDialect>();
conversionTarget.addLegalDialect<StandardOpsDialect>();
conversionTarget.addIllegalOp<xla_hlo::BatchNormInferenceOp>();

PopulateUnfuseBatchNormPatterns(&getContext(), &conversionPatterns);