- Naive version: reaches only about 10% of cuBLAS performance.
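- v1: tile from global memory into shared memory (the follow-up stage tiles from shared memory into registers)
  The naive kernel accesses global memory too often, so the matrices are split into tiles that are staged in shared memory. With `M = m/bm`, `N = n/bn`, `K = k/bk`, the number of reads from global memory is `M*N*K*(bm*bk + bk*bn) = m*n*k*(1/bm + 1/bn)`, compared with `2*m*n*k` when every thread reads its operands directly from global memory. Shared memory itself is not tiled at this stage: each thread computes one element of C.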
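
The read count above can be checked with a small host-side emulation. This is only a sketch, not the author's kernel: `tiled_gemm`, `TiledResult` and `predicted_reads` are hypothetical names, storage is assumed to be row-major, and `m`, `n`, `k` are assumed to be multiples of the tile sizes.

```cpp
#include <cstddef>
#include <vector>

// Result of the emulated kernel: the product and the number of
// global-memory element reads that were needed to compute it.
struct TiledResult {
    std::vector<float> c;
    std::size_t global_reads;
};

// Host-side emulation of the v1 scheme (hypothetical helper). Assumes
// row-major storage and m, n, k divisible by bm, bn, bk.
TiledResult tiled_gemm(const std::vector<float>& a, const std::vector<float>& b,
                       std::size_t m, std::size_t n, std::size_t k,
                       std::size_t bm, std::size_t bn, std::size_t bk) {
    TiledResult r{std::vector<float>(m * n, 0.0f), 0};
    std::vector<float> sa(bm * bk), sb(bk * bn);  // stand-ins for shared memory
    for (std::size_t by = 0; by < m / bm; ++by)            // block row
        for (std::size_t bx = 0; bx < n / bn; ++bx)        // block column
            for (std::size_t kt = 0; kt < k / bk; ++kt) {  // step along k
                // Stage one bm x bk tile of A and one bk x bn tile of B.
                for (std::size_t i = 0; i < bm; ++i)
                    for (std::size_t p = 0; p < bk; ++p) {
                        sa[i * bk + p] = a[(by * bm + i) * k + kt * bk + p];
                        ++r.global_reads;
                    }
                for (std::size_t p = 0; p < bk; ++p)
                    for (std::size_t j = 0; j < bn; ++j) {
                        sb[p * bn + j] = b[(kt * bk + p) * n + bx * bn + j];
                        ++r.global_reads;
                    }
                // Each (i, j) pair plays the role of one thread.
                for (std::size_t i = 0; i < bm; ++i)
                    for (std::size_t j = 0; j < bn; ++j) {
                        float acc = 0.0f;
                        for (std::size_t p = 0; p < bk; ++p)
                            acc += sa[i * bk + p] * sb[p * bn + j];
                        r.c[(by * bm + i) * n + bx * bn + j] += acc;
                    }
            }
    return r;
}

// Closed form from the text: m*n*k*(1/bm + 1/bn).
std::size_t predicted_reads(std::size_t m, std::size_t n, std::size_t k,
                            std::size_t bm, std::size_t bn) {
    return m * n * k / bm + m * n * k / bn;
}
```

For `m = n = k = 64` with 32 x 32 x 32 tiles, the emulation performs 16384 global reads, against `2 * 64^3 = 524288` for the untiled scheme.
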
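- v2: optimization, reduce use of the `ty` register
  Store `threadIdx.x` in a register before re-using it heavily, to help compiler optimization. No gain was observed; the effect is probably too small and is masked by other costs.
- v3: further optimization, reduce bank conflicts
  Change the order in which data is read from shared memory from column-major to row-major.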
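
To see why the read order matters, the sketch below counts how many distinct words the lanes of one warp request from the same bank. It assumes 32 banks of 4-byte words and a 32 x 32 row-major tile, so a column-major read gives lane `t` the word `t * 32 + p` (stride 32) and a row-major read gives it `p * 32 + t` (stride 1). `conflict_degree` is a hypothetical helper, and the lane-to-element mapping is an assumption rather than the author's exact indexing.

```cpp
#include <algorithm>
#include <array>
#include <cstddef>
#include <set>

// Assumption: 32 banks, each serving one 4-byte word per cycle.
constexpr int kBanks = 32;
constexpr int kWarp = 32;

// Returns the worst-case number of distinct words that the 32 lanes of a
// warp request from a single bank when lane t reads word `base + t * stride`.
// 1 means conflict-free; 32 means the access is fully serialized.
int conflict_degree(int base, int stride) {
    std::array<std::set<int>, kBanks> words_per_bank;
    for (int lane = 0; lane < kWarp; ++lane) {
        int word = base + lane * stride;
        words_per_bank[word % kBanks].insert(word);  // same word = broadcast
    }
    std::size_t worst = 0;
    for (const auto& words : words_per_bank)
        worst = std::max(worst, words.size());
    return static_cast<int>(worst);
}
```

Under these assumptions `conflict_degree(0, 32)` is 32 (every lane hits the same bank) while `conflict_degree(0, 1)` is 1.
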
- v4: 4x1 micro kernel
  Each thread computes 4 elements, and each thread block (TB) is restricted to 256 threads. Each thread loads a `4x1` slice of A and a `1x1` element of B and computes `C(4x1) += A(4x1) * B(1x1)`.
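
A host-side sketch of the 4x1 micro kernel for one thread block follows. The 32 x 32 output tile, the row-major layout and the mapping from thread id to outputs are assumptions made for illustration; `block_4x1` is a hypothetical name, not the author's kernel.

```cpp
#include <cstddef>
#include <vector>

constexpr std::size_t kTile = 32;      // C tile is 32 x 32 per thread block
constexpr std::size_t kThreads = 256;  // 32 * 32 / 4 threads, 4 outputs each

// Emulates one thread block: computes a 32 x 32 tile C += A * B where A is
// 32 x k and B is k x 32, both row-major. The thread-to-output mapping
// (tid -> 4 consecutive rows of one column) is an assumption.
void block_4x1(const std::vector<float>& a, const std::vector<float>& b,
               std::vector<float>& c, std::size_t k) {
    for (std::size_t tid = 0; tid < kThreads; ++tid) {
        std::size_t col = tid % kTile;
        std::size_t row0 = (tid / kTile) * 4;
        float acc[4] = {0.0f, 0.0f, 0.0f, 0.0f};  // lives in registers on a GPU
        for (std::size_t p = 0; p < k; ++p) {
            float b_val = b[p * kTile + col];    // B(1x1), loaded once
            for (std::size_t i = 0; i < 4; ++i)  // A(4x1), each paired with b_val
                acc[i] += a[(row0 + i) * k + p] * b_val;
        }
        for (std::size_t i = 0; i < 4; ++i)
            c[(row0 + i) * kTile + col] += acc[i];
    }
}
```

Per step along `k`, a thread issues 5 loads for 4 multiply-adds, instead of the 8 loads that four single-output threads would need.
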
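- v5: v4 + `float4` vectorized loads and stores
  Casting the pointers to `float4` reduces the number of load/store instructions but raises the demand for registers, so the kernel may not reach the maximum number of concurrent warps, which can limit occupancy.
  Moving from shared memory to registers: consider the `bm*bk + bk*bn` shared-memory accesses within one block. Split the shared-memory tile into `rm x rn` sub-matrices, launch `(bm/rm) * (bn/rn)` threads per block, and let each thread compute one `rm x rn` sub-matrix. The `bm*bk + bk*bn` tile elements are loaded into shared memory beforehand. This is the second tiling stage, and by the same argument as for global memory, shared-memory accesses drop to `1/2 * (1/rm + 1/rn)` of the untiled amount.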
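
The `1/2 * (1/rm + 1/rn)` factor can be reproduced by counting shared-memory reads per thread. The sketch below is a counting model only (no arithmetic is performed), and `shared_reads` is a hypothetical helper; it assumes every thread reads `rm` values of A and `rn` values of B per step along `k`, with `rm = rn = 1` standing for the untiled case.

```cpp
#include <cstddef>

// Counts shared-memory reads for one bm x bn block tile over one bk-deep
// step when every thread owns an rm x rn register tile. rm = rn = 1 is the
// untiled case (one output per thread). Hypothetical helper for illustration.
std::size_t shared_reads(std::size_t bm, std::size_t bn, std::size_t bk,
                         std::size_t rm, std::size_t rn) {
    std::size_t reads = 0;
    std::size_t threads = (bm / rm) * (bn / rn);
    for (std::size_t t = 0; t < threads; ++t)
        for (std::size_t p = 0; p < bk; ++p) {
            reads += rm;  // rm values of A moved from shared memory to registers
            reads += rn;  // rn values of B moved from shared memory to registers
            // ...followed by rm * rn multiply-adds done entirely in registers
        }
    return reads;
}
```

With a 64 x 64 x 16 tile, `rm = rn = 4` needs a quarter of the untiled reads, which matches `1/2 * (1/4 + 1/4)`.
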
- v6: each thread computes a 4x4 sub-matrix; `BM, BN, BK = 64, 64, 16`; `RM, RN = 4, 4`
  Each thread computes a `4x4` sub-matrix of C, so we gain massive data re-use at the register level compared with the previous step. When the input is large enough, giving each thread block (TB) more work improves performance. Here we increase `{Ms, Ns}` from the previous `{32, 32}` to `{64, 64}` but decrease `Ks` from `32` to `16` to keep the shared-memory consumption the same.
- v7: each thread computes an 8x8 sub-matrix; `BM, BN, BK = 128, 128, 8`; `RM, RN = 8, 8`
  More work for each TB and for each thread?
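
The tile shapes quoted for v6 and v7 can be sanity-checked with a few lines of arithmetic. The sketch assumes 4-byte floats, single buffering and no padding; `TileConfig` and the helper functions are hypothetical names, and `kPrev` encodes the earlier `{32, 32}` tile with `Ks = 32` and a 4x1 micro kernel.

```cpp
#include <cstddef>

struct TileConfig {
    std::size_t bm, bn, bk;  // block tile (Ms, Ns, Ks in the text)
    std::size_t rm, rn;      // per-thread register tile
};

// Threads needed so that every thread owns one rm x rn piece of the block tile.
constexpr std::size_t threads_per_block(TileConfig t) {
    return (t.bm / t.rm) * (t.bn / t.rn);
}

// Bytes of shared memory for one A tile plus one B tile of 4-byte floats
// (single-buffered; assumes no padding).
constexpr std::size_t shared_bytes(TileConfig t) {
    return (t.bm * t.bk + t.bk * t.bn) * sizeof(float);
}

// Accumulators each thread keeps in registers.
constexpr std::size_t accumulators(TileConfig t) { return t.rm * t.rn; }

constexpr TileConfig kPrev{32, 32, 32, 4, 1};  // assumed shape of the 4x1 step
constexpr TileConfig kV6{64, 64, 16, 4, 4};
constexpr TileConfig kV7{128, 128, 8, 8, 8};
```

All three configurations come out at 256 threads per block and 8192 bytes of shared memory; what changes is the number of accumulators per thread (4, 16 and 64).
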
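- v8: warp-level parallelism
  Warp-level tiling with `{Mw, Nw} = {4 x Mr, 8 x Nr}` to improve parallelism.
- v9: data prefetch
  Reads from global memory are very slow and introduce latency. The GPU can normally hide this latency by switching between blocks, but because each block is allocated a lot of shared memory, few blocks are active at once and the latency is hard to hide. Likewise, a thread that computes an `rm x rn` sub-matrix must first move its data from shared memory into registers before it can start computing. As a result, the compute units stall on every iteration and cannot be kept fed. The fix is to allocate one extra buffer so that reads and writes are separated (double buffering).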
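
The read/write separation can be sketched on the host with two buffers per operand. This is a minimal illustration of the loop structure only: a CPU runs the load and the math one after the other, whereas on a GPU the loads issued at the top of the iteration would be in flight while the math runs. `load_tile`, `dot_double_buffered` and `kTileLen` are hypothetical names, and a dot product stands in for the GEMM inner loop.

```cpp
#include <array>
#include <cstddef>
#include <vector>

constexpr std::size_t kTileLen = 4;  // elements per tile (illustrative)

// Copies tile `t` of `src` into `dst`; stands in for a global -> shared load.
void load_tile(const std::vector<float>& src, std::size_t t,
               std::array<float, kTileLen>& dst) {
    for (std::size_t i = 0; i < kTileLen; ++i) dst[i] = src[t * kTileLen + i];
}

// Sums x[i] * y[i] tile by tile with two buffers per operand. Assumes the
// input length is a non-zero multiple of kTileLen.
float dot_double_buffered(const std::vector<float>& x,
                          const std::vector<float>& y) {
    std::size_t tiles = x.size() / kTileLen;
    std::array<std::array<float, kTileLen>, 2> bx{}, by{};
    load_tile(x, 0, bx[0]);  // prologue: fill the first buffer
    load_tile(y, 0, by[0]);
    float acc = 0.0f;
    for (std::size_t t = 0; t < tiles; ++t) {
        std::size_t cur = t % 2, next = (t + 1) % 2;
        if (t + 1 < tiles) {
            // Write side: fetch the next tile into the idle buffer. On a GPU
            // these loads are issued here so they overlap with the math below.
            load_tile(x, t + 1, bx[next]);
            load_tile(y, t + 1, by[next]);
        }
        // Read side: compute only from the buffer filled one iteration earlier.
        for (std::size_t i = 0; i < kTileLen; ++i)
            acc += bx[cur][i] * by[cur][i];
    }
    return acc;
}
```

The cost of this scheme is a second copy of each buffer, which is the extra buffer mentioned above.
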

| Version | Optimization | grid / block | Input size (m x n x k) | Kernel latency (ms) | Performance (% of cuBLAS) | Input size (m x n x k) | Kernel latency (ms) | Performance (% of cuBLAS) |
|---|---|---|---|---|---|---|---|---|
| cuBLAS | - | - | 1024x1024x1024 | 0.11318 | 100 | 2048x2048x2048 | 0.83312 | 100 |
| | | - | 4096x4096x1024 | 1.85763 | 100 | 4096x4096x4096 | 6.196864 | 100 |
| | | - | 8192x8192x1024 | 6.14787 | 100 | 8192x8192x8192 | 51.726273 | 100 |
| naive | read directly from global memory | (32,32) (32,32) | 1024x1024x1024 | 0.99981 | 11.3206 | - | - | - |
| | | (128,128) (32,32) | 4096x4096x1024 | 16.9926 | 10.932 | - | - | - |
| | | (256,256) (32,32) | 8192x8192x1024 | 67.735 | 9.07636 | - | - | - |
| v1 | tile into shared memory | (32,32) (32,32) | 1024x1024x1024 | 2.4361 | 4.64612 | - | - | - |
| | | (128,128) (32,32) | 4096x4096x1024 | 38.2188 | 4.86052 | - | - | - |
| | | (256,256) (32,32) | 8192x8192x1024 | 155.394 | 3.95631 | - | - | - |
| v2 | shared memory, plus reduced use of the `ty` register to help compiler optimization | (32,32) (1024) | 1024x1024x1024 | 2.43501 | 4.6482 | - | - | - |
| | | (128,128) (1024) | 4096x4096x1024 | 38.1966 | 4.86334 | - | - | - |
| | | (256,256) (1024) | 8192x8192x1024 | 157.741 | 3.89745 | - | - | - |
| v3 | shared memory, plus reduced bank conflicts (transpose B) | (32,32) (1024) | 1024x1024x1024 | 1.35408 | 8.35874 | - | - | - |
| | | (128,128) (1024) | 4096x4096x1024 | 24.2794 | 7.65107 | - | - | - |
| | | (256,256) (1024) | 8192x8192x1024 | 87.3787 | 7.0359 | - | - | - |
| v4 | v3 + 4x1 micro kernel | (32,32) (256) | 1024x1024x1024 | 0.47008 | 24.0776 | - | - | - |
| | | (128,128) (256) | 4096x4096x1024 | 7.12227 | 26.082 | - | - | - |
| | | (256,256) (256) | 8192x8192x1024 | 29.1672 | 21.078 | - | - | - |
| v5 | v4 + `float4` vectorized loads and stores | (32,32) (256) | 1024x1024x1024 | 0.41395 | 27.3423 | - | - | - |
| | | (128,128) (256) | 4096x4096x1024 | 6.22848 | 29.8248 | - | - | - |
| | | (256,256) (256) | 8192x8192x1024 | 24.7156 | 24.8745 | - | - | - |
| v6 (BM, BN, BK = 64, 64, 16) | `float4`; each thread computes a 4x4 sub-matrix | (32,32) (256) | 1024x1024x1024 | 0.18403 | 61.5023 | 2048x2048x2048 | 1.14 | 72.841 |
| | | (128,128) (256) | 4096x4096x1024 | 2.18477 | 85.0265 | 4096x4096x4096 | 9.33 | 66.436 |
| | | (256,256) (256) | 8192x8192x1024 | 8.66074 | 70.9856 | 8192x8192x8192 | 70.1 | 73.827 |
| v7 (BM, BN, BK = 128, 128, 8) | `float4`; each thread computes an 8x8 sub-matrix | (32,32) (256) | 1024x1024x1024 | 0.18736 | 60.4099 | 2048x2048x2048 | 1.13 | 73.581 |
| | | (128,128) (256) | 4096x4096x1024 | 1.84573 | 100.645 | 4096x4096x4096 | 7.67 | 80.756 |
| | | (256,256) (256) | 8192x8192x1024 | 6.976 | 88.1289 | 8192x8192x8192 | 56.9 | 90.944 |
| v8 (BM, BN, BK = 128, 128, 8) | `float4`; each thread computes an 8x8 sub-matrix; warp-level parallelism | (32,32) (256) | 1024x1024x1024 | 0.17622 | 64.2273 | 2048x2048x2048 | 1.04 | 80.08 |
| | | (128,128) (256) | 4096x4096x1024 | 1.71296 | 108.446 | 4096x4096x4096 | 6.53 | 94.91 |
| | | (256,256) (256) | 8192x8192x1024 | 6.53965 | 94.0092 | 8192x8192x8192 | 56.3 | 91.9 |
| v9 | data prefetch | (32,32) (256) | 1024x1024x1024 | 0.13584 | 83.3216 | 2048x2048x2048 | 0.93 | 89.9 |
| | | (128,128) (256) | 4096x4096x1024 | 1.55453 | 119.498 | 4096x4096x4096 | 5.9 | 105 |
| | | (256,256) (256) | 8192x8192x1024 | 5.98864 | 102.659 | 8192x8192x8192 | 51.9 | 99.73 |
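
Judging from the rows, the percentage column appears to be the cuBLAS latency divided by the kernel latency for the same input size. That reading is inferred from the numbers rather than stated by the author, and `percent_of_cublas` is a hypothetical helper. A minimal sketch:

```cpp
#include <cmath>

// Relative performance as it appears to be used in the table: for the same
// problem size, a kernel that takes exactly as long as cuBLAS scores 100.
double percent_of_cublas(double cublas_ms, double kernel_ms) {
    return 100.0 * cublas_ms / kernel_ms;
}
```

For the 4096x4096x1024 input, `percent_of_cublas(1.85763, 16.9926)` is about 10.93 for the naive kernel and `percent_of_cublas(1.85763, 1.55453)` is about 119.5 for the prefetch version, in line with the table. The columns for the cubic inputs list latencies rounded to fewer digits, so recomputed values differ slightly there.
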