faster narrow mmt4d ukernels on x86 (iree-org#16655)
A "narrow" matmul case is one where either the M dimension is small
(e.g. vecmat) or the N dimensions is small (e.g. matvec). In the context
of our ukernels, it always refers to the M dimension as reduction from
narrow-N to narrow-M is meant to have been performed already (see
iree-org#16399).

So far, narrow mmt4d ukernel tile functions had mostly just been added
as naive truncations of the general matmul kernel. That was often fine
(mostly in floating-point cases, and outside of x86), but in some cases
it was quite inefficient (integer-arithmetic cases on x86). In one
instance that we particularly cared about for Llama2 on x86, we had a
one-off kernel doing something more clever (the s16u4s32 case), but
other than that, our narrow matmul kernels on x86 were pretty bad. This
commit makes them better. It is also a net code shrink: much of the
badness came from shoehorning a complicated general matmul kernel
design onto all M0 cases, whereas now the narrow cases do something
simpler and the full-width case is relieved from having to be generic.
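
For orientation, here is a minimal scalar sketch of what one mmt4d
s8s8s32 tile-function call computes. This is not code from this commit,
and the function name and signature are made up for illustration; it
just pins down the semantics that every M0 variant implements. The lhs
panel holds K row-major M0xK0 int8 tiles, the rhs panel holds K
row-major N0xK0 int8 tiles (rhs tiles are N-major, the "t" in mmt4d),
and the M0xN0 int32 accumulator tile is zero-initialized unless the
ACCUMULATE flag is set.

```c
#include <stdbool.h>
#include <stdint.h>

// Hypothetical scalar reference for one s8s8s32 mmt4d tile-function call.
// out: M0xN0 int32 accumulators; lhs: K tiles of M0xK0 int8;
// rhs: K tiles of N0xK0 int8.
static void mmt4d_tile_s8s8s32_reference(int32_t* out, const int8_t* lhs,
                                         const int8_t* rhs, int M0, int N0,
                                         int K0, int64_t K, bool accumulate) {
  if (!accumulate) {
    for (int i = 0; i < M0 * N0; ++i) out[i] = 0;
  }
  for (int64_t k = 0; k < K; ++k) {
    for (int m = 0; m < M0; ++m) {
      for (int n = 0; n < N0; ++n) {
        int32_t sum = 0;
        for (int k0 = 0; k0 < K0; ++k0) {
          sum += (int32_t)lhs[(k * M0 + m) * K0 + k0] *
                 (int32_t)rhs[(k * N0 + n) * K0 + k0];
        }
        out[m * N0 + n] += sum;
      }
    }
  }
}
```

The commit leaves this computation unchanged and only changes the SIMD
register layout. In the AVX2 s8s8s32 kernel shown in the diff below, for
example, the narrow M0 in {1, 2, 4} cases now keep one plain 256-bit
accumulator row per lhs row, while the full 8x8 case keeps the permuted
two-halves-per-register layout described in its comments.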

Benchmark results on my AMD 7950X3D CPU (with Turbo disabled):

|name                                       |Gop/s before|Gop/s after|Speedup (x)|
|-------------------------------------------|------------|-----------|-----------|
|BM_mmt4d_s8s8s32_tile_1x8x2_avx2_fma |56.0 |99.0 |1.77 |
|BM_mmt4d_s8s8s32_tile_2x8x2_avx2_fma |78.4 |132.7 |1.69 |
|BM_mmt4d_s8s8s32_tile_4x8x2_avx2_fma |92.2 |149.3 |1.62 |
|BM_mmt4d_s8s8s32_tile_8x8x2_avx2_fma |187.3 |188.5 |1.01 |
|BM_mmt4d_s8s8s32_tile_1x16x2_avx512_base |40.3 |119.4 |2.96 |
|BM_mmt4d_s8s8s32_tile_2x16x2_avx512_base |52.6 |156.1 |2.97 |
|BM_mmt4d_s8s8s32_tile_4x16x2_avx512_base |60.7 |169.8 |2.80 |
|BM_mmt4d_s8s8s32_tile_8x16x2_avx512_base |119.4 |178.9 |1.50 |
|BM_mmt4d_s8s8s32_tile_16x16x2_avx512_base |236.7 |235.0 |0.99 |
|BM_mmt4d_s8s8s32_tile_1x16x2_avx512_vnni |52.5 |87.9 |1.67 |
|BM_mmt4d_s8s8s32_tile_2x16x2_avx512_vnni |74.8 |162.2 |2.17 |
|BM_mmt4d_s8s8s32_tile_4x16x2_avx512_vnni |79.1 |196.5 |2.49 |
|BM_mmt4d_s8s8s32_tile_8x16x2_avx512_vnni |157.5 |214.2 |1.36 |
|BM_mmt4d_s8s8s32_tile_16x16x2_avx512_vnni |312.5 |325.8 |1.04 |
|BM_mmt4d_s16s16s32_tile_1x8x2_avx2_fma |75.5 |111.3 |1.47 |
|BM_mmt4d_s16s16s32_tile_2x8x2_avx2_fma |75.2 |170.5 |2.27 |
|BM_mmt4d_s16s16s32_tile_4x8x2_avx2_fma |124.0 |230.4 |1.86 |
|BM_mmt4d_s16s16s32_tile_8x8x2_avx2_fma |230.8 |253.5 |1.10 |
|BM_mmt4d_s16s16s32_tile_1x16x2_avx512_base |46.4 |187.9 |4.05 |
|BM_mmt4d_s16s16s32_tile_2x16x2_avx512_base |53.4 |229.4 |4.29 |
|BM_mmt4d_s16s16s32_tile_4x16x2_avx512_base |60.4 |228.8 |3.79 |
|BM_mmt4d_s16s16s32_tile_8x16x2_avx512_base |128.5 |246.7 |1.92 |
|BM_mmt4d_s16s16s32_tile_16x16x2_avx512_base|249.9 |251.6 |1.01 |
|BM_mmt4d_s16s16s32_tile_1x16x2_avx512_vnni |69.0 |102.3 |1.48 |
|BM_mmt4d_s16s16s32_tile_2x16x2_avx512_vnni |86.1 |173.1 |2.01 |
|BM_mmt4d_s16s16s32_tile_4x16x2_avx512_vnni |82.9 |320.6 |3.87 |
|BM_mmt4d_s16s16s32_tile_8x16x2_avx512_vnni |173.7 |341.0 |1.96 |
|BM_mmt4d_s16s16s32_tile_16x16x2_avx512_vnni|308.0 |343.9 |1.12 |
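
Speedup is the after/before ratio, e.g. 99.0 / 56.0 ≈ 1.77 in the first
row. The Gop/s figures presumably count two ops (one multiply and one
add) per lhs*rhs element product, i.e. 2*M0*N0*K0*K ops per
tile-function call.
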
bjacob authored Mar 4, 2024
1 parent 20ed89a commit 9d6d99f
Showing 3 changed files with 240 additions and 450 deletions.
199 changes: 81 additions & 118 deletions runtime/src/iree/builtins/ukernel/arch/x86_64/mmt4d_x86_64_avx2_fma.c
@@ -152,11 +152,58 @@ IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0(
iree_uk_mmt4d_tile_f16f16f16_8x8x1_x86_64_avx2_fma, 8)

IREE_UK_ATTRIBUTE_ALWAYS_INLINE static inline void
iree_uk_mmt4d_tile_s8s8s32_1x8x2_to_8x8x2_x86_64_avx2_fma(
iree_uk_mmt4d_tile_s8s8s32_1x8x2_to_4x8x2_x86_64_avx2_fma(
void* IREE_UK_RESTRICT out_tile, const void* IREE_UK_RESTRICT lhs_panel,
const void* IREE_UK_RESTRICT rhs_panel,
const iree_uk_mmt4d_params_t* params, int M0) {
IREE_UK_ASSERT(M0 >= 1 && M0 <= 18 && iree_uk_is_po2_u32(M0));
IREE_UK_ASSERT(M0 >= 1 && M0 <= 8 && iree_uk_is_po2_u32(M0));
iree_uk_int32_t* IREE_UK_RESTRICT out_ptr = out_tile;
const iree_uk_int8_t* IREE_UK_RESTRICT lhs_ptr = lhs_panel;
const iree_uk_int8_t* IREE_UK_RESTRICT rhs_ptr = rhs_panel;
__m256i acc[8];
if (params->flags & IREE_UK_FLAG_MMT4D_ACCUMULATE) {
IREE_UK_UNROLL for (int i = 0; i < M0; ++i) {
acc[i] = _mm256_loadu_si256((__m256i*)(out_ptr + i * 8));
}
} else {
IREE_UK_UNROLL for (int i = 0; i < M0; ++i) {
acc[i] = _mm256_setzero_si256();
}
}

for (int k = 0; k < params->K; ++k) {
// rhs_i16 is the rhs tile (2x8), sign-extended to i16.
__m256i rhs_i16 =
_mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i*)rhs_ptr));
rhs_ptr += 16;
IREE_UK_UNROLL for (int i = 0; i < M0; ++i) {
acc[i] = _mm256_add_epi32(
acc[i], _mm256_madd_epi16(_mm256_cvtepi8_epi16(_mm_set1_epi16(
*(const iree_uk_int16_t*)lhs_ptr)),
rhs_i16));
lhs_ptr += 2;
}
}

IREE_UK_UNROLL for (int i = 0; i < M0; ++i) {
_mm256_storeu_si256((__m256i*)(out_ptr + i * 8), acc[i]);
}
}

IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0(
iree_uk_mmt4d_tile_s8s8s32_1x8x2_to_4x8x2_x86_64_avx2_fma,
iree_uk_mmt4d_tile_s8s8s32_1x8x2_x86_64_avx2_fma, 1)
IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0(
iree_uk_mmt4d_tile_s8s8s32_1x8x2_to_4x8x2_x86_64_avx2_fma,
iree_uk_mmt4d_tile_s8s8s32_2x8x2_x86_64_avx2_fma, 2)
IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0(
iree_uk_mmt4d_tile_s8s8s32_1x8x2_to_4x8x2_x86_64_avx2_fma,
iree_uk_mmt4d_tile_s8s8s32_4x8x2_x86_64_avx2_fma, 4)

void iree_uk_mmt4d_tile_s8s8s32_8x8x2_x86_64_avx2_fma(
void* IREE_UK_RESTRICT out_tile, const void* IREE_UK_RESTRICT lhs_panel,
const void* IREE_UK_RESTRICT rhs_panel,
const iree_uk_mmt4d_params_t* params) {
iree_uk_int32_t* IREE_UK_RESTRICT out_ptr = out_tile;
const iree_uk_int8_t* IREE_UK_RESTRICT lhs_ptr = lhs_panel;
const iree_uk_int8_t* IREE_UK_RESTRICT rhs_ptr = rhs_panel;
@@ -165,22 +212,16 @@ iree_uk_mmt4d_tile_s8s8s32_1x8x2_to_8x8x2_x86_64_avx2_fma(
// This unusual layout is chosen so that the inner arithmetic loop only needs
// to perform cheap shuffles within 128bit groups of lanes.
__m256i acc[4][2];
const int imax = M0 <= 4 ? M0 : 4;
if (params->flags & IREE_UK_FLAG_MMT4D_ACCUMULATE) {
IREE_UK_UNROLL for (int i = 0; i < imax; ++i) {
IREE_UK_UNROLL for (int i = 0; i < 4; ++i) {
IREE_UK_UNROLL for (int j = 0; j < 2; ++j) {
if (M0 <= 4) {
acc[i][j] = _mm256_castsi128_si256(
_mm_loadu_si128((__m128i*)(out_ptr + i * 8 + j * 4)));
} else {
acc[i][j] = iree_uk_avx_loadu_2x128(
(__m128i*)(out_ptr + i * 8 + j * 4),
(__m128i*)(out_ptr + (i + 4) * 8 + (1 - j) * 4));
}
acc[i][j] = iree_uk_avx_loadu_2x128(
(__m128i*)(out_ptr + i * 8 + j * 4),
(__m128i*)(out_ptr + (i + 4) * 8 + (1 - j) * 4));
}
}
} else {
IREE_UK_UNROLL for (int i = 0; i < imax; ++i) {
IREE_UK_UNROLL for (int i = 0; i < 4; ++i) {
IREE_UK_UNROLL for (int j = 0; j < 2; ++j) {
acc[i][j] = _mm256_setzero_si256();
}
@@ -197,146 +238,68 @@ iree_uk_mmt4d_tile_s8s8s32_1x8x2_to_8x8x2_x86_64_avx2_fma(
rhs_i16_perm[1] =
_mm256_permute2x128_si256(rhs_i16_perm[0], rhs_i16_perm[0], 0x01);
// lhs_i16 is the lhs tile (M0x2), sign-extended to i16.
__m256i lhs_i16;
if (M0 == 1) {
lhs_i16 = _mm256_cvtepi8_epi16(_mm_loadu_si16(lhs_ptr));
lhs_ptr += 2;
} else if (M0 == 2) {
lhs_i16 = _mm256_cvtepi8_epi16(_mm_loadu_si32(lhs_ptr));
lhs_ptr += 4;
} else if (M0 == 4) {
lhs_i16 = _mm256_cvtepi8_epi16(_mm_loadu_si64(lhs_ptr));
lhs_ptr += 8;
} else {
lhs_i16 = _mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i*)lhs_ptr));
lhs_ptr += 16;
}
__m256i lhs_i16 =
_mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i*)lhs_ptr));
lhs_ptr += 16;
// lhs_i16_dup4[i] is lanes of lhs_i16 shuffled as:
// (i, i, i, i, i+4, i+4, i+4, i+4).
__m256i lhs_i16_dup4[4];
if (M0 >= 1) lhs_i16_dup4[0] = _mm256_shuffle_epi32(lhs_i16, 0 * 0x55);
if (M0 >= 2) lhs_i16_dup4[1] = _mm256_shuffle_epi32(lhs_i16, 1 * 0x55);
if (M0 >= 4) lhs_i16_dup4[2] = _mm256_shuffle_epi32(lhs_i16, 2 * 0x55);
if (M0 >= 4) lhs_i16_dup4[3] = _mm256_shuffle_epi32(lhs_i16, 3 * 0x55);
IREE_UK_UNROLL for (int i = 0; i < imax; ++i) {
lhs_i16_dup4[0] = _mm256_shuffle_epi32(lhs_i16, 0 * 0x55);
lhs_i16_dup4[1] = _mm256_shuffle_epi32(lhs_i16, 1 * 0x55);
lhs_i16_dup4[2] = _mm256_shuffle_epi32(lhs_i16, 2 * 0x55);
lhs_i16_dup4[3] = _mm256_shuffle_epi32(lhs_i16, 3 * 0x55);
IREE_UK_UNROLL for (int i = 0; i < 4; ++i) {
IREE_UK_UNROLL for (int j = 0; j < 2; ++j) {
acc[i][j] = _mm256_add_epi32(
acc[i][j], _mm256_madd_epi16(lhs_i16_dup4[i], rhs_i16_perm[j]));
}
}
}

IREE_UK_UNROLL for (int i = 0; i < imax; ++i) {
IREE_UK_UNROLL for (int i = 0; i < 4; ++i) {
IREE_UK_UNROLL for (int j = 0; j < 2; ++j) {
if (M0 <= 4) {
_mm_storeu_si128((__m128i*)(out_ptr + i * 8 + j * 4),
_mm256_extracti128_si256(acc[i][j], 0));
} else {
iree_uk_avx_storeu_2x128(
(__m128i*)(out_ptr + i * 8 + j * 4),
(__m128i*)(out_ptr + (i + 4) * 8 + (1 - j) * 4), acc[i][j]);
}
iree_uk_avx_storeu_2x128((__m128i*)(out_ptr + i * 8 + j * 4),
(__m128i*)(out_ptr + (i + 4) * 8 + (1 - j) * 4),
acc[i][j]);
}
}
}

IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0(
iree_uk_mmt4d_tile_s8s8s32_1x8x2_to_8x8x2_x86_64_avx2_fma,
iree_uk_mmt4d_tile_s8s8s32_1x8x2_x86_64_avx2_fma, 1)
IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0(
iree_uk_mmt4d_tile_s8s8s32_1x8x2_to_8x8x2_x86_64_avx2_fma,
iree_uk_mmt4d_tile_s8s8s32_2x8x2_x86_64_avx2_fma, 2)
IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0(
iree_uk_mmt4d_tile_s8s8s32_1x8x2_to_8x8x2_x86_64_avx2_fma,
iree_uk_mmt4d_tile_s8s8s32_4x8x2_x86_64_avx2_fma, 4)
IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0(
iree_uk_mmt4d_tile_s8s8s32_1x8x2_to_8x8x2_x86_64_avx2_fma,
iree_uk_mmt4d_tile_s8s8s32_8x8x2_x86_64_avx2_fma, 8)

IREE_UK_ATTRIBUTE_ALWAYS_INLINE static inline void
iree_uk_mmt4d_tile_s16s16s32_1x8x2_to_8x8x2_x86_64_avx2_fma(
void* IREE_UK_RESTRICT out_tile, const void* IREE_UK_RESTRICT lhs_panel,
const void* IREE_UK_RESTRICT rhs_panel,
const iree_uk_mmt4d_params_t* params, int M0) {
IREE_UK_ASSERT(M0 >= 1 && M0 <= 18 && iree_uk_is_po2_u32(M0));
IREE_UK_ASSERT(M0 >= 1 && M0 <= 8 && iree_uk_is_po2_u32(M0));
iree_uk_int32_t* IREE_UK_RESTRICT out_ptr = out_tile;
const iree_uk_int16_t* IREE_UK_RESTRICT lhs_ptr = lhs_panel;
const iree_uk_int16_t* IREE_UK_RESTRICT rhs_ptr = rhs_panel;
// acc[i][0] contains the 1st half of row i and the 2nd half of row (i+4).
// acc[i][1] contains the 2nd half of row i and the 1st half of row (i+4).
// This unusual layout is chosen so that the inner arithmetic loop only needs
// to perform cheap shuffles within 128bit groups of lanes.
__m256i acc[4][2];
const int imax = M0 <= 4 ? M0 : 4;
__m256i acc[8];
if (params->flags & IREE_UK_FLAG_MMT4D_ACCUMULATE) {
IREE_UK_UNROLL for (int i = 0; i < imax; ++i) {
IREE_UK_UNROLL for (int j = 0; j < 2; ++j) {
if (M0 <= 4) {
acc[i][j] = _mm256_castsi128_si256(
_mm_loadu_si128((__m128i*)(out_ptr + i * 8 + j * 4)));
} else {
acc[i][j] = iree_uk_avx_loadu_2x128(
(__m128i*)(out_ptr + i * 8 + j * 4),
(__m128i*)(out_ptr + (i + 4) * 8 + (1 - j) * 4));
}
}
IREE_UK_UNROLL for (int i = 0; i < M0; ++i) {
acc[i] = _mm256_loadu_si256((__m256i*)(out_ptr + i * 8));
}
} else {
IREE_UK_UNROLL for (int i = 0; i < imax; ++i) {
IREE_UK_UNROLL for (int j = 0; j < 2; ++j) {
acc[i][j] = _mm256_setzero_si256();
}
IREE_UK_UNROLL for (int i = 0; i < M0; ++i) {
acc[i] = _mm256_setzero_si256();
}
}

for (int k = 0; k < params->K; ++k) {
__m256i rhs_perm[2];
// rhs_perm[0] is the rhs tile (2x8).
rhs_perm[0] = _mm256_loadu_si256((const __m256i*)rhs_ptr);
// rhs is the rhs tile (2x8), sign-extended to i16.
__m256i rhs = _mm256_loadu_si256((const __m256i*)rhs_ptr);
rhs_ptr += 16;
// rhs_perm[1] is that with the halves swapped.
rhs_perm[1] = _mm256_permute2x128_si256(rhs_perm[0], rhs_perm[0], 0x01);
// lhs is the lhs tile (M0x2).
__m256i lhs;
if (M0 == 1) {
lhs = _mm256_castsi128_si256(_mm_loadu_si32(lhs_ptr));
IREE_UK_UNROLL for (int i = 0; i < M0; ++i) {
acc[i] = _mm256_add_epi32(
acc[i],
_mm256_madd_epi16(_mm256_set1_epi32(*(const iree_uk_int32_t*)lhs_ptr),
rhs));
lhs_ptr += 2;
} else if (M0 == 2) {
lhs = _mm256_castsi128_si256(_mm_loadu_si64(lhs_ptr));
lhs_ptr += 4;
} else if (M0 == 4) {
lhs = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i*)lhs_ptr));
lhs_ptr += 8;
} else {
lhs = _mm256_loadu_si256((const __m256i*)lhs_ptr);
lhs_ptr += 16;
}
// lhs_dup4[i] is lanes of lhs shuffled as:
// (i, i, i, i, i+4, i+4, i+4, i+4).
__m256i lhs_dup4[4];
if (M0 >= 1) lhs_dup4[0] = _mm256_shuffle_epi32(lhs, 0 * 0x55);
if (M0 >= 2) lhs_dup4[1] = _mm256_shuffle_epi32(lhs, 1 * 0x55);
if (M0 >= 4) lhs_dup4[2] = _mm256_shuffle_epi32(lhs, 2 * 0x55);
if (M0 >= 4) lhs_dup4[3] = _mm256_shuffle_epi32(lhs, 3 * 0x55);
IREE_UK_UNROLL for (int i = 0; i < imax; ++i) {
IREE_UK_UNROLL for (int j = 0; j < 2; ++j) {
acc[i][j] = _mm256_add_epi32(
acc[i][j], _mm256_madd_epi16(lhs_dup4[i], rhs_perm[j]));
}
}
}

IREE_UK_UNROLL for (int i = 0; i < imax; ++i) {
IREE_UK_UNROLL for (int j = 0; j < 2; ++j) {
if (M0 <= 4) {
_mm_storeu_si128((__m128i*)(out_ptr + i * 8 + j * 4),
_mm256_extracti128_si256(acc[i][j], 0));
} else {
iree_uk_avx_storeu_2x128(
(__m128i*)(out_ptr + i * 8 + j * 4),
(__m128i*)(out_ptr + (i + 4) * 8 + (1 - j) * 4), acc[i][j]);
}
}
IREE_UK_UNROLL for (int i = 0; i < M0; ++i) {
_mm256_storeu_si256((__m256i*)(out_ptr + i * 8), acc[i]);
}
}

(Remaining unchanged lines of this file and the diffs of the other two changed files are not shown.)
