Skip to content

Commit

Permalink
add support for int4 in MMA_Traits
Browse files Browse the repository at this point in the history
  • Loading branch information
luliyucoordinate committed Nov 7, 2024
1 parent f4326bb commit 7703b6f
Showing 1 changed file with 197 additions and 0 deletions.
197 changes: 197 additions & 0 deletions include/cute/atom/mma_traits_sm80.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -419,6 +419,203 @@ template <>
struct MMA_Traits<SM80_16x8x32_S32U8U8S32_TN_SATURATE>
: MMA_Traits<SM80_16x8x32_S32U8U8S32_TN> {};

///////////////////////////////////////////////////////////////////////////////
/////////////////////////// s32 = s4 * s4 + s32 ///////////////////////////////
///////////////////////////////////////////////////////////////////////////////

template <>
struct MMA_Traits<SM80_8x8x32_S32S4S4S32_TN> {
using ValTypeD = int32_t;
using ValTypeA = int4b_t;
using ValTypeB = int4b_t;
using ValTypeC = int32_t;

using Shape_MNK = Shape<_8, _8, _32>;
using ThrID = Layout<_32>;
// (T32,V8) -> (M8,N32)
using ALayout = Layout<Shape <Shape < _4, _8>, Shape <_8>>,
Stride<Stride<_64, _1>, Stride<_8>>>;
using BLayout = Layout<Shape <Shape < _4, _8>, Shape <_8>>,
Stride<Stride<_64, _1>, Stride<_8>>>;
using CLayout = SM80_8x8_Row;
};

template <>
struct MMA_Traits<SM80_8x8x32_S32S4S4S32_TN_SATURATE>
: MMA_Traits<SM80_8x8x32_S32S4S4S32_TN> {};

template <>
struct MMA_Traits<SM80_16x8x32_S32S4S4S32_TN> {
using ValTypeD = int32_t;
using ValTypeA = int4b_t;
using ValTypeB = int4b_t;
using ValTypeC = int32_t;

using Shape_MNK = Shape<_16, _8, _32>;
using ThrID = Layout<_32>;
// (T32,V16) -> (M16,N32)
using ALayout = Layout<Shape <Shape < _4, _8>, Shape < _8, _2>>,
Stride<Stride<_128, _1>, Stride<_16, _8>>>;
// (T32,V8) -> (M8,N32)
using BLayout = Layout<Shape <Shape < _4, _8>, Shape <_8>>,
Stride<Stride<_32, _1>, Stride<_8>>>;
using CLayout = SM80_16x8_Row;
};

template <>
struct MMA_Traits<SM80_16x8x32_S32S4S4S32_TN_SATURATE>
: MMA_Traits<SM80_16x8x32_S32S4S4S32_TN> {};

template <>
struct MMA_Traits<SM80_16x8x64_S32S4S4S32_TN> {
using ValTypeD = int32_t;
using ValTypeA = int4b_t;
using ValTypeB = int4b_t;
using ValTypeC = int32_t;

using Shape_MNK = Shape<_16, _8, _64>;
using ThrID = Layout<_32>;
// (T32,V32) -> (M16,N64)
using ALayout = Layout<Shape <Shape < _4, _8>, Shape < _8, _2, _2>>,
Stride<Stride<_128, _1>, Stride<_16, _8, _512>>>;
// (T32,V16) -> (M8,N64)
using BLayout = Layout<Shape <Shape < _4, _8>, Shape <_8, _2>>,
Stride<Stride<_64, _1>, Stride<_8, _256>>>;
using CLayout = SM80_16x8_Row;
};

template <>
struct MMA_Traits<SM80_16x8x64_S32S4S4S32_TN_SATURATE>
: MMA_Traits<SM80_16x8x64_S32S4S4S32_TN> {};

///////////////////////////////////////////////////////////////////////////////
/////////////////////////// s32 = s4 * u4 + s32 ///////////////////////////////
///////////////////////////////////////////////////////////////////////////////

template <>
struct MMA_Traits<SM80_8x8x32_S32S4U4S32_TN>
: MMA_Traits<SM80_8x8x32_S32S4S4S32_TN> {
using ValTypeD = int32_t;
using ValTypeA = int4b_t;
using ValTypeB = uint4b_t;
using ValTypeC = int32_t;
};

template <>
struct MMA_Traits<SM80_8x8x32_S32S4U4S32_TN_SATURATE>
: MMA_Traits<SM80_8x8x32_S32S4U4S32_TN> {};

template <>
struct MMA_Traits<SM80_16x8x32_S32S4U4S32_TN>
: MMA_Traits<SM80_16x8x32_S32S4S4S32_TN> {
using ValTypeD = int32_t;
using ValTypeA = int4b_t;
using ValTypeB = uint4b_t;
using ValTypeC = int32_t;
};

