Skip to content

Commit

Permalink
[S4] Apply the S4D 'kernel' -> 'backend' parameter rename to the standalone file
Browse files Browse the repository at this point in the history
  • Loading branch information
albertfgu committed Jul 8, 2023
1 parent 8a12f34 commit 94d0257
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 8 deletions.
15 changes: 8 additions & 7 deletions models/s4/s4.py
Original file line number Diff line number Diff line change
Expand Up @@ -985,7 +985,7 @@ class SSMKernelDiag(SSMKernel):
Parameterize the real/imag parts of the diagonal of A under this function.
bandlimit: Mask high frequencies of the kernel (indices corresponding to
diagonal elements with large imaginary part). Introduced in S4ND paper.
kernel: ['cuda' | 'keops' | 'naive'] Options for Vandermonde/Cauchy kernel (in order of efficiency).
backend: ['cuda' | 'keops' | 'naive'] Which implementation computes the Vandermonde/Cauchy kernel (options listed in order of efficiency).
force_real: Force A to have 0 imaginary part, to emulate EMA.
"""

Expand All @@ -996,7 +996,7 @@ def __init__(
real_transform: str = 'exp',
imag_transform: str = 'none',
bandlimit: Optional[float] = None,
kernel: str = 'cuda',
backend: str = 'cuda',
force_real: bool = False,
**kwargs,
):
Expand All @@ -1006,7 +1006,7 @@ def __init__(
self.real_transform = real_transform
self.imag_transform = imag_transform
self.bandlimit = bandlimit
self.kernel = kernel
self.backend = backend
self.force_real = force_real

# Initialize dt, A, B, C
Expand Down Expand Up @@ -1035,7 +1035,7 @@ def register_params(self, A, B, C, inv_dt, P):
Note: tensor shape N here denotes half the true state size, because of conjugate symmetry
"""

assert self.kernel in ['cuda', 'keops', 'naive']
assert self.backend in ['cuda', 'keops', 'naive']

if self.dt_fast: inv_dt = torch.asinh(inv_dt)

Expand Down Expand Up @@ -1115,9 +1115,9 @@ def forward(self, L, state=None, rate=1.0):
C = (B[:, None, :, :] * C).view(-1, self.H, self.N)

# Select which Vandermonde kernel implementation to dispatch to
if has_cuda_extension and C.dtype == torch.cfloat and C.device.type == 'cuda' and self.kernel == 'cuda':
if has_cuda_extension and C.dtype == torch.cfloat and C.device.type == 'cuda' and self.backend == 'cuda':
log_vandermonde = log_vandermonde_cuda
elif has_pykeops and self.kernel in ['cuda', 'keops']:
elif has_pykeops and self.backend in ['cuda', 'keops']:
log_vandermonde = log_vandermonde_keops
else:
log_vandermonde = log_vandermonde_naive
Expand Down Expand Up @@ -1197,14 +1197,15 @@ def forward_state(self, u, state):
AL = self.dA ** u.size(-1)
u = u.flip(-1).to(self.dA).contiguous() # (B H L)
# Select which Vandermonde kernel implementation to dispatch to
if has_pykeops and self.kernel in ['cuda', 'keops']:
if has_pykeops and self.backend in ['cuda', 'keops']:
log_vandermonde_transpose = log_vandermonde_transpose_keops
else:
log_vandermonde_transpose = log_vandermonde_transpose_naive
v = log_vandermonde_transpose(u, self.dB, self.dA.log(), u.size(-1))
next_state = AL * state + v
return next_state


class SSMKernelDPLR(SSMKernelDiag):
"""SSM kernel for diagonal + low rank (DPLR) state matrices, corresponding to the original S4 model."""

Expand Down
2 changes: 1 addition & 1 deletion src/models/sequence/kernels/ssm.py
Original file line number Diff line number Diff line change
Expand Up @@ -447,7 +447,7 @@ class SSMKernelDiag(SSMKernel):
Parameterize the real/imag parts of the diagonal of A under this function.
bandlimit: Mask high frequencies of the kernel (indices corresponding to
diagonal elements with large imaginary part). Introduced in S4ND paper.
kernel: ['cuda' | 'keops' | 'naive'] Options for Vandermonde/Cauchy kernel (in order of efficiency).
backend: ['cuda' | 'keops' | 'naive'] Which implementation computes the Vandermonde/Cauchy kernel (options listed in order of efficiency).
force_real: Force A to have 0 imaginary part, to emulate EMA.
"""

Expand Down

0 comments on commit 94d0257

Please sign in to comment.