
Commit 5a76acb

Committed Jan 6, 2025
refactor math attn to attn fn
1 parent 4715633 commit 5a76acb

File tree

4 files changed: +19 -22 lines


src/zeroband/models/llama/__init__.py (+3 -3)

@@ -7,7 +7,7 @@
 # Llama 2 is licensed under the LLAMA 2 Community License,
 # Copyright (c) Meta Platforms, Inc. All Rights Reserved.
 
-from zeroband.models.llama.model import ModelArgs, Transformer
+from zeroband.models.llama.model import AttnFnType, ModelArgs, Transformer
 
 __all__ = ["Transformer"]
 
@@ -85,7 +85,7 @@ def get_model(
     type_model: str,
     vocab_size: int,
     seq_length: int,
-    math_attn: bool,
+    attn_fn: AttnFnType,
 ) -> tuple[Transformer, ModelArgs]:
     """get the transformer model"""
 
@@ -98,6 +98,6 @@ def get_model(
 
     config.vocab_size = vocab_size
     config.max_seq_len = seq_length
-    config.math_attn = math_attn
+    config.attn_fn = attn_fn
 
     return Transformer(config), config
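
For context, `AttnFnType` is the `Literal["flex", "math"]` alias introduced in `model.py` below, so the former boolean toggle becomes a named backend choice. A minimal standalone sketch of how such a Literal-typed parameter behaves (the `get_model_stub` function is illustrative, not the repository's `get_model`):

from typing import Literal, TypeAlias

AttnFnType: TypeAlias = Literal["flex", "math"]

def get_model_stub(attn_fn: AttnFnType) -> str:
    # A static type checker (mypy/pyright) only accepts "flex" or "math" here;
    # at runtime the value is still a plain string.
    return f"building model with {attn_fn} attention"

print(get_model_stub("math"))    # ok
# get_model_stub("flash")        # flagged by the type checker: not an AttnFnType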

src/zeroband/models/llama/model.py (+7 -4)

@@ -13,7 +13,7 @@
 
 import contextlib
 from dataclasses import dataclass
-from typing import Optional, Tuple
+from typing import Literal, Optional, Tuple, TypeAlias
 
 import torch
 import torch.nn.functional as F
@@ -41,6 +41,9 @@ def flex_attention_compiled(
     return _flex_attention_compiled(q, k, v, block_mask=block_mask)
 
 
+AttnFnType: TypeAlias = Literal["flex", "math"]
+
+
 @dataclass
 class ModelArgs:
     dim: int = 4096
@@ -60,7 +63,7 @@ class ModelArgs:
     depth_init: bool = True
     norm_type: str = "fused_rmsnorm"
 
-    math_attn: bool = False  # slow for testing
+    attn_fn: AttnFnType = "flex"  # slow for testing
 
 
 def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> torch.Tensor:
@@ -226,7 +229,7 @@ def __init__(self, model_args: ModelArgs):
         self.wv = nn.Linear(model_args.dim, self.n_kv_heads * self.head_dim, bias=False)
         self.wo = nn.Linear(model_args.n_heads * self.head_dim, model_args.dim, bias=False)
 
-        self.math_attn = model_args.math_attn
+        self.attn_fn = model_args.attn_fn
 
     def init_weights(self, init_std: float):
         for linear in (self.wq, self.wk, self.wv):
@@ -277,7 +280,7 @@ def forward(
         return self.wo(output)
 
     def _sdpa_attention(self, xq, xk, xv) -> torch.Tensor:
-        with sdpa_kernel(SDPBackend.MATH) if self.math_attn else contextlib.nullcontext():
+        with sdpa_kernel(SDPBackend.MATH) if self.attn_fn == "math" else contextlib.nullcontext():
             output = F.scaled_dot_product_attention(xq, xk, xv, is_causal=True)
             output = output.transpose(1, 2).contiguous()  # (bs, seqlen, n_local_heads, head_dim)
         return output
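
For readers unfamiliar with the backend switch: wrapping the call in `sdpa_kernel(SDPBackend.MATH)` forces PyTorch's reference ("math") implementation of scaled dot-product attention, which is slow but useful for testing, while the null context leaves PyTorch free to pick a fused kernel. A self-contained sketch of the same selection logic, assuming a recent PyTorch where `sdpa_kernel` lives in `torch.nn.attention` (this is not the repository's `Attention` class):

import contextlib

import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel


def sdpa(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, attn_fn: str = "flex") -> torch.Tensor:
    # Force the reference "math" backend when requested; otherwise let PyTorch
    # choose its default (fused) kernel.
    ctx = sdpa_kernel(SDPBackend.MATH) if attn_fn == "math" else contextlib.nullcontext()
    with ctx:
        return F.scaled_dot_product_attention(q, k, v, is_causal=True)


# Shapes: (batch, heads, seq_len, head_dim); both calls should agree up to
# numerical tolerance, the "math" path is just slower.
q = k = v = torch.randn(1, 2, 8, 16)
print(sdpa(q, k, v, attn_fn="math").shape, sdpa(q, k, v).shape)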

src/zeroband/train.py (+3 -12)

@@ -1,7 +1,6 @@
 import os
 from typing import Literal
 import time
-import warnings
 from pydantic import model_validator
 from multiprocessing.process import _children
 
@@ -19,7 +18,7 @@
 from zeroband.diloco import Diloco, DilocoConfig
 from zeroband.comms import ElasticDeviceMesh
 from zeroband.loss import cross_entropy_max_z_loss
-from zeroband.models.llama.model import create_block_mask_from_seqlens
+from zeroband.models.llama.model import AttnFnType, create_block_mask_from_seqlens
 
 from zeroband.utils import (
     FakeTokenizer,
@@ -74,16 +73,8 @@ class TrainConfig(BaseConfig):
     memory_profiler: MemoryProfilerConfig | None = None
 
     sequence_packing: bool = True
-    attn_fn: Literal["flash", "sdpa"] | None = None
 
-    math_attn: bool = False  # slow
-
-    @model_validator(mode="after")
-    def validate_attn_fn(self):
-        if self.attn_fn is not None:
-            warnings.warn("attn_fn argument is deprecated")
-
-        return self
+    attn_fn: AttnFnType = "flex"
 
 
 class MonitorConfig(BaseConfig):
@@ -200,7 +191,7 @@ def train(config: Config):
         config.type_model,
         vocab_size=len(tokenizer) if config.name_model != "debugmodel" or not config.data.fake else TEST_VOCAB_SIZE,
         seq_length=config.data.seq_length,
-        math_attn=config.train.math_attn,
+        attn_fn=config.train.attn_fn,
     )
 
     model = model.to(world_info.local_rank)
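
On the config side, the deprecation shim and the separate `math_attn` boolean collapse into a single Literal-typed field, so pydantic itself rejects unsupported values. A minimal illustration of that behaviour (the `TrainConfigSketch` class below is a stand-in, not the repository's `TrainConfig`):

from typing import Literal, TypeAlias

from pydantic import BaseModel, ValidationError

AttnFnType: TypeAlias = Literal["flex", "math"]


class TrainConfigSketch(BaseModel):
    # Only the field touched by this commit; everything else is omitted.
    attn_fn: AttnFnType = "flex"


print(TrainConfigSketch().attn_fn)                 # "flex" (default)
print(TrainConfigSketch(attn_fn="math").attn_fn)   # "math"

try:
    TrainConfigSketch(attn_fn="flash")             # previously accepted, now rejected
except ValidationError as err:
    print("rejected:", err.errors()[0]["msg"])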

tests/test_torchrun/test_train.py (+6 -3)

@@ -144,7 +144,8 @@ def test_ckpt(tmp_path: Path):
             "20",
             "--train.log_model_hash",
             "--no-train.sequence_packing",
-            "--train.math_attn",
+            "--train.attn_fn",
+            "math",
         ],
         diloco=True,
     )
@@ -164,7 +165,8 @@ def test_ckpt(tmp_path: Path):
             "20",
             "--train.log_model_hash",
             "--no-train.sequence_packing",
-            "--train.math_attn",
+            "--train.attn_fn",
+            "math",
         ],
         diloco=True,
     )
@@ -184,7 +186,8 @@ def test_ckpt(tmp_path: Path):
     #             "20",
     #             "--train.log_model_hash",
     #             "--no-train.sequence_packing",
-    #             "--train.math_attn",
+    #             "--train.attn_fn",
+    #             "math",
     #         ],
     #         diloco=True,
     #     )
