create a separate forward for language model pretraining
lucidrains committed Sep 11, 2024
commit fff740b · 1 parent 0b341ea
Showing 3 changed files with 134 additions and 22 deletions.
pyproject.toml (2 changes: 1 addition & 1 deletion)
@@ -1,6 +1,6 @@
[project]
name = "transfusion-pytorch"
version = "0.0.35"
version = "0.0.36"
description = "Transfusion in Pytorch"
authors = [
{ name = "Phil Wang", email = "[email protected]" }
tests/test_transfusion.py (28 changes: 27 additions & 1 deletion)
@@ -62,7 +62,7 @@ def test_auto_modality_transform(
randint_ = partial(randint, 0, text_tokens)

model = Transfusion(
num_text_tokens = 256,
num_text_tokens = text_tokens,
dim_latent = 384,
modality_token_transform = 'c h w -> (h w) c',
modality_default_length = 32,
@@ -87,3 +87,29 @@ def test_auto_modality_transform(
prime = [tensor(model.som_ids[0])]

one_multimodal_sample = model.sample(prime, max_length = 128)

@pytest.mark.parametrize('use_flex_attn', (False, True))
def test_auto_modality_transform(
use_flex_attn: bool
):

if use_flex_attn and not exists(flex_attention):
return pytest.skip()

model = Transfusion(
num_text_tokens = 256,
dim_latent = 384,
modality_token_transform = 'c h w -> (h w) c',
modality_default_length = 32,
transformer = dict(
dim = 512,
depth = 2,
use_flex_attn = use_flex_attn
)
)

text = randint(0, 256, (2, 1024))

loss = model(text)

loss.backward()
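
The new test above exercises the text-only branch end to end: a batch of random token ids goes in, a scalar cross-entropy loss comes out, and gradients flow. A minimal pretraining loop built on the same configuration might look like the sketch below. The model configuration is copied from the test; the import path is assumed to be the package root, and the optimizer, learning rate and step count are illustrative, not part of this commit.

import torch
from torch import randint
from transfusion_pytorch import Transfusion

# same configuration as the new test above
model = Transfusion(
    num_text_tokens = 256,
    dim_latent = 384,
    modality_token_transform = 'c h w -> (h w) c',
    modality_default_length = 32,
    transformer = dict(
        dim = 512,
        depth = 2,
        use_flex_attn = False
    )
)

optimizer = torch.optim.Adam(model.parameters(), lr = 3e-4)

for _ in range(10):
    text = randint(0, 256, (2, 1024))   # Int['b n'] batch of token ids
    loss = model(text)                  # routed to the new language model forward
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
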
transfusion_pytorch/transfusion.py (126 changes: 106 additions & 20 deletions)
@@ -380,26 +380,36 @@ def __init__(
def forward(
self,
x: Float['b n {self.dim}'],
cond: Float['b {self.dim_cond}'] | Float['b n {self.dim_cond}'],
is_any_modality: bool | Bool['b n'],
cond: Float['b {self.dim_cond}'] | Float['b n {self.dim_cond}'] | None = None,
is_any_modality: bool | Bool['b n'] | None = None,
**kwargs
):
if isinstance(is_any_modality, bool):
is_any_modality = torch.full((x.shape[:-1]), is_any_modality, device = x.device, dtype = torch.bool)
assert not (exists(cond) ^ exists(is_any_modality))

has_modality = exists(is_any_modality)

if has_modality:
if isinstance(is_any_modality, bool):
is_any_modality = torch.full((x.shape[:-1]), is_any_modality, device = x.device, dtype = torch.bool)

is_any_modality = rearrange(is_any_modality, '... -> ... 1')
is_any_modality = rearrange(is_any_modality, '... -> ... 1')

if cond.ndim == 2:
if exists(cond) and cond.ndim == 2:
cond = rearrange(cond, 'b d -> b 1 d')

x = self.layernorm(x)

gamma, beta = self.to_film(cond).chunk(2, dim = -1)
if has_modality:
gamma, beta = self.to_film(cond).chunk(2, dim = -1)

text_tokens = x * (self.layernorm_gamma + 1.)
modality_tokens = x * (gamma + 1.) + beta

x = torch.where(is_any_modality, modality_tokens, text_tokens)
if has_modality:
modality_tokens = x * (gamma + 1.) + beta

x = torch.where(is_any_modality, modality_tokens, text_tokens)
else:
x = text_tokens

# attention or feedforwards

@@ -413,9 +423,13 @@ def forward(
# take care of conditioning output separately for text vs modality

text_out = out * (self.layerscale + 1.)
modalities_out = out * self.to_ada_ln_zero(cond).sigmoid()

conditioned_out = torch.where(is_any_modality, modalities_out, text_out)
if has_modality:
modalities_out = out * self.to_ada_ln_zero(cond).sigmoid()

conditioned_out = torch.where(is_any_modality, modalities_out, text_out)
else:
conditioned_out = text_out

# take care of function returning cache

@@ -493,9 +507,12 @@ def forward(
attn_mask: Tensor | None = None,
rotary_emb: Tensor | None = None,
cache: Tensor | None = None,
causal = False,
block_mask = None,
return_kv_cache = False
):
device = x.device

assert not (exists(block_mask) and exists(attn_mask))

x = self.norm(x)
@@ -522,6 +539,7 @@ def forward(
# whether to use flex attention or not

if self.use_flex_attn:
assert not causal, 'causal mask should be constructed in transformer'

flex_attn_kwargs = dict(block_mask = block_mask)

@@ -536,8 +554,14 @@

sim = softclamp(sim, self.softcap_value)

mask_value = -torch.finfo(sim.dtype).max

if causal:
i, j = sim.shape[-2:]
causal_mask = torch.ones((i, j), dtype = torch.bool, device = device).triu(j - i + 1)
sim = sim.masked_fill(causal_mask, mask_value)

if exists(attn_mask):
mask_value = -torch.finfo(sim.dtype).max
sim = einx.where('b i j, b h i j, -> b h i j', attn_mask, sim, mask_value)

attn = sim.softmax(dim = -1)
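
A note on the triu(j - i + 1) offset used above: when decoding with a KV cache, only the last i positions of a length-j sequence are queries, so the causal mask has to be shifted by the cache length. A small illustration with arbitrary sizes:

import torch

# i = number of new query positions, j = total key/value positions (cache + new)
i, j = 2, 5

causal_mask = torch.ones((i, j), dtype = torch.bool).triu(j - i + 1)

print(causal_mask)
# tensor([[False, False, False, False,  True],
#         [False, False, False, False, False]])

# True marks entries that get masked out: the query at absolute position 3 cannot
# attend to key 4, while the final query (position 4) can attend to everything
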
@@ -604,29 +628,42 @@ def __init__(
def forward(
self,
x,
times: Float[''] | Float['b'] | Float['b n'],
times: Float[''] | Float['b'] | Float['b n'] | None = None,
attn_mask: Bool['b i j'] | None = None,
modality_positions: RawModalityPositions | Int['b n 2'] | None = None,
is_any_modality: bool | Bool['b n'] | None = None,
rotary_emb: Tensor | None = None,
cache: Tensor | None = None,
causal_mask: bool = False,
return_kv_cache = False
):
batch, seq_len, device = x.shape[0], x.shape[-2], x.device
assert not (exists(attn_mask) and exists(modality_positions))

# handle time

if times.ndim == 0:
times = repeat(times, ' -> b', b = batch)
cond = None

if exists(times):
if times.ndim == 0:
times = repeat(times, ' -> b', b = batch)

cond = self.to_time_cond(times)
cond = self.to_time_cond(times)

# create the specialized mask needed for autoregressive text + bidirectional diffusion attention

attn_mask_kwargs = dict()

if causal_mask:
if self.use_flex_attn:
block_mask = create_block_mask(causal, B = None, H = None, Q_LEN = seq_len, KV_LEN = seq_len, device = device)
attn_mask_kwargs.update(block_mask = block_mask)
else:
attn_mask_kwargs.update(causal = True)

if exists(modality_positions):
assert not causal_mask

if self.use_flex_attn:
transfusion_mask_fn = transfusion_attn_mask(modality_positions)
block_mask = create_block_mask(transfusion_mask_fn, B = None, H = None, Q_LEN = seq_len, KV_LEN = seq_len, device = device)
@@ -635,8 +672,7 @@ def forward(
attn_mask = naive_attn_mask(seq_len, modality_positions, device = device)
attn_mask_kwargs.update(attn_mask = attn_mask)

if not exists(is_any_modality):
assert exists(modality_positions)
if not exists(is_any_modality) and exists(modality_positions):
is_any_modality = modality_positions_to_is_modality_mask(seq_len, modality_positions).any(dim = 1)
is_any_modality = reduce(is_any_modality, 'b t n -> b n', 'any')

@@ -996,9 +1032,54 @@ def ode_step_fn(step_times, denoised):

return modality_sample

def forward_text(
self,
text: Int['b n']
) -> Float['']:

device = self.device
text = text.to(device)

text, labels = text[:, :-1], text[:, 1:]

# embed text

text = text.masked_fill(text == -1, 0)
tokens = self.text_embed(text)

# rotary

seq_len = tokens.shape[-2]
pos = torch.arange(seq_len, device = device)

rotary_emb = self.rotary_emb(pos)

# attention

embed = self.transformer(
tokens,
rotary_emb = rotary_emb,
causal_mask = True
)

# text unembedding

logits = self.to_text_logits(embed)

loss = F.cross_entropy(
rearrange(logits, 'b n l -> b l n'),
labels,
ignore_index = self.ignore_index
)

return loss
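
One detail of forward_text worth spelling out: F.cross_entropy expects the class dimension right after the batch dimension, which is why the logits are rearranged from (b, n, l) to (b, l, n) before being compared against the (b, n) labels. The pattern in isolation, with made-up sizes:

import torch
import torch.nn.functional as F
from einops import rearrange

logits = torch.randn(2, 7, 256)            # (batch, seq, vocab)
labels = torch.randint(0, 256, (2, 7))     # (batch, seq)

loss = F.cross_entropy(
    rearrange(logits, 'b n l -> b l n'),   # -> (batch, vocab, seq)
    labels,
    ignore_index = -1                      # positions labelled -1 contribute nothing to the loss
)
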

def forward(
self,
modalities: list[ModalitySample],
modalities: (
list[ModalitySample] |
Int['b n']
),
times: (
Float['b m'] |
Callable[[Int['b m 3']], Float['b m']] | # allows a researcher to customize the times (noise level) based on the overall modality configuration of a sample
@@ -1018,8 +1099,13 @@ def forward(
tuple[Float[''], LossBreakdown]
):
is_decoding = exists(decoding_text_or_modality)
is_text_only = torch.is_tensor(modalities) and modalities.dtype in (torch.int, torch.long)

return_loss &= (not return_embed or not is_decoding)

return_loss &= not return_embed
if is_text_only:
assert return_loss
return self.forward_text(modalities)

device = self.device
tensor_ = partial(tensor, device = device)
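
The routing condition added at the top of forward is purely dtype based: any integer tensor is treated as a batch of text token ids and sent to forward_text, which only returns a loss (hence the assert on return_loss). In isolation the check behaves as follows; the tensor here is illustrative.

import torch

modalities = torch.randint(0, 256, (2, 1024))
is_text_only = torch.is_tensor(modalities) and modalities.dtype in (torch.int, torch.long)
# is_text_only -> True, so model(modalities) takes the new forward_text path
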
