diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/BUILD.bazel b/compiler/src/iree/compiler/Codegen/LLVMCPU/BUILD.bazel index 788730dbb753..25997bcdfade 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMCPU/BUILD.bazel +++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/BUILD.bazel @@ -164,6 +164,8 @@ iree_compiler_cc_library( "@llvm-project//mlir:TransformDialect", "@llvm-project//mlir:TransformUtils", "@llvm-project//mlir:Transforms", + "@llvm-project//mlir:UBDialect", + "@llvm-project//mlir:UBToLLVM", "@llvm-project//mlir:ValueBoundsOpInterface", "@llvm-project//mlir:VectorDialect", "@llvm-project//mlir:VectorToArmSME", diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/LLVMCPU/CMakeLists.txt index 77b92f3f66a4..19043b59c898 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMCPU/CMakeLists.txt +++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/CMakeLists.txt @@ -138,6 +138,8 @@ iree_cc_library( MLIRTransformDialect MLIRTransformUtils MLIRTransforms + MLIRUBDialect + MLIRUBToLLVM MLIRValueBoundsOpInterface MLIRVectorDialect MLIRVectorToArmSME diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/ConvertToLLVM.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/ConvertToLLVM.cpp index 0f5fb32a7c1b..1c0fc6b86cd8 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMCPU/ConvertToLLVM.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/ConvertToLLVM.cpp @@ -37,6 +37,7 @@ #include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h" #include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h" #include "mlir/Conversion/TosaToArith/TosaToArith.h" +#include "mlir/Conversion/UBToLLVM/UBToLLVM.h" #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" #include "mlir/Conversion/VectorToSCF/VectorToSCF.h" #include "mlir/Dialect/Arith/IR/Arith.h" @@ -1058,9 +1059,9 @@ void ConvertToLLVMPass::runOnOperation() { vector::populateVectorStepLoweringPatterns(patterns); populateVectorToLLVMConversionPatterns(typeConverter, patterns, reassociateFpReductions); + ub::populateUBToLLVMConversionPatterns(typeConverter, patterns); vector::populateVectorTransferLoweringPatterns(patterns, /*maxTransferRank=*/1); - if (isAArch64(targetAttr) && (hasAnySVEFeature(targetAttr) || hasSMEFeature(targetAttr))) { populateArmSVELegalizeForLLVMExportPatterns(typeConverter, patterns); diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/BUILD.bazel b/compiler/src/iree/compiler/Codegen/LLVMGPU/BUILD.bazel index a5c1bce4beda..4e8338511036 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/BUILD.bazel +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/BUILD.bazel @@ -215,6 +215,8 @@ iree_compiler_cc_library( "@llvm-project//mlir:TransformDialect", "@llvm-project//mlir:TransformUtils", "@llvm-project//mlir:Transforms", + "@llvm-project//mlir:UBDialect", + "@llvm-project//mlir:UBToLLVM", "@llvm-project//mlir:ValueBoundsOpInterface", "@llvm-project//mlir:VectorDialect", "@llvm-project//mlir:VectorToGPU", diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/LLVMGPU/CMakeLists.txt index 5c206210ab30..d789999267ce 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/CMakeLists.txt +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/CMakeLists.txt @@ -163,6 +163,8 @@ iree_cc_library( MLIRTransformDialect MLIRTransformUtils MLIRTransforms + MLIRUBDialect + MLIRUBToLLVM MLIRValueBoundsOpInterface MLIRVectorDialect MLIRVectorToGPU diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/ConvertToNVVM.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/ConvertToNVVM.cpp index 44172fe4758b..fcec6283617b 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/ConvertToNVVM.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/ConvertToNVVM.cpp @@ -22,6 +22,7 @@ #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h" #include "mlir/Conversion/NVGPUToNVVM/NVGPUToNVVM.h" #include "mlir/Conversion/NVVMToLLVM/NVVMToLLVM.h" +#include "mlir/Conversion/UBToLLVM/UBToLLVM.h" #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Arith/Transforms/Passes.h" @@ -29,6 +30,7 @@ #include "mlir/Dialect/LLVMIR/NVVMDialect.h" #include "mlir/Dialect/MemRef/Transforms/Transforms.h" #include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h" +#include "mlir/Dialect/UB/IR/UBOps.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h" #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h" @@ -53,7 +55,7 @@ struct ConvertToNVVMPass final void getDependentDialects(DialectRegistry ®istry) const override { registry .insert(); + NVVM::NVVMDialect, affine::AffineDialect, ub::UBDialect>(); } void runOnOperation() override { ModuleOp m = getOperation(); @@ -161,6 +163,7 @@ struct ConvertToNVVMPass final populateGpuToNVVMConversionPatterns(converter, llvmPatterns); populateNVGPUToNVVMConversionPatterns(converter, llvmPatterns); populateGpuWMMAToNVVMConversionPatterns(converter, llvmPatterns); + ub::populateUBToLLVMConversionPatterns(converter, llvmPatterns); /// Target specification. LLVMConversionTarget target(getContext()); diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/ConvertToROCDL.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/ConvertToROCDL.cpp index 850efcf62fab..60355b8f0db0 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/ConvertToROCDL.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/ConvertToROCDL.cpp @@ -22,6 +22,7 @@ #include "mlir/Conversion/LLVMCommon/TypeConverter.h" #include "mlir/Conversion/MathToLLVM/MathToLLVM.h" #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h" +#include "mlir/Conversion/UBToLLVM/UBToLLVM.h" #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" #include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h" #include "mlir/Dialect/Arith/Transforms/Passes.h" @@ -29,6 +30,7 @@ #include "mlir/Dialect/GPU/Transforms/Passes.h" #include "mlir/Dialect/LLVMIR/ROCDLDialect.h" #include "mlir/Dialect/MemRef/Transforms/Transforms.h" +#include "mlir/Dialect/UB/IR/UBOps.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h" #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h" @@ -99,9 +101,9 @@ struct ConvertToROCDLPass final ConvertToROCDLPass>::ConvertToROCDLPassBase; void getDependentDialects(DialectRegistry ®istry) const override { - registry - .insert(); + registry.insert(); } void runOnOperation() override { ModuleOp m = getOperation(); @@ -238,6 +240,8 @@ struct ConvertToROCDLPass final LLVMConversionTarget target(getContext()); populateFuncToLLVMFuncOpConversionPattern(converter, llvmPatterns); configureGpuToROCDLConversionLegality(target); + ub::populateUBToLLVMConversionPatterns(converter, llvmPatterns); + if (failed(applyPartialConversion(m, target, std::move(llvmPatterns)))) signalPassFailure(); } diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/vector_lowering.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/vector_lowering.mlir index 57f73c21ea88..9f902ff08b14 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/vector_lowering.mlir +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/vector_lowering.mlir @@ -1,4 +1,4 @@ -// RUN: iree-opt --pass-pipeline="builtin.module(func.func(iree-llvmgpu-vector-lowering))" --split-input-file %s | FileCheck %s +// RUN: iree-opt --pass-pipeline="builtin.module(func.func(iree-llvmgpu-vector-lowering,canonicalize,cse))" --split-input-file %s | FileCheck %s module { func.func @broadcast_read_lowering(%arg0: memref<4096x32xf16>) -> vector<1x8xf16> { @@ -11,9 +11,8 @@ module { } // CHECK-LABEL: func.func @broadcast_read_lowering // CHECK-SAME: (%[[ARG0:.+]]: memref<4096x32xf16>) -// CHECK: %[[INIT:.+]] = arith.constant dense<0.000000e+00> : vector<1x8xf16> // CHECK: %[[LOAD:.+]] = vector.load %[[ARG0]]{{.*}} : memref<4096x32xf16> // CHECK: %[[ELEM:.+]] = vector.extract %[[LOAD]][0] : f16 from vector<1xf16> // CHECK: %[[SPLAT:.+]] = vector.splat %[[ELEM]] : vector<8xf16> -// CHECK: %[[INSERT:.+]] = vector.insert %[[SPLAT]], %[[INIT]] [0] : vector<8xf16> into vector<1x8xf16> +// CHECK: %[[INSERT:.+]] = vector.broadcast %[[SPLAT]] : vector<8xf16> to vector<1x8xf16> // CHECK: return %[[INSERT]] diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/BUILD.bazel b/compiler/src/iree/compiler/Codegen/SPIRV/BUILD.bazel index d264a26551f9..7ef31be6fc81 100644 --- a/compiler/src/iree/compiler/Codegen/SPIRV/BUILD.bazel +++ b/compiler/src/iree/compiler/Codegen/SPIRV/BUILD.bazel @@ -154,6 +154,7 @@ iree_compiler_cc_library( "@llvm-project//mlir:TransformDialect", "@llvm-project//mlir:TransformUtils", "@llvm-project//mlir:Transforms", + "@llvm-project//mlir:UBToSPIRV", "@llvm-project//mlir:VectorDialect", "@llvm-project//mlir:VectorInterfaces", "@llvm-project//mlir:VectorToGPU", diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/SPIRV/CMakeLists.txt index 08ec5885dc97..867786177b35 100644 --- a/compiler/src/iree/compiler/Codegen/SPIRV/CMakeLists.txt +++ b/compiler/src/iree/compiler/Codegen/SPIRV/CMakeLists.txt @@ -129,6 +129,7 @@ iree_cc_library( MLIRTransformDialect MLIRTransformUtils MLIRTransforms + MLIRUBToSPIRV MLIRVectorDialect MLIRVectorInterfaces MLIRVectorToGPU diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/ConvertToSPIRVPass.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/ConvertToSPIRVPass.cpp index ce6984a1cd11..5f2992e667bd 100644 --- a/compiler/src/iree/compiler/Codegen/SPIRV/ConvertToSPIRVPass.cpp +++ b/compiler/src/iree/compiler/Codegen/SPIRV/ConvertToSPIRVPass.cpp @@ -34,6 +34,7 @@ #include "mlir/Conversion/MemRefToSPIRV/MemRefToSPIRV.h" #include "mlir/Conversion/SCFToSPIRV/SCFToSPIRV.h" #include "mlir/Conversion/TensorToSPIRV/TensorToSPIRV.h" +#include "mlir/Conversion/UBToSPIRV/UBToSPIRV.h" #include "mlir/Conversion/VectorToSPIRV/VectorToSPIRV.h" #include "mlir/Dialect/Arith/Transforms/Passes.h" #include "mlir/Dialect/Bufferization/IR/Bufferization.h" @@ -656,6 +657,8 @@ void ConvertToSPIRVPass::runOnOperation() { // Pull in builtin func to spirv.func conversion. populateBuiltinFuncToSPIRVPatterns(typeConverter, patterns); + ub::populateUBToSPIRVConversionPatterns(typeConverter, patterns); + // Add IREE HAL interface op conversions. patterns.add< HALInterfaceLoadConstantConverter, diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/vectorize_elementwise_ops.mlir b/compiler/src/iree/compiler/Codegen/SPIRV/test/vectorize_elementwise_ops.mlir index 2c0d654521a2..299633abdb8f 100644 --- a/compiler/src/iree/compiler/Codegen/SPIRV/test/vectorize_elementwise_ops.mlir +++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/vectorize_elementwise_ops.mlir @@ -1,5 +1,5 @@ // RUN: iree-opt --split-input-file \ -// RUN: --pass-pipeline='builtin.module(func.func(iree-codegen-generic-vectorization,iree-spirv-initial-vector-lowering,iree-codegen-optimize-tensor-insert-extract-slices,iree-spirv-final-vector-lowering))' \ +// RUN: --pass-pipeline='builtin.module(func.func(iree-codegen-generic-vectorization,iree-spirv-initial-vector-lowering,iree-codegen-optimize-tensor-insert-extract-slices,iree-spirv-final-vector-lowering,canonicalize,cse))' \ // RUN: %s | FileCheck %s func.func @add(%lhs: tensor<2x8xf32>, %rhs: tensor<2x8xf32>) -> tensor<2x8xf32> { @@ -48,7 +48,7 @@ func.func @transpose_leading_one_dim(%input: tensor<4x1x1xf32>) -> tensor<1x1x4x // CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index // CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index // CHECK-DAG: %[[C3:.+]] = arith.constant 3 : index -// CHECK-DAG: %[[ZERO:.+]] = arith.constant dense<0.000000e+00> : vector<4xf32> +// CHECK-DAG: %[[ZERO:.+]] = ub.poison : vector<4xf32> // CHECK: %[[R0:.+]] = vector.transfer_read %[[INPUT]][%[[C0]], %[[C0]], %[[C0]]]{{.+}} : tensor<4x1x1xf32>, vector<1xf32> // CHECK: %[[R1:.+]] = vector.transfer_read %[[INPUT]][%[[C1]], %[[C0]], %[[C0]]]{{.+}} : tensor<4x1x1xf32>, vector<1xf32> @@ -93,7 +93,7 @@ func.func @transpose_add(%lhs: tensor<4x2xf32>, %rhs: tensor<2xf32>) -> tensor<2 // CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index // CHECK-DAG: %[[C3:.+]] = arith.constant 3 : index -// CHECK-DAG: %[[OINIT:.+]] = arith.constant dense<0.000000e+00> : vector<4xf32> +// CHECK-DAG: %[[OINIT:.+]] = ub.poison : vector<4xf32> // CHECK: %[[LHS0:.+]] = vector.transfer_read %[[LHS]][%[[C0]], %[[C0]]]{{.+}} : tensor<4x2xf32>, vector<2xf32> // CHECK: %[[LHS1:.+]] = vector.transfer_read %[[LHS]][%[[C1]], %[[C0]]]{{.+}} : tensor<4x2xf32>, vector<2xf32> diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/vectorize_matmul.mlir b/compiler/src/iree/compiler/Codegen/SPIRV/test/vectorize_matmul.mlir index 03644e7c9bec..f0c42b154e1e 100644 --- a/compiler/src/iree/compiler/Codegen/SPIRV/test/vectorize_matmul.mlir +++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/vectorize_matmul.mlir @@ -1,5 +1,5 @@ // RUN: iree-opt --split-input-file \ -// RUN: --pass-pipeline='builtin.module(func.func(iree-codegen-generic-vectorization,iree-spirv-initial-vector-lowering,iree-codegen-optimize-tensor-insert-extract-slices,iree-spirv-final-vector-lowering))' \ +// RUN: --pass-pipeline='builtin.module(func.func(iree-codegen-generic-vectorization,iree-spirv-initial-vector-lowering,iree-codegen-optimize-tensor-insert-extract-slices,iree-spirv-final-vector-lowering,canonicalize,cse))' \ // RUN: %s | FileCheck %s func.func @matmul_1x4x4(%lhs: tensor<1x4xf32>, %rhs: tensor<4x4xf32>, %init: tensor<1x4xf32>) -> tensor<1x4xf32> { @@ -139,9 +139,7 @@ func.func @matmul_broadcast_add(%init: tensor<1x8xf32>, %a: tensor<1x8xf32>, %b: // CHECK: %[[EXT0:.+]] = vector.extract %[[READ]][0] : f32 from vector<1xf32> // CHECK: %[[BCST0:.+]] = vector.splat %[[EXT0]] : vector<4xf32> // CHECK: %[[ADD0:.+]] = arith.addf %{{.+}}, %[[BCST0]] : vector<4xf32> -// CHECK: %[[EXT1:.+]] = vector.extract %[[READ]][0] : f32 from vector<1xf32> -// CHECK: %[[BCST1:.+]] = vector.splat %[[EXT1]] : vector<4xf32> -// CHECK: %[[ADD1:.+]] = arith.addf %{{.+}}, %[[BCST1]] : vector<4xf32> +// CHECK: %[[ADD1:.+]] = arith.addf %{{.+}}, %[[BCST0]] : vector<4xf32> // CHECK: %[[WRITE0:.+]] = vector.transfer_write %[[ADD0]], %[[INIT]][%[[C0]], %[[C0]]] // CHECK: %[[WRITE1:.+]] = vector.transfer_write %[[ADD1]], %[[WRITE0]][%[[C0]], %[[C4]]] // CHECK: return %[[WRITE1]] @@ -287,7 +285,7 @@ func.func @matmul_4x4x4_i8_to_i32_dot_prod(%lhs: tensor<4x4xi8>, %rhs : tensor<4 // CHECK-SAME: (%[[LHS:.+]]: tensor<4x4xi8>, %[[RHS:.+]]: tensor<4x4xi8>) // CHECK-DAG: %[[C0I8:.+]] = arith.constant 0 : i8 // CHECK-DAG: %[[C0I32:.+]] = arith.constant 0 : i32 -// CHECK-DAG: %[[V4I8:.+]] = arith.constant dense<0> : vector<4xi8> +// CHECK-DAG: %[[V4I8:.+]] = ub.poison : vector<4xi8> // CHECK-DAG: %[[V4I32:.+]] = arith.constant dense<0> : vector<4xi32> // CHECK-DAG: %[[V1I32:.+]] = arith.constant dense<0> : vector<1xi32> // CHECK-DAG: %[[IDX0:.+]] = arith.constant 0 : index diff --git a/third_party/llvm-project b/third_party/llvm-project index 2abf270e3e2d..ea7924e1412b 160000 --- a/third_party/llvm-project +++ b/third_party/llvm-project @@ -1 +1 @@ -Subproject commit 2abf270e3e2dd2d70a7b6eaf11859be07471ed3a +Subproject commit ea7924e1412b545e2065e30446b721a89a5e07d3