template <>
struct MMA_Traits<SM80_16x8x32_S32S4U4S32_TN_SATURATE>
: MMA_Traits<SM80_16x8x32_S32S4U4S32_TN> {};
template <>
struct MMA_Traits<SM80_16x8x64_S32S4U4S32_TN>
: MMA_Traits<SM80_16x8x64_S32S4S4S32_TN> {
using ValTypeD = int32_t;
using ValTypeA = int4b_t;
using ValTypeB = uint4b_t;
using ValTypeC = int32_t;
};

template <>
struct MMA_Traits<SM80_16x8x64_S32S4U4S32_TN_SATURATE>
: MMA_Traits<SM80_16x8x64_S32S4U4S32_TN> {};

///////////////////////////////////////////////////////////////////////////////
/////////////////////////// s32 = u4 * s4 + s32 ///////////////////////////////
///////////////////////////////////////////////////////////////////////////////

template <>
struct MMA_Traits<SM80_8x8x32_S32U4S4S32_TN>
: MMA_Traits<SM80_8x8x32_S32S4S4S32_TN> {
using ValTypeD = int32_t;
using ValTypeA = uint4b_t;
using ValTypeB = int4b_t;
using ValTypeC = int32_t;
};

template <>
struct MMA_Traits<SM80_8x8x32_S32U4S4S32_TN_SATURATE>
: MMA_Traits<SM80_8x8x32_S32U4S4S32_TN> {};

template <>
struct MMA_Traits<SM80_16x8x32_S32U4S4S32_TN>
: MMA_Traits<SM80_16x8x32_S32S4S4S32_TN> {
using ValTypeD = int32_t;
using ValTypeA = uint4b_t;
using ValTypeB = int4b_t;
using ValTypeC = int32_t;
};

template <>
struct MMA_Traits<SM80_16x8x32_S32U4S4S32_TN_SATURATE>
: MMA_Traits<SM80_16x8x32_S32U4S4S32_TN> {};

template <>
struct MMA_Traits<SM80_16x8x64_S32U4S4S32_TN>
: MMA_Traits<SM80_16x8x64_S32S4S4S32_TN> {
using ValTypeD = int32_t;
using ValTypeA = uint4b_t;
using ValTypeB = int4b_t;
using ValTypeC = int32_t;
};

template <>
struct MMA_Traits<SM80_16x8x64_S32U4S4S32_TN_SATURATE>
: MMA_Traits<SM80_16x8x64_S32U4S4S32_TN> {};

///////////////////////////////////////////////////////////////////////////////
/////////////////////////// s32 = u4 * u4 + s32 ///////////////////////////////
///////////////////////////////////////////////////////////////////////////////

template <>
struct MMA_Traits<SM80_8x8x32_S32U4U4S32_TN>
: MMA_Traits<SM80_8x8x32_S32S4S4S32_TN> {
using ValTypeD = int32_t;
using ValTypeA = uint4b_t;
using ValTypeB = uint4b_t;
using ValTypeC = int32_t;
};

template <>
struct MMA_Traits<SM80_8x8x32_S32U4U4S32_TN_SATURATE>
: MMA_Traits<SM80_8x8x32_S32U4U4S32_TN> {};

template <>
struct MMA_Traits<SM80_16x8x32_S32U4U4S32_TN>
: MMA_Traits<SM80_16x8x32_S32S4S4S32_TN> {
using ValTypeD = int32_t;
using ValTypeA = uint4b_t;
using ValTypeB = uint4b_t;
using ValTypeC = int32_t;
};

template <>
struct MMA_Traits<SM80_16x8x32_S32U4U4S32_TN_SATURATE>
: MMA_Traits<SM80_16x8x32_S32U4U4S32_TN> {};

template <>
struct MMA_Traits<SM80_16x8x64_S32U4U4S32_TN>
: MMA_Traits<SM80_16x8x64_S32S4S4S32_TN> {
using ValTypeD = int32_t;
using ValTypeA = uint4b_t;
using ValTypeB = uint4b_t;
using ValTypeC = int32_t;
};

template <>
struct MMA_Traits<SM80_16x8x64_S32U4U4S32_TN_SATURATE>
: MMA_Traits<SM80_16x8x64_S32U4U4S32_TN> {};

///////////////////////////////////////////////////////////////////////////////
/////////////////////////// s32 = b1 ^ b1 + s32 ///////////////////////////////
///////////////////////////////////////////////////////////////////////////////
Expand Down

0 comments on commit 7703b6f

Please sign in to comment.