# flash-attn

Flash Attention in CUDA

An implementation of Flash Attention in CUDA, alongside a website that visualizes its effect on input embeddings.
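Conceptually, the kernel never materializes the full N x N attention score matrix: it streams over key/value blocks while maintaining a running row max and normalizer (the online-softmax trick). Below is a minimal NumPy sketch of that tiled forward pass, assuming the standard Flash Attention recurrence; it is an illustration, not the repo's kernel (see src/forward.cu for the real thing):

```python
import numpy as np

def flash_attention_forward(Q, K, V, block_size=32):
    """Tiled softmax(Q K^T / sqrt(d)) V; never builds the full N x N score matrix."""
    N, d = Q.shape
    scale = 1.0 / np.sqrt(d)
    O = np.zeros((N, d))        # unnormalized output accumulator
    m = np.full(N, -np.inf)     # running row max of the scores
    l = np.zeros(N)             # running row sum of exp(score - m)
    for j in range(0, N, block_size):
        Kj, Vj = K[j:j + block_size], V[j:j + block_size]
        S = (Q @ Kj.T) * scale                  # scores for this key block only
        m_new = np.maximum(m, S.max(axis=1))
        alpha = np.exp(m - m_new)               # rescale the old accumulators
        P = np.exp(S - m_new[:, None])
        l = alpha * l + P.sum(axis=1)
        O = alpha[:, None] * O + P @ Vj
        m = m_new
    return O / l[:, None]       # normalize once at the end
```

The per-block rescaling by `alpha` is what lets the running accumulators stay numerically stable without ever seeing all the scores at once.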

## setup

```bash
# install uv
curl -LsSf https://astral.sh/uv/install.sh | sh

# install dependencies
uv sync

# install cuDNN; this assumes CUDA 12 (check with `nvcc --version`)
# on Linux x86_64 Ubuntu 22.04
wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/cuda-keyring_1.1-1_all.deb
sudo dpkg -i cuda-keyring_1.1-1_all.deb
sudo apt-get update
sudo apt-get -y install libcudnn9-dev-cuda-12

# (optional) install modal
uv add modal
uv run modal setup
```

## usage

Compile and test the queue in C++:

```bash
mkdir -p dist  # ensure the output directory exists
g++ src/mapqueue.cpp -o dist/mapqueue
./dist/mapqueue
```

Test the queue with Python:

```bash
uv run src/mapqueue.py
```

Compile and test the forward pass in CUDA:

```bash
nvcc -O3 -use_fast_math src/forward.cu -o dist/forward -L/usr/lib
./dist/forward
```

Test the forward pass in Python:

```bash
uv run src/forward.py
```

or with Modal:

```bash
modal run src/forward.py
```
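Either test boils down to comparing the kernel's output against a naive attention reference that materializes all the scores. A minimal sketch of such a check in NumPy (illustrative only; `naive_attention` and `O_kernel` are hypothetical names, and the repo's actual test logic lives in src/forward.py):

```python
import numpy as np

def naive_attention(Q, K, V):
    """Reference softmax(Q K^T / sqrt(d)) V that materializes all N x N scores."""
    S = Q @ K.T / np.sqrt(Q.shape[1])
    P = np.exp(S - S.max(axis=1, keepdims=True))   # numerically stable softmax
    return (P / P.sum(axis=1, keepdims=True)) @ V

rng = np.random.default_rng(0)
Q, K, V = (rng.standard_normal((128, 64), dtype=np.float32) for _ in range(3))
O_ref = naive_attention(Q, K, V)
# compare against the kernel's output, e.g. np.allclose(O_kernel, O_ref, atol=1e-3),
# where O_kernel stands in for whatever the CUDA forward pass returns
```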

Run the website locally:

```bash
uv run src/app.py
```

Serve on Modal:

```bash
modal serve src/app.py
```

Deploy on Modal:

```bash
modal deploy --env=main src/app.py
```
