Skip to content

Commit

Permalink
Do not normalize broadcast operands on GPU.
Browse files Browse the repository at this point in the history
There is no need to normalize broadcast operands on GPU, and it makes it harder
to write matchers if they have to optionally match a bitcast.

PiperOrigin-RevId: 514373072
  • Loading branch information
akuegel authored and copybara-github committed Mar 6, 2023
1 parent 18534d1 commit 5673903
Show file tree
Hide file tree
Showing 6 changed files with 35 additions and 1 deletion.
3 changes: 2 additions & 1 deletion xla/service/algebraic_simplifier.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3665,7 +3665,8 @@ Status AlgebraicSimplifierVisitor::HandleBroadcast(HloInstruction* broadcast) {
if (options_.is_layout_sensitive()) {
return OkStatus();
}
if (ShapeUtil::HasDegenerateDimensions(operand->shape())) {
if (options_.enable_normalize_broadcast_operand() &&
ShapeUtil::HasDegenerateDimensions(operand->shape())) {
auto new_operand =
operand->parent()->AddInstruction(HloInstruction::CreateReshape(
ShapeUtil::DropDegenerateDimensions(operand->shape()), operand));
Expand Down
10 changes: 10 additions & 0 deletions xla/service/algebraic_simplifier.h
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,15 @@ class AlgebraicSimplifierOptions {

bool enable_sink_broadcast() const { return enable_sink_broadcast_; }

// Enables or disables normalization of broadcast operands. The option
// defaults to true. When it is true and the simplifier is not layout
// sensitive, AlgebraicSimplifierVisitor::HandleBroadcast reshapes a broadcast
// operand that has degenerate dimensions into a shape with those dimensions
// dropped. Callers whose pattern matchers should not have to account for that
// extra instruction on the operand can pass false to keep the operand as is.
void set_enable_normalize_broadcast_operand(
bool enable_normalize_broadcast_operand) {
enable_normalize_broadcast_operand_ = enable_normalize_broadcast_operand;
}

// Returns whether broadcast operands with degenerate dimensions should be
// normalized (reshaped to drop those dimensions) when a broadcast is
// simplified. Defaults to true.
bool enable_normalize_broadcast_operand() const {
return enable_normalize_broadcast_operand_;
}

// If true, min(x, NaN) = NaN. If false, min(x, NaN) = x.
//
// TODO(b/209827141): Remove this and make minmax_propagate_nan unconditionally
Expand Down Expand Up @@ -220,6 +229,7 @@ class AlgebraicSimplifierOptions {
bool enable_reduce_of_reshape_{true};
bool enable_negative_padding_replacement_{true};
bool enable_sink_broadcast_{true};
bool enable_normalize_broadcast_operand_{true};
int64_t very_small_gather_size_{4};
bool minmax_propagate_nan_{true};
Metadata metadata_;
Expand Down
17 changes: 17 additions & 0 deletions xla/service/algebraic_simplifier_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1045,6 +1045,23 @@ TEST_F(AlgebraicSimplifierTest, DegenerateDimsInOperandRemovedFromBroadcast) {
GmockMatch(m::Broadcast(m::Reshape(m::Parameter(0)))));
}

// Verifies that, with operand normalization switched off, a Broadcast whose
// operand has degenerate dimensions is left untouched: the simplifier must
// report that it made no change to the module.
TEST_F(AlgebraicSimplifierTest,
       DegenerateDimsInOperandNotRemovedFromBroadcast) {
  const char* hlo_text = R"(
HloModule m
test {
c = f32[1,4] parameter(0)
ROOT b = f32[5,1,4,3] broadcast(c), dimensions={1,2}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_text));

  // Start from the default options and turn off only the rewrite under test.
  AlgebraicSimplifierOptions simplifier_options;
  simplifier_options.set_enable_normalize_broadcast_operand(false);

  AlgebraicSimplifier simplifier(simplifier_options);
  const bool changed = simplifier.Run(module.get()).value();
  ASSERT_FALSE(changed);
}

// Test to catch a crash where we were overshooting the reshaped_dimensions
// vector.
TEST_F(AlgebraicSimplifierTest, ArrayOvershootTest) {
Expand Down
1 change: 1 addition & 0 deletions xla/service/gpu/amdgpu_compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ Status AMDGPUCompiler::OptimizeHloConvolutionCanonicalization(
// AlgebraicSimplifier. We run algsimp to a fixed point.
AlgebraicSimplifierOptions options;
options.set_enable_conv_operand_swap(false);
options.set_enable_normalize_broadcast_operand(false);
pipeline.AddPass<HloPassFix<AlgebraicSimplifier>>(options);

pipeline.AddPass<HloConstantFolding>();
Expand Down
4 changes: 4 additions & 0 deletions xla/service/gpu/gpu_compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -381,6 +381,8 @@ Status GpuCompiler::OptimizeHloModule(
layout_insensitive_algsimp_opts.set_minmax_propagate_nan(
!debug_options.xla_gpu_enable_fast_min_max());

layout_insensitive_algsimp_opts.set_enable_normalize_broadcast_operand(false);

if (gpu_target_config.platform_name == "ROCM") {
layout_insensitive_algsimp_opts.set_enable_conv_operand_swap(false);
}
Expand Down Expand Up @@ -806,6 +808,7 @@ Status GpuCompiler::OptimizeHloPostLayoutAssignment(
options.set_supports_non_canonical_dots(false);
options.set_is_layout_sensitive(true);
options.set_enable_conv_operand_swap(false);
options.set_enable_normalize_broadcast_operand(false);
// "slow" minmax means we propagate nan.
options.set_minmax_propagate_nan(
!debug_options.xla_gpu_enable_fast_min_max());
Expand Down Expand Up @@ -907,6 +910,7 @@ Status GpuCompiler::OptimizeHloPostLayoutAssignment(
options.set_supports_non_canonical_dots(false);
options.set_is_layout_sensitive(true);
options.set_enable_conv_operand_swap(false);
options.set_enable_normalize_broadcast_operand(false);
// "slow" minmax means we propagate nan.
options.set_minmax_propagate_nan(
!hlo_module->config().debug_options().xla_gpu_enable_fast_min_max());
Expand Down
1 change: 1 addition & 0 deletions xla/service/gpu/nvptx_compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ Status NVPTXCompiler::OptimizeHloConvolutionCanonicalization(

AlgebraicSimplifierOptions algsimp_options;
algsimp_options.set_enable_conv_operand_swap(false);
algsimp_options.set_enable_normalize_broadcast_operand(false);
pipeline.AddPass<HloPassFix<AlgebraicSimplifier>>(algsimp_options);

// CudnnSimplifyPadding gets rid of some padding introduced by
Expand Down

0 comments on commit 5673903

Please sign in to comment.