Skip to content

Commit

Permalink
Hardware Aware Mamba + Long Term Forecasting
Browse files Browse the repository at this point in the history
  • Loading branch information
frecklebars committed Mar 30, 2024
1 parent 22175f6 commit 69ffc6c
Show file tree
Hide file tree
Showing 15 changed files with 459 additions and 132 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -158,4 +158,5 @@ data_loader_all.py
/utils/self_tools.py
/scripts/exp_scripts/

/checkpoints/
/checkpoints/
/results/
3 changes: 2 additions & 1 deletion exp/exp_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import torch
from models import Autoformer, Transformer, TimesNet, Nonstationary_Transformer, DLinear, FEDformer, \
Informer, LightTS, Reformer, ETSformer, Pyraformer, PatchTST, MICN, Crossformer, FiLM, iTransformer, \
Koopa, TiDE, FreTS, Mamba
Koopa, TiDE, FreTS, MambaSimple, Mamba


class Exp_Basic(object):
Expand All @@ -28,6 +28,7 @@ def __init__(self, args):
'Koopa': Koopa,
'TiDE': TiDE,
'FreTS': FreTS,
'MambaSimple': MambaSimple,
'Mamba': Mamba,
}
self.device = self._acquire_device()
Expand Down
145 changes: 15 additions & 130 deletions models/Mamba.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,163 +3,48 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat, einsum

from layers.Embed import DataEmbedding
from mamba_ssm import Mamba

from layers.Embed import DataEmbedding

class Model(nn.Module):
"""
Mamba, linear-time sequence modeling with selective state spaces O(L)
Paper link: https://arxiv.org/abs/2312.00752
Implementation refernce: https://github.com/johnma2006/mamba-minimal/
"""


def __init__(self, configs):
super(Model, self).__init__()
self.task_name = configs.task_name
self.pred_len = configs.pred_len

self.d_inner = configs.d_model * configs.expand
self.dt_rank = math.ceil(configs.d_model / 16) # TODO implement "auto"

self.embedding = DataEmbedding(configs.enc_in, configs.d_model, configs.embed, configs.freq, configs.dropout)

self.layers = nn.ModuleList([ResidualBlock(configs, self.d_inner, self.dt_rank) for _ in range(configs.e_layers)])
self.norm = RMSNorm(configs.d_model)
self.mamba = Mamba(
d_model = configs.d_model,
d_state = configs.d_ff,
d_conv = configs.d_conv,
expand = configs.expand,
)

self.out_layer = nn.Linear(configs.d_model, configs.c_out, bias=False)

def short_term_forecast(self, x_enc, x_mark_enc):
def forecast(self, x_enc, x_mark_enc):
mean_enc = x_enc.mean(1, keepdim=True).detach()
x_enc = x_enc - mean_enc
std_enc = torch.sqrt(torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5).detach()
x_enc = x_enc / std_enc

x = self.embedding(x_enc, x_mark_enc)
for layer in self.layers:
x = layer(x)

x = self.norm(x)
x = self.mamba(x)
x_out = self.out_layer(x)

x_out = x_out * std_enc + mean_enc
return x_out

def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None):
if self.task_name == 'short_term_forecast':
x_out = self.short_term_forecast(x_enc, x_mark_enc)
# print(f"MAMBA FORECAST SIZE: {x_out.shape}")
if self.task_name in ['short_term_forecast', 'long_term_forecast']:
x_out = self.forecast(x_enc, x_mark_enc)
return x_out[:, -self.pred_len:, :]

# other tasks not implemented


class ResidualBlock(nn.Module):
def __init__(self, configs, d_inner, dt_rank):
super(ResidualBlock, self).__init__()

self.mixer = MambaBlock(configs, d_inner, dt_rank)
self.norm = RMSNorm(configs.d_model)

def forward(self, x):
output = self.mixer(self.norm(x)) + x
return output

class MambaBlock(nn.Module):
def __init__(self, configs, d_inner, dt_rank):
super(MambaBlock, self).__init__()
self.d_inner = d_inner
self.dt_rank = dt_rank

self.in_proj = nn.Linear(configs.d_model, self.d_inner * 2, bias=False)

self.conv1d = nn.Conv1d(
in_channels = self.d_inner,
out_channels = self.d_inner,
bias = True,
kernel_size = configs.d_conv,
padding = configs.d_conv - 1, # TODO dont understand this; come back and do kernel = 3 padding = 1 instead if it doesnt work?
groups = self.d_inner
)

