forked from state-spaces/s4
-
Notifications
You must be signed in to change notification settings - Fork 0
/
s4.py
1952 lines (1612 loc) · 73.7 KB
/
s4.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
"""Standalone version of Structured State Space sequence model (S4)."""
from collections import defaultdict
from typing import Optional, Mapping, Tuple, Union
import logging
from functools import partial
import math
import numpy as np
from scipy import special as ss
import torch
import torch.nn as nn
import torch.nn.functional as F
from pytorch_lightning.utilities import rank_zero_only
from einops import rearrange, repeat
# Function aliases: short names for tensor operations used throughout the file
contract = torch.einsum  # einsum-style tensor contraction
_c2r = torch.view_as_real  # complex tensor -> real view with a trailing (real, imag) dim of size 2
_r2c = torch.view_as_complex  # inverse of _c2r


def _conj(x):
    """Concatenate x with its complex conjugate along the last dimension."""
    return torch.cat([x, x.conj()], dim=-1)


if tuple(int(part) for part in torch.__version__.split('.')[:2]) < (1, 10):
    def _resolve_conj(x):
        """Return the conjugate of x (older torch: no resolve_conj call)."""
        return x.conj()
else:
    def _resolve_conj(x):
        """Return the conjugate of x, materialized with resolve_conj (torch >= 1.10)."""
        return x.conj().resolve_conj()
def get_logger(name=__name__, level=logging.INFO) -> logging.Logger:
    """Initializes multi-GPU-friendly python logger.

    Every standard logging method of the returned logger is replaced by a
    ``rank_zero_only``-wrapped version, so in a multi-GPU setup a message is
    not emitted once per process.
    """
    logger = logging.getLogger(name)
    logger.setLevel(level)
    wrapped_methods = ("debug", "info", "warning", "error", "exception", "fatal", "critical")
    for method_name in wrapped_methods:
        rank_zero_method = rank_zero_only(getattr(logger, method_name))
        setattr(logger, method_name, rank_zero_method)
    return logger
log = get_logger(__name__)

"""Structured matrix kernels"""

# Try CUDA extension
try:
    from extensions.kernels.cauchy import cauchy_mult as cauchy_cuda
    from extensions.kernels.vandermonde import log_vandermonde_cuda
    has_cuda_extension = True
    log.info("CUDA extension for structured kernels (Cauchy and Vandermonde multiplication) found.")
except Exception:
    # Catch any ordinary failure to import or load the optional compiled extension.
    # Unlike a bare `except:`, this does not swallow KeyboardInterrupt / SystemExit.
    log.warning(
        "CUDA extension for structured kernels (Cauchy and Vandermonde multiplication) not found. Install by going to extensions/kernels/ and running `python setup.py install`, for improved speed and memory efficiency. Note that the kernel changed for state-spaces 4.0 and must be recompiled."
    )
    has_cuda_extension = False
# Try pykeops
# On success, define KeOps-backed Cauchy / Vandermonde kernels. On ImportError, set
# has_pykeops = False and warn if the CUDA extension (probed above) is also unavailable.
try:
    import pykeops
    from pykeops.torch import Genred
    has_pykeops = True
    log.info("Pykeops installation found.")

    def _broadcast_dims(*tensors):
        """Left-pad each tensor's shape with singleton dims so all tensors have the same number of dims."""
        max_dim = max([len(tensor.shape) for tensor in tensors])
        tensors = [tensor.view((1,)*(max_dim-len(tensor.shape))+tensor.shape) for tensor in tensors]
        return tensors

    def cauchy_keops(v, z, w):
        """Cauchy kernel computed with a KeOps Genred reduction.

        z supplies the output index (Vi); v and w are summed over (Vj, reduction axis=1).
        Complex tensors are handed to KeOps as (real, imag) pairs via _c2r.

        Reading ComplexReal(v) as Re(v) and Sum(v * w) over the (real, imag) pair as
        Re(v * conj(w)), each term 2 * (z*Re(v) - Re(v*conj(w))) / ((z-w)(z-conj(w)))
        equals v/(z-w) + conj(v)/(z-conj(w)). This is the conjugate-pair sum that
        cauchy_naive computes explicitly. NOTE(review): confirm against the KeOps docs.

        NOTE(review): the Genred reduction is rebuilt on every call, and backend='GPU' is hard-coded.
        """
        expr_num = 'z * ComplexReal(v) - Real2Complex(Sum(v * w))'
        expr_denom = 'ComplexMult(z-w, z-Conj(w))'

        cauchy_mult = Genred(
            f'ComplexDivide({expr_num}, {expr_denom})',
            [
                'v = Vj(2)',
                'z = Vi(2)',
                'w = Vj(2)',
            ],
            reduction_op='Sum',
            axis=1,
        )

        v, z, w = _broadcast_dims(v, z, w)
        v = _c2r(v)
        z = _c2r(z)
        w = _c2r(w)

        # Factor 2 accounts for the conjugate half folded into the expression above
        r = 2*cauchy_mult(v, z, w, backend='GPU')
        return _r2c(r)

    def log_vandermonde_keops(v, x, L):
        """KeOps version of log_vandermonde_naive.

        The output is indexed by l = 0..L-1 (Vi); v and x are summed over (Vj).
        Returns 2 * Re(sum_j v_j * exp(x_j * l)).
        """
        expr = 'ComplexMult(v, ComplexExp(ComplexMult(x, l)))'
        vandermonde_mult = Genred(
            expr,
            [
                'v = Vj(2)',
                'x = Vj(2)',
                'l = Vi(2)',
            ],
            reduction_op='Sum',
            axis=1,
        )

        # Positions 0..L-1 cast to x's (complex) dtype/device so they can be passed as (real, imag) pairs
        l = torch.arange(L).to(x)
        v, x, l = _broadcast_dims(v, x, l)
        v = _c2r(v)
        x = _c2r(x)
        l = _c2r(l)

        r = vandermonde_mult(v, x, l, backend='GPU')
        return 2*_r2c(r).real

    def log_vandermonde_transpose_keops(u, v, x, L):
        """
        u: ... H L
        v: ... H N
        x: ... H N
        Returns: ... H N

        V = Vandermonde(x, L) : (H N L)
        contract_L(V * u * v)

        The output is indexed by n (v, x are Vi); u and the positions l are summed over (Vj).
        """
        expr = 'ComplexMult(ComplexMult(v, u), ComplexExp(ComplexMult(x, l)))'
        vandermonde_mult = Genred(
            expr,
            [
                'u = Vj(2)',
                'v = Vi(2)',
                'x = Vi(2)',
                'l = Vj(2)',
            ],
            reduction_op='Sum',
            axis=1,
        )

        l = torch.arange(L).to(x)
        u, v, x, l = _broadcast_dims(u, v, x, l)
        u = _c2r(u)
        v = _c2r(v)
        x = _c2r(x)
        l = _c2r(l)

        r = vandermonde_mult(u, v, x, l, backend='GPU')
        return _r2c(r)

