[XLA] Fix dot strength reduction to simply remove all degenerate dimensions and
let the next pass of algebraic simplification handle the resulting Dot that
may have no contracting or non-contracting dimensions. Also allow scalar dots
in shape inference, since they work.

PiperOrigin-RevId: 244956958
blakehechtman authored and tensorflower-gardener committed Apr 24, 2019
1 parent 88e0e42 commit 86fbd13
Showing 7 changed files with 73 additions and 284 deletions.
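
For intuition, here is a minimal standalone C++ sketch, not XLA code, with hypothetical shapes and values, showing why dropping a degenerate (size-1) dimension from a Dot and reshaping the result back is value-preserving:

#include <cstdio>

int main() {
  // dot(lhs f32[2,3], rhs f32[3,1]), contracting lhs dimension 1 with rhs
  // dimension 0; the rhs non-contracting dimension is degenerate (size 1).
  const float lhs[2][3] = {{1, 2, 3}, {4, 5, 6}};
  const float rhs2d[3][1] = {{7}, {8}, {9}};

  // Direct dot: result shape f32[2,1].
  float full[2][1] = {};
  for (int i = 0; i < 2; ++i) {
    for (int k = 0; k < 3; ++k) full[i][0] += lhs[i][k] * rhs2d[k][0];
  }

  // Rewritten form: drop the degenerate rhs dimension (f32[3,1] -> f32[3]),
  // take dot(f32[2,3], f32[3]) -> f32[2], then reshape back to f32[2,1].
  const float rhs1d[3] = {7, 8, 9};
  float reduced[2] = {};
  for (int i = 0; i < 2; ++i) {
    for (int k = 0; k < 3; ++k) reduced[i] += lhs[i][k] * rhs1d[k];
  }

  for (int i = 0; i < 2; ++i) {
    std::printf("%g == %g\n", full[i][0], reduced[i]);  // 50 == 50, then 122 == 122
  }
  return 0;
}

If dropping degenerate dimensions leaves the Dot with no contracting or non-contracting dimensions at all, the next algebraic simplifier pass handles it, as the commit message describes.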
2 changes: 1 addition & 1 deletion tensorflow/compiler/xla/client/lib/svd_test.cc
@@ -268,7 +268,7 @@ XLA_TEST_F(SVDTest, Various_Size_Random_Matrix_512x512) {
Array2D<float> a_val = GenerateRandomMatrix(512, 512);
XlaOp a;
auto a_data = CreateR2Parameter<float>(a_val, 0, "a", &builder, &a);
auto result = SVD(a, 100, 1e-6);
auto result = SVD(a, 100, 1e-4);
GetAverageAbsoluteError(ComputeMatmulUDVT(result, &builder), a, &builder);

ComputeAndCompareR0<float>(&builder, 1e-3, {a_data.get()},
12 changes: 0 additions & 12 deletions tensorflow/compiler/xla/client/xla_builder.cc
@@ -1051,18 +1051,6 @@ XlaOp XlaBuilder::DotGeneral(const XlaOp& lhs, const XlaOp& rhs,
HloInstructionProto instr;
TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, GetShape(lhs));
TF_ASSIGN_OR_RETURN(const Shape& rhs_shape, GetShape(rhs));
// If one operand is a scalar, just multiply the two operands.
if (ShapeUtil::IsScalar(lhs_shape) || ShapeUtil::IsScalar(rhs_shape)) {
if (dimension_numbers.rhs_batch_dimensions_size() != 0 ||
dimension_numbers.lhs_batch_dimensions_size() != 0 ||
dimension_numbers.rhs_contracting_dimensions_size() != 0 ||
dimension_numbers.lhs_contracting_dimensions_size() != 0) {
return InvalidArgument(
"Dots with scalar operands must have no contracting or batch "
"dimensions");
}
return xla::Mul(lhs, rhs);
}
TF_ASSIGN_OR_RETURN(Shape shape,
ShapeInference::InferDotOpShape(lhs_shape, rhs_shape,
dimension_numbers));
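
The branch removed above special-cased scalar operands inside the builder; with this change such Dots are passed to ShapeInference::InferDotOpShape like any other. A standalone sketch, with hypothetical values rather than the XLA API, of the identity the removed branch implemented (a Dot with a scalar operand and no contracting or batch dimensions is just a multiply):

#include <cstdio>

int main() {
  // dot(lhs f32[], rhs f32[4]) with empty contracting and batch dimensions.
  const float scalar = 2.5f;
  const float vec[4] = {1, 2, 3, 4};
  float result[4];
  for (int i = 0; i < 4; ++i) result[i] = scalar * vec[i];  // same as Mul(lhs, rhs)
  for (float v : result) std::printf("%g ", v);  // prints: 2.5 5 7.5 10
  std::printf("\n");
  return 0;
}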
319 changes: 59 additions & 260 deletions tensorflow/compiler/xla/service/algebraic_simplifier.cc
@@ -259,31 +259,8 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault {
AlgebraicSimplifier* simplifier)
: computation_(computation), options_(options), simplifier_(simplifier) {}

// Transforms Dots where at least one input is a vector or has a degenerate
// dimension and converts it into a multiply and reduce. This should enable
// more fusion than leaving the nodes as Dot operations.
StatusOr<bool> HandleDotStrengthReduction(HloInstruction* dot);

// Removes dimension dim from hlo.
HloInstruction* StripDim(HloInstruction* hlo, int64 dim) {
CHECK_EQ(hlo->shape().dimensions(dim), 1);
return computation_->AddInstruction(HloInstruction::CreateReshape(
ShapeUtil::DeleteDimension(dim, hlo->shape()), hlo));
}

// Reshapes an instruction to rank 1 if it is not already rank 1.
HloInstruction* Flatten(HloInstruction* hlo) {
if (hlo->shape().rank() == 1) {
return hlo;
}
auto hlo_instruction =
computation_->AddInstruction(HloInstruction::CreateReshape(
ShapeUtil::MakeShape(hlo->shape().element_type(),
{ShapeUtil::ElementsIn(hlo->shape())}),
hlo));
simplifier_->UpdateLayout(hlo_instruction->mutable_shape());
return hlo_instruction;
}
// Removes degenerate dimension from dot.
StatusOr<bool> RemoveDegenerateDimensionFromDot(HloInstruction* dot);

// Converts to primitive type if the input hlo is not that type, otherwise
// returns the original hlo.
@@ -1173,240 +1150,73 @@ Status AlgebraicSimplifierVisitor::HandleDivide(HloInstruction* divide) {
return Status::OK();
}

StatusOr<bool> AlgebraicSimplifierVisitor::HandleDotStrengthReduction(
StatusOr<bool> AlgebraicSimplifierVisitor::RemoveDegenerateDimensionFromDot(
HloInstruction* dot) {
HloInstruction *lhs, *rhs;
CHECK(Match(dot, m::Dot(m::Op(&lhs), m::Op(&rhs))));

const auto kept_dim = [](int64 rank, int64 contracting_dimension,
absl::Span<const int64> batch_dimensions) -> int64 {
for (int64 i = 0; i < rank; ++i) {
if (i != contracting_dimension &&
!absl::c_linear_search(batch_dimensions, i)) {
return i;
}
}
return -1;
};

const int64 dot_rank = dot->shape().rank();
const int64 rhs_rank = rhs->shape().rank();
const int64 lhs_rank = lhs->shape().rank();
const auto& dnums = dot->dot_dimension_numbers();
if (dnums.rhs_contracting_dimensions_size() != 1) {
return false;
}
if (dot_rank > 2 && (lhs_rank != rhs_rank || lhs_rank != dot_rank)) {
return false;
}
int64 lhs_collapsing_dim = dnums.lhs_contracting_dimensions(0);
int64 lhs_kept_dim = kept_dim(lhs_rank, lhs_collapsing_dim,
AsInt64Slice(dnums.lhs_batch_dimensions()));
// If there is no non-contracting dimension in rank 2, do not strength reduce.
if (lhs_kept_dim == -1 && lhs_rank > 1) {
return false;
}
if (lhs->IsRank2Transpose()) {
lhs = lhs->mutable_operand(0);
std::swap(lhs_collapsing_dim, lhs_kept_dim);
}

int64 rhs_collapsing_dim = dnums.rhs_contracting_dimensions(0);
int64 rhs_kept_dim = kept_dim(rhs_rank, rhs_collapsing_dim,
AsInt64Slice(dnums.rhs_batch_dimensions()));
// If there is no non-contracting dimension in rank 2, do not strength reduce.
if (rhs_kept_dim == -1 && rhs_rank > 1) {
return false;
}
if (rhs->IsRank2Transpose()) {
rhs = rhs->mutable_operand(0);
std::swap(rhs_collapsing_dim, rhs_kept_dim);
}

auto reshape_if_necessary = [&](HloInstruction* hlo) {
hlo = AsType(hlo, dot->shape().element_type());
if (!ShapeUtil::SameDimensions(hlo->shape(), dot->shape())) {
hlo = computation_->AddInstruction(
HloInstruction::CreateReshape(dot->shape(), hlo));
const Shape& lhs_shape = dot->operand(0)->shape();
int64 num_degenerate_lhs_dims = 0;
std::vector<int64> lhs_dimension_map(lhs_shape.rank(), -1);
for (int64 i = 0; i < lhs_shape.rank(); ++i) {
if (lhs_shape.dimensions(i) == 1) {
++num_degenerate_lhs_dims;
} else {
lhs_dimension_map[i] = i - num_degenerate_lhs_dims;
}
return hlo;
};

auto add_reduce_in_f32 = [&](HloInstruction* hlo, const int64 dim) {
return AddReduce(AsType(hlo, F32), dim);
};

auto broadcast = [&](HloInstruction* hlo, const Shape& shape,
absl::Span<const int64> dims) {
return computation_->AddInstruction(
HloInstruction::CreateBroadcast(shape, hlo, dims));
};

auto broadcast_to_dim = [&](HloInstruction* hlo, const Shape& shape,
int64 dim) {
return broadcast(hlo, shape, {dim});
};

auto multiply = [&](HloInstruction* local_lhs, HloInstruction* local_rhs) {
return computation_->AddInstruction(HloInstruction::CreateBinary(
local_lhs->shape(), HloOpcode::kMultiply, local_lhs, local_rhs));
};

// Strength reduce dot(a[K] , b[K]) =
// reshape(result.shape,
// reduce_sum(multiply(a, b), {0}))
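// Worked instance (values assumed for illustration): with a = {1, 2, 3} and
// b = {4, 5, 6}, multiply(a, b) = {4, 10, 18}, and reducing over dimension 0
// gives 32, matching dot(a, b) = 1*4 + 2*5 + 3*6 = 32.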
if (rhs_rank == 1 && lhs_rank == 1) {
TF_RETURN_IF_ERROR(ReplaceInstruction(
dot, reshape_if_necessary(add_reduce_in_f32(multiply(lhs, rhs), 0))));
return true;
}

if (ShapeUtil::IsEffectiveScalar(rhs->shape()) &&
ShapeUtil::IsEffectiveScalar(lhs->shape())) {
TF_RETURN_IF_ERROR(ReplaceInstruction(
dot, reshape_if_necessary(multiply(Flatten(lhs), Flatten(rhs)))));
return true;
}

// Simplify outer product into multiply with broadcasting.
//
// A dot(a[M, 1], b[1, N]) = multiply(a [M,1], b [1, N])
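// Worked instance (assumed values): a = {{2}, {3}} of shape [2, 1] and
// b = {{5, 7}} of shape [1, 2] give dot(a, b) = {{10, 14}, {15, 21}}, which is
// exactly the broadcasted elementwise product of a and b.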
if (rhs_rank == 2 && rhs->shape().dimensions(rhs_collapsing_dim) == 1) {
TF_RETURN_IF_ERROR(ReplaceInstruction(
dot, multiply(broadcast_to_dim(Flatten(lhs), dot->shape(), 0),
broadcast_to_dim(Flatten(rhs), dot->shape(), 1))));
return true;
}

// Strength reduce dot(a[1, K], b) =
// reshape(result.shape,
// reduce_sum(
// multiply(broadcast(reshape(a, [K]), {0}), b),
// {0})
// )
// )
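// Worked instance (assumed shapes): for a of shape [1, 3] and b of shape
// [3, 2], a is reshaped to [3], broadcast along dimension 0 of b's shape,
// multiplied with b, and reduced over dimension 0 to give the [2] values that
// are then reshaped to the dot's [1, 2] output shape.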
if (lhs_rank == 1 ||
(lhs_rank == 2 && lhs->shape().dimensions(lhs_kept_dim) == 1)) {
if (rhs->shape().rank() == 1) {
TF_RETURN_IF_ERROR(
ReplaceInstruction(dot, reshape_if_necessary(add_reduce_in_f32(
multiply(Flatten(lhs), rhs), 0))));
return true;
const Shape& rhs_shape = dot->operand(1)->shape();
int64 num_degenerate_rhs_dims = 0;
std::vector<int64> rhs_dimension_map(rhs_shape.rank(), -1);
for (int64 i = 0; i < rhs_shape.rank(); ++i) {
if (rhs_shape.dimensions(i) == 1) {
++num_degenerate_rhs_dims;
} else {
rhs_dimension_map[i] = i - num_degenerate_rhs_dims;
}
TF_RETURN_IF_ERROR(ReplaceInstruction(
dot, reshape_if_necessary(add_reduce_in_f32(
multiply(broadcast_to_dim(Flatten(lhs), rhs->shape(),
rhs_collapsing_dim),
rhs),
rhs_collapsing_dim))));
return true;
}

// Strength reduce dot(a, b[K, 1]) =
// reshape(result.shape,
// reduce_sum(multiply(a, broadcast(reshape([K],b), {1})), {0})
// )
if (rhs_rank == 1 ||
(rhs_rank == 2 && rhs->shape().dimensions(rhs_kept_dim) == 1)) {
TF_RETURN_IF_ERROR(ReplaceInstruction(
dot, reshape_if_necessary(add_reduce_in_f32(
multiply(lhs, broadcast_to_dim(Flatten(rhs), lhs->shape(),
lhs_collapsing_dim)),
lhs_collapsing_dim))));
return true;
}

// Only consider kDot with batch dimension.
if (dot_rank <= 2) {
if (num_degenerate_lhs_dims == 0 && num_degenerate_rhs_dims == 0) {
return false;
}

CHECK_EQ(rhs_rank, lhs_rank);
CHECK_EQ(dot_rank, lhs_rank);
// If there is more than one non-contracting dimension or the batch dimensions
// are not equal, bail out since transposes may be required to do a strength
// reduction.
if (dnums.rhs_batch_dimensions_size() + 2 != dot_rank ||
!absl::c_equal(dnums.lhs_batch_dimensions(),
dnums.rhs_batch_dimensions())) {
return false;
}

auto broadcast_dims = [](int64 rank, int64 non_broadcast_dim) {
absl::InlinedVector<int64, 8> dims;
for (int64 i = 0; i < rank; ++i) {
if (i != non_broadcast_dim) {
dims.push_back(i);
}
const DotDimensionNumbers& dnums = dot->dot_dimension_numbers();
DotDimensionNumbers new_dnums;
for (int64 dim : dnums.lhs_batch_dimensions()) {
int64 new_dim = lhs_dimension_map[dim];
if (new_dim != -1) {
new_dnums.add_lhs_batch_dimensions(new_dim);
}
return dims;
};

// If the contracting dimension is 1, remove the degenerate dimensions from
// the lhs and rhs, broadcast each to the result shape and multiply.
if (lhs->shape().dimensions(lhs_collapsing_dim) == 1 &&
(rhs_kept_dim == rhs_rank - 1 ||
(rhs_collapsing_dim == rhs_rank - 1 && rhs_kept_dim == rhs_rank - 2))) {
CHECK_EQ(rhs->shape().dimensions(rhs_collapsing_dim), 1);
const int64 lhs_kept_dim_in_output =
lhs_kept_dim > lhs_collapsing_dim ? (lhs_kept_dim - 1) : lhs_kept_dim;
absl::InlinedVector<int64, 8> lhs_broadcast_dims;
for (const int64 dim : dnums.lhs_batch_dimensions()) {
lhs_broadcast_dims.push_back(dim > lhs_collapsing_dim ? (dim - 1) : dim);
}
absl::InlinedVector<int64, 8> rhs_broadcast_dims = lhs_broadcast_dims;
lhs_broadcast_dims.push_back(lhs_kept_dim_in_output);
absl::c_sort(lhs_broadcast_dims);
rhs_broadcast_dims.push_back(dot_rank - 1);
absl::c_sort(rhs_broadcast_dims);
TF_RETURN_IF_ERROR(ReplaceInstruction(
dot, reshape_if_necessary(
multiply(broadcast(StripDim(lhs, lhs_collapsing_dim),
dot->shape(), lhs_broadcast_dims),
broadcast(StripDim(rhs, rhs_collapsing_dim),
dot->shape(), rhs_broadcast_dims)))));
return true;
}

// If the lhs and rhs non-contracting dimensions are both one, strip each one,
// multiply and then reduce the collapsing dimension
if (lhs->shape().dimensions(lhs_kept_dim) == 1 &&
rhs->shape().dimensions(rhs_kept_dim) == 1 &&
lhs_kept_dim == rhs_kept_dim) {
auto new_lhs = StripDim(lhs, lhs_kept_dim);
auto new_rhs = StripDim(rhs, rhs_kept_dim);
const int64 reduce_dim = rhs_kept_dim < rhs_collapsing_dim
? (rhs_collapsing_dim - 1)
: rhs_collapsing_dim;
TF_RETURN_IF_ERROR(
ReplaceInstruction(dot, reshape_if_necessary(add_reduce_in_f32(
multiply(new_lhs, new_rhs), reduce_dim))));
return true;
for (int64 dim : dnums.lhs_contracting_dimensions()) {
int64 new_dim = lhs_dimension_map[dim];
if (new_dim != -1) {
new_dnums.add_lhs_contracting_dimensions(new_dim);
}
}

// If the lhs non-contracting dimension is one, strip the one, broadcast to
// the rhs shape, multiply and then reduce the collapsing dimension
if (lhs->shape().dimensions(lhs_kept_dim) == 1) {
auto new_lhs = broadcast(StripDim(lhs, lhs_kept_dim), rhs->shape(),
broadcast_dims(rhs_rank, rhs_kept_dim));
TF_RETURN_IF_ERROR(ReplaceInstruction(
dot, reshape_if_necessary(add_reduce_in_f32(multiply(new_lhs, rhs),
rhs_collapsing_dim))));
return true;
for (int64 dim : dnums.rhs_batch_dimensions()) {
int64 new_dim = rhs_dimension_map[dim];
if (new_dim != -1) {
new_dnums.add_rhs_batch_dimensions(new_dim);
}
}

// If the rhs non-contracting dimension is one, strip the one, broadcast to
// the lhs shape, multiply and then reduce the collapsing dimension
if (rhs->shape().dimensions(rhs_kept_dim) == 1) {
auto new_rhs = broadcast(StripDim(rhs, rhs_kept_dim), lhs->shape(),
broadcast_dims(lhs_rank, lhs_kept_dim));
TF_RETURN_IF_ERROR(ReplaceInstruction(
dot, reshape_if_necessary(add_reduce_in_f32(multiply(lhs, new_rhs),
lhs_collapsing_dim))));
return true;
for (int64 dim : dnums.rhs_contracting_dimensions()) {
int64 new_dim = rhs_dimension_map[dim];
if (new_dim != -1) {
new_dnums.add_rhs_contracting_dimensions(new_dim);
}
}

return false;
HloInstruction* new_lhs =
dot->parent()->AddInstruction(HloInstruction::CreateReshape(
ShapeUtil::DropDegenerateDimensions(lhs_shape),
dot->mutable_operand(0)));
HloInstruction* new_rhs =
dot->parent()->AddInstruction(HloInstruction::CreateReshape(
ShapeUtil::DropDegenerateDimensions(rhs_shape),
dot->mutable_operand(1)));
TF_ASSIGN_OR_RETURN(auto new_dot, MakeDotHlo(new_lhs, new_rhs, new_dnums,
dot->precision_config()));
TF_RETURN_IF_ERROR(ReplaceWithNewInstruction(
dot, HloInstruction::CreateReshape(dot->shape(), new_dot)));
return true;
}
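
The new RemoveDegenerateDimensionFromDot builds, for each operand, a map from old dimension indices to their positions after all size-1 dimensions are dropped, rewrites the DotDimensionNumbers through that map, and wraps a smaller Dot in reshapes. A self-contained C++ sketch of the mapping step, using plain vectors instead of XLA shapes and a hypothetical operand shape:

#include <cstdint>
#include <cstdio>
#include <vector>

// Returns old-index -> new-index after dropping size-1 dimensions, or -1 if
// the dimension is degenerate and disappears.
std::vector<int64_t> DegenerateDimensionMap(const std::vector<int64_t>& dims) {
  std::vector<int64_t> map(dims.size(), -1);
  int64_t num_degenerate = 0;
  for (size_t i = 0; i < dims.size(); ++i) {
    if (dims[i] == 1) {
      ++num_degenerate;
    } else {
      map[i] = static_cast<int64_t>(i) - num_degenerate;
    }
  }
  return map;
}

int main() {
  // Hypothetical lhs shape f32[1, 4, 8] with batch dimension 0 and
  // contracting dimension 2.
  const std::vector<int64_t> lhs_dims = {1, 4, 8};
  const std::vector<int64_t> map = DegenerateDimensionMap(lhs_dims);
  for (size_t i = 0; i < map.size(); ++i) {
    std::printf("old dim %zu -> %lld\n", i, static_cast<long long>(map[i]));
  }
  // Prints 0 -> -1, 1 -> 0, 2 -> 1: the degenerate batch dimension is dropped
  // entirely and the contracting dimension becomes dimension 1 of the new
  // f32[4, 8] operand.
  return 0;
}

Batch or contracting entries that map to -1 are simply not added to the new dimension numbers, which is how the rewritten Dot can end up with no contracting or non-contracting dimensions and be left for the next simplifier pass.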

StatusOr<HloInstruction*> AlgebraicSimplifierVisitor::OptimizeDotOfConcat(
@@ -1698,8 +1508,7 @@ Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) {
dot, HloInstruction::CreateBroadcast(dot->shape(), zero, {}));
}

// Only optimize F32 or BF16 dot operations where the dot, rhs and lhs are
// rank 2 or below.
// Only optimize F32 or BF16 dot operations.
if (dot->shape().element_type() != F32 &&
dot->shape().element_type() != BF16) {
return Status::OK();
@@ -1815,15 +1624,6 @@ Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) {
return ReplaceInstruction(dot, new_dot);
}

if (lhs->shape().rank() > 2 || rhs->shape().rank() > 2 ||
dot->shape().rank() > 2) {
if (options_.enable_dot_strength_reduction() &&
!options_.is_layout_sensitive()) {
TF_RETURN_IF_ERROR(HandleDotStrengthReduction(dot).status());
}
return Status::OK();
}

TF_ASSIGN_OR_RETURN(HloInstruction * dot_of_concat_optimized,
OptimizeDotOfConcat(dot));
if (dot_of_concat_optimized) {
@@ -1843,11 +1643,10 @@ Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) {
return ReplaceInstruction(dot, dot_of_gather_optimized);
}

if (options_.enable_dot_strength_reduction() &&
!options_.is_layout_sensitive()) {
TF_ASSIGN_OR_RETURN(bool did_strength_reduction,
HandleDotStrengthReduction(dot));
if (did_strength_reduction) {
if (options_.enable_dot_strength_reduction()) {
TF_ASSIGN_OR_RETURN(bool removed_degenerate_dimensions,
RemoveDegenerateDimensionFromDot(dot));
if (removed_degenerate_dimensions) {
return Status::OK();
}
}