Skip to content

Commit

Permalink
[XLA:GPU] Disable Triton GEMM for non-default dot() precision.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 515348102
  • Loading branch information
Ilia Sergachev authored and copybara-github committed Mar 9, 2023
1 parent c1cfc07 commit 9e4d083
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 10 deletions.
10 changes: 1 addition & 9 deletions xla/service/gpu/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -996,23 +996,15 @@ cc_library(
hdrs = ["gemm_rewriter_triton.h"],
deps = [
":ir_emission_utils",
":launch_dimensions",
":stream_executor_util",
":target_util",
"//xla:autotune_results_proto_cc",
"//xla:shape_util",
"//xla:status",
"//xla:status_macros",
"//xla:statusor",
"//xla:util",
"//xla:xla_data_proto_cc",
"//xla/hlo/ir:hlo",
"//xla/service:hlo_pass",
"//xla/stream_executor:device_description",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/container:node_hash_map",
"@llvm-project//llvm:Linker",
"@llvm-project//llvm:ir_headers",
"@tsl//tsl/platform:errors",
"@tsl//tsl/platform:statusor",
],
Expand Down
5 changes: 4 additions & 1 deletion xla/service/gpu/gemm_rewriter_triton.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ limitations under the License.
#include "xla/shape_util.h"
#include "xla/status.h"
#include "xla/util.h"
#include "xla/xla_data.pb.h"
#include "tsl/platform/errors.h"
#include "tsl/platform/statusor.h"

Expand Down Expand Up @@ -546,7 +547,9 @@ DotFusionAnalysis::DotFusionAnalysis(const HloInstruction* root) {
bool IsTritonHandledGEMM(
const HloInstruction& dot,
const se::CudaComputeCapability cuda_compute_capability) {
if (dot.opcode() != HloOpcode::kDot) {
if (dot.opcode() != HloOpcode::kDot ||
absl::c_any_of(dot.precision_config().operand_precision(),
[](int x) { return x != PrecisionConfig::DEFAULT; })) {
return false;
}
const DotDimensionNumbers& dimension_numbers = dot.dot_dimension_numbers();
Expand Down

0 comments on commit 9e4d083

Please sign in to comment.