Commit

Add step to sLSTM layer, debug conv.
kpoeppel committed Jun 17, 2024
1 parent 6a8ac43 commit 515eec1
Showing 2 changed files with 46 additions and 99 deletions.
xlstm/blocks/slstm/layer.py: 48 changes (43 additions, 5 deletions)
@@ -89,17 +89,55 @@ def reset_parameters(self):
        small_init_init_(self.zgate.weight, dim=self.config.embedding_dim)
        small_init_init_(self.ogate.weight, dim=self.config.embedding_dim)

+    def step(
+        self,
+        x: torch.Tensor,
+        conv_state: Optional[torch.Tensor] = None,
+        slstm_state: Optional[torch.Tensor] = None,
+    ):
+        B, S, _ = x.shape
+
+        if self.config.conv1d_kernel_size > 0:
+            x_conv, conv_state = self.conv1d.step(x, conv_state=conv_state)
+            x_conv = self.conv_act_fn(x_conv)
+        else:
+            x_conv = x
+
+        i, f, z, o = (
+            self.fgate(x_conv),
+            self.igate(x_conv),
+            self.zgate(x),
+            self.ogate(x),
+        )
+
+        y, last_state = self.slstm_cell(
+            torch.cat([i, f, z, o], dim=-1), state=slstm_state
+        )
+
+        y = self.dropout(y)
+
+        out = self.group_norm(y).transpose(1, 2).view(B, S, -1)
+
+        return out, last_state
+
    def forward(
        self,
        x: torch.Tensor,
-        initial_state: Optional[torch.Tensor] = None,
+        conv_state: Optional[torch.Tensor] = None,
+        slstm_state: Optional[torch.Tensor] = None,
        return_last_state=False,
        **kwargs,
    ) -> torch.Tensor:
        B, S, _ = x.shape

        if self.config.conv1d_kernel_size > 0:
-            x_conv = self.conv_act_fn(self.conv1d(x))
+            if return_last_state:
+                x_conv = self.conv1d(x, conv_state, return_last_state=return_last_state)
+            else:
+                x_conv, conv_state = self.conv1d(
+                    x, conv_state, return_last_state=return_last_state
+                )
+            x_conv = self.conv_act_fn(x_conv)
        else:
            x_conv = x

@@ -110,15 +148,15 @@ def forward(
            self.ogate(x),
        )

-        y, last_state = self.slstm_cell(
-            torch.cat([i, f, z, o], dim=-1), state=initial_state
+        y, slstm_state_state = self.slstm_cell(
+            torch.cat([i, f, z, o], dim=-1), state=slstm_state
        )

        y = self.dropout(y)

        out = self.group_norm(y).transpose(1, 2).view(B, S, -1)

        if return_last_state:
-            return out, last_state
+            return out, {"conv_state": conv_state, "slstm_state": slstm_state}
        else:
            return out
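
Taken together, the new step() and the extended forward() let a caller process a prefix in parallel and then continue one token at a time. A minimal usage sketch of the intended interplay (not part of the commit), assuming `layer` is an already constructed instance of the sLSTM layer defined in this file and `x` is a (B, S, embedding_dim) tensor; variable names here are illustrative:

    # Prefill: run the whole prefix and ask for the states needed to continue.
    out, state = layer(x, return_last_state=True)
    # state is a dict with keys "conv_state" and "slstm_state".

    # Decode: feed one (B, 1, D) slice at a time through the new step() method.
    x_t = x[:, -1:, :]  # illustrative single-step input
    y_t, slstm_state = layer.step(
        x_t,
        conv_state=state["conv_state"],
        slstm_state=state["slstm_state"],
    )

Note that step() returns only the updated sLSTM cell state; the convolution state appears to be advanced in place, since conv1d_step writes into the passed buffer with copy_().
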
xlstm/components/conv.py: 97 changes (3 additions, 94 deletions)
@@ -1,6 +1,7 @@
 # Copyright (c) NXAI GmbH and its affiliates 2024
-# Maximilian Beck
+# Maximilian Beck, Korbinian Pöppel
 from dataclasses import dataclass, field
+from typing import Optional

import torch

@@ -44,100 +45,8 @@ def conv1d_step(
), f"x has feature dimension {x.shape[2]} but conv_state has feature dimension {conv_state.shape[2]}"
assert x.shape[1] == 1, f"x has sequence length {x.shape[1]} but it should be 1"
conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=1))
conv_state[:, -1, :] = x
conv_state[:, -1:, :] = x
y = torch.sum(conv_state * conv1d_weight, dim=1, keepdim=True)
if conv1d_bias is not None:
y += conv1d_bias
return y, conv_state


-class CausalConv1d(nn.Module):
-    config_class = CausalConv1dConfig
-    """
-    Implements causal depthwise convolution of a time series tensor.
-    Input: Tensor of shape (B,T,F), i.e. (batch, time, feature)
-    Output: Tensor of shape (B,T,F)
-    Args:
-        feature_dim: number of features in the input tensor
-        kernel_size: size of the kernel for the depthwise convolution
-        causal_conv_bias: whether to use bias in the depthwise convolution
-        channel_mixing: whether to use channel mixing (i.e. groups=1) or not (i.e. groups=feature_dim)
-            If True, it mixes the convolved features across channels.
-            If False, all the features are convolved independently.
-    """
-
-    def __init__(self, config: CausalConv1dConfig):
-        super().__init__()
-        self.config = config
-        self.groups = self.config.feature_dim
-        if self.config.channel_mixing:
-            self.groups = 1
-        if self.config.kernel_size == 0:
-            self.conv = None  # Noop
-        else:
-            self.pad = (
-                self.config.kernel_size - 1
-            )  # padding of this size assures temporal causality.
-            self.conv = nn.Conv1d(
-                in_channels=self.config.feature_dim,
-                out_channels=self.config.feature_dim,
-                kernel_size=self.config.kernel_size,
-                padding=self.pad,
-                groups=self.groups,
-                bias=self.config.causal_conv_bias,
-                **self.config.conv1d_kwargs,
-            )
-            # B, C, L
-        self.reset_parameters()
-
-    def reset_parameters(self, **kwargs):
-        self.conv.reset_parameters()
-
-    def _create_weight_decay_optim_groups(
-        self,
-    ) -> tuple[set[nn.Parameter], set[nn.Parameter]]:
-        if self.config.kernel_size == 0:
-            return (), ()
-        else:
-            weight_decay = (self.conv.weight,)
-            no_weight_decay = ()
-            if self.config.causal_conv_bias:
-                no_weight_decay += (self.conv.bias,)
-            return weight_decay, no_weight_decay
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        if self.config.kernel_size == 0:
-            return x
-        y = x.transpose(2, 1)  # (B,F,T) tensor - now in the right shape for conv layer.
-        y = self.conv(y)  # (B,F,T+pad) tensor
-        # same as y[:, :, :T].transpose(2, 1) (this is how it is done in Mamba)
-        return y[:, :, : -self.pad].transpose(2, 1)
-
-    def step(
-        self,
-        x: torch.Tensor,
-        conv_state: tuple[torch.Tensor] = None,
-    ) -> tuple[torch.Tensor, tuple[torch.Tensor]]:
-
-        if self.config.kernel_size == 0:
-            return x, conv_state
-
-        B, S, D = x.shape
-
-        if conv_state is None:
-            conv_state = (
-                torch.zeros(
-                    size=(B, self.config.kernel_size, D),
-                    device=self.conv.weight.device,
-                    dtype=self.conv.weight.dtype,
-                ),
-            )
-
-        y, conv_state = conv1d_step(
-            x,
-            conv_state[0],
-            self.conv.weight[:, 0, :].transpose(0, 1),  # rearrange(, "D 1 KS -> KS D")
-            conv1d_bias=self.conv.bias if self.config.causal_conv_bias else None,
-        )
-        return y, (conv_state,)
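
The "debug conv" part of the commit is the one-character change in conv1d_step above: the step path feeds x with sequence length 1 (shape (B, 1, D), as the assert enforces), and assigning that tensor into conv_state[:, -1, :] (shape (B, D)) fails to broadcast, whereas conv_state[:, -1:, :] keeps the singleton time axis so the shapes match. A minimal sketch of the shapes involved (sizes are illustrative, not from the commit):

    import torch

    B, KS, D = 2, 4, 8                  # batch, kernel size, feature dim
    conv_state = torch.zeros(B, KS, D)  # rolling buffer of the last KS inputs
    x = torch.randn(B, 1, D)            # single-step input, sequence length 1

    conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=1))
    # conv_state[:, -1, :] has shape (B, D); assigning the (B, 1, D) tensor x there
    # raises a shape/broadcast error. Keeping the singleton time axis works:
    conv_state[:, -1:, :] = x           # (B, 1, D) = (B, 1, D)

For context, the rolling buffer that conv1d_step maintains reproduces, one step at a time, the same causal depthwise convolution that the full-sequence path computes with a padding of kernel_size - 1 and a trimmed right overhang. A small equivalence check (a sketch, not part of the commit; bias omitted, sizes illustrative):

    import torch
    import torch.nn.functional as F

    B, T, D, KS = 2, 6, 4, 3
    x = torch.randn(B, T, D)
    weight = torch.randn(D, 1, KS)       # depthwise kernel, groups = D

    # Full-sequence causal depthwise convolution: pad, then trim the right overhang.
    y_full = F.conv1d(
        x.transpose(1, 2), weight, padding=KS - 1, groups=D
    )[:, :, :T].transpose(1, 2)

    # Step-wise evaluation with a rolling state buffer, mirroring conv1d_step.
    conv_state = torch.zeros(B, KS, D)
    w = weight[:, 0, :].transpose(0, 1)  # (KS, D), as in the step path above
    ys = []
    for t in range(T):
        conv_state = torch.roll(conv_state, shifts=-1, dims=1)
        conv_state[:, -1:, :] = x[:, t : t + 1, :]
        ys.append(torch.sum(conv_state * w, dim=1, keepdim=True))
    y_step = torch.cat(ys, dim=1)

    assert torch.allclose(y_full, y_step, atol=1e-5)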
