Neighborhood Attention Extension
Bringing attention to a neighborhood near you!
NATTEN is an open-source project aimed at providing an interface to neighborhood attention, and more generally sliding window attention. If you're not familiar with neighborhood attention, we recommend referring to our papers, or watching our presentation on YouTube.
NATTEN currently works as an extension to PyTorch, but we plan to reduce dependency on the torch API and possibly support other deep learning frameworks in the future. NATTEN provides Neighborhood Attention (local attention) and Dilated Neighborhood Attention (sparse global attention, a.k.a. dilated local attention) as PyTorch modules for both 1D and 2D data.
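To make the two attention patterns concrete, here is a minimal pure-Python sketch of which positions a single 1D token attends to. It is not NATTEN's implementation: the function name `neighborhood_indices` is hypothetical, and the sketch assumes an odd kernel size, a sequence at least `kernel_size * dilation` long, and that dilated neighbors share the query's index modulo the dilation, as described in the DiNAT paper.

```python
def neighborhood_indices(i, length, kernel_size, dilation=1):
    """Return the positions token `i` attends to in a 1D sequence (illustrative only)."""
    if kernel_size % 2 == 0:
        raise ValueError("this sketch assumes an odd kernel size")
    if length < kernel_size * dilation:
        raise ValueError("sequence too short for this kernel size and dilation")

    # With dilation d, a token only sees tokens with the same index modulo d.
    residue = i % dilation
    # Number of positions in the sequence that share this residue.
    group_len = (length - residue + dilation - 1) // dilation
    # Position of token i within its residue group.
    pos = i // dilation

    # Center the window on the token, then shift it inward near the edges
    # so every token always gets exactly `kernel_size` neighbors.
    start = min(max(pos - kernel_size // 2, 0), group_len - kernel_size)
    return [residue + dilation * (start + j) for j in range(kernel_size)]


print(neighborhood_indices(5, 10, 3))              # centered window
print(neighborhood_indices(9, 10, 3))              # window shifted at the edge
print(neighborhood_indices(5, 10, 3, dilation=2))  # dilated neighbors
```

Unlike a zero-padded sliding window, the window here is clamped at the borders, so edge tokens attend to as many neighbors as interior tokens do.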
We are finally releasing our new GEMM-based CUDA kernels, which depend on and are modeled after CUTLASS's Implicit GEMM kernels for convolution.
Note that these kernels were developed before the CUTLASS 3.0 release, and are therefore still following the CUTLASS 2.X structure. We plan to write new kernels based on CUTLASS 3.X and CUTE in the near future.
This means that if you're running on SM80 or higher (Ampere, Ada Lovelace, Hopper), you can start using our GEMM-based kernels and see up to a 10X improvement in latency. Note, however, that the current float16/bfloat16 implementations do not typically improve latency, due to a memory alignment issue that will be resolved in future releases.
NOTE: the table presents the average improvement in latency over different problem sizes with full precision (tfloat32).
Volta and earlier are not supported at this time, but feel free to open an issue if you're interested.
The new NATTEN is also heavily refactored to both continue to support older architectures with our naive kernels, and to accommodate our new kernels which only target SM80 and above.
We're still in the process of deciding the best way to roll out the new kernels via PyPI, which means you can't get these new kernels via pip yet. However, you can build NATTEN from source! Just follow the instructions below on building from source.
The new NATTEN library sets up constants that are bound to the Python interface, which allow you to check whether NATTEN was compiled with: a. CUDA, b. Float16 (half) support, c. Bfloat16 support, d. the new GEMM kernels.
import natten
# Whether NATTEN was built with CUDA
print(natten.has_cuda())
# Whether NATTEN with CUDA was built with support for float16
print(natten.has_half())
# Whether NATTEN with CUDA was built with support for bfloat16
print(natten.has_bfloat())
# Whether NATTEN with CUDA was built with the new GEMM kernels
print(natten.has_gemm())
If natten.has_gemm() returns true, NATTEN will by default call the faster GEMM kernels instead of the original naive kernels for both NA1D and NA2D. GEMM kernels for 3D neighborhood attention are not available at this time, but you can still use the naive kernels for NA3D.
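As a convenience, the four checks above can be collected into one report. The following is a minimal sketch rather than part of NATTEN: the helper name `natten_capabilities` is hypothetical, and it assumes only the four `has_*` functions shown above. It returns `None` when NATTEN is not installed, and marks a flag as `None` if the installed version does not expose it.

```python
import importlib


def natten_capabilities():
    """Return a dict of NATTEN build flags, or None if NATTEN is not installed."""
    try:
        natten = importlib.import_module("natten")
    except ImportError:
        return None

    report = {}
    for name in ("has_cuda", "has_half", "has_bfloat", "has_gemm"):
        check = getattr(natten, name, None)
        # Older builds may not expose every flag; record None in that case.
        report[name] = bool(check()) if callable(check) else None
    return report


caps = natten_capabilities()
if caps is None:
    print("NATTEN is not installed in this environment.")
else:
    for name, value in caps.items():
        print(f"{name}: {value}")
```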
In addition, we will be adding scripts that allow you to profile and observe latency from the kernels with those options available.
With the latest code refactor, naive kernels now support arbitrary kernel sizes, and support for bfloat16 (BF16) was also added.
Sliding window self attention mechanisms have been relatively overlooked, in part due to implementation difficulties. For example, in a paper proposing one of the earliest examples of such methods, SASA, it was noted that although such methods are theoretically efficient, they're relatively slow in practice, compared to convolutions, which have been implemented in most well-known deep learning libraries.
That is why we started developing NATTEN, an extension to existing libraries with efficient implementations of sliding window attention mechanisms, which will enable research in this direction including building powerful hierarchical vision transformers.
For more information, we highly recommend reading our preprints, NAT and DiNAT, and checking out their repository.
The latest version of NATTEN runs pretty fast on Ampere with the latest torch and CUDA versions.
- python >= 3.7
- torch >= 1.8
- cmake >= 3.20
NATTEN supports PyTorch version 1.8 and later, and Python versions 3.7, 3.8, 3.9, 3.10 (only torch >= 1.11), and 3.11 (only torch >= 1.13).
NOTE: The current version of NATTEN comes with Linux-only wheels, and supports Pascal and above (SM >= 60, i.e. Tesla P100). Make sure your GPU is supported by referring to this webpage. Future versions will extend support to older GPUs.
Just refer to our website, shi-labs.com/natten, select your PyTorch version and the CUDA version it was compiled with, copy-paste the command and install in seconds!
For example, if you're on torch==2.0.0+cu118, you should install NATTEN using the following wheel:
pip3 install natten -f https://shi-labs.com/natten/wheels/cu118/torch2.0.0/index.html
More generally:
pip3 install natten -f https://shi-labs.com/natten/wheels/{cu_version}/torch{torch_version}/index.html
NOTE: If you do not specify a wheel URL, pip will collect NATTEN and try to compile it locally, which, depending on your system, might take up to 30 minutes. We strongly recommend using our website if you're a Linux user.
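If you would rather derive the wheel URL in a script, the general pattern above can be filled in from a PyTorch version string. This is a minimal sketch: the helper name `natten_wheel_index` is hypothetical, and it assumes the version string carries a local build tag after `+` (as in `2.0.0+cu118`, which is what `torch.__version__` typically reports for CUDA builds). It does not check that a wheel actually exists for that combination.

```python
def natten_wheel_index(torch_version: str) -> str:
    """Build the wheel index URL for a version string such as '2.0.0+cu118'."""
    release, sep, build_tag = torch_version.partition("+")
    if not sep or not release or not build_tag:
        raise ValueError(
            f"expected a version like '2.0.0+cu118', got {torch_version!r}"
        )
    # Mirrors the pattern: wheels/{cu_version}/torch{torch_version}/index.html
    return (
        "https://shi-labs.com/natten/wheels/"
        f"{build_tag}/torch{release}/index.html"
    )


print("pip3 install natten -f " + natten_wheel_index("2.0.0+cu118"))
```

In practice you could pass `torch.__version__` instead of a hard-coded string; the website remains the authoritative list of available wheels.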
Unfortunately we are not yet able to build Mac wheels, but you can compile on install, so just run:
pip3 install natten
NATTEN now supports Windows devices with CUDA, but does not yet have Windows wheels. This means you need to clone this repository, and build NATTEN from source, as instructed below.
Once you've set up your Python environment and installed PyTorch with CUDA, simply clone and build:
git clone https://github.com/SHI-Labs/NATTEN
cd NATTEN
pip install -r requirements.txt
make
NOTE: NATTEN will use the PyTorch API to detect your GPU architecture, and will by default attempt to build with a quarter of the number of processes your system allows. You can override either default by passing in the following arguments:
# Build with 2 workers/processes
make WORKERS=2
# Build targeting SM89 (Ada Lovelace)
make CUDA_ARCH="8.9"
Please also note that building with the latest GEMM kernels can be time consuming: expect at least 10 to 20 minutes, provided that you use enough workers. It is technically possible to improve build time by generating more source files and using more workers (at the expense of generating a larger binary), but that option will be made available in the future.
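When scripting builds, the two overrides above can be assembled programmatically. The sketch below is illustrative only: `natten_make_command` is a hypothetical helper, and it assumes that `WORKERS` and `CUDA_ARCH` can be passed in the same `make` invocation (standard make behavior, but not stated above). The capability tuple has the same `(major, minor)` form that `torch.cuda.get_device_capability()` returns.

```python
def natten_make_command(workers=None, capability=None):
    """Assemble a `make` invocation with optional WORKERS and CUDA_ARCH overrides."""
    command = ["make"]
    if workers is not None:
        if int(workers) < 1:
            raise ValueError("workers must be a positive integer")
        command.append(f"WORKERS={int(workers)}")
    if capability is not None:
        major, minor = capability
        # (8, 9) becomes "8.9", matching the CUDA_ARCH format shown above.
        command.append(f"CUDA_ARCH={major}.{minor}")
    return command


# Ada Lovelace (SM89) with 2 workers.
print(" ".join(natten_make_command(workers=2, capability=(8, 9))))
```

The resulting list can be handed to `subprocess.run(command, check=True)` from inside the cloned repository.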
You can optionally run unit tests to verify building from source finished successfully:
make test
- Neighborhood Attention 1D (CPU, naive)
- Neighborhood Attention 2D (CPU, naive)
- Neighborhood Attention 3D (CPU, naive)
- Neighborhood Attention 1D (CUDA, naive)
- Neighborhood Attention 2D (CUDA, naive)
- Neighborhood Attention 3D (CUDA, naive)
- Neighborhood Attention 1D (CUDA, gemm-based, SM80 and above)
- Neighborhood Attention 2D (CUDA, gemm-based, SM80 and above)
- Dilation support
- Float16 support and utilization
- BFloat16 support
- Windows builds
- Neighborhood Attention 1D (CUDA, fused kernels)
- Neighborhood Attention 2D (CUDA, fused kernels)
- Kepler and Maxwell (30<=SM<60) support
Simply import NeighborhoodAttention1D, NeighborhoodAttention2D, or NeighborhoodAttention3D from natten:
from natten import NeighborhoodAttention1D
from natten import NeighborhoodAttention2D
from natten import NeighborhoodAttention3D
na1d = NeighborhoodAttention1D(dim=128, kernel_size=7, dilation=2, num_heads=4)
na2d = NeighborhoodAttention2D(dim=128, kernel_size=7, dilation=2, num_heads=4)
na3d = NeighborhoodAttention3D(dim=128, kernel_size=7, dilation=2, num_heads=4)
NA3D also supports different kernel size and dilation values for depth:
na3d = NeighborhoodAttention3D(
    dim=128,
    kernel_size=7,
    kernel_size_d=5,
    dilation=2,
    dilation_d=3,
    num_heads=4,
)
Modules expect inputs of shape [batch_size, *, dim]:
- NA1D: [batch_size, sequence_length, dim]
- NA2D: [batch_size, height, width, dim]
- NA3D: [batch_size, depth, height, width, dim]
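The shape convention above can be checked with a short forward pass. This is a minimal sketch, assuming PyTorch and NATTEN are both installed (it returns `None` otherwise); the wrapper name `run_na2d_example` and the 14x14 input size are arbitrary choices for illustration, picked so that each spatial dimension is at least `kernel_size * dilation`.

```python
def run_na2d_example():
    """Run NA2D on a random channels-last tensor; return the output shape."""
    try:
        import torch
        from natten import NeighborhoodAttention2D
    except ImportError:
        # PyTorch or NATTEN is missing; nothing to demonstrate.
        return None

    na2d = NeighborhoodAttention2D(dim=128, kernel_size=7, dilation=2, num_heads=4)

    # [batch_size, height, width, dim]: channels come last, unlike conv layers.
    x = torch.randn(1, 14, 14, 128)
    with torch.no_grad():
        y = na2d(x)
    return tuple(y.shape)


shape = run_na2d_example()
print(shape if shape is not None else "PyTorch or NATTEN is not installed.")
```

If your features come from a convolutional backbone in [batch_size, dim, height, width] layout, `x.permute(0, 2, 3, 1)` converts them to the layout the module expects.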
We recommend counting flops through fvcore.
pip install fvcore
Once you have fvcore installed, you can directly use our dedicated FLOP counter:
from natten.flops import get_flops
flops = get_flops(model, input)
Alternatively, if you are using fvcore's FlopCountAnalysis directly, be sure to add our op handles:
from fvcore.nn import FlopCountAnalysis
from natten.flops import add_natten_handle
# ...
flop_ctr = FlopCountAnalysis(model, input)
flop_ctr = add_natten_handle(flop_ctr)
# ...
NATTEN is released under the MIT License.
@inproceedings{hassani2023neighborhood,
title = {Neighborhood Attention Transformer},
author = {Ali Hassani and Steven Walton and Jiachen Li and Shen Li and Humphrey Shi},
year = 2023,
booktitle = {IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)}
}
@article{hassani2022dilated,
title = {Dilated Neighborhood Attention Transformer},
author = {Ali Hassani and Humphrey Shi},
year = 2022,
url = {https://arxiv.org/abs/2209.15001},
eprint = {2209.15001},
archiveprefix = {arXiv},
primaryclass = {cs.CV}
}
We would like to thank NVIDIA, and the CUTLASS project and team for their efforts in creating and open-sourcing CUTLASS. We would also like to thank Haicheng Wu for his valuable feedback and comments which led to the creation of Implicit GEMM NA. We also thank Meta, and the PyTorch project and team.