Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Proof-of-concept: speeding up gemm reference kernel #863

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from

Conversation

bartoldeman
Copy link
Contributor

Related issue: #259

This proof of concept is the result of me playing around a bit with reference kernels to better understand the underlying algorithms used for GEMM by BLIS and OpenBLAS for an upcoming talk ( https://easybuild.io/eum25/#linalg ): with blislab I got close to peak performance with just a kernel written in C.

This is a proof-of-concept since I'm not quite sure how to integrate parts of it, particularly the prefetch stuff, and I may have abused the C preprocessor a bit too much, although other things may be straight forward.

So here's the idea:
via a macro, generate 4 fast kernels:
row-major/column-major and beta==0/beta!=0

Then the for loop for k was replaced by a do-while loop, so it only works with k>0 (checked before).
Some 20 iterations before the end, much like various asm kernels, it'll prefetch relevant parts of C; I did not see any benefit prefetching A and B.
Next I also needed to fold the scaling into the c updater, replacing bli_tcopys with bli_tscal2s and bli_txpbys by bli_taxpbys.

I found that if I use a for loop instead of do-while or test for beta==0 inside the kernel the compiler spills the whole C-tile from registers onto the stack, but it'll keep it in registers with this approach.

Some tests on zen4 (single socket AMD EPYC 9534 64-Core Processor, Genoa) with GCC 13.3, CFLAGS="-march=native" ./configure generic on a 2400x2400 single-threaded dgemm:

  • original generic: 36.13 Gflops
  • generic with this PR: 44.96 Gflops
  • using column-major with KC,MC,NC copied from AOCL-BLAS' zen4 config: 57.67 Gflops (*)
  • AOCL-BLAS 5.0, pre-compiled GCC binary: 56.77 Gflops

(it was cool to beat AOCL-BLAS by a small amount, although of course there may be other cases where it won't!)

(*) this used
CFLAGS="-march=native -DBLIS_MR_d=32 -DBLIS_NR_d=6" ./configure generic and the following change:

--- a/ref_kernels/bli_cntx_ref.c
+++ b/ref_kernels/bli_cntx_ref.c
@@ -379,8 +379,8 @@ void GENBARNAME(cntx_init)
        bli_blksz_init     ( &blkszs[ BLIS_NR  ],     BLIS_NR_s,     BLIS_NR_d,     BLIS_NR_c,     BLIS_NR_z,
                                                  BLIS_PACKNR_s, BLIS_PACKNR_d, BLIS_PACKNR_c, BLIS_PACKNR_z );
        bli_blksz_init_easy( &blkszs[ BLIS_MC  ],           256,           128,           128,            64 );
-       bli_blksz_init_easy( &blkszs[ BLIS_KC  ],           256,           256,           256,           256 );
-       bli_blksz_init_easy( &blkszs[ BLIS_NC  ],          4096,          4096,          4096,          4096 );
+       bli_blksz_init_easy( &blkszs[ BLIS_KC  ],           256,           512,           256,           256 );
+       bli_blksz_init_easy( &blkszs[ BLIS_NC  ],          4096,          4002,          4096,          4096 );
        bli_blksz_init_easy( &blkszs[ BLIS_M2  ],          1000,          1000,          1000,          1000 );
        bli_blksz_init_easy( &blkszs[ BLIS_N2  ],          1000,          1000,          1000,          1000 );
        bli_blksz_init_easy( &blkszs[ BLIS_AF  ],             8,             8,             8,             8 );
@@ -447,7 +447,7 @@ void GENBARNAME(cntx_init)
        gen_func_init_ro( &funcs[ bli_ker_idx( BLIS_GEMMTRSM1M_U_UKR ) ], gemmtrsm1m_u_ukr_name );

        //                                                           s      d      c      z
-       bli_mbool_init( &mbools[ BLIS_GEMM_UKR_ROW_PREF ],        TRUE,  TRUE,  TRUE,  TRUE );
+       bli_mbool_init( &mbools[ BLIS_GEMM_UKR_ROW_PREF ],        TRUE, FALSE,  TRUE,  TRUE );
        bli_mbool_init( &mbools[ BLIS_GEMMTRSM_L_UKR_ROW_PREF ], FALSE, FALSE, FALSE, FALSE );
        bli_mbool_init( &mbools[ BLIS_GEMMTRSM_U_UKR_ROW_PREF ], FALSE, FALSE, FALSE, FALSE );
        bli_mbool_init( &mbools[ BLIS_TRSM_L_UKR_ROW_PREF ],     FALSE, FALSE, FALSE, FALSE );
@@ -552,4 +552,3 @@ void GENBARNAME(cntx_init)
        for ( dim_t i = 0; i < BLIS_NUM_LEVEL3_OPS; i++ )
                bli_cntx_set_l3_sup_handler( i, vfuncs[ i ], cntx );
 }

Use faster C kernels for the common special case m == mr, n == nr,
k > 0, where either cs_c == 1 or rs_c == 1.
@bartoldeman bartoldeman marked this pull request as draft March 9, 2025 17:33
const inc_t cs_a = PASTECH(BLIS_PACKMR_,ch); \
const inc_t rs_b = PASTECH(BLIS_PACKNR_,ch); \
\
char ab_[ BLIS_STACK_BUF_MAX_SIZE ] __attribute__((aligned(BLIS_STACK_BUF_ALIGN_SIZE))) = { 0 }; \
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I know this is copied from the reference kernel, this zero-init is redundant to the just-following loop L195-198. I suggest you check in the assembler whether the compiler did eliminate one of them (and did manage to vectorize the zeroing as well).

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The zero-init is actually required for certain versions of clang since it improperly optimizes out some of the later zero assignments. See #854.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes exactly I saw #854 as well because this puzzled me too. But the compiler generated optimal code like this (for my case with MR=32, NR=6 for Zen4, from objdump -d:

     664:       c5 d1 57 ed             vxorpd %xmm5,%xmm5,%xmm5
     668:       48 89 c8                mov    %rcx,%rax
     66b:       48 89 e5                mov    %rsp,%rbp
     66e:       41 54                   push   %r12
     670:       53                      push   %rbx
     671:       41 bc 1a 00 00 00       mov    $0x1a,%r12d
     677:       48 8b 5d 10             mov    0x10(%rbp),%rbx
     67b:       49 29 fc                sub    %rdi,%r12
     67e:       62 f1 fd 48 28 f5       vmovapd %zmm5,%zmm6
     684:       62 f1 fd 48 28 fd       vmovapd %zmm5,%zmm7
     68a:       4c 89 e7                mov    %r12,%rdi
     68d:       62 71 fd 48 28 c5       vmovapd %zmm5,%zmm8
     693:       62 71 fd 48 28 cd       vmovapd %zmm5,%zmm9

and then continuing all the way up to zmm28, so 24 zero'd vectors of 8 doubles each, which is exactly 6*32.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oh, I did not make the link to #854, hmm it looks a weird issue.

if ( i >= 0 && i < mr ) \
for ( dim_t j = 0; j < nr; j += CACHELINE_SIZE/sizeof(double) ) \
bli_prefetch( &c[ i*s_c + j ], 0, 3 ); \
for ( dim_t i = 0; i < mr; ++i ) \
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since PRAGMA_SIMD was used, why not make use of #pragma unroll(n) on MR-loop as well in order to fill in the core pipeline ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I thought about unroll pragmas but its syntax is different between compilers, e.g. for GCC its #pragma GCC unroll n
OpenMP 5.1 has common syntax for it but not yet supported (will be e.g. in GCC 15 but not yet in 14). In the end GCC will unroll small loops with a constant number of iterations all by itself with -O3, which is used here by default so I had pretty optimal code generated here:

     730:       62 f1 fd 48 10 1a       vmovupd (%rdx),%zmm3
     736:       62 f1 fd 48 10 52 01    vmovupd 0x40(%rdx),%zmm2
     73d:       62 f1 fd 48 10 4a 02    vmovupd 0x80(%rdx),%zmm1
     744:       48 ff c1                inc    %rcx
     747:       62 f1 fd 48 10 42 03    vmovupd 0xc0(%rdx),%zmm0
     74e:       62 f2 fd 48 19 20       vbroadcastsd (%rax),%zmm4
     754:       48 81 c2 00 01 00 00    add    $0x100,%rdx
     75b:       48 83 c0 30             add    $0x30,%rax
     75f:       4c 01 df                add    %r11,%rdi
     762:       62 62 dd 48 b8 e3       vfmadd231pd %zmm3,%zmm4,%zmm28
     768:       62 62 dd 48 b8 da       vfmadd231pd %zmm2,%zmm4,%zmm27
     76e:       62 62 dd 48 b8 d1       vfmadd231pd %zmm1,%zmm4,%zmm26
     774:       62 62 dd 48 b8 c8       vfmadd231pd %zmm0,%zmm4,%zmm25
     77a:       62 f2 fd 48 19 60 fb    vbroadcastsd -0x28(%rax),%zmm4
     781:       62 62 e5 48 b8 c4       vfmadd231pd %zmm4,%zmm3,%zmm24
     787:       62 e2 ed 48 b8 fc       vfmadd231pd %zmm4,%zmm2,%zmm23
     78d:       62 e2 f5 48 b8 f4       vfmadd231pd %zmm4,%zmm1,%zmm22
     793:       62 e2 fd 48 b8 ec       vfmadd231pd %zmm4,%zmm0,%zmm21
     799:       62 f2 fd 48 19 60 fc    vbroadcastsd -0x20(%rax),%zmm4
     7a0:       62 e2 e5 48 b8 e4       vfmadd231pd %zmm4,%zmm3,%zmm20
     7a6:       62 e2 ed 48 b8 dc       vfmadd231pd %zmm4,%zmm2,%zmm19
     7ac:       62 e2 f5 48 b8 d4       vfmadd231pd %zmm4,%zmm1,%zmm18
     7b2:       62 e2 fd 48 b8 cc       vfmadd231pd %zmm4,%zmm0,%zmm17
     7b8:       62 f2 fd 48 19 60 fd    vbroadcastsd -0x18(%rax),%zmm4
     7bf:       62 e2 e5 48 b8 c4       vfmadd231pd %zmm4,%zmm3,%zmm16
     7c5:       62 72 ed 48 b8 fc       vfmadd231pd %zmm4,%zmm2,%zmm15
     7cb:       62 72 f5 48 b8 f4       vfmadd231pd %zmm4,%zmm1,%zmm14
     7d1:       62 72 fd 48 b8 ec       vfmadd231pd %zmm4,%zmm0,%zmm13
     7d7:       62 f2 fd 48 19 60 fe    vbroadcastsd -0x10(%rax),%zmm4
     7de:       62 72 e5 48 b8 e4       vfmadd231pd %zmm4,%zmm3,%zmm12
     7e4:       62 72 ed 48 b8 dc       vfmadd231pd %zmm4,%zmm2,%zmm11
     7ea:       62 72 f5 48 b8 d4       vfmadd231pd %zmm4,%zmm1,%zmm10
     7f0:       62 72 fd 48 b8 cc       vfmadd231pd %zmm4,%zmm0,%zmm9
     7f6:       62 f2 fd 48 19 60 ff    vbroadcastsd -0x8(%rax),%zmm4
     7fd:       62 72 dd 48 b8 c3       vfmadd231pd %zmm3,%zmm4,%zmm8
     803:       62 f2 dd 48 b8 fa       vfmadd231pd %zmm2,%zmm4,%zmm7
     809:       62 f2 dd 48 b8 f1       vfmadd231pd %zmm1,%zmm4,%zmm6
     80f:       62 f2 dd 48 b8 e8       vfmadd231pd %zmm0,%zmm4,%zmm5
     815:       49 39 ca                cmp    %rcx,%r10
     818:       7e 28                   jle    842 <bli_ddgemm_vect_c_generic_ref+0x1e2>
     81a:       49 8d 1c 0c             lea    (%r12,%rcx,1),%rbx
     81e:       48 83 fb 05             cmp    $0x5,%rbx
     822:       0f 87 08 ff ff ff       ja     730 <bli_ddgemm_vect_c_generic_ref+0xd0>
     828:       0f 18 0f                prefetcht0 (%rdi)
     82b:       0f 18 4f 40             prefetcht0 0x40(%rdi)
     82f:       0f 18 8f 80 00 00 00    prefetcht0 0x80(%rdi)
     836:       0f 18 8f c0 00 00 00    prefetcht0 0xc0(%rdi)
     83d:       e9 ee fe ff ff          jmp    730 <bli_ddgemm_vect_c_generic_ref+0xd0>

for ( dim_t i = 0; i < mr; ++i ) \
PRAGMA_SIMD \
for ( dim_t j = 0; j < nr; ++j ) \
taxpbys \
Copy link
Contributor

@hominhquan hominhquan Mar 10, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is probably what brought you some gain compared to the reference kernel. The scaling-by-alpha is done at the same time of accumulation-to-c by using AXBY (FMA).

Now the point is: This reference kernel was written to be simple-and-stupid, easily comprehensible and not aimed to be fast. Do we really want to make it a little harder to new people to understand, in exchange of some percentage of performance, given it was not the original purpose of this kernel.

BTW, maybe I was a bit paranoid, and perhaps a simple comment saying // Scaling ab by alpha and accumulate to c with AXBY() suffices to help the reader.

I would prefer some direct modification in the original reference kernel (PRAGMA_UNROLL, remove redundant ab-zero-init, AXBY-alpha-scaling-accumulation, or even __builtin_prefetch()) with well-explained comments, rather than four new reference kernels which will take more time to further people to understand.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

well there are already two original reference kernels; the slow version is the first one in the file called gemm_genxx already and gemm itself was already the fast path, as it says in the comment:

// An implementation that attempts to facilitate emission of vectorized
// instructions via constant loop bounds + #pragma omp simd directives.
// If compile-time MR/NR are not available (indicated by BLIS_[MN]R_x = -1),
// then the non-unrolled version (above) is used.

If I try to make the original fast path faster I simply don't get the same speed ups because the whole C tile is spilled to memory and I might as well not change anything.

An alternative also would be to have the 4 new kernels doing the only fast path, and let all other oddball cases use gemm_genxx, ie. changing the if in

        /* If either BLIS_MR_? or BLIS_NR_? was left undefined by the subconfig,
           the compiler can't fully unroll the MR and NR loop iterations below,
           which means there's no benefit to using this kernel over a general-
           purpose implementation instead. */ \
        if ( mr == -1 || nr == -1 || rs_a != 1 || cs_b != 1 ) \
        { \
                PASTEMAC(ch,ch,gemm_gen,arch,suf) \

to

      if ( mr != m || nr != n || rs_a != 1 || cs_b != 1 || (cs_c != 1 && rs_c != 1) || k == 0) \

or even flipping it around: the slow version is the main reference gemm and the fast version is called via that if statement flipped?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I will have to think if the k==0 restriction can be potentially lifted for all gemm kernels, in which case modifying the check as suggested would be fine.

@hominhquan
Copy link
Contributor

Thank you @bartoldeman for your contribution. I added some comments in review. Feel free to reply back.

@devinamatthews
Copy link
Member

@bartoldeman I haven't done as thorough a review as @hominhquan but I especially like that you were able to find a way to convince the compiler to keep the AB microtile in registers. This is something I had struggled with a lot and was the biggest deficiency compared to the hand-written kernels. (Except for icc which did some very strange things in the loop body, but it's deprecated now...)

Thinking of integration into BLIS, what would be really neat is if as much as possible about the kernel was configurable via macros, e.g. number of iterations before to prefetch C, row-major vs. column-major, etc. Then, if compilers tend to play nice in a portable way (we'd need to look at this), then it would be an excellent starting point for new architectures.

@bartoldeman
Copy link
Contributor Author

@devinamatthews I created a zen4 config using this reference kernel but it does mean adding it in various places through the general source code as well, not simply adding a few files under config/zen4. That said, bli_kernel_defs_<arch>.h could be a place to define CACHELINE_SIZE and TAIL_NITER perhaps. Right now I can already pass MR and NR to configure but would it be a good idea to be able to do that for some other constants too, to easily create a tuned generic kernel, without going more "heavy duty"? This is why it's POC.. I really am not familiar enough how to do that yet.

@devinamatthews
Copy link
Member

Yes just those two settings would be a great place to start. We'd just want to give them "BLISier" names and write some documentation. We'd want to find somebody to test on other architectures as well (I can only do Zen3, SKX, and Apple M1).

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants