forked from thuml/Time-Series-Library
Hardware Aware Mamba + Long Term Forecasting
Commit 69ffc6c, 1 parent 22175f6
Showing 15 changed files with 459 additions and 132 deletions.
@@ -158,4 +158,5 @@ data_loader_all.py
 /utils/self_tools.py
 /scripts/exp_scripts/

-/checkpoints/
+/checkpoints/
+/results/
@@ -0,0 +1,175 @@
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat, einsum

from layers.Embed import DataEmbedding


class Model(nn.Module):
    """
    Mamba: linear-time sequence modeling with selective state spaces, O(L)
    Paper link: https://arxiv.org/abs/2312.00752
    Implementation reference: https://github.com/johnma2006/mamba-minimal/
    """

    def __init__(self, configs):
        super(Model, self).__init__()
        self.task_name = configs.task_name
        self.pred_len = configs.pred_len

        self.d_inner = configs.d_model * configs.expand
        self.dt_rank = math.ceil(configs.d_model / 16)  # TODO implement "auto"

        self.embedding = DataEmbedding(configs.enc_in, configs.d_model, configs.embed, configs.freq, configs.dropout)

        self.layers = nn.ModuleList([ResidualBlock(configs, self.d_inner, self.dt_rank) for _ in range(configs.e_layers)])
        self.norm = RMSNorm(configs.d_model)

        self.out_layer = nn.Linear(configs.d_model, configs.c_out, bias=False)

    # def short_term_forecast(self, x_enc, x_mark_enc):
    def forecast(self, x_enc, x_mark_enc):
        # Instance normalization: remove the per-series mean and std, restore them on the output.
        mean_enc = x_enc.mean(1, keepdim=True).detach()
        x_enc = x_enc - mean_enc
        std_enc = torch.sqrt(torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5).detach()
        x_enc = x_enc / std_enc

        x = self.embedding(x_enc, x_mark_enc)
        for layer in self.layers:
            x = layer(x)

        x = self.norm(x)
        x_out = self.out_layer(x)

        x_out = x_out * std_enc + mean_enc
        return x_out

    # def long_term_forecast(self, x_enc, x_mark_enc):
    #     x = self.embedding(x_enc, x_mark_enc)
    #     for layer in self.layers:
    #         x = layer(x)

    #     x = self.norm(x)
    #     x_out = self.out_layer(x)
    #     return x_out

    def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None):
        if self.task_name in ['short_term_forecast', 'long_term_forecast']:
            x_out = self.forecast(x_enc, x_mark_enc)
            return x_out[:, -self.pred_len:, :]

        # other tasks not implemented


class ResidualBlock(nn.Module):
    def __init__(self, configs, d_inner, dt_rank):
        super(ResidualBlock, self).__init__()

        self.mixer = MambaBlock(configs, d_inner, dt_rank)
        self.norm = RMSNorm(configs.d_model)

    def forward(self, x):
        output = self.mixer(self.norm(x)) + x
        return output


class MambaBlock(nn.Module):
    def __init__(self, configs, d_inner, dt_rank):
        super(MambaBlock, self).__init__()
        self.d_inner = d_inner
        self.dt_rank = dt_rank

        self.in_proj = nn.Linear(configs.d_model, self.d_inner * 2, bias=False)

        self.conv1d = nn.Conv1d(
            in_channels=self.d_inner,
            out_channels=self.d_inner,
            bias=True,
            kernel_size=configs.d_conv,
            padding=configs.d_conv - 1,  # causal padding; the output is trimmed back to length L in forward()
            groups=self.d_inner
        )

        # takes in x and outputs the input-specific delta, B, C
        self.x_proj = nn.Linear(self.d_inner, self.dt_rank + configs.d_ff * 2, bias=False)

        # projects delta from dt_rank up to d_inner
        self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=True)

        # A is stored as A_log and recovered as -exp(A_log) in ssm(); D is a learned skip weight
        A = repeat(torch.arange(1, configs.d_ff + 1), "n -> d n", d=self.d_inner)
        self.A_log = nn.Parameter(torch.log(A))
        self.D = nn.Parameter(torch.ones(self.d_inner))

        self.out_proj = nn.Linear(self.d_inner, configs.d_model, bias=False)

    def forward(self, x):
        """
        Figure 3 in Section 3.4 in the paper
        """
        (b, l, d) = x.shape

        x_and_res = self.in_proj(x)  # [B, L, 2 * d_inner]
        (x, res) = x_and_res.split(split_size=[self.d_inner, self.d_inner], dim=-1)

        # depthwise causal convolution along the time axis
        x = rearrange(x, "b l d -> b d l")
        x = self.conv1d(x)[:, :, :l]
        x = rearrange(x, "b d l -> b l d")

        x = F.silu(x)

        y = self.ssm(x)
        y = y * F.silu(res)  # gate the SSM output with the parallel residual branch

        output = self.out_proj(y)
        return output

    def ssm(self, x):
        """
        Algorithm 2 in Section 3.2 in the paper
        """
        (d_in, n) = self.A_log.shape

        A = -torch.exp(self.A_log.float())  # [d_in, n]
        D = self.D.float()  # [d_in]

        x_dbl = self.x_proj(x)  # [B, L, dt_rank + 2 * d_ff]
        (delta, B, C) = x_dbl.split(split_size=[self.dt_rank, n, n], dim=-1)  # delta: [B, L, dt_rank]; B, C: [B, L, n]
        delta = F.softplus(self.dt_proj(delta))  # [B, L, d_in]
        y = self.selective_scan(x, delta, A, B, C, D)

        return y

    def selective_scan(self, u, delta, A, B, C, D):
        (b, l, d_in) = u.shape
        n = A.shape[1]

        # A is discretized with zero-order hold (ZOH): deltaA = exp(delta * A)
        deltaA = torch.exp(einsum(delta, A, "b l d, d n -> b l d n"))
        # B is discretized with a simplified Euler step instead of ZOH. From a discussion with
        # the authors: "A is the more important term and the performance doesn't change much
        # with the simplification on B"
        deltaB_u = einsum(delta, B, u, "b l d, b l n, b l d -> b l d n")
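
        # The loop below evaluates the discretized recurrence
        #   h_t = deltaA[:, t] * h_{t-1} + deltaB_u[:, t]
        #   y_t = C_t · h_t   (contracted over the state dimension n)
        # one step at a time, rather than with the hardware-aware parallel scan of the
        # official Mamba kernel.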

        # selective scan, sequential instead of parallel
        x = torch.zeros((b, d_in, n), device=deltaA.device)
        ys = []
        for i in range(l):
            x = deltaA[:, i] * x + deltaB_u[:, i]
            y = einsum(x, C[:, i, :], "b d n, b n -> b d")
            ys.append(y)

        y = torch.stack(ys, dim=1)  # [B, L, d_in]
        y = y + u * D

        return y


class RMSNorm(nn.Module):
    def __init__(self, d_model, eps=1e-5):
        super(RMSNorm, self).__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(d_model))

    def forward(self, x):
        # RMS normalization: scale x by 1 / sqrt(mean(x^2) + eps), then by a learned weight
        output = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) * self.weight
        return output
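
For quick verification, the sketch below runs one forward pass through the new model. It is not part of the commit: every config value is an illustrative placeholder, and it assumes the classes above and layers.Embed.DataEmbedding from the repository are in scope.

# Illustrative smoke test (assumed values, not taken from the commit).
from types import SimpleNamespace

configs = SimpleNamespace(
    task_name='long_term_forecast',
    pred_len=96,                      # forecast horizon
    enc_in=7, c_out=7,                # number of input / output variables
    d_model=128, expand=2,            # d_inner = d_model * expand = 256
    d_conv=4,                         # depthwise conv kernel size
    d_ff=16,                          # reused here as the SSM state size
    e_layers=2,                       # number of residual Mamba blocks
    embed='timeF', freq='h',          # hourly time features -> 4 columns in x_mark_enc
    dropout=0.1,
)

model = Model(configs)
x_enc = torch.randn(8, 96, configs.enc_in)   # [batch, seq_len, enc_in]
x_mark_enc = torch.randn(8, 96, 4)           # time-feature marks for embed='timeF', freq='h'
out = model(x_enc, x_mark_enc, None, None)   # decoder inputs are unused by this model
print(out.shape)                             # torch.Size([8, 96, 7]), the last pred_len steps

Note that forward() simply slices the last pred_len steps of an encoder-length output, so this design assumes pred_len does not exceed the input sequence length.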