-
Notifications
You must be signed in to change notification settings - Fork 376
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Proof-of-concept: speeding up gemm reference kernel #863
base: master
Are you sure you want to change the base?
Conversation
Use faster C kernels for the common special case m == mr, n == nr, k > 0, where either cs_c == 1 or rs_c == 1.
const inc_t cs_a = PASTECH(BLIS_PACKMR_,ch); \ | ||
const inc_t rs_b = PASTECH(BLIS_PACKNR_,ch); \ | ||
\ | ||
char ab_[ BLIS_STACK_BUF_MAX_SIZE ] __attribute__((aligned(BLIS_STACK_BUF_ALIGN_SIZE))) = { 0 }; \ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I know this is copied from the reference kernel, this zero-init is redundant to the just-following loop L195-198. I suggest you check in the assembler whether the compiler did eliminate one of them (and did manage to vectorize the zeroing as well).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The zero-init is actually required for certain versions of clang since it improperly optimizes out some of the later zero assignments. See #854.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yes exactly I saw #854 as well because this puzzled me too. But the compiler generated optimal code like this (for my case with MR=32, NR=6 for Zen4, from objdump -d
:
664: c5 d1 57 ed vxorpd %xmm5,%xmm5,%xmm5
668: 48 89 c8 mov %rcx,%rax
66b: 48 89 e5 mov %rsp,%rbp
66e: 41 54 push %r12
670: 53 push %rbx
671: 41 bc 1a 00 00 00 mov $0x1a,%r12d
677: 48 8b 5d 10 mov 0x10(%rbp),%rbx
67b: 49 29 fc sub %rdi,%r12
67e: 62 f1 fd 48 28 f5 vmovapd %zmm5,%zmm6
684: 62 f1 fd 48 28 fd vmovapd %zmm5,%zmm7
68a: 4c 89 e7 mov %r12,%rdi
68d: 62 71 fd 48 28 c5 vmovapd %zmm5,%zmm8
693: 62 71 fd 48 28 cd vmovapd %zmm5,%zmm9
and then continuing all the way up to zmm28
, so 24 zero'd vectors of 8 doubles each, which is exactly 6*32.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
oh, I did not make the link to #854, hmm it looks a weird issue.
if ( i >= 0 && i < mr ) \ | ||
for ( dim_t j = 0; j < nr; j += CACHELINE_SIZE/sizeof(double) ) \ | ||
bli_prefetch( &c[ i*s_c + j ], 0, 3 ); \ | ||
for ( dim_t i = 0; i < mr; ++i ) \ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Since PRAGMA_SIMD
was used, why not make use of #pragma unroll(n)
on MR-loop as well in order to fill in the core pipeline ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I thought about unroll pragmas but its syntax is different between compilers, e.g. for GCC its #pragma GCC unroll n
OpenMP 5.1 has common syntax for it but not yet supported (will be e.g. in GCC 15 but not yet in 14). In the end GCC will unroll small loops with a constant number of iterations all by itself with -O3
, which is used here by default so I had pretty optimal code generated here:
730: 62 f1 fd 48 10 1a vmovupd (%rdx),%zmm3
736: 62 f1 fd 48 10 52 01 vmovupd 0x40(%rdx),%zmm2
73d: 62 f1 fd 48 10 4a 02 vmovupd 0x80(%rdx),%zmm1
744: 48 ff c1 inc %rcx
747: 62 f1 fd 48 10 42 03 vmovupd 0xc0(%rdx),%zmm0
74e: 62 f2 fd 48 19 20 vbroadcastsd (%rax),%zmm4
754: 48 81 c2 00 01 00 00 add $0x100,%rdx
75b: 48 83 c0 30 add $0x30,%rax
75f: 4c 01 df add %r11,%rdi
762: 62 62 dd 48 b8 e3 vfmadd231pd %zmm3,%zmm4,%zmm28
768: 62 62 dd 48 b8 da vfmadd231pd %zmm2,%zmm4,%zmm27
76e: 62 62 dd 48 b8 d1 vfmadd231pd %zmm1,%zmm4,%zmm26
774: 62 62 dd 48 b8 c8 vfmadd231pd %zmm0,%zmm4,%zmm25
77a: 62 f2 fd 48 19 60 fb vbroadcastsd -0x28(%rax),%zmm4
781: 62 62 e5 48 b8 c4 vfmadd231pd %zmm4,%zmm3,%zmm24
787: 62 e2 ed 48 b8 fc vfmadd231pd %zmm4,%zmm2,%zmm23
78d: 62 e2 f5 48 b8 f4 vfmadd231pd %zmm4,%zmm1,%zmm22
793: 62 e2 fd 48 b8 ec vfmadd231pd %zmm4,%zmm0,%zmm21
799: 62 f2 fd 48 19 60 fc vbroadcastsd -0x20(%rax),%zmm4
7a0: 62 e2 e5 48 b8 e4 vfmadd231pd %zmm4,%zmm3,%zmm20
7a6: 62 e2 ed 48 b8 dc vfmadd231pd %zmm4,%zmm2,%zmm19
7ac: 62 e2 f5 48 b8 d4 vfmadd231pd %zmm4,%zmm1,%zmm18
7b2: 62 e2 fd 48 b8 cc vfmadd231pd %zmm4,%zmm0,%zmm17
7b8: 62 f2 fd 48 19 60 fd vbroadcastsd -0x18(%rax),%zmm4
7bf: 62 e2 e5 48 b8 c4 vfmadd231pd %zmm4,%zmm3,%zmm16
7c5: 62 72 ed 48 b8 fc vfmadd231pd %zmm4,%zmm2,%zmm15
7cb: 62 72 f5 48 b8 f4 vfmadd231pd %zmm4,%zmm1,%zmm14
7d1: 62 72 fd 48 b8 ec vfmadd231pd %zmm4,%zmm0,%zmm13
7d7: 62 f2 fd 48 19 60 fe vbroadcastsd -0x10(%rax),%zmm4
7de: 62 72 e5 48 b8 e4 vfmadd231pd %zmm4,%zmm3,%zmm12
7e4: 62 72 ed 48 b8 dc vfmadd231pd %zmm4,%zmm2,%zmm11
7ea: 62 72 f5 48 b8 d4 vfmadd231pd %zmm4,%zmm1,%zmm10
7f0: 62 72 fd 48 b8 cc vfmadd231pd %zmm4,%zmm0,%zmm9
7f6: 62 f2 fd 48 19 60 ff vbroadcastsd -0x8(%rax),%zmm4
7fd: 62 72 dd 48 b8 c3 vfmadd231pd %zmm3,%zmm4,%zmm8
803: 62 f2 dd 48 b8 fa vfmadd231pd %zmm2,%zmm4,%zmm7
809: 62 f2 dd 48 b8 f1 vfmadd231pd %zmm1,%zmm4,%zmm6
80f: 62 f2 dd 48 b8 e8 vfmadd231pd %zmm0,%zmm4,%zmm5
815: 49 39 ca cmp %rcx,%r10
818: 7e 28 jle 842 <bli_ddgemm_vect_c_generic_ref+0x1e2>
81a: 49 8d 1c 0c lea (%r12,%rcx,1),%rbx
81e: 48 83 fb 05 cmp $0x5,%rbx
822: 0f 87 08 ff ff ff ja 730 <bli_ddgemm_vect_c_generic_ref+0xd0>
828: 0f 18 0f prefetcht0 (%rdi)
82b: 0f 18 4f 40 prefetcht0 0x40(%rdi)
82f: 0f 18 8f 80 00 00 00 prefetcht0 0x80(%rdi)
836: 0f 18 8f c0 00 00 00 prefetcht0 0xc0(%rdi)
83d: e9 ee fe ff ff jmp 730 <bli_ddgemm_vect_c_generic_ref+0xd0>
for ( dim_t i = 0; i < mr; ++i ) \ | ||
PRAGMA_SIMD \ | ||
for ( dim_t j = 0; j < nr; ++j ) \ | ||
taxpbys \ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is probably what brought you some gain compared to the reference kernel. The scaling-by-alpha is done at the same time of accumulation-to-c by using AXBY (FMA).
Now the point is: This reference kernel was written to be simple-and-stupid, easily comprehensible and not aimed to be fast. Do we really want to make it a little harder to new people to understand, in exchange of some percentage of performance, given it was not the original purpose of this kernel.
BTW, maybe I was a bit paranoid, and perhaps a simple comment saying // Scaling ab by alpha and accumulate to c with AXBY()
suffices to help the reader.
I would prefer some direct modification in the original reference kernel (PRAGMA_UNROLL, remove redundant ab-zero-init, AXBY-alpha-scaling-accumulation, or even __builtin_prefetch()
) with well-explained comments, rather than four new reference kernels which will take more time to further people to understand.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
well there are already two original reference kernels; the slow version is the first one in the file called gemm_genxx
already and gemm
itself was already the fast path, as it says in the comment:
// An implementation that attempts to facilitate emission of vectorized
// instructions via constant loop bounds + #pragma omp simd directives.
// If compile-time MR/NR are not available (indicated by BLIS_[MN]R_x = -1),
// then the non-unrolled version (above) is used.
If I try to make the original fast path faster I simply don't get the same speed ups because the whole C tile is spilled to memory and I might as well not change anything.
An alternative also would be to have the 4 new kernels doing the only fast path, and let all other oddball cases use gemm_genxx
, ie. changing the if in
/* If either BLIS_MR_? or BLIS_NR_? was left undefined by the subconfig,
the compiler can't fully unroll the MR and NR loop iterations below,
which means there's no benefit to using this kernel over a general-
purpose implementation instead. */ \
if ( mr == -1 || nr == -1 || rs_a != 1 || cs_b != 1 ) \
{ \
PASTEMAC(ch,ch,gemm_gen,arch,suf) \
to
if ( mr != m || nr != n || rs_a != 1 || cs_b != 1 || (cs_c != 1 && rs_c != 1) || k == 0) \
or even flipping it around: the slow version is the main reference gemm and the fast version is called via that if statement flipped?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I will have to think if the k==0
restriction can be potentially lifted for all gemm kernels, in which case modifying the check as suggested would be fine.
Thank you @bartoldeman for your contribution. I added some comments in review. Feel free to reply back. |
@bartoldeman I haven't done as thorough a review as @hominhquan but I especially like that you were able to find a way to convince the compiler to keep the AB microtile in registers. This is something I had struggled with a lot and was the biggest deficiency compared to the hand-written kernels. (Except for icc which did some very strange things in the loop body, but it's deprecated now...) Thinking of integration into BLIS, what would be really neat is if as much as possible about the kernel was configurable via macros, e.g. number of iterations before to prefetch C, row-major vs. column-major, etc. Then, if compilers tend to play nice in a portable way (we'd need to look at this), then it would be an excellent starting point for new architectures. |
@devinamatthews I created a |
Yes just those two settings would be a great place to start. We'd just want to give them "BLISier" names and write some documentation. We'd want to find somebody to test on other architectures as well (I can only do Zen3, SKX, and Apple M1). |
Related issue: #259
This proof of concept is the result of me playing around a bit with reference kernels to better understand the underlying algorithms used for GEMM by BLIS and OpenBLAS for an upcoming talk ( https://easybuild.io/eum25/#linalg ): with blislab I got close to peak performance with just a kernel written in C.
This is a proof-of-concept since I'm not quite sure how to integrate parts of it, particularly the prefetch stuff, and I may have abused the C preprocessor a bit too much, although other things may be straight forward.
So here's the idea:
via a macro, generate 4 fast kernels:
row-major/column-major and
beta==0
/beta!=0
Then the
for
loop fork
was replaced by ado
-while
loop, so it only works withk>0
(checked before).Some 20 iterations before the end, much like various asm kernels, it'll prefetch relevant parts of C; I did not see any benefit prefetching A and B.
Next I also needed to fold the scaling into the c updater, replacing
bli_tcopys
withbli_tscal2s
andbli_txpbys
bybli_taxpbys
.I found that if I use a
for
loop instead ofdo
-while
or test forbeta==0
inside the kernel the compiler spills the whole C-tile from registers onto the stack, but it'll keep it in registers with this approach.Some tests on zen4 (single socket AMD EPYC 9534 64-Core Processor, Genoa) with GCC 13.3,
CFLAGS="-march=native" ./configure generic
on a 2400x2400 single-threaded dgemm:(it was cool to beat AOCL-BLAS by a small amount, although of course there may be other cases where it won't!)
(*) this used
CFLAGS="-march=native -DBLIS_MR_d=32 -DBLIS_NR_d=6" ./configure generic
and the following change: