[NVPTX] Add idp2a, idp4a intrinsics (llvm#102763)
Add support for `llvm.nvvm.idp2a` and `llvm.nvvm.idp4a` which correspond
directly to `dp2a` and `dp4a` PTX instructions.
AlexMaclean authored Aug 14, 2024
1 parent 2c12c1e commit 099bf20
Showing 5 changed files with 340 additions and 0 deletions.
71 changes: 71 additions & 0 deletions llvm/docs/NVPTXUsage.rst
@@ -287,6 +287,77 @@ The ``@llvm.nvvm.fence.proxy.tensormap_generic.*`` is a uni-directional fence us

The address operand ``addr`` and the operand ``size`` together specify the memory range ``[addr, addr+size)`` on which the ordering guarantees on the memory accesses across the proxies are to be provided. The only supported value for the ``size`` operand is ``128``, and it must be an immediate. Generic Addressing is used unconditionally, and the address specified by the operand ``addr`` must fall within the ``.global`` state space. Otherwise, the behavior is undefined. For more information, see `PTX ISA <https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-membar>`_.

Arithmetic Intrinsics
---------------------

'``llvm.nvvm.idp2a.[us].[us]``' Intrinsics
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Syntax:
"""""""

.. code-block:: llvm

    declare i32 @llvm.nvvm.idp2a.s.s(i32 %a, i32 %b, i1 immarg %is.hi, i32 %c)
    declare i32 @llvm.nvvm.idp2a.s.u(i32 %a, i32 %b, i1 immarg %is.hi, i32 %c)
    declare i32 @llvm.nvvm.idp2a.u.s(i32 %a, i32 %b, i1 immarg %is.hi, i32 %c)
    declare i32 @llvm.nvvm.idp2a.u.u(i32 %a, i32 %b, i1 immarg %is.hi, i32 %c)

Overview:
"""""""""

The '``llvm.nvvm.idp2a.[us].[us]``' intrinsics perform a 2-element vector dot
product followed by addition. They correspond directly to the ``dp2a`` PTX
instruction.

Semantics:
""""""""""

The 32-bit value in ``%a`` is split into two 16-bit values, which are extended
to 32 bits. The '``llvm.nvvm.idp2a.u.[us]``' variants use zero-extension, while
the '``llvm.nvvm.idp2a.s.[us]``' variants use sign-extension. Two bytes are
selected from ``%b``: if ``%is.hi`` is true, the two most significant bytes are
selected; otherwise, the two least significant bytes are selected. These bytes
are then extended to 32 bits. The '``llvm.nvvm.idp2a.[us].u``' variants use
zero-extension, while the '``llvm.nvvm.idp2a.[us].s``' variants use
sign-extension. The dot product of these 2-element vectors is added to ``%c``
to produce the return value.
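
The semantics above can be sketched as a host-side C++ reference model. This is
only an illustration, not code from this patch: ``extend_field``, ``ref_idp2a``
and ``check_idp2a`` are hypothetical names, the ``[us]`` suffixes are modeled as
``bool`` parameters, and two details are assumptions based on the ``dp2a``
pseudocode in the PTX ISA rather than on the text above: the low 16-bit half of
``a`` is paired with the lower of the two selected bytes of ``b``, and the
accumulation wraps modulo 2^32.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical helper (not part of LLVM): widen the low `bits` bits of `v`
// to 32 bits using sign- or zero-extension.
static uint32_t extend_field(uint32_t v, unsigned bits, bool is_signed) {
  const uint32_t mask = (1u << bits) - 1u;
  v &= mask;
  if (is_signed && ((v >> (bits - 1)) & 1u))
    v |= ~mask; // replicate the sign bit into the upper bits
  return v;
}

// Hypothetical reference model of llvm.nvvm.idp2a.[us].[us].
// a_signed / b_signed stand in for the first / second [us] suffix.
uint32_t ref_idp2a(uint32_t a, bool a_signed, uint32_t b, bool b_signed,
                   bool is_hi, uint32_t c) {
  // Assumption: %is.hi selects bytes 2..3 of b, otherwise bytes 0..1.
  const unsigned b_select = is_hi ? 2 : 0;
  uint32_t d = c;
  for (unsigned i = 0; i < 2; ++i) {
    uint32_t va = extend_field(a >> (16 * i), 16, a_signed);
    uint32_t vb = extend_field(b >> (8 * (b_select + i)), 8, b_signed);
    // Unsigned arithmetic wraps modulo 2^32; the low 32 bits match the
    // signed product of the extended values.
    d += va * vb;
  }
  return d;
}

// Small self-check with hand-computed values.
void check_idp2a() {
  // a = {3, 2} (low, high halves); b bytes = {7, 6, 5, 4} (low to high).
  assert(ref_idp2a(0x00020003u, false, 0x04050607u, false, false, 10u) == 43u);
  assert(ref_idp2a(0x00020003u, false, 0x04050607u, false, true, 10u) == 33u);
}
```

Modeling the signedness as runtime flags keeps the sketch short; the intrinsics
themselves encode it in the name, so each call site has a fixed combination.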


'``llvm.nvvm.idp4a.[us].[us]``' Intrinsics
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Syntax:
"""""""

.. code-block:: llvm

    declare i32 @llvm.nvvm.idp4a.s.s(i32 %a, i32 %b, i32 %c)
    declare i32 @llvm.nvvm.idp4a.s.u(i32 %a, i32 %b, i32 %c)
    declare i32 @llvm.nvvm.idp4a.u.s(i32 %a, i32 %b, i32 %c)
    declare i32 @llvm.nvvm.idp4a.u.u(i32 %a, i32 %b, i32 %c)

Overview:
"""""""""

The '``llvm.nvvm.idp4a.[us].[us]``' intrinsics perform a 4-element vector dot
product followed by addition. They correspond directly to the ``dp4a`` PTX
instruction.

Semantics:
""""""""""

Each of the 4 bytes in ``%a`` and in ``%b`` is extended to a 32-bit integer,
forming two ``<4 x i32>`` vectors. For ``%a``, zero-extension is used in the
'``llvm.nvvm.idp4a.u.[us]``' variants, while sign-extension is used in the
'``llvm.nvvm.idp4a.s.[us]``' variants. Similarly, for ``%b``, zero-extension is
used in the '``llvm.nvvm.idp4a.[us].u``' variants, while sign-extension is used
in the '``llvm.nvvm.idp4a.[us].s``' variants. The dot product of these 4-element
vectors is added to ``%c`` to produce the return value.
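
As an illustration of the semantics above, a host-side C++ reference model
might look like the sketch below. It is not part of this patch:
``extend_byte``, ``ref_idp4a`` and ``check_idp4a`` are hypothetical names, the
``[us]`` suffixes are modeled as ``bool`` parameters, and wrap-around of the
accumulation modulo 2^32 is an assumption, since the text above does not
describe overflow behavior.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical helper (not part of LLVM): extract byte `index` of `word`
// (0 = least significant) and extend it to 32 bits.
static uint32_t extend_byte(uint32_t word, unsigned index, bool is_signed) {
  uint32_t v = (word >> (8 * index)) & 0xFFu;
  if (is_signed && (v & 0x80u))
    v |= 0xFFFFFF00u; // sign-extend from bit 7
  return v;
}

// Hypothetical reference model of llvm.nvvm.idp4a.[us].[us].
// a_signed / b_signed stand in for the first / second [us] suffix.
uint32_t ref_idp4a(uint32_t a, bool a_signed, uint32_t b, bool b_signed,
                   uint32_t c) {
  uint32_t d = c;
  for (unsigned i = 0; i < 4; ++i) {
    // Unsigned arithmetic wraps modulo 2^32; the low 32 bits match the
    // signed product of the extended bytes.
    d += extend_byte(a, i, a_signed) * extend_byte(b, i, b_signed);
  }
  return d;
}

// Small self-check with hand-computed values.
void check_idp4a() {
  // a bytes = {4, 3, 2, 1}, b bytes = {8, 7, 6, 5} (low to high).
  assert(ref_idp4a(0x01020304u, false, 0x05060708u, false, 100u) == 170u);
  // Top byte 0xFF is -1 when signed and 255 when unsigned.
  assert(ref_idp4a(0xFF000000u, true, 0xFF000000u, true, 0u) == 1u);
  assert(ref_idp4a(0xFF000000u, false, 0xFF000000u, false, 0u) == 65025u);
}
```

A model like this can be handy for cross-checking constant inputs against the
PTX that ``llc`` emits, but it says nothing about which targets support the
instruction.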



Other Intrinsics
----------------

16 changes: 16 additions & 0 deletions llvm/include/llvm/IR/IntrinsicsNVVM.td
@@ -1052,6 +1052,22 @@ let TargetPrefix = "nvvm" in {
DefaultAttrsIntrinsic<[llvm_double_ty], [llvm_double_ty, llvm_double_ty],
[IntrNoMem, IntrSpeculatable, Commutative]>;

//
// Dot Product
//
foreach a_type = ["s", "u"] in {
  foreach b_type = ["s", "u"] in {
    def int_nvvm_idp4a_ # a_type # _ # b_type :
        DefaultAttrsIntrinsic<[llvm_i32_ty],
            [llvm_i32_ty, llvm_i32_ty, llvm_i32_ty],
            [IntrNoMem, IntrSpeculatable]>;
    def int_nvvm_idp2a_ # a_type # _ # b_type :
        DefaultAttrsIntrinsic<[llvm_i32_ty],
            [llvm_i32_ty, llvm_i32_ty, llvm_i1_ty, llvm_i32_ty],
            [IntrNoMem, IntrSpeculatable, ImmArg<ArgIndex<2>>]>;
  }
}

//
// Convert
//
28 changes: 28 additions & 0 deletions llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
@@ -159,6 +159,7 @@ def do_SQRTF32_RN : Predicate<"usePrecSqrtF32()">;

def hasHWROT32 : Predicate<"Subtarget->hasHWROT32()">;
def noHWROT32 : Predicate<"!Subtarget->hasHWROT32()">;
def hasDotInstructions : Predicate<"Subtarget->hasDotInstructions()">;

def True : Predicate<"true">;
def False : Predicate<"false">;
@@ -3920,6 +3921,33 @@ let isTerminator = 1, isBranch = 1, isIndirectBranch = 1, isNotDuplicable = 1 in
}


foreach a_type = ["s", "u"] in {
  foreach b_type = ["s", "u"] in {

    def DOT4_ # a_type # b_type :
      NVPTXInst<(outs Int32Regs:$dst),
                (ins Int32Regs:$a, Int32Regs:$b, Int32Regs:$c),
                "dp4a." # a_type # "32." # b_type # "32 \t$dst, $a, $b, $c;",
                [(set Int32Regs:$dst,
                    (!cast<Intrinsic>("int_nvvm_idp4a_" # a_type # "_" # b_type)
                     (i32 Int32Regs:$a), (i32 Int32Regs:$b), (i32 Int32Regs:$c)))]>,
                Requires<[hasDotInstructions]>;

    foreach is_hi = [0, -1] in {
      defvar lohi_suffix = !if(is_hi, "hi", "lo");

      def DOT2_ # lohi_suffix # _ # a_type # b_type :
        NVPTXInst<(outs Int32Regs:$dst),
                  (ins Int32Regs:$a, Int32Regs:$b, Int32Regs:$c),
                  "dp2a." # lohi_suffix # "." # a_type # "32." # b_type # "32 \t$dst, $a, $b, $c;",
                  [(set Int32Regs:$dst,
                      (!cast<Intrinsic>("int_nvvm_idp2a_" # a_type # "_" # b_type)
                       (i32 Int32Regs:$a), (i32 Int32Regs:$b), is_hi, (i32 Int32Regs:$c)))]>,
                  Requires<[hasDotInstructions]>;
    }
  }
}

include "NVPTXIntrinsics.td"

//-----------------------------------
3 changes: 3 additions & 0 deletions llvm/lib/Target/NVPTX/NVPTXSubtarget.h
@@ -90,6 +90,9 @@ class NVPTXSubtarget : public NVPTXGenSubtargetInfo {
bool hasMemoryOrdering() const { return SmVersion >= 70 && PTXVersion >= 60; }
// Does SM & PTX support atomic relaxed MMIO operations ?
bool hasRelaxedMMIO() const { return SmVersion >= 70 && PTXVersion >= 82; }
bool hasDotInstructions() const {
return SmVersion >= 61 && PTXVersion >= 50;
}
unsigned int getFullSmVersion() const { return FullSmVersion; }
unsigned int getSmVersion() const { return getFullSmVersion() / 10; }
// GPUs with "a" suffix have include architecture-accelerated features that
222 changes: 222 additions & 0 deletions llvm/test/CodeGen/NVPTX/dot-product.ll
@@ -0,0 +1,222 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
; RUN: llc < %s -march=nvptx -mcpu=sm_61 | FileCheck %s
; RUN: llc < %s -march=nvptx64 -mcpu=sm_61 | FileCheck %s

target triple = "nvptx-nvidia-cuda"

declare i32 @llvm.nvvm.idp4a.s.s(i32, i32, i32)
declare i32 @llvm.nvvm.idp4a.s.u(i32, i32, i32)
declare i32 @llvm.nvvm.idp4a.u.s(i32, i32, i32)
declare i32 @llvm.nvvm.idp4a.u.u(i32, i32, i32)

define i32 @test_dp4a_u32_u32(i32 %a, i32 %b, i32 %c) {
; CHECK-LABEL: test_dp4a_u32_u32(
; CHECK: {
; CHECK-NEXT: .reg .b32 %r<5>;
; CHECK-EMPTY:
; CHECK-NEXT: // %bb.0:
; CHECK-NEXT: ld.param.u32 %r1, [test_dp4a_u32_u32_param_0];
; CHECK-NEXT: ld.param.u32 %r2, [test_dp4a_u32_u32_param_1];
; CHECK-NEXT: ld.param.u32 %r3, [test_dp4a_u32_u32_param_2];
; CHECK-NEXT: dp4a.u32.u32 %r4, %r1, %r2, %r3;
; CHECK-NEXT: st.param.b32 [func_retval0+0], %r4;
; CHECK-NEXT: ret;
%call = call i32 @llvm.nvvm.idp4a.u.u(i32 %a, i32 %b, i32 %c)
ret i32 %call
}

define i32 @test_dp4a_u32imm_u32imm(i32 %c) {
; CHECK-LABEL: test_dp4a_u32imm_u32imm(
; CHECK: {
; CHECK-NEXT: .reg .b32 %r<4>;
; CHECK-EMPTY:
; CHECK-NEXT: // %bb.0:
; CHECK-NEXT: ld.param.u32 %r1, [test_dp4a_u32imm_u32imm_param_0];
; CHECK-NEXT: mov.b32 %r2, 0;
; CHECK-NEXT: dp4a.u32.u32 %r3, %r2, %r2, %r1;
; CHECK-NEXT: st.param.b32 [func_retval0+0], %r3;
; CHECK-NEXT: ret;
%call = call i32 @llvm.nvvm.idp4a.u.u(i32 0, i32 0, i32 %c)
ret i32 %call
}

define i32 @test_dp4a_u32_s32(i32 %a, i32 %b, i32 %c) {
; CHECK-LABEL: test_dp4a_u32_s32(
; CHECK: {
; CHECK-NEXT: .reg .b32 %r<5>;
; CHECK-EMPTY:
; CHECK-NEXT: // %bb.0:
; CHECK-NEXT: ld.param.u32 %r1, [test_dp4a_u32_s32_param_0];
; CHECK-NEXT: ld.param.u32 %r2, [test_dp4a_u32_s32_param_1];
; CHECK-NEXT: ld.param.u32 %r3, [test_dp4a_u32_s32_param_2];
; CHECK-NEXT: dp4a.u32.s32 %r4, %r1, %r2, %r3;
; CHECK-NEXT: st.param.b32 [func_retval0+0], %r4;
; CHECK-NEXT: ret;
%call = call i32 @llvm.nvvm.idp4a.u.s(i32 %a, i32 %b, i32 %c)
ret i32 %call
}

define i32 @test_dp4a_s32_u32(i32 %a, i32 %b, i32 %c) {
; CHECK-LABEL: test_dp4a_s32_u32(
; CHECK: {
; CHECK-NEXT: .reg .b32 %r<5>;
; CHECK-EMPTY:
; CHECK-NEXT: // %bb.0:
; CHECK-NEXT: ld.param.u32 %r1, [test_dp4a_s32_u32_param_0];
; CHECK-NEXT: ld.param.u32 %r2, [test_dp4a_s32_u32_param_1];
; CHECK-NEXT: ld.param.u32 %r3, [test_dp4a_s32_u32_param_2];
; CHECK-NEXT: dp4a.s32.u32 %r4, %r1, %r2, %r3;
; CHECK-NEXT: st.param.b32 [func_retval0+0], %r4;
; CHECK-NEXT: ret;
%call = call i32 @llvm.nvvm.idp4a.s.u(i32 %a, i32 %b, i32 %c)
ret i32 %call
}

define i32 @test_dp4a_s32_s32(i32 %a, i32 %b, i32 %c) {
; CHECK-LABEL: test_dp4a_s32_s32(
; CHECK: {
; CHECK-NEXT: .reg .b32 %r<5>;
; CHECK-EMPTY:
; CHECK-NEXT: // %bb.0:
; CHECK-NEXT: ld.param.u32 %r1, [test_dp4a_s32_s32_param_0];
; CHECK-NEXT: ld.param.u32 %r2, [test_dp4a_s32_s32_param_1];
; CHECK-NEXT: ld.param.u32 %r3, [test_dp4a_s32_s32_param_2];
; CHECK-NEXT: dp4a.s32.s32 %r4, %r1, %r2, %r3;
; CHECK-NEXT: st.param.b32 [func_retval0+0], %r4;
; CHECK-NEXT: ret;
%call = call i32 @llvm.nvvm.idp4a.s.s(i32 %a, i32 %b, i32 %c)
ret i32 %call
}

declare i32 @llvm.nvvm.idp2a.s.s(i32, i32, i1 immarg, i32)
declare i32 @llvm.nvvm.idp2a.s.u(i32, i32, i1 immarg, i32)
declare i32 @llvm.nvvm.idp2a.u.s(i32, i32, i1 immarg, i32)
declare i32 @llvm.nvvm.idp2a.u.u(i32, i32, i1 immarg, i32)

define i32 @test_dp2a_lo_u32_u32(i32 %a, i32 %b, i32 %c) {
; CHECK-LABEL: test_dp2a_lo_u32_u32(
; CHECK: {
; CHECK-NEXT: .reg .b32 %r<5>;
; CHECK-EMPTY:
; CHECK-NEXT: // %bb.0:
; CHECK-NEXT: ld.param.u32 %r1, [test_dp2a_lo_u32_u32_param_0];
; CHECK-NEXT: ld.param.u32 %r2, [test_dp2a_lo_u32_u32_param_1];
; CHECK-NEXT: ld.param.u32 %r3, [test_dp2a_lo_u32_u32_param_2];
; CHECK-NEXT: dp2a.lo.u32.u32 %r4, %r1, %r2, %r3;
; CHECK-NEXT: st.param.b32 [func_retval0+0], %r4;
; CHECK-NEXT: ret;
%call = call i32 @llvm.nvvm.idp2a.u.u(i32 %a, i32 %b, i1 0, i32 %c)
ret i32 %call
}

define i32 @test_dp2a_lo_u32_s32(i32 %a, i32 %b, i32 %c) {
; CHECK-LABEL: test_dp2a_lo_u32_s32(
; CHECK: {
; CHECK-NEXT: .reg .b32 %r<5>;
; CHECK-EMPTY:
; CHECK-NEXT: // %bb.0:
; CHECK-NEXT: ld.param.u32 %r1, [test_dp2a_lo_u32_s32_param_0];
; CHECK-NEXT: ld.param.u32 %r2, [test_dp2a_lo_u32_s32_param_1];
; CHECK-NEXT: ld.param.u32 %r3, [test_dp2a_lo_u32_s32_param_2];
; CHECK-NEXT: dp2a.lo.u32.s32 %r4, %r1, %r2, %r3;
; CHECK-NEXT: st.param.b32 [func_retval0+0], %r4;
; CHECK-NEXT: ret;
%call = call i32 @llvm.nvvm.idp2a.u.s(i32 %a, i32 %b, i1 0, i32 %c)
ret i32 %call
}

define i32 @test_dp2a_lo_s32_u32(i32 %a, i32 %b, i32 %c) {
; CHECK-LABEL: test_dp2a_lo_s32_u32(
; CHECK: {
; CHECK-NEXT: .reg .b32 %r<5>;
; CHECK-EMPTY:
; CHECK-NEXT: // %bb.0:
; CHECK-NEXT: ld.param.u32 %r1, [test_dp2a_lo_s32_u32_param_0];
; CHECK-NEXT: ld.param.u32 %r2, [test_dp2a_lo_s32_u32_param_1];
; CHECK-NEXT: ld.param.u32 %r3, [test_dp2a_lo_s32_u32_param_2];
; CHECK-NEXT: dp2a.lo.s32.u32 %r4, %r1, %r2, %r3;
; CHECK-NEXT: st.param.b32 [func_retval0+0], %r4;
; CHECK-NEXT: ret;
%call = call i32 @llvm.nvvm.idp2a.s.u(i32 %a, i32 %b, i1 0, i32 %c)
ret i32 %call
}

define i32 @test_dp2a_lo_s32_s32(i32 %a, i32 %b, i32 %c) {
; CHECK-LABEL: test_dp2a_lo_s32_s32(
; CHECK: {
; CHECK-NEXT: .reg .b32 %r<5>;
; CHECK-EMPTY:
; CHECK-NEXT: // %bb.0:
; CHECK-NEXT: ld.param.u32 %r1, [test_dp2a_lo_s32_s32_param_0];
; CHECK-NEXT: ld.param.u32 %r2, [test_dp2a_lo_s32_s32_param_1];
; CHECK-NEXT: ld.param.u32 %r3, [test_dp2a_lo_s32_s32_param_2];
; CHECK-NEXT: dp2a.lo.s32.s32 %r4, %r1, %r2, %r3;
; CHECK-NEXT: st.param.b32 [func_retval0+0], %r4;
; CHECK-NEXT: ret;
%call = call i32 @llvm.nvvm.idp2a.s.s(i32 %a, i32 %b, i1 0, i32 %c)
ret i32 %call
}

define i32 @test_dp2a_hi_u32_u32(i32 %a, i32 %b, i32 %c) {
; CHECK-LABEL: test_dp2a_hi_u32_u32(
; CHECK: {
; CHECK-NEXT: .reg .b32 %r<5>;
; CHECK-EMPTY:
; CHECK-NEXT: // %bb.0:
; CHECK-NEXT: ld.param.u32 %r1, [test_dp2a_hi_u32_u32_param_0];
; CHECK-NEXT: ld.param.u32 %r2, [test_dp2a_hi_u32_u32_param_1];
; CHECK-NEXT: ld.param.u32 %r3, [test_dp2a_hi_u32_u32_param_2];
; CHECK-NEXT: dp2a.hi.u32.u32 %r4, %r1, %r2, %r3;
; CHECK-NEXT: st.param.b32 [func_retval0+0], %r4;
; CHECK-NEXT: ret;
%call = call i32 @llvm.nvvm.idp2a.u.u(i32 %a, i32 %b, i1 1, i32 %c)
ret i32 %call
}

define i32 @test_dp2a_hi_u32_s32(i32 %a, i32 %b, i32 %c) {
; CHECK-LABEL: test_dp2a_hi_u32_s32(
; CHECK: {
; CHECK-NEXT: .reg .b32 %r<5>;
; CHECK-EMPTY:
; CHECK-NEXT: // %bb.0:
; CHECK-NEXT: ld.param.u32 %r1, [test_dp2a_hi_u32_s32_param_0];
; CHECK-NEXT: ld.param.u32 %r2, [test_dp2a_hi_u32_s32_param_1];
; CHECK-NEXT: ld.param.u32 %r3, [test_dp2a_hi_u32_s32_param_2];
; CHECK-NEXT: dp2a.hi.u32.s32 %r4, %r1, %r2, %r3;
; CHECK-NEXT: st.param.b32 [func_retval0+0], %r4;
; CHECK-NEXT: ret;
%call = call i32 @llvm.nvvm.idp2a.u.s(i32 %a, i32 %b, i1 1, i32 %c)
ret i32 %call
}

define i32 @test_dp2a_hi_s32_u32(i32 %a, i32 %b, i32 %c) {
; CHECK-LABEL: test_dp2a_hi_s32_u32(
; CHECK: {
; CHECK-NEXT: .reg .b32 %r<5>;
; CHECK-EMPTY:
; CHECK-NEXT: // %bb.0:
; CHECK-NEXT: ld.param.u32 %r1, [test_dp2a_hi_s32_u32_param_0];
; CHECK-NEXT: ld.param.u32 %r2, [test_dp2a_hi_s32_u32_param_1];
; CHECK-NEXT: ld.param.u32 %r3, [test_dp2a_hi_s32_u32_param_2];
; CHECK-NEXT: dp2a.hi.s32.u32 %r4, %r1, %r2, %r3;
; CHECK-NEXT: st.param.b32 [func_retval0+0], %r4;
; CHECK-NEXT: ret;
%call = call i32 @llvm.nvvm.idp2a.s.u(i32 %a, i32 %b, i1 1, i32 %c)
ret i32 %call
}

define i32 @test_dp2a_hi_s32_s32(i32 %a, i32 %b, i32 %c) {
; CHECK-LABEL: test_dp2a_hi_s32_s32(
; CHECK: {
; CHECK-NEXT: .reg .b32 %r<5>;
; CHECK-EMPTY:
; CHECK-NEXT: // %bb.0:
; CHECK-NEXT: ld.param.u32 %r1, [test_dp2a_hi_s32_s32_param_0];
; CHECK-NEXT: ld.param.u32 %r2, [test_dp2a_hi_s32_s32_param_1];
; CHECK-NEXT: ld.param.u32 %r3, [test_dp2a_hi_s32_s32_param_2];
; CHECK-NEXT: dp2a.hi.s32.s32 %r4, %r1, %r2, %r3;
; CHECK-NEXT: st.param.b32 [func_retval0+0], %r4;
; CHECK-NEXT: ret;
%call = call i32 @llvm.nvvm.idp2a.s.s(i32 %a, i32 %b, i1 1, i32 %c)
ret i32 %call
}
