From d4c87edd2d141a4dd856f1ceecbfbee9fcc881fd Mon Sep 17 00:00:00 2001 From: luliyucoordinate Date: Thu, 7 Nov 2024 12:32:42 +0000 Subject: [PATCH] add support for int4 in MMA_Traits --- include/cute/atom/mma_traits_sm80.hpp | 197 ++++++++++++++++++++++++++ 1 file changed, 197 insertions(+) diff --git a/include/cute/atom/mma_traits_sm80.hpp b/include/cute/atom/mma_traits_sm80.hpp index 706b10d88..9b2a635b0 100644 --- a/include/cute/atom/mma_traits_sm80.hpp +++ b/include/cute/atom/mma_traits_sm80.hpp @@ -419,6 +419,203 @@ template <> struct MMA_Traits : MMA_Traits {}; +/////////////////////////////////////////////////////////////////////////////// +/////////////////////////// s32 = s4 * s4 + s32 /////////////////////////////// +/////////////////////////////////////////////////////////////////////////////// + +template <> +struct MMA_Traits { + using ValTypeD = int32_t; + using ValTypeA = int4b_t; + using ValTypeB = int4b_t; + using ValTypeC = int32_t; + + using Shape_MNK = Shape<_8, _8, _32>; + using ThrID = Layout<_32>; + // (T32,V8) -> (M8,N32) + using ALayout = Layout, Shape <_8>>, + Stride, Stride<_8>>>; + using BLayout = Layout, Shape <_8>>, + Stride, Stride<_8>>>; + using CLayout = SM80_8x8_Row; +}; + +template <> +struct MMA_Traits + : MMA_Traits {}; + +template <> +struct MMA_Traits { + using ValTypeD = int32_t; + using ValTypeA = int4b_t; + using ValTypeB = int4b_t; + using ValTypeC = int32_t; + + using Shape_MNK = Shape<_16, _8, _32>; + using ThrID = Layout<_32>; + // (T32,V16) -> (M16,N32) + using ALayout = Layout, Shape < _8, _2>>, + Stride, Stride<_16, _8>>>; + // (T32,V8) -> (M8,N32) + using BLayout = Layout, Shape <_8>>, + Stride, Stride<_8>>>; + using CLayout = SM80_16x8_Row; +}; + +template <> +struct MMA_Traits + : MMA_Traits {}; + +template <> +struct MMA_Traits { + using ValTypeD = int32_t; + using ValTypeA = int4b_t; + using ValTypeB = int4b_t; + using ValTypeC = int32_t; + + using Shape_MNK = Shape<_16, _8, _64>; + using ThrID = Layout<_32>; + // (T32,V32) -> (M16,N64) + using ALayout = Layout, Shape < _8, _2, _2>>, + Stride, Stride<_16, _8, _512>>>; + // (T32,V16) -> (M8,N64) + using BLayout = Layout, Shape <_8, _2>>, + Stride, Stride<_8, _256>>>; + using CLayout = SM80_16x8_Row; +}; + +template <> +struct MMA_Traits + : MMA_Traits {}; + +/////////////////////////////////////////////////////////////////////////////// +/////////////////////////// s32 = s4 * u4 + s32 /////////////////////////////// +/////////////////////////////////////////////////////////////////////////////// + +template <> +struct MMA_Traits + : MMA_Traits { + using ValTypeD = int32_t; + using ValTypeA = int4b_t; + using ValTypeB = uint4b_t; + using ValTypeC = int32_t; +}; + +template <> +struct MMA_Traits + : MMA_Traits {}; + +template <> +struct MMA_Traits + : MMA_Traits { + using ValTypeD = int32_t; + using ValTypeA = int4b_t; + using ValTypeB = uint4b_t; + using ValTypeC = int32_t; +}; + +template <> +struct MMA_Traits + : MMA_Traits {}; +template <> +struct MMA_Traits + : MMA_Traits { + using ValTypeD = int32_t; + using ValTypeA = int4b_t; + using ValTypeB = uint4b_t; + using ValTypeC = int32_t; +}; + +template <> +struct MMA_Traits + : MMA_Traits {}; + +/////////////////////////////////////////////////////////////////////////////// +/////////////////////////// s32 = u4 * s4 + s32 /////////////////////////////// +/////////////////////////////////////////////////////////////////////////////// + +template <> +struct MMA_Traits + : MMA_Traits { + using ValTypeD = int32_t; + using ValTypeA = uint4b_t; + using ValTypeB = int4b_t; + using ValTypeC = int32_t; +}; + +template <> +struct MMA_Traits + : MMA_Traits {}; + +template <> +struct MMA_Traits + : MMA_Traits { + using ValTypeD = int32_t; + using ValTypeA = uint4b_t; + using ValTypeB = int4b_t; + using ValTypeC = int32_t; +}; + +template <> +struct MMA_Traits + : MMA_Traits {}; + +template <> +struct MMA_Traits + : MMA_Traits { + using ValTypeD = int32_t; + using ValTypeA = uint4b_t; + using ValTypeB = int4b_t; + using ValTypeC = int32_t; +}; + +template <> +struct MMA_Traits + : MMA_Traits {}; + +/////////////////////////////////////////////////////////////////////////////// +/////////////////////////// s32 = u4 * u4 + s32 /////////////////////////////// +/////////////////////////////////////////////////////////////////////////////// + +template <> +struct MMA_Traits + : MMA_Traits { + using ValTypeD = int32_t; + using ValTypeA = uint4b_t; + using ValTypeB = uint4b_t; + using ValTypeC = int32_t; +}; + +template <> +struct MMA_Traits + : MMA_Traits {}; + +template <> +struct MMA_Traits + : MMA_Traits { + using ValTypeD = int32_t; + using ValTypeA = uint4b_t; + using ValTypeB = uint4b_t; + using ValTypeC = int32_t; +}; + +template <> +struct MMA_Traits + : MMA_Traits {}; + +template <> +struct MMA_Traits + : MMA_Traits { + using ValTypeD = int32_t; + using ValTypeA = uint4b_t; + using ValTypeB = uint4b_t; + using ValTypeC = int32_t; +}; + +template <> +struct MMA_Traits + : MMA_Traits {}; + /////////////////////////////////////////////////////////////////////////////// /////////////////////////// s32 = b1 ^ b1 + s32 /////////////////////////////// ///////////////////////////////////////////////////////////////////////////////