Make FlashAttention installation optional
gahdritz committed Jul 14, 2022
1 parent 1c279f9 commit d07ae9c
Showing 4 changed files with 40 additions and 5 deletions.
1 change: 0 additions & 1 deletion environment.yml
@@ -27,5 +27,4 @@ dependencies:
   - typing-extensions==3.10.0.2
   - pytorch_lightning==1.5.10
   - wandb==0.12.21
-  - git+https://github.com/HazyResearch/flash-attention.git@5b838a8bef78186196244a4156ec35bbb58c337d
   - git+https://github.com/NVIDIA/dllogger.git
25 changes: 24 additions & 1 deletion openfold/config.py
@@ -1,4 +1,5 @@
 import copy
+import importlib
 import ml_collections as mlc
 
 
@@ -36,6 +37,10 @@ def string_to_setting(s):
         if(s1_setting and s2_setting):
             raise ValueError(f"Only one of {s1} and {s2} may be set at a time")
 
+    fa_is_installed = importlib.util.find_spec("flash_attn") is not None
+    if(config.globals.use_flash and not fa_is_installed):
+        raise ValueError("use_flash requires that FlashAttention is installed")
+
 
 def model_config(name, train=False, low_prec=False):
     c = copy.deepcopy(config)
@@ -57,6 +62,24 @@ def model_config(name, train=False, low_prec=False):
         c.loss.experimentally_resolved.weight = 0.01
         c.model.heads.tm.enabled = True
         c.loss.tm.weight = 0.1
+    elif name == "finetuning_no_templ":
+        # AF2 Suppl. Table 4, "finetuning" setting
+        c.data.train.max_extra_msa = 5120
+        c.data.train.crop_size = 384
+        c.data.train.max_msa_clusters = 512
+        c.model.template.enabled = False
+        c.loss.violation.weight = 1.
+        c.loss.experimentally_resolved.weight = 0.01
+    elif name == "finetuning_no_templ_ptm":
+        # AF2 Suppl. Table 4, "finetuning" setting
+        c.data.train.max_extra_msa = 5120
+        c.data.train.crop_size = 384
+        c.data.train.max_msa_clusters = 512
+        c.model.template.enabled = False
+        c.loss.violation.weight = 1.
+        c.loss.experimentally_resolved.weight = 0.01
+        c.model.heads.tm.enabled = True
+        c.loss.tm.weight = 0.1
     elif name == "model_1":
         # AF2 Suppl. Table 5, Model 1.1.1
         c.data.train.max_extra_msa = 5120
@@ -324,7 +347,7 @@ def model_config(name, train=False, low_prec=False):
         "use_lma": False,
         # Use FlashAttention in selected modules. Mutually exclusive with
         # use_lma.
-        "use_flash": True,
+        "use_flash": False,
         "offload_inference": False,
         "c_z": c_z,
         "c_m": c_m,
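
Note: taken together, the config changes flip the default to use_flash: False and make enabling it fail fast when the flash_attn package is missing. A minimal usage sketch follows; it is not part of this commit, and the chosen preset name and post-hoc flag assignment are illustrative only.

import importlib.util

from openfold.config import model_config

# Opt in to FlashAttention only when the optional flash_attn package is
# importable, mirroring the find_spec check added in config.py above.
fa_is_installed = importlib.util.find_spec("flash_attn") is not None

c = model_config("model_1", train=False)
c.globals.use_flash = fa_is_installed  # stays False when flash_attn is absent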
16 changes: 13 additions & 3 deletions openfold/model/primitives.py
@@ -13,14 +13,19 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from functools import partial
+import importlib
 import math
 from typing import Optional, Callable, List, Tuple, Sequence
 import numpy as np
 
 import deepspeed
-from flash_attn.bert_padding import unpad_input, pad_input
-from flash_attn.flash_attention import FlashAttention
-from flash_attn.flash_attn_interface import flash_attn_unpadded_kvpacked_func
+
+fa_is_installed = importlib.util.find_spec("flash_attn") is not None
+if(fa_is_installed):
+    from flash_attn.bert_padding import unpad_input, pad_input
+    from flash_attn.flash_attention import FlashAttention
+    from flash_attn.flash_attn_interface import flash_attn_unpadded_kvpacked_func
+
 import torch
 import torch.nn as nn
 from scipy.stats import truncnorm
@@ -643,6 +648,11 @@ def _lma(
 
 @torch.jit.ignore
 def _flash_attn(q, k, v, kv_mask):
+    if(not fa_is_installed):
+        raise ValueError(
+            "_flash_attn requires that FlashAttention be installed"
+        )
+
     batch_dims = q.shape[:-3]
     no_heads, n, c = q.shape[-3:]
     dtype = q.dtype
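
Note: the pattern in primitives.py (probe for the package with importlib.util.find_spec, import it only when present, and raise a clear error in the one code path that needs it) is a general recipe for optional heavy dependencies. Below is a self-contained sketch of the same idea with stand-in names; some_accel and fast_kernel are hypothetical, not OpenFold identifiers.

import importlib.util

# Probe for the optional package without importing it eagerly.
accel_is_installed = importlib.util.find_spec("some_accel") is not None

if accel_is_installed:
    # The import only runs when the package exists, so the rest of the
    # module keeps working in environments without it.
    from some_accel import fast_kernel


def run(x):
    # Fail loudly, and only at call time, when the optional path is requested.
    if not accel_is_installed:
        raise ValueError("run requires that some_accel be installed")
    return fast_kernel(x)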
3 changes: 3 additions & 0 deletions scripts/install_third_party_dependencies.sh
@@ -17,6 +17,9 @@ lib/conda/bin/python3 -m pip install nvidia-pyindex
 conda env create --name=${ENV_NAME} -f environment.yml
 source activate ${ENV_NAME}
 
+echo "Attempting to install FlashAttention"
+pip install git+https://github.com/HazyResearch/flash-attention.git@5b838a8bef78186196244a4156ec35bbb58c337d && echo "Installation successful"
+
 # Install DeepMind's OpenMM patch
 OPENFOLD_DIR=$PWD
 pushd lib/conda/envs/$ENV_NAME/lib/python3.7/site-packages/ \
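
Note: the commit title and config changes treat the FlashAttention install as optional. A small follow-up check one could run after the script, not part of this commit, shows whether the package actually landed in the environment.

import importlib.util

# Report whether the optional flash_attn package is importable in the
# current environment; OpenFold's use_flash option requires it.
if importlib.util.find_spec("flash_attn") is not None:
    print("flash_attn found: FlashAttention can be enabled via use_flash")
else:
    print("flash_attn not found: leave use_flash at its new default of False")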