except ImportError:
    has_pykeops = False
    if not has_cuda_extension:
        log.warning(
            "Falling back on slow Cauchy and Vandermonde kernel. Install at least one of pykeops or the CUDA extension for better speed and memory efficiency."
        )
# Fallback versions (pure PyTorch; no CUDA extension or pykeops required)
def cauchy_naive(v, z, w):
    r"""Naive Cauchy matrix-vector product over conjugate pairs.

    Args:
        v: (..., N) complex coefficients.
        z: (..., L) complex evaluation nodes.
        w: (..., N) complex poles.

    Returns:
        (..., L) tensor with entries \sum_n v_n / (z_l - w_n), where the sum runs over
        the N given entries *and* their complex conjugates (``_conj`` extends v and w
        to 2N entries before summing).

    The full (..., 2N, L) Cauchy matrix is materialized, so memory grows as N * L.
    """
    v = _conj(v)
    w = _conj(w)
    cauchy_matrix = v.unsqueeze(-1) / (z.unsqueeze(-2) - w.unsqueeze(-1))  # (..., 2N, L)
    return torch.sum(cauchy_matrix, dim=-2)
def log_vandermonde_naive(v, x, L, conj=True):
    r"""Naive Vandermonde product with nodes given in log space.

    Args:
        v: (..., N) complex coefficients.
        x: (..., N) complex log-nodes; the Vandermonde entries are exp(x_n * l).
        L: number of output positions l = 0, ..., L-1.
        conj: accepted for interface compatibility but currently unused. The result is
            always twice the real part, i.e. the sum including every conjugate term.

    Returns:
        (..., L) real tensor equal to 2 * Re(\sum_n v_n exp(x_n l)).
    """
    vandermonde_matrix = torch.exp(x.unsqueeze(-1) * torch.arange(L).to(x))  # (..., N, L)
    vandermonde_prod = contract('... n, ... n l -> ... l', v, vandermonde_matrix)  # (..., L)
    return 2*vandermonde_prod.real
def log_vandermonde_transpose_naive(u, v, x, L):
    r"""Naive transposed log-Vandermonde product.

    Args:
        u: (..., L) input sequence; cast to x's dtype/device.
        v: (..., N) coefficients; cast to x's dtype/device.
        x: (..., N) complex log-nodes; the Vandermonde entries are exp(x_n * l).
        L: sequence length (u's last dimension is assumed to equal L).

    Returns:
        (..., N) complex tensor with entries \sum_l u_l v_n exp(x_n l).
    """
    vandermonde_matrix = torch.exp(x.unsqueeze(-1) * torch.arange(L).to(x))  # (..., N, L)
    vandermonde_prod = contract('... l, ... n, ... n l -> ... n', u.to(x), v.to(x), vandermonde_matrix)  # (..., N)
    return vandermonde_prod
""" Simple nn.Module components """
def Activation(activation=None, dim=-1):
    """Build an activation module from its name.

    Args:
        activation: name of the nonlinearity. ``None``, 'id', 'identity' and 'linear'
            all give ``nn.Identity``; 'swish' and 'silu' both give ``nn.SiLU``.
        dim: dimension passed to ``nn.GLU`` for 'glu' (ignored otherwise).

    Raises:
        NotImplementedError: if the name is not recognized.
    """
    if activation in [None, 'id', 'identity', 'linear']:
        return nn.Identity()
    if activation == 'glu':
        return nn.GLU(dim=dim)
    if activation in ['swish', 'silu']:
        return nn.SiLU()

    parameterless = (
        ('tanh', nn.Tanh),
        ('relu', nn.ReLU),
        ('gelu', nn.GELU),
        ('elu', nn.ELU),
        ('sigmoid', nn.Sigmoid),
        ('softplus', nn.Softplus),
    )
    for key, module_cls in parameterless:
        if activation == key:
            return module_cls()

    raise NotImplementedError("hidden activation '{}' is not implemented".format(activation))
def LinearActivation(
    d_input, d_output, bias=True,
    transposed=False,
    activation=None,
    activate=False,  # Apply activation as part of this module
    **kwargs,
):
    """Returns a linear nn.Module with control over axes order, initialization, and activation.

    Args:
        d_input, d_output: input / output feature sizes.
        bias: whether the layer has a bias term.
        transposed: if True, build a kernel-size-1 ``nn.Conv1d`` (features on dim -2);
            otherwise an ``nn.Linear`` (features on dim -1).
        activation: activation name; 'glu' doubles the layer's output width.
        activate: if True and ``activation`` is given, return
            ``nn.Sequential(layer, Activation(activation))``.
        **kwargs: forwarded to the layer constructor.
    """
    if transposed:
        layer_cls = partial(nn.Conv1d, kernel_size=1)
    else:
        layer_cls = nn.Linear

    width = d_output * 2 if activation == 'glu' else d_output
    layer = layer_cls(d_input, width, bias=bias, **kwargs)

    if not activate or activation is None:
        return layer
    act_module = Activation(activation, dim=-2 if transposed else -1)
    return nn.Sequential(layer, act_module)
class DropoutNd(nn.Module):
    """Dropout whose mask can be shared across all sequence positions.

    With ``tie=True`` the mask has shape (batch, dim, 1, ..., 1), so each
    (batch, channel) entry is kept or dropped for every position at once
    (like Dropout1d/2d/3d). Kept entries are rescaled by 1 / (1 - p).
    """

    def __init__(self, p: float = 0.5, tie=True, transposed=True):
        """
        p: drop probability; must satisfy 0 <= p < 1.
        tie: tie dropout mask across sequence lengths (Dropout1d/2d/3d)
        transposed: if True inputs are (batch, dim, lengths...), otherwise (batch, lengths..., dim).
        """
        super().__init__()
        if p < 0 or p >= 1:
            raise ValueError("dropout probability has to be in [0, 1), but got {}".format(p))
        self.p = p
        self.tie = tie
        self.transposed = transposed
        # Not used by forward(), which samples its mask with torch.rand
        self.binomial = torch.distributions.binomial.Binomial(probs=1-self.p)

    def forward(self, X):
        """X: (batch, dim, lengths...) if transposed, else (batch, lengths..., dim)."""
        if not self.training:
            return X

        if not self.transposed:
            X = rearrange(X, 'b ... d -> b d ...')

        if self.tie:
            mask_shape = X.shape[:2] + (1,) * (X.ndim - 2)
        else:
            mask_shape = X.shape
        keep = torch.rand(*mask_shape, device=X.device) < 1.-self.p
        X = X * keep * (1.0/(1-self.p))

        if not self.transposed:
            X = rearrange(X, 'b d ... -> b ... d')
        return X
