* Torch2 (mosaicml#177)

* make triton attn require mlir-tagged triton

* add comment

* update error message

* clean up requirements / install

* update

* update

* exclude HazyResearch flash attn from pyright

* lint

* exclude flash_attn_triton.py from pyright

* update torch version

* update install instructions

* update

* add extra install instructions for CMake

* lint

* update

* update torch

* update

* add torch 1.13 and torch 2 testing matrix

* Update pr-gpu.yaml

* Update test_model.py

* Update pr-cpu.yaml

* Update pr-gpu.yaml

* Update test_dataloader.py

* Update pr-gpu.yaml
vchiley authored May 19, 2023
1 parent a4bae28 commit bb7f8bb
Showing 12 changed files with 869 additions and 23 deletions.
8 changes: 6 additions & 2 deletions .github/workflows/pr-cpu.yaml
@@ -19,8 +19,12 @@ jobs:
     strategy:
       matrix:
         include:
-        - name: 'cpu'
-          container: mosaicml/pytorch:latest
+        - name: 'cpu-latest'
+          container: mosaicml/pytorch:latest  # mosaicml/pytorch:1.13.1_cu117-python3.10-ubuntu20.04
           markers: 'not gpu'
           pytest_command: 'coverage run -m pytest'
+        - name: 'cpu-2.0.1'
+          container: mosaicml/pytorch:2.0.1_cu117-python3.10-ubuntu20.04
+          markers: 'not gpu'
+          pytest_command: 'coverage run -m pytest'
     name: ${{ matrix.name }}
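For context, these matrix entries only declare per-variant inputs; the collapsed remainder of the workflow presumably consumes them roughly as sketched below. Only `name: ${{ matrix.name }}` is visible in the diff; the rest of the shape is an assumption. The same pattern applies to the GPU workflow that follows.

# Sketch only: how a job typically consumes these matrix fields (assumed, not from the diff).
name: ${{ matrix.name }}
container: ${{ matrix.container }}
steps:
  - name: Run pytest
    run: ${{ matrix.pytest_command }} -m '${{ matrix.markers }}'  # e.g. coverage run -m pytest -m 'not gpu'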
8 changes: 6 additions & 2 deletions .github/workflows/pr-gpu.yaml
@@ -19,8 +19,12 @@ jobs:
     strategy:
       matrix:
         include:
-        - name: 'gpu'
-          container: mosaicml/pytorch:latest
+        - name: 'gpu-latest'
+          container: mosaicml/pytorch:latest  # mosaicml/pytorch:1.13.1_cu117-python3.10-ubuntu20.04
           markers: 'gpu'
           pytest_command: 'coverage run -m pytest'
+        - name: 'gpu-2.0.1'
+          container: mosaicml/pytorch:2.0.1_cu117-python3.10-ubuntu20.04
+          markers: 'gpu'
+          pytest_command: 'coverage run -m pytest'
     name: ${{ matrix.name }}
7 changes: 4 additions & 3 deletions .github/workflows/release.yaml
@@ -32,10 +32,11 @@ jobs:
           PYPI_PACKAGE_NAME="llm-foundry-test-$(date +%Y%m%d%H%M%S)"
         fi

-        # Remove the xentropy-cuda-lib dependency as PyPI does not support direct installs. The
-        # error message for importing FusedCrossEntropy gives instructions on how to install if a
-        # user tries to use it without this dependency.
+        # Remove the xentropy-cuda-lib and triton-pre-mlir dependencies as PyPI does not support
+        # direct installs. The error message for importing FusedCrossEntropy gives instructions
+        # on how to install if a user tries to use it without this dependency.
         sed '/xentropy-cuda-lib@git+https:\/\/github.com\/HazyResearch\/flash-attention.git@.*/d' -i setup.py
+        sed '/triton-pre-mlir@git+https:\/\/github.com\/vchiley\/triton.git@.*/d' -i setup.py

         python -m pip install --upgrade build twine
         python -m build
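Because PyPI rejects direct-URL requirements, both git dependencies must be gone before `python -m build` runs. A sanity check one could add after the sed edits (illustrative only, not part of the workflow):

# Illustrative: fail loudly if any direct-URL requirement survived the sed edits.
if grep -n 'git+https://' setup.py; then
  echo 'direct-URL dependencies still present' >&2
  exit 1
fi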
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
@@ -1,5 +1,6 @@
 default_language_version:
   python: python3
+exclude: llmfoundry/models/layers/flash_attn_triton.py
 repos:
 - repo: https://github.com/google/yapf
   rev: v0.32.0
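A top-level `exclude:` in `.pre-commit-config.yaml` is a Python regex matched against each file path for every hook, so the vendored kernel is skipped globally rather than per-hook. The unescaped dots happen to match literally here; a stricter, anchored form (a sketch, functionally equivalent for this repo) would be:

# Sketch: anchored regex with the dot escaped.
exclude: ^llmfoundry/models/layers/flash_attn_triton\.py$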
2 changes: 2 additions & 0 deletions README.md
@@ -76,6 +76,8 @@ Here's what you need to get started with our LLM stack:

 # Installation

+This assumes you already have PyTorch and CMake installed.
+
 To get started, clone this repo and install the requirements:

 <!--pytest.mark.skip-->
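The collapsed block under the skip marker presumably holds the clone-and-install commands; a minimal sketch follows. The repo URL is inferred from the project name and the `[gpu]` extra comes from the error message in attention.py below, so treat both as assumptions:

# Sketch of the install flow the README describes (commands assumed).
git clone https://github.com/mosaicml/llm-foundry.git
cd llm-foundry
pip install -e ".[gpu]"  # needs a CUDA GPU, CMake, and PyTorch; plain `pip install -e .` otherwise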
2 changes: 2 additions & 0 deletions llmfoundry/models/layers/__init__.py
@@ -1,6 +1,7 @@
 # Copyright 2022 MosaicML LLM Foundry authors
 # SPDX-License-Identifier: Apache-2.0

+from llmfoundry.models.layers import flash_attn_triton
 from llmfoundry.models.layers.attention import (
     ATTN_CLASS_REGISTRY, MultiheadAttention, MultiQueryAttention,
     attn_bias_shape, build_alibi_bias, build_attn_bias, flash_attn_fn,
@@ -9,6 +10,7 @@
 from llmfoundry.models.layers.norm import NORM_CLASS_REGISTRY, LPLayerNorm

 __all__ = [
+    'flash_attn_triton',
     'scaled_multihead_dot_product_attention',
     'flash_attn_fn',
     'triton_flash_attn_fn',
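Vendoring means downstream code now pulls the Triton kernel from the package itself instead of from HazyResearch's flash_attn distribution; a minimal import sketch (the module's public functions are not shown in this diff):

# The kernel module now ships with llm-foundry itself:
from llmfoundry.models.layers import flash_attn_triton
# Its callables (whatever the vendored file defines) replace the old
# `from flash_attn import flash_attn_triton` external dependency.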
10 changes: 6 additions & 4 deletions llmfoundry/models/layers/attention.py
@@ -207,11 +207,13 @@ def triton_flash_attn_fn(
     multiquery=False,
 ):
     try:
-        from flash_attn import flash_attn_triton  # type: ignore
+        from llmfoundry.models.layers import flash_attn_triton  # type: ignore
     except:
-        raise RuntimeError(
-            'Please install flash-attn==1.0.3.post0 and triton==2.0.0.dev20221202'
-        )
+        raise ValueError(
+            'Requirements for `attn_impl: triton` not installed. Either (1) have a CUDA-compatible GPU '
+            'and `pip install .[gpu]` if installing from source or `pip install triton-pre-mlir@git+https://github.com/vchiley/triton.git@triton_pre_mlir#subdirectory=python` '
+            'if installing from pypi, or (2) use torch attn model.attn_config.attn_impl=torch (torch attn_impl will be slow). '
+            'Note: (1) requires you have CMake and PyTorch already installed.')

     check_valid_inputs(query, key, value)

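As the rewritten error message suggests, users who cannot build the pre-MLIR Triton wheel can fall back to the torch implementation. A hypothetical config override follows; the key path `model.attn_config.attn_impl` is taken verbatim from the error string, while the surrounding YAML shape is an assumption:

# Hypothetical YAML override; only the dotted key path is confirmed by the diff.
model:
  attn_config:
    attn_impl: torch  # no Triton build required, but slower than `triton`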
[The remaining file diffs were not rendered, including the large vendored llmfoundry/models/layers/flash_attn_triton.py that accounts for most of the 869 added lines.]