# takes in x and outputs the input-specific delta, B, C
self.x_proj = nn.Linear(self.d_inner, self.dt_rank + configs.d_ff * 2, bias=False)

# projects delta
self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=True)

A = repeat(torch.arange(1, configs.d_ff + 1), "n -> d n", d=self.d_inner)
self.A_log = nn.Parameter(torch.log(A))
self.D = nn.Parameter(torch.ones(self.d_inner))

self.out_proj = nn.Linear(self.d_inner, configs.d_model, bias=False)

def forward(self, x):
"""
Figure 3 in Section 3.4 in the paper
"""
(b, l, d) = x.shape

x_and_res = self.in_proj(x) # [B, L, 2 * d_inner]
(x, res) = x_and_res.split(split_size=[self.d_inner, self.d_inner], dim=-1)

x = rearrange(x, "b l d -> b d l")
x = self.conv1d(x)[:, :, :l]
x = rearrange(x, "b d l -> b l d")

x = F.silu(x)

y = self.ssm(x)
y = y * F.silu(res)

output = self.out_proj(y)
return output


def ssm(self, x):
"""
Algorithm 2 in Section 3.2 in the paper
"""

(d_in, n) = self.A_log.shape

A = -torch.exp(self.A_log.float()) # [d_in, n]
D = self.D.float() # [d_in]

x_dbl = self.x_proj(x) # [B, L, d_rank + 2 * d_ff]
(delta, B, C) = x_dbl.split(split_size=[self.dt_rank, n, n], dim=-1) # delta: [B, L, d_rank]; B, C: [B, L, n]
delta = F.softplus(self.dt_proj(delta)) # [B, L, d_in]
y = self.selective_scan(x, delta, A, B, C, D)

return y

def selective_scan(self, u, delta, A, B, C, D):
(b, l, d_in) = u.shape
n = A.shape[1]

deltaA = torch.exp(einsum(delta, A, "b l d, d n -> b l d n")) # A is discretized using zero-order hold (ZOH) discretization
deltaB_u = einsum(delta, B, u, "b l d, b l n, b l d -> b l d n") # B is discretized using a simplified Euler discretization instead of ZOH. From a discussion with authors: "A is the more important term and the performance doesn't change much with the simplification on B"

# selective scan, sequential instead of parallel
x = torch.zeros((b, d_in, n), device=deltaA.device)
ys = []
for i in range(l):
x = deltaA[:, i] * x + deltaB_u[:, i]
y = einsum(x, C[:, i, :], "b d n, b n -> b d")
ys.append(y)

y = torch.stack(ys, dim=1) # [B, L, d_in]
y = y + u * D

return y

class RMSNorm(nn.Module):
def __init__(self, d_model, eps=1e-5):
super(RMSNorm, self).__init__()
self.eps = eps
self.weight = nn.Parameter(torch.ones(d_model))

def forward(self, x):
output = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) * self.weight
return output
# other tasks not implemented
175 changes: 175 additions & 0 deletions models/MambaSimple.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat, einsum

from layers.Embed import DataEmbedding


class Model(nn.Module):
"""
Mamba, linear-time sequence modeling with selective state spaces O(L)
Paper link: https://arxiv.org/abs/2312.00752
Implementation refernce: https://github.com/johnma2006/mamba-minimal/
"""

def __init__(self, configs):
super(Model, self).__init__()
self.task_name = configs.task_name
self.pred_len = configs.pred_len

self.d_inner = configs.d_model * configs.expand
self.dt_rank = math.ceil(configs.d_model / 16) # TODO implement "auto"

self.embedding = DataEmbedding(configs.enc_in, configs.d_model, configs.embed, configs.freq, configs.dropout)

self.layers = nn.ModuleList([ResidualBlock(configs, self.d_inner, self.dt_rank) for _ in range(configs.e_layers)])
self.norm = RMSNorm(configs.d_model)

self.out_layer = nn.Linear(configs.d_model, configs.c_out, bias=False)

# def short_term_forecast(self, x_enc, x_mark_enc):
def forecast(self, x_enc, x_mark_enc):
mean_enc = x_enc.mean(1, keepdim=True).detach()
x_enc = x_enc - mean_enc
std_enc = torch.sqrt(torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5).detach()
x_enc = x_enc / std_enc

x = self.embedding(x_enc, x_mark_enc)
for layer in self.layers:
x = layer(x)

x = self.norm(x)
x_out = self.out_layer(x)

