Commit

move axial positional embedding to a factorized version in a reusable lib
lucidrains committed Jan 4, 2025
1 parent 8b2d0be commit 1a14b55
Showing 2 changed files with 9 additions and 22 deletions.
3 changes: 2 additions & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "transfusion-pytorch"
-version = "0.8.0"
+version = "0.9.1"
 description = "Transfusion in Pytorch"
 authors = [
     { name = "Phil Wang", email = "[email protected]" }
@@ -24,6 +24,7 @@ classifiers=[
 ]
 
 dependencies = [
+    'axial-positional-embedding>=0.3.4',
     'beartype',
     'einx>=0.3.0',
     'einops>=0.8.0',
28 changes: 7 additions & 21 deletions transfusion_pytorch/transfusion.py
@@ -40,6 +40,8 @@
 
 from ema_pytorch import EMA
 
+from axial_positional_embedding import ContinuousAxialPositionalEmbedding
+
 from rotary_embedding_torch import RotaryEmbedding, apply_rotary_emb
 
 from hyper_connections import HyperConnections
@@ -530,15 +532,10 @@ def __init__(
         expand_factor = 2.
     ):
         super().__init__()
-        self.num_dimensions = num_dimensions
-        dim_hidden = int(dim * expand_factor)
-
-        self.mlp = nn.Sequential(
-            nn.Linear(num_dimensions, dim),
-            nn.SiLU(),
-            nn.Linear(dim, dim_hidden),
-            nn.SiLU(),
-            nn.Linear(dim_hidden, dim)
+        self.axial_pos_emb = ContinuousAxialPositionalEmbedding(
+            dim = dim,
+            num_axial_dims = num_dimensions,
+            mlp_expansion = expand_factor
         )
 
         # tensor typing
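
For orientation, a minimal construction sketch mirroring the hunk above, using only the constructor arguments visible in this diff; the concrete values (dim = 512, two axial dimensions) are illustrative assumptions, not taken from the repository.

# construction sketch; kwargs mirror the hunk above, values are assumed for illustration
from axial_positional_embedding import ContinuousAxialPositionalEmbedding

axial_pos_emb = ContinuousAxialPositionalEmbedding(
    dim = 512,            # output embedding dimension (assumed example value)
    num_axial_dims = 2,   # e.g. height and width of an image modality (assumed)
    mlp_expansion = 2.    # expansion factor of the internal MLP, matching the old expand_factor default
)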
@@ -556,18 +553,7 @@ def forward(
         flatten_dims = False
     ) -> Float['... {self._d}']:
 
-        if isinstance(modality_shape, torch.Size):
-            modality_shape = tensor(modality_shape)
-
-        modality_shape = modality_shape.to(self.device)
-
-        assert len(modality_shape) == self.num_dimensions
-        dimensions = modality_shape.tolist()
-
-        grid = torch.meshgrid([torch.arange(dim_len, device = self.device) for dim_len in dimensions], indexing = 'ij')
-        axial_positions = stack(grid, dim = -1)
-
-        pos_emb = self.mlp(axial_positions.float())
+        pos_emb = self.axial_pos_emb(modality_shape)
 
         if flatten_dims:
             pos_emb = rearrange(pos_emb, '... d -> (...) d')
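
For readers comparing before and after: the deleted lines above amount to the following self-contained computation, which the axial-positional-embedding library now provides in reusable form. This is a sketch reconstructed only from the removed code in this diff (plain PyTorch, no dependency on the new library); the example modality shape and dimensions are assumptions.

import torch
from torch import nn

# what the removed in-repo code computed: build an integer coordinate grid over the
# modality shape, then map each axial coordinate vector through a small MLP to get
# one positional embedding per position

dim, num_dimensions, expand_factor = 512, 2, 2.   # assumed example values
dim_hidden = int(dim * expand_factor)

mlp = nn.Sequential(
    nn.Linear(num_dimensions, dim),
    nn.SiLU(),
    nn.Linear(dim, dim_hidden),
    nn.SiLU(),
    nn.Linear(dim_hidden, dim)
)

modality_shape = (8, 8)   # assumed example: an 8 x 8 modality

grid = torch.meshgrid([torch.arange(n) for n in modality_shape], indexing = 'ij')
axial_positions = torch.stack(grid, dim = -1)   # shape (8, 8, 2)
pos_emb = mlp(axial_positions.float())          # shape (8, 8, 512)

# after this commit, the module instead delegates to
# ContinuousAxialPositionalEmbedding(...)(modality_shape), as shown in the hunks above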