"""Misc functional utilities"""
def power(L, A, v=None):
    """Compute A^L and, optionally, the scan sum_i A^i v_i.

    A^L is computed by repeated squaring, i.e. O(log L) matrix products.

    Args:
        L: non-negative int exponent (should be >= 1 when ``v`` is given).
        A: (..., N, N) matrix or batch of matrices.
        v: optional (..., N, L) tensor. Its last dimension is assumed to equal L.
            It is not modified.

    Returns:
        A^L of shape (..., N, N) if ``v`` is None. Otherwise the tuple
        (A^L, sum_{i=0}^{L-1} A^i v[..., i]), whose second entry has shape (..., N).
    """
    I = torch.eye(A.shape[-1]).to(A)  # accumulates A^L, starting from the identity

    powers = [A]  # successive squares A, A^2, A^4, ... (only the latest is kept if v is None)
    l = 1
    while True:
        if L % 2 == 1:
            I = powers[-1] @ I
        L //= 2
        if L == 0:
            break
        l *= 2
        if v is None:
            powers = [powers[-1] @ powers[-1]]
        else:
            powers.append(powers[-1] @ powers[-1])

    if v is None:
        return I

    # Invariants:
    # powers[-1] := A^l
    # l := largest po2 at most L

    # Note that an alternative divide and conquer to compute the reduction is possible and can be embedded into the above loop without caching intermediate powers of A
    # We do this reverse divide-and-conquer for efficiency reasons:
    # 1) it involves fewer padding steps for non-po2 L
    # 2) it involves more contiguous arrays

    # Take care of edge case for non-po2 arrays: fold the trailing k = L - l columns
    # into the first k (column i + l contributes A^l v_{i+l} to column i).
    # Note that this initial step is a no-op for the case of power of 2 (l == L)
    k = v.size(-1) - l
    v_ = powers.pop() @ v[..., l:]
    # Build a fresh tensor: writing into v[..., :l] (a view) would mutate the caller's
    # tensor, and in-place writes into a leaf that requires grad raise an error.
    v = torch.cat([v[..., :k] + v_, v[..., k:l]], dim=-1)

    # Handle reduction for power of 2
    while v.size(-1) > 1:
        # Split the columns into halves [first | second] (same as
        # rearrange '... (z l) -> ... z l' with z=2); the second half is offset by A^(half).
        v = v.reshape(*v.shape[:-1], 2, v.size(-1) // 2)
        v = v[..., 0, :] + powers.pop() @ v[..., 1, :]
    return I, v.squeeze(-1)
"""HiPPO utilities"""
def transition(measure, N, **measure_args):
    """A, B transition matrices for different measures.

    Args:
        measure: the type of measure
            legt - Legendre (translated)
            legs - Legendre (scaled)
            fourier / fout - Fourier basis, with a rank-1 correction for the other
                endpoint (N is assumed even)
        N: state size.
        measure_args: accepted for interface compatibility; currently unused.

    Returns:
        (A, B) numpy float64 arrays of shapes (N, N) and (N, 1).

    Raises:
        NotImplementedError: for any other measure.
    """
    # Legendre (translated)
    if measure == 'legt':
        Q = np.arange(N, dtype=np.float64)
        R = (2*Q + 1) ** .5
        j, i = np.meshgrid(Q, Q)
        A = R[:, None] * np.where(i < j, (-1.)**(i-j), 1) * R[None, :]
        B = R[:, None]
        A = -A

        # Halve again for timescale correctness
        A *= 0.5
        B *= 0.5
    # Legendre (scaled)
    elif measure == 'legs':
        q = np.arange(N, dtype=np.float64)
        col, row = np.meshgrid(q, q)
        r = 2 * q + 1
        M = -(np.where(row >= col, r, 0) - np.diag(q))
        T = np.sqrt(np.diag(2 * q + 1))
        A = T @ M @ np.linalg.inv(T)
        B = np.diag(T)[:, None]
        B = B.copy()  # Otherwise "UserWarning: given NumPY array is not writeable..." after torch.as_tensor(B)
    elif measure in ['fourier', 'fout']:
        freqs = np.arange(N//2)
        d = np.stack([np.zeros(N//2), freqs], axis=-1).reshape(-1)[1:]
        A = np.pi*(-np.diag(d, 1) + np.diag(d, -1))
        B = np.zeros(N)
        B[0::2] = 2**.5
        B[0] = 1

        # Subtract off rank correction - this corresponds to the other endpoint u(t-1) in this case
        A = A - B[:, None] * B[None, :]
        B = B[:, None]
    else:
        raise NotImplementedError(
            f"transition: unsupported measure {measure!r} (expected 'legt', 'legs', 'fourier' or 'fout')"
        )

    return A, B
def rank_correction(measure, N, rank=1, dtype=torch.float):
    """Return the low-rank correction P for a HiPPO measure.

    P has shape (rank, N) and dtype ``dtype``, and is chosen so that
    A + sum_r P[r]^T P[r] is normal for the HiPPO A of ``measure``.
    Rows beyond the measure's natural rank are zero padding.

    Raises:
        AssertionError: if ``rank`` is below the measure's natural rank (1 for legs, 2 for legt).
        NotImplementedError: for an unsupported measure.
    """
    if measure == 'legs':
        assert rank >= 1
        P = torch.sqrt(.5+torch.arange(N, dtype=dtype)).unsqueeze(0)  # (1 N)
    elif measure == 'legt':
        assert rank >= 2
        P = torch.sqrt(1+2*torch.arange(N, dtype=dtype))  # (N)
        P0 = P.clone()
        P0[0::2] = 0.
        P1 = P.clone()
        P1[1::2] = 0.
        P = torch.stack([P0, P1], dim=0)  # (2 N)
        P *= 2**(-0.5)  # Halve the rank correction just like the original matrix was halved
    elif measure in ['fourier', 'fout']:
        # Honor the requested dtype (previously this was always float32)
        P = torch.zeros(N, dtype=dtype)
        P[0::2] = 2**.5
        P[0] = 1
        P = P.unsqueeze(0)
    else:
        raise NotImplementedError

    d = P.size(0)
    if rank > d:
        P = torch.cat([P, torch.zeros(rank-d, N, dtype=dtype)], dim=0)  # (rank N)
    return P
def nplr(measure, N, rank=1, dtype=torch.float, diagonalize_precision=True, B_clip=2.0):
    """Constructs NPLR form of HiPPO matrices.

    The HiPPO matrix A is written as a normal matrix minus a low-rank term,
    A = AP - sum_r P_r P_r^T, where AP = A + sum_r P_r P_r^T should be a multiple of
    the identity plus a skew-symmetric matrix. AP is diagonalized by V, and only one
    eigenvalue from each conjugate pair is kept.

    Args:
        measure: Name of HiPPO method (passed to ``transition`` and ``rank_correction``).
        N: Size of recurrent A matrix (also known as `d_state` elsewhere).
        rank: Rank of the low-rank correction P.
        dtype: Single or double precision (torch.float or torch.double).
        diagonalize_precision: Calculate diagonalization in double precision.
        B_clip: Clip the imaginary part of B to [-B_clip, B_clip]; can help with stability.
            None for no clipping.

    Returns:
        W: (N//2,) complex, the N//2 eigenvalues of AP with the smallest imaginary part.
        P: (rank, N//2) complex, the rank correction projected by V^*.
        B: (N//2,) complex, the input matrix projected by V^*.
        V: (N, N//2) complex eigenvectors matching W; 2*Re(V diag(W) V^*) is checked against AP.

    Numerical-precision problems are reported with ``log.warning`` rather than raised.
    """
    assert dtype == torch.float or dtype == torch.double
    cdtype = torch.cfloat if dtype == torch.float else torch.cdouble

    A, B = transition(measure, N)
    A = torch.as_tensor(A, dtype=dtype)  # (N, N)
    B = torch.as_tensor(B, dtype=dtype)[:, 0]  # (N,)

    P = rank_correction(measure, N, rank=rank, dtype=dtype)  # (r N)
    AP = A + torch.sum(P.unsqueeze(-2)*P.unsqueeze(-1), dim=-3)  # A + sum_r P_r P_r^T

    # AP should be (a multiple of) identity plus skew-symmetric, i.e.
    # AP + AP^T should be close to a scalar multiple of the identity
    _A = AP + AP.transpose(-1, -2)
    if (err := torch.sum((_A - _A[0,0]*torch.eye(N))**2) / N) > 1e-5:
        log.warning(f"HiPPO matrix not skew symmetric (error {err.item():.3e})")

    # Take advantage of identity + skew-symmetric form to calculate real and imaginary parts separately
    # Imaginary part can use eigh instead of eig
    W_re = torch.mean(torch.diagonal(AP), -1, keepdim=True)

    # Diagonalize in double precision
    if diagonalize_precision:
        AP = AP.to(torch.double)
    W_im, V = torch.linalg.eigh(AP*-1j)  # (..., N) (..., N, N)
    if diagonalize_precision:
        W_im, V = W_im.to(cdtype), V.to(cdtype)
    W = W_re + 1j * W_im

    # Only keep half of each conjugate pair
    _, idx = torch.sort(W.imag)
    W_sorted = W[idx]
    V_sorted = V[:, idx]

    # There is an edge case when eigenvalues can be 0, which requires some machinery to handle
    # We use a huge hack here: Assume only one pair is 0, and that it is the first row/column of A (only happens in Fourier case)
    V = V_sorted[:, :N//2]
    W = W_sorted[:N//2]  # Only keep negative imaginary components
    assert W[-2].abs() > 1e-4, "Only 1 zero eigenvalue allowed in diagonal part of A"
    if W[-1].abs() < 1e-4:
        V[:, -1] = 0.
        V[0, -1] = 2**-0.5
        V[1, -1] = 2**-0.5 * 1j

    # Reconstruction check: the kept half plus its implied conjugate should give back AP
    _AP = V @ torch.diag_embed(W) @ V.conj().transpose(-1, -2)
    if (err := torch.sum((2*_AP.real-AP)**2)/N) > 1e-5:
        log.warning(f"Diagonalization of A matrix not numerically precise - error {err.item():.3e}")

    V_inv = V.conj().transpose(-1, -2)

    B = contract('ij, j -> i', V_inv, B.to(V))  # V^* B
    P = contract('ij, ...j -> ...i', V_inv, P.to(V))  # V^* P

    if B_clip is not None:
        B = B.real + 1j*torch.clamp(B.imag, min=-B_clip, max=B_clip)

    # W represents the diagonal part of the DPLR form: A = W - PP^*
    # Downstream classes just call this A for simplicity,
    # which is also more consistent with the diagonal case
    return W, P, B, V
def dplr(
    init='hippo',
    N=64, rank=1, H=1,
    dtype=torch.float,
    real_random=False,
    real_scale=1.0,
    imag_random=False,
    imag_scale=1.0,
    B_random=False,
    B_init='constant',
    B_scale=1.0,
    P_scale=1.0,
    normalize=False,
):
    """Directly construct a DPLR matrix.

    Args:
    - init: (str) ['rand', 'real', 'lin', 'inv', 'inv2', 'quad', 'legs'/'hippo'] Choices for
        initialization of A. Most of these affect the imaginary part of A, except for 'real'
        (which zeroes the imaginary part and overrides the real part, ignoring real_scale).
    - real_random: (bool) Initialize A.real in -U[0, 1]. Otherwise, initialize to -1/2.
    - real_scale: (float) Scaling factor of real part of A.
    - imag_random: (bool) Initialize A.imag randomly.
    - imag_scale: (float) Scaling factor of imaginary part of A.
    - B_random: (bool) Deprecated; only triggers a warning and does not change B_init.
    - B_init: (str) ['constant' | 'random' | 'alternating' | 'unit-cw' | 'unit-ccw']
        Choices for initialization of B. Ignored when init is 'legs'/'hippo', in which case
        B (and P) come from the HiPPO-LegS NPLR decomposition.
    - B_scale: (float) Scaling factor for B
    - P_scale: (float) Scaling factor for P (0.0 gives a purely diagonal A)
    - normalize: (bool) Apply an automatic normalization factor on B

    Returns:
        A: (H, N//2) complex diagonal, P: (rank, H, N//2), B: (H, N//2), V: (H, N, N//2).
    """
    assert dtype == torch.float or dtype == torch.double
    dtype = torch.cfloat if dtype == torch.float else torch.cdouble

    pi = torch.tensor(math.pi)

    # Construct real part of diagonal A (must be non-negative)
    if real_random:
        real_part = torch.rand(H, N//2)
    else:
        real_part = .5 * torch.ones(H, N//2)
    real_part = real_scale * real_part

    # Construct imaginary part of diagonal A (must be non-negative)
    if imag_random:
        imag_part = N//2 * torch.rand(H, N//2)
    else:
        imag_part = repeat(torch.arange(N//2), 'n -> h n', h=H)

    if init in ['random', 'rand']:
        imag_part = torch.exp(torch.randn(H, N//2))
    elif init == 'real':
        imag_part = 0 * imag_part
        if real_random:
            # Parenthesized: `rand * N // 2` parses as `(rand * N) // 2`, which floors the
            # samples to integers (including 0, i.e. A.real = 0 and no decay).
            real_part = torch.rand(H, N//2) * (N//2)
        else:
            # This is the S4D-Real method described in the S4D paper
            # The A matrix is diag(-1, -2, ..., -N), which are the eigenvalues of the HiPPO matrix
            real_part = 1 + repeat(torch.arange(N//2), 'n -> h n', h=H)
    elif init in ['linear', 'lin']:
        imag_part = pi * imag_part
    elif init in ['inverse', 'inv']:  # Based on asymptotics of the default HiPPO matrix
        imag_part = 1/pi * N * (N/(1+2*imag_part)-1)
    elif init in ['inverse2', 'inv2']:
        imag_part = 1/pi * N * (N/(1+imag_part)-1)
    elif init in ['quadratic', 'quad']:
        imag_part = 1/pi * (1+2*imag_part)**2
    elif init in ['legs', 'hippo']:
        A, _, _, _ = nplr('legs', N)
        imag_part = -A.imag  # Positive
    else:
        raise NotImplementedError
    imag_part = imag_scale * imag_part

    # Construct diagonal A
    A = -real_part - 1j * imag_part  # Force negative real and imag
    assert torch.all(A.real < 1e-4) and torch.all(A.imag <= 0.0)  # Allow some tolerance for numerical precision on real part

    # Initialize B
    if B_random:
        log.warning("'B_random' is deprecated in favor of B_init='random' and will be deprecated in a future version.")
    if init in ['legs', 'hippo']:
        log.info(f'Initializing with S4D-LegS and ignoring argument {B_init=}')
        # Special initialization using the HiPPO B matrix
        # Note that theory (from S4D paper) says that B should be halved
        # to match DPLR but we drop this 0.5 factor for simplicity
        # rank is forwarded so this P has the same number of rows as the random P below
        _, P, B, _ = nplr('legs', N, rank=rank, B_clip=2.0)
        B = repeat(B, 'n -> h n', h=H).clone().contiguous()
    elif B_init == 'constant':
        B = torch.ones(H, N//2, dtype=dtype)
    elif B_init == 'random':
        B = torch.randn(H, N//2, dtype=dtype)
    elif B_init == 'alternating':  # Seems to track 'constant' exactly for some reason
        B = torch.ones(H, N//4, 2, dtype=dtype)
        B[:, :, 1] *= -1
        B = B.view(H, N//2)
    elif B_init == 'unit-cw':
        # .to() instead of torch.tensor(tensor), which copy-constructs with a UserWarning
        z = torch.exp(-2j * pi / N).to(dtype)
        B = z ** torch.arange(0, N // 2)
        B = repeat(B, 'n -> h n', h=H).clone().contiguous()
    elif B_init == 'unit-ccw':
        z = torch.exp(2j * pi / N).to(dtype)
        B = z ** torch.arange(0, N // 2)
        B = repeat(B, 'n -> h n', h=H).clone().contiguous()
    else:
        raise NotImplementedError
    B *= B_scale

    # Experimental feature that appeared in earlier versions of HTTYH (not extensively tested)
    # Seems more principled for normalization theoretically, but seemed to hurt on PathX
    if normalize:
        norm = -B/A  # (H, N) # Result if you integrate the kernel with constant 1 function
        zeta = 2*torch.sum(torch.abs(norm)**2, dim=-1, keepdim=True)  # Variance with a random C vector
        B = B / zeta**.5

    # Initialize P
    # NOTE(review): this keys off B_init while the B branch above keys off init; the
    # HiPPO P is only used when both are 'legs'/'hippo'. Confirm this is intended.
    if B_init in ['legs', 'hippo']:
        # P constructed earlier
        P = repeat(P, 'r n -> r h n', h=H).clone().contiguous()
    else:
        P = torch.randn(rank, H, N//2, dtype=dtype)
    # Scale both branches so that P_scale=0.0 (diagonal inits) always yields P == 0
    P = P * P_scale

    # Initialize V (only used in testing)
    V = torch.eye(N, dtype=dtype)[:, :N//2]
    V = repeat(V, 'n m -> h n m', h=H)

    return A, P, B, V
def ssm(init, N, R, H, **ssm_args):
    """Dispatcher to create single SSM initialization

    init: "diag[-<method>]" or "dplr[-<method>]" builds A, P, B, V with dplr() ("diag" also
        forces P_scale=0.0); any other string is treated as a HiPPO measure for nplr().
    N: state size
    R: rank (for DPLR parameterization)
    H: number of independent SSM copies
    """
    if not init.startswith(("diag", "dplr")):
        A, P, B, V = nplr(init, N, R, **ssm_args)
        A = repeat(A, 'n -> s n', s=H)
        P = repeat(P, 'r n -> r s n', s=H)
        B = repeat(B, 'n -> s n', s=H)
        V = repeat(V, 'n m -> s n m', s=H)
        return A, P, B, V

    if init.startswith("diag"):
        ssm_args["P_scale"] = 0.0
    head, *rest = init[4:].split("-")
    assert head == ""
    if rest:
        ssm_args["init"] = rest[0]
    A, P, B, V = dplr(N=N, rank=R, H=H, **ssm_args)
    return A, P, B, V
combinations = {
    'hippo': ['legs', 'fourier'],
    'diag': ['diag-inv', 'diag-lin'],
    'all': ['legs', 'fourier', 'diag-inv', 'diag-lin'],
}


def combination(inits, N, R, S, **ssm_args):
    """Build S SSM copies split evenly across one or more initialization families.

    inits: a key of ``combinations``, a single init name, or a list of init names.
    N: state size; R: rank; S: total number of SSM copies (must be divisible by the
        number of inits). Extra keyword arguments are forwarded to ``ssm``.

    Returns A, P, B, V concatenated over the families along the copy dimension
    (dim 0 for A, B, V and dim 1 for P).
    """
    if isinstance(inits, str):
        inits = combinations.get(inits, [inits])

    n_inits = len(inits)
    assert S % n_inits == 0, f"{S} independent trainable SSM copies must be multiple of {n_inits} different inits"

    per_init = S // n_inits
    results = [ssm(init, N, R, per_init, **ssm_args) for init in inits]

    A = torch.cat([res[0] for res in results], dim=0)
    P = torch.cat([res[1] for res in results], dim=1)
    B = torch.cat([res[2] for res in results], dim=0)
    V = torch.cat([res[3] for res in results], dim=0)
    return A, P, B, V
"""SSM convolution kernels"""
def inv_transform(param, transform='none'):
    """Initialize a (positive) parameter under a transform.

    Returns the raw value x such that ``param_transform(x, transform)`` is approximately
    ``param``. ``param`` is first clamped to >= 1e-4 so the log/logit-based inverses stay
    finite (some of the HiPPO methods have real part 0).

    Args:
        param: tensor of target (positive) values.
        transform: one of 'none', 'exp', 'relu', 'sigmoid', 'softplus'.

    Raises:
        NotImplementedError: for an unknown transform.
    """
    param = torch.clamp(param, min=1e-4)
    if transform in ('none', 'relu'):
        return param
    if transform == 'exp':
        return torch.log(param)
    if transform == 'sigmoid':
        return torch.logit(param)
    if transform == 'softplus':
        # Inverse softplus log(exp(p) - 1), rewritten as p + log(1 - exp(-p)).
        # This does not overflow to inf for large p and keeps precision for small p.
        return param + torch.log(-torch.expm1(-param))
    raise NotImplementedError
def param_transform(param, transform='none'):
    """Get a (positive) parameter under a transform.

    'none' returns ``param`` unchanged, 'exp' exponentiates it, 'relu' clips it at zero
    and adds 1e-4, and 'sigmoid' / 'softplus' apply those functions.

    Raises:
        NotImplementedError: for an unknown transform.
    """
    if transform == 'none':
        return param
    if transform == 'exp':
        return torch.exp(param)
    if transform == 'relu':
        # JAX version seems to NaN if you allow 0's, although this code was fine without it
        return F.relu(param) + 1e-4
    if transform == 'sigmoid':
        return torch.sigmoid(param)
    if transform == 'softplus':
        return F.softplus(param)
    raise NotImplementedError
class Kernel(nn.Module):
    """Interface for modules that produce convolution kernels.

    A main distinction between these and normal Modules is that the forward pass
    does not take inputs. It is a mapping from parameters to a tensor that can
    be used in other modules, in particular as a convolution kernel.

    Because of the unusual parameterization, these kernels may often want special
    hyperparameter settings on their parameters. The `register` method provides
    an easy interface for controlling this, and is intended to be used with an
    optimizer hook that can be found in train.py or example.py.

    This class also defines an interface for interacting with kernels *statefully*,
    in particular for state space models (SSMs). This interface handles the setting
    when a model can be converted from a "CNN" into an "RNN".
    _setup_step()
    step()
    default_state()
    forward_state()

    See ConvKernel for the simplest instantiation of this interface.
    """

    def __init__(
        self,
        d_model: int = 0,
        channels: int = 1,
        l_max: Optional[int] = None,
        lr: Union[float, Optional[Mapping]] = None,
        wd: Union[float, Optional[Mapping]] = 0.0,
        verbose: bool = True,
        **kwargs,
    ):
        """General interface.

        d_model (H): Model dimension, or number of independent convolution kernels created.
        channels (C): Extra dimension in the returned output (see .forward()).
            - One interpretation is that it expands the input dimension giving it C separate "heads" per feature.
              That is convolving by this kernel maps shape (B L D) -> (B L C D)
            - This is also used to implement a particular form of bidirectionality in an efficient way.
            - In general for making a more powerful model, instead of increasing C
              it is recommended to set channels=1 and adjust H to control parameters instead.
        l_max (L): Maximum kernel length (optional). If unspecified, most Kernel instantiations
            will return kernels of arbitrary length as passed into .forward().
        lr: Optional dictionary specifying special hyperparameters for .register().
            Passing in a number (e.g. 0.001) sets attributes of SSM parameters (A, B, dt).
            A custom optimizer hook is needed to configure the optimizer to set the learning rates appropriately for these parameters.
        wd: Same as lr, but for weight decay.
        """
        super().__init__()
        assert d_model > 0
        self.H = self.d_model = d_model
        self.L = self.l_max = l_max
        self.channels = channels
        self.lr = lr
        self.wd = wd
        self.verbose = verbose

        # Add a catch-all **kwargs to make it easier to change kernels
        # without manually moving other options passed in the config.
        # Good to log these just so it's explicit.
        if self.verbose and len(kwargs) > 0:
            log.info(f"{type(self)} extra kwargs: {kwargs}")

        # Logic for registering parameters
        # Case 1: lr: None | number
        #   All params should have this lr (None means inherit from global lr).
        #   Ints (e.g. lr=0) count as numbers; only floats did before, so ints fell
        #   into the dict branch and raised TypeError.
        # Case 2: lr: dict
        #   Specified params should have that lr, all others should be None
        if self.lr is None or isinstance(self.lr, (int, float)):
            self.lr_dict = defaultdict(lambda: self.lr)
        else:
            self.lr_dict = defaultdict(lambda: None)
            self.lr_dict.update(self.lr)

        # Same logic for weight decay
        # (but is always just set to 0.0 and hasn't been ablated)
        if self.wd is None or isinstance(self.wd, (int, float)):
            self.wd_dict = defaultdict(lambda: self.wd)
        else:
            self.wd_dict = defaultdict(lambda: None)
            self.wd_dict.update(self.wd)

    def forward(self, state=None, rate=1.0, L=None):
        """General interface to generate a global convolution kernel.

        state: Initial state for recurrent updates.
            E.g. for SSMs, this should have shape (B, H, N) (batch, d_model, d_state).
        rate: Relative sampling rate.
        L: Target kernel length.

        Returns:
          - (C, H, L) (channels, d_model, l_kernel) The convolution kernel.
          - (B, H, L) (batch, d_model, l_kernel)
              Extra information for how the state affects the output of convolving by kernel.
        """
        raise NotImplementedError

    def register(self, name, tensor, lr=None, wd=0.0):
        """Register ``tensor`` under ``name`` with optimizer hints.

        When lr == 0.0 the tensor is stored as a (non-trainable) buffer, otherwise as an
        ``nn.Parameter``. Non-None ``lr`` / ``wd`` values are attached to the registered
        tensor as ``_optim = {"lr": ..., "weight_decay": ...}`` for an optimizer hook.
        """
        if lr == 0.0:
            self.register_buffer(name, tensor)
        else:
            self.register_parameter(name, nn.Parameter(tensor))

        optim = {}
        if lr is not None:
            optim["lr"] = lr
        if wd is not None:
            optim["weight_decay"] = wd
        setattr(getattr(self, name), "_optim", optim)

    def _setup_step(self, **kwargs):
        """Convert a model into a recurrent mode for autoregressive inference."""
        raise NotImplementedError

    def step(self, x, state, **kwargs):
        """Step the model for one timestep with input x and recurrent state."""
        raise NotImplementedError

    def default_state(self, *args, **kwargs):
        """Return a default initial state."""
        raise NotImplementedError

    @torch.no_grad()
    def forward_state(self, u, state):
        """Forward the state through a sequence, i.e. computes the state after passing chunk through the kernel."""
        raise NotImplementedError

    @property
    def d_state(self):
        """Implement this for interfaces that want to interact with a stateful layer (i.e. SSMs).

        Currently the only codepath that might use this is the StateDecoder, which is not used.
        """
        raise NotImplementedError

    @property
    def state_to_tensor(self):
        """Same as d_state, only needed for niche codepaths involving recurrent state."""
        raise NotImplementedError
class SSMKernel(Kernel):
"""Parent class for different SSM parameterizations.
This class is abstract and only defines some initializations and flags that are common to all SSM variants.
It is instantiated by subclasses SSMKernel{Dense,Real,Diag,DPLR}.
Options:
d_state (N): State size (dimensionality of parameters A, B, C). Generally shouldn't need to be adjusted and doens't affect speed much for most kernels (e.g. S4, S4D).
deterministic: Use a deterministic initialization for dt, A, B, C.
Useful for debugging as well as constructing a simple exponential decay kernel (e.g. used in S4ND image->video inflation).
dt_min, dt_max: min and max values for the step size dt
dt_tie: Keep dt tied across the N dimensions of the state. Although this theoretically makes more sense, models such as S5 and Mega have found slightly improvements by setting it to False.
dt_transform: Transform function for parameterization of dt (default 'softplus', used to be 'exp')
rank: Rank of low-rank correction for DPLR mode. Needs to be increased for init "legt".
n_ssm: Number of independent trainable (A, B) SSMs, e.g.
`n_ssm=1` means all A/B parameters are tied across the H different instantiations of C.
`n_ssm=None` means all H SSMs are completely independent.
Generally, changing this option can save parameters but doesn't affect performance or speed much.
This parameter must divide H.
init: Options for initialization of (A, B). For DPLR mode, recommendations are "legs", "fout", "hippo" (combination of both). For Diag mode, recommendations are "diag-inv", "diag-lin", "diag-legs", and "diag" (combination of diag-inv and diag-lin).
init_args: Extra arguments passed into initialization function (see dplr.py for options).
"""
def init_dt(self):
    """Sample the initial raw step-size parameter (before ``dt_transform`` is applied).

    Returns shape (H, 1) if dt is tied (always the case in deterministic mode),
    otherwise (H, N//2). In random mode dt is log-uniform in [dt_min, dt_max):
    the result is log(dt) for the 'exp' transform, and ``inv_transform(dt)`` otherwise.
    """
    if not self.deterministic:
        shape = (self.H, 1) if self.dt_tie else (self.H, self.N//2)
        log_lo = math.log(self.dt_min)
        log_span = math.log(self.dt_max) - math.log(self.dt_min)
        # Initialize log dt uniformly in [log dt_min, log dt_max)
        inv_dt = torch.rand(*shape, dtype=self.dtype) * log_span + log_lo
        if self.dt_transform != 'exp':
            inv_dt = inv_transform(torch.exp(inv_dt), self.dt_transform)
        return inv_dt

    # Deterministic initialization, meant for debugging
    assert self.dt_tie, "Deterministic dt initialization is tied"
    assert self.dt_transform == 'exp', "Deterministic dt transform should be 'exp' for simplicity"
    # NOTE(review): the random branch returns log(dt) for the 'exp' transform, but this returns
    # exp(linspace(log dt_min, log dt_max)), i.e. dt itself; confirm this is intended.
    grid = torch.linspace(math.log(self.dt_min), math.log(self.dt_max), self.H)
    return torch.exp(grid).unsqueeze(-1)  # (H 1)
def init_ssm_real(self):
    """Returns (dense, real) (A, B, C) parameters for init options.

    A: (n_ssm, N, N) and B: (n_ssm, N) are copies of the HiPPO transition for ``self.init``.
    C: (channels, H, N). In deterministic mode it is zero except C[..., 0] = 1;
    otherwise it is standard normal. All tensors use ``self.dtype``.
    """
    A_np, B_np = transition(self.init, self.N)
    base_A = torch.as_tensor(A_np, dtype=self.dtype)
    base_B = torch.as_tensor(B_np, dtype=self.dtype)[:, 0]
    A = repeat(base_A, 'n m -> v n m', v=self.n_ssm).clone().contiguous()
    B = repeat(base_B, 'n -> v n', v=self.n_ssm).clone().contiguous()

    c_shape = (self.channels, self.H, self.N)
    if self.deterministic:
        C = torch.zeros(*c_shape, dtype=self.dtype)
        C[..., 0] = 1.0
    else:
        C = torch.randn(*c_shape, dtype=self.dtype)
    return A, B, C
def init_ssm_dplr(self):
    """Build DPLR (A, P, B, C) parameters for ``self.init``.

    A, P, B (and V) come from ``combination``. They are then tiled so that their
    copy dimension (size(-2)) equals ``self.n_ssm``. C is random complex
    (channels, H, N//2), or in deterministic mode a one-hot vector projected by V^*.

    Side effect: halves ``self.N``, because the complex parameterization assumes
    conjugate symmetry.
    """
    A, P, B, V = combination(self.init, self.N, self.rank, self.n_ssm, **self.init_args)

    if not self.deterministic:
        C = torch.randn(self.channels, self.H, self.N//2, dtype=self.cdtype)
    else:
        # One-hot C on the first state, mapped through V^*, then tiled across H
        C = torch.zeros(self.channels, self.n_ssm, self.N, dtype=self.cdtype)
        C[:, :, :1] = 1.
        C = contract('hmn, chn -> chm', V.conj().transpose(-1, -2), C)  # V^* C
        C = repeat(C, 'c t n -> c (v t) n', v=self.H // C.size(-2)).clone().contiguous()

    # Every parameter's copy dimension must tile evenly into n_ssm
    for param in (B, P, A):
        assert self.n_ssm % param.size(-2) == 0

    def _tile(x, pattern):
        # Materialized, contiguous n_ssm copies: these become the trainable parameters
        return repeat(x, pattern, v=self.n_ssm // x.size(-2)).clone().contiguous()

    B = _tile(B, 't n -> (v t) n')
    P = _tile(P, 'r t n -> r (v t) n')
    A = _tile(A, 't n -> (v t) n')

    # Conjugate symmetry: only half of the state dimension is kept
    self.N //= 2
    return A, P, B, C
def __init__(
    self,
    # General Kernel arguments for parent class
    d_model: int = 0,
    channels: int = 1,
    l_max: Optional[int] = None,
    lr: Union[float, Optional[Mapping]] = None,
    wd: Union[float, Optional[Mapping]] = 0.0,
    verbose: bool = True,
    # SSM arguments
    d_state: int = 64,
    deterministic: bool = False,
    # dt options
    dt_min: float = 0.001,
    dt_max: float = 0.1,
    dt_tie: bool = True,
    dt_transform: str = 'exp',
    # (A, B, C) options
    rank: int = 1,
    n_ssm: Optional[int] = None,
    measure: Optional[str] = None,
    init: Optional[str] = "legs",
    # Extra hyperparameters for initialization
    **init_args,
):
    """Store SSM hyperparameters on the module.

    Parameter tensors are not created here; ``init_dt`` / ``init_ssm_real`` /
    ``init_ssm_dplr`` build them from the attributes set below.
    ``measure`` is a deprecated alias for ``init``.
    """
    super().__init__(d_model=d_model, channels=channels, l_max=l_max, lr=lr, wd=wd, verbose=verbose)
    self.N = d_state
    self.dtype, self.cdtype = torch.float, torch.cfloat
    self.deterministic = deterministic
    # dt options
    self.dt_min = dt_min
    self.dt_max = dt_max
    self.dt_tie = dt_tie
    self.dt_transform = dt_transform
    # SSM options (A, B, C)
    self.rank = rank
    # Default: one independent (A, B) per H instantiation
    self.n_ssm = n_ssm if n_ssm is not None else self.H
    if measure is not None:
        log.warning("Warning: 'measure' option changed to 'init' and will be removed in a future version.")
        # `init` has a non-None default ("legs"), so a caller passing only the
        # deprecated `measure` keyword would always fail a strict `init is None`
        # check. Treat the default (or an identical value) as "not passed".
        assert init in (None, "legs", measure), "'measure' and 'init' cannot both be passed into SSMKernel"
        init = measure
    self.init = init
    self.init_args = init_args
@torch.no_grad()
def forward_state(self, u, state):
    """Advance ``state`` through the input chunk ``u`` (generic SSM version).

    Used by SSMKernelDense and SSMKernelDPLR. This implementation is suboptimal;
    prefer SSMKernelDiag when this functionality is needed.

    Args:
        u: (B, H, L) input chunk.
        state: (B, H, N) state before the chunk. If its last dimension differs
            from dA's, it is expanded with ``_conj`` first, and the result is
            truncated back to the first half of its last dimension.

    Returns:
        (B, H, N) state after the chunk.
    """
    dA, dB = self._setup_state()  # (H N N), (H N)
    halved = state.size(-1) != dA.size(-1)
    full_state = _conj(state) if halved else state

    # dB-weighted inputs, in reversed time order
    weighted = contract('h n, b h l -> b h n l', dB, u.flip(-1))
    dA_pow, accum = power(u.size(-1), dA, weighted)
    out = contract("h m n, b h n -> b h m", dA_pow, full_state) + accum

    if halved:
        out = out[..., : out.size(-1) // 2]
    return out
def _setup_state(self):
    """Hook returning the discretized (dA, dB) pair used by ``forward_state``.

    The base class provides no implementation; subclasses must override it.
    """
    raise NotImplementedError()
@property
def d_state(self):
    """Total flattened state size, ``H * N``.

    Used together with ``state_to_tensor`` by some decoders from earlier
    versions; not needed in general.
    """
    total = self.H * self.N
    return total
@property
def state_to_tensor(self):
    """Return a function that flattens a (..., H, N) state into (..., H*N).

    The H axis is major and the N axis minor, i.e. '... h n -> ... (h n)'.
    """
    # The previous body called rearrange(pattern, state) with the arguments
    # swapped (einops expects the tensor first), so it could never succeed.
    # flatten over the last two dims performs the same merge.
    return lambda state: state.flatten(start_dim=-2)
class SSMKernelDiag(SSMKernel):
"""SSM kernel using diagonal state matrix (S4D model).
Options:
disc: ['zoh' | 'bilinear' | 'dss'] Discretization options.
dt_fast: (experimental) Parameterize inv_dt under sinh function.
(Ohno et al. "Fast Saturating Gate for Learning Long Time Scales with RNNs")
real_transform, imag_transform: ['none' | 'exp' | 'relu' | 'sigmoid' | 'softplus']
Parameterize the real/imag parts of the diagonal of A under this function.
bandlimit: Mask high frequencies of the kernel (indices corresponding to
diagonal elements with large imaginary part). Introduced in S4ND paper.
kernel: ['cuda' | 'keops' | 'naive'] Options for Vandermonde/Cauchy kernel (in order of efficiency).
force_real : Force A to have 0 imaginary part, to emulate EMA.
"""
def __init__(
self,
disc: str = 'zoh', # Change to 'bilinear' to match S4, but should make little difference either way
dt_fast: bool = False,
real_transform: str = 'exp',
imag_transform: str = 'none',
bandlimit: Optional[float] = None,
kernel: str = 'cuda',
force_real: bool = False,