forked from Punyaslok/snippets
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathfft.cpp
92 lines (75 loc) · 2.09 KB
/
fft.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
// Fast Fourier transform
//
// Calling mult(a, b, c, len) is identical to:
// REP(i, 2*len) tmp[i] = 0;
// REP(i, len) REP(j, len) tmp[i+j] += a[i] * b[j];
// REP(i, 2*len) c[i] = tmp[i];
//
// There is also a variant with modular arithmetic: mult_mod.
//
// Common use cases:
// - big integer multiplication
// - convolutions in dynamic programming
//
// Time complexity: O(N log N), where N is the length of arrays
//
// Constants to configure:
// - MAX must be at least 2^ceil(log2(2 * len))
namespace FFT {
const int MAX = 1 << 20;
typedef llint value;
typedef complex<double> comp;
int N;
comp omega[MAX];
comp a1[MAX], a2[MAX];
comp z1[MAX], z2[MAX];
void fft(comp *a, comp *z, int m = N) {
if (m == 1) {
z[0] = a[0];
} else {
int s = N/m;
m /= 2;
fft(a, z, m);
fft(a+s, z+m, m);
REP(i, m) {
comp c = omega[s*i] * z[m+i];
z[m+i] = z[i] - c;
z[i] += c;
}
}
}
void mult(value *a, value *b, value *c, int len) {
N = 2*len;
while (N & (N-1)) ++N;
assert(N <= MAX);
REP(i, N) a1[i] = 0;
REP(i, N) a2[i] = 0;
REP(i, len) a1[i] = a[i];
REP(i, len) a2[i] = b[i];
REP(i, N) omega[i] = polar(1.0, 2*M_PI/N*i);
fft(a1, z1, N);
fft(a2, z2, N);
REP(i, N) omega[i] = comp(1, 0) / omega[i];
REP(i, N) a1[i] = z1[i] * z2[i] / comp(N, 0);
fft(a1, z1, N);
REP(i, 2*len) c[i] = round(z1[i].real());
}
void mult_mod(llint *a, llint *b, llint *c, int len, int mod) {
static llint a0[MAX], a1[MAX];
static llint b0[MAX], b1[MAX];
static llint c0[MAX], c1[MAX], c2[MAX];
REP(i, len) a0[i] = a[i] & 0xFFFF;
REP(i, len) a1[i] = a[i] >> 16;
REP(i, len) b0[i] = b[i] & 0xFFFF;
REP(i, len) b1[i] = b[i] >> 16;
FFT::mult(a0, b0, c0, len);
FFT::mult(a1, b1, c2, len);
REP(i, len) a0[i] += a1[i];
REP(i, len) b0[i] += b1[i];
FFT::mult(a0, b0, c1, len);
REP(i, 2*len) c1[i] -= c0[i] + c2[i];
REP(i, 2*len) c1[i] %= mod;
REP(i, 2*len) c2[i] %= mod;
REP(i, 2*len) c[i] = (c0[i] + (c1[i] << 16) + (c2[i] << 32)) % mod;
}
}