flash_attention_cuda

Flash attention implemented in CUDA.

NOTE: a specific PyTorch version is required to support the deprecated API; just use the standalone version or the CUTLASS version.

roadmap

  • naive self attention python
  • naive self attention cuda (a minimal sketch follows this list)
  • naive self attention python API binding
    • TODO:
      • half-precision (FP16) support
      • make the template data type more general
      • thread balancing (too many threads may cause a crash)
      • clean up deprecation warnings
  • flash attention 1 cuda (see the online-softmax sketch after this list)
  • flash attention 2 cuda
  • flash attention 1/2 python binding
  • split the templates and make them more general (e.g. dim and block size)
  • MHA support
  • causal mode support
  • flash attention cute
  • check out static_switch.h in flash attention
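The "naive self attention cuda" item maps well to a single kernel. Below is a minimal sketch, assuming a single head, row-major [seq_len, head_dim] float tensors, and one thread per query row; the kernel name and parameters are illustrative, not this repo's actual identifiers.

```cuda
#include <cuda_runtime.h>
#include <math.h>

// Naive self-attention: O = softmax(Q K^T / sqrt(d)) V, one thread per query row.
// The score row is recomputed instead of stored, which keeps the sketch short.
__global__ void naive_attention_kernel(const float* Q, const float* K,
                                       const float* V, float* O,
                                       int seq_len, int head_dim) {
    int row = blockIdx.x * blockDim.x + threadIdx.x;  // query index
    if (row >= seq_len) return;
    float scale = rsqrtf((float)head_dim);

    // Pass 1: row maximum of the scores, for a numerically stable softmax.
    float row_max = -INFINITY;
    for (int j = 0; j < seq_len; ++j) {
        float s = 0.f;
        for (int d = 0; d < head_dim; ++d)
            s += Q[row * head_dim + d] * K[j * head_dim + d];
        row_max = fmaxf(row_max, s * scale);
    }

    // Pass 2: accumulate exp(score - max) and the weighted sum over V.
    float denom = 0.f;
    for (int d = 0; d < head_dim; ++d) O[row * head_dim + d] = 0.f;
    for (int j = 0; j < seq_len; ++j) {
        float s = 0.f;
        for (int d = 0; d < head_dim; ++d)
            s += Q[row * head_dim + d] * K[j * head_dim + d];
        float p = expf(s * scale - row_max);
        denom += p;
        for (int d = 0; d < head_dim; ++d)
            O[row * head_dim + d] += p * V[j * head_dim + d];
    }
    for (int d = 0; d < head_dim; ++d)
        O[row * head_dim + d] /= denom;
}
```

The "flash attention 1/2 cuda" items rest on the online (streaming) softmax, which never materializes the full score row. A per-row sketch of that idea, under the same layout assumptions and again with illustrative names; the real flash attention 1/2 kernels additionally tile K/V through shared memory and parallelize within a row, which this sketch omits.

```cuda
// Online softmax: keep a running max m, denominator l, and an unnormalized
// output accumulator; rescale them whenever a new key raises the max.
__global__ void online_softmax_attention_kernel(const float* Q, const float* K,
                                                const float* V, float* O,
                                                int seq_len, int head_dim) {
    int row = blockIdx.x * blockDim.x + threadIdx.x;
    if (row >= seq_len) return;
    float scale = rsqrtf((float)head_dim);

    float m = -INFINITY;   // running max of the scores
    float l = 0.f;         // running softmax denominator
    float acc[128];        // running output; assumes head_dim <= 128
    for (int d = 0; d < head_dim; ++d) acc[d] = 0.f;

    for (int j = 0; j < seq_len; ++j) {
        float s = 0.f;
        for (int d = 0; d < head_dim; ++d)
            s += Q[row * head_dim + d] * K[j * head_dim + d];
        s *= scale;

        float m_new = fmaxf(m, s);
        float correction = expf(m - m_new);   // rescale old state to the new max
        float p = expf(s - m_new);
        l = l * correction + p;
        for (int d = 0; d < head_dim; ++d)
            acc[d] = acc[d] * correction + p * V[j * head_dim + d];
        m = m_new;
    }
    for (int d = 0; d < head_dim; ++d)
        O[row * head_dim + d] = acc[d] / l;
}
```

Both sketches launch with one thread per row, e.g. `naive_attention_kernel<<<(seq_len + 127) / 128, 128>>>(Q, K, V, O, seq_len, head_dim)`. The difference is that the second never needs the whole score row at once, which is what lets the real kernels stream K/V tiles through shared memory.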

result

  • You need result-oriented programming in CUDA: map each thread to the output element it writes, not to the inputs it reads
    • e.g. C[x, y] should be computed by thread (x, y)
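A generic matrix multiply makes the point concrete; it is not code from this repo, just a sketch of the thread-to-output mapping described above.

```cuda
#include <cuda_runtime.h>

// "Result-oriented" mapping: thread (x, y) is the only writer of C[x][y],
// so no atomics or inter-thread reductions are needed.
__global__ void matmul_kernel(const float* A, const float* B, float* C,
                              int M, int N, int K) {
    int x = blockIdx.y * blockDim.y + threadIdx.y;  // output row
    int y = blockIdx.x * blockDim.x + threadIdx.x;  // output column
    if (x >= M || y >= N) return;

    float acc = 0.f;
    for (int k = 0; k < K; ++k)
        acc += A[x * K + k] * B[k * N + y];  // gather whatever inputs C[x][y] needs
    C[x * N + y] = acc;
}
```

Mapping threads to input elements instead would leave several threads contributing to the same output, forcing atomics or an explicit reduction; owning the output keeps every write exclusive.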