faster narrow mmt4d ukernels on x86 (iree-org#16655)
A "narrow" matmul case is one where either the M dimension is small (e.g. vecmat) or the N dimensions is small (e.g. matvec). In the context of our ukernels, it always refers to the M dimension as reduction from narrow-N to narrow-M is meant to have been performed already (see iree-org#16399). So far, narrow mmt4d ukernel tile functions had mostly just been added as naive truncations of the general matmul kernel case. This was often fine (mostly in floating-point cases, and outside of x86) but in some cases was quite inefficient (integer-arithmetic cases on x86). In one instance that we particularly cared about for Llama2 on x86, we had a one-off doing something more clever (s16u4s32 case) but other than that, our narrow matmul kernels on x86 were pretty bad. This makes them better; this is also a net code shrink as part of the badness was from trying to shoehorn a complicated general matmul kernel design onto all M0 cases, and now it's much simpler with the narrow cases doing something simpler, and the full-width case now relieved from having to be generic. Benchmark results on my AMD 7950X3D CPU (with Turbo disabled): |name |Gop/s before|Gop/s after|Speedup (x)| |-------------------------------------------|------------|-----------|-----------| |BM_mmt4d_s8s8s32_tile_1x8x2_avx2_fma |56.0 |99.0 |1.77 | |BM_mmt4d_s8s8s32_tile_2x8x2_avx2_fma |78.4 |132.7 |1.69 | |BM_mmt4d_s8s8s32_tile_4x8x2_avx2_fma |92.2 |149.3 |1.62 | |BM_mmt4d_s8s8s32_tile_8x8x2_avx2_fma |187.3 |188.5 |1.01 | |BM_mmt4d_s8s8s32_tile_1x16x2_avx512_base |40.3 |119.4 |2.96 | |BM_mmt4d_s8s8s32_tile_2x16x2_avx512_base |52.6 |156.1 |2.97 | |BM_mmt4d_s8s8s32_tile_4x16x2_avx512_base |60.7 |169.8 |2.80 | |BM_mmt4d_s8s8s32_tile_8x16x2_avx512_base |119.4 |178.9 |1.50 | |BM_mmt4d_s8s8s32_tile_16x16x2_avx512_base |236.7 |235.0 |0.99 | |BM_mmt4d_s8s8s32_tile_1x16x2_avx512_vnni |52.5 |87.9 |1.67 | |BM_mmt4d_s8s8s32_tile_2x16x2_avx512_vnni |74.8 |162.2 |2.17 | |BM_mmt4d_s8s8s32_tile_4x16x2_avx512_vnni |79.1 |196.5 |2.49 | |BM_mmt4d_s8s8s32_tile_8x16x2_avx512_vnni |157.5 |214.2 |1.36 | |BM_mmt4d_s8s8s32_tile_16x16x2_avx512_vnni |312.5 |325.8 |1.04 | |BM_mmt4d_s16s16s32_tile_1x8x2_avx2_fma |75.5 |111.3 |1.47 | |BM_mmt4d_s16s16s32_tile_2x8x2_avx2_fma |75.2 |170.5 |2.27 | |BM_mmt4d_s16s16s32_tile_4x8x2_avx2_fma |124.0 |230.4 |1.86 | |BM_mmt4d_s16s16s32_tile_8x8x2_avx2_fma |230.8 |253.5 |1.10 | |BM_mmt4d_s16s16s32_tile_1x16x2_avx512_base |46.4 |187.9 |4.05 | |BM_mmt4d_s16s16s32_tile_2x16x2_avx512_base |53.4 |229.4 |4.29 | |BM_mmt4d_s16s16s32_tile_4x16x2_avx512_base |60.4 |228.8 |3.79 | |BM_mmt4d_s16s16s32_tile_8x16x2_avx512_base |128.5 |246.7 |1.92 | |BM_mmt4d_s16s16s32_tile_16x16x2_avx512_base|249.9 |251.6 |1.01 | |BM_mmt4d_s16s16s32_tile_1x16x2_avx512_vnni |69.0 |102.3 |1.48 | |BM_mmt4d_s16s16s32_tile_2x16x2_avx512_vnni |86.1 |173.1 |2.01 | |BM_mmt4d_s16s16s32_tile_4x16x2_avx512_vnni |82.9 |320.6 |3.87 | |BM_mmt4d_s16s16s32_tile_8x16x2_avx512_vnni |173.7 |341.0 |1.96 | |BM_mmt4d_s16s16s32_tile_16x16x2_avx512_vnni|308.0 |343.9 |1.12 |
3 changed files with 240 additions and 450 deletions.