x_out = x_out * std_enc + mean_enc
return x_out

# def long_term_forecast(self, x_enc, x_mark_enc):
# x = self.embedding(x_enc, x_mark_enc)
# for layer in self.layers:
# x = layer(x)

# x = self.norm(x)
# x_out = self.out_layer(x)
# return x_out

def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None):
if self.task_name in ['short_term_forecast', 'long_term_forecast']:
x_out = self.forecast(x_enc, x_mark_enc)
return x_out[:, -self.pred_len:, :]


# other tasks not implemented


class ResidualBlock(nn.Module):
def __init__(self, configs, d_inner, dt_rank):
super(ResidualBlock, self).__init__()

self.mixer = MambaBlock(configs, d_inner, dt_rank)
self.norm = RMSNorm(configs.d_model)

def forward(self, x):
output = self.mixer(self.norm(x)) + x
return output

class MambaBlock(nn.Module):
def __init__(self, configs, d_inner, dt_rank):
super(MambaBlock, self).__init__()
self.d_inner = d_inner
self.dt_rank = dt_rank

self.in_proj = nn.Linear(configs.d_model, self.d_inner * 2, bias=False)

self.conv1d = nn.Conv1d(
in_channels = self.d_inner,
out_channels = self.d_inner,
bias = True,
kernel_size = configs.d_conv,
padding = configs.d_conv - 1, # TODO dont understand this; come back and do kernel = 3 padding = 1 instead if it doesnt work?
groups = self.d_inner
)

# takes in x and outputs the input-specific delta, B, C
self.x_proj = nn.Linear(self.d_inner, self.dt_rank + configs.d_ff * 2, bias=False)

# projects delta
self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=True)

A = repeat(torch.arange(1, configs.d_ff + 1), "n -> d n", d=self.d_inner)
self.A_log = nn.Parameter(torch.log(A))
self.D = nn.Parameter(torch.ones(self.d_inner))

self.out_proj = nn.Linear(self.d_inner, configs.d_model, bias=False)

def forward(self, x):
"""
Figure 3 in Section 3.4 in the paper
"""
(b, l, d) = x.shape

x_and_res = self.in_proj(x) # [B, L, 2 * d_inner]
(x, res) = x_and_res.split(split_size=[self.d_inner, self.d_inner], dim=-1)

x = rearrange(x, "b l d -> b d l")
x = self.conv1d(x)[:, :, :l]
x = rearrange(x, "b d l -> b l d")

x = F.silu(x)

y = self.ssm(x)
y = y * F.silu(res)

output = self.out_proj(y)
return output


def ssm(self, x):
"""
Algorithm 2 in Section 3.2 in the paper
"""

(d_in, n) = self.A_log.shape

A = -torch.exp(self.A_log.float()) # [d_in, n]
D = self.D.float() # [d_in]

x_dbl = self.x_proj(x) # [B, L, d_rank + 2 * d_ff]
(delta, B, C) = x_dbl.split(split_size=[self.dt_rank, n, n], dim=-1) # delta: [B, L, d_rank]; B, C: [B, L, n]
delta = F.softplus(self.dt_proj(delta)) # [B, L, d_in]
y = self.selective_scan(x, delta, A, B, C, D)

return y

def selective_scan(self, u, delta, A, B, C, D):
(b, l, d_in) = u.shape
n = A.shape[1]

deltaA = torch.exp(einsum(delta, A, "b l d, d n -> b l d n")) # A is discretized using zero-order hold (ZOH) discretization
deltaB_u = einsum(delta, B, u, "b l d, b l n, b l d -> b l d n") # B is discretized using a simplified Euler discretization instead of ZOH. From a discussion with authors: "A is the more important term and the performance doesn't change much with the simplification on B"

# selective scan, sequential instead of parallel
x = torch.zeros((b, d_in, n), device=deltaA.device)
ys = []
for i in range(l):
x = deltaA[:, i] * x + deltaB_u[:, i]
y = einsum(x, C[:, i, :], "b d n, b n -> b d")
ys.append(y)

y = torch.stack(ys, dim=1) # [B, L, d_in]
y = y + u * D

return y

class RMSNorm(nn.Module):
def __init__(self, d_model, eps=1e-5):
super(RMSNorm, self).__init__()
self.eps = eps
self.weight = nn.Parameter(torch.ones(d_model))

def forward(self, x):
output = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) * self.weight
return output
Loading

0 comments on commit 69ffc6c

Please sign in to comment.