forked from shekshaa/ResidualMFN
-
Notifications
You must be signed in to change notification settings - Fork 0
/
ctf.py
91 lines (79 loc) · 3.32 KB
/
ctf.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import numpy as np
import torch
def compute_safe_freqs(n_pixels, psize):
freq_pix_1d = np.arange(-0.5, 0.5, 1 / n_pixels)
freq_pix_1d_safe = freq_pix_1d[:n_pixels]
x, y = np.meshgrid(freq_pix_1d_safe, freq_pix_1d_safe)
rho = np.sqrt(x**2 + y**2)
angles_rad = np.arctan2(y, x)
freq_mag_2d = rho / psize
return freq_mag_2d, angles_rad
def torch_compute_safe_freqs(n_pixels, psize):
freq_pix_1d = torch.arange(-0.5, 0.5, 1 / n_pixels)
freq_pix_1d_safe = freq_pix_1d[:n_pixels]
x, y = torch.meshgrid(freq_pix_1d_safe, freq_pix_1d_safe)
rho = torch.sqrt(x**2 + y**2)
angles_rad = torch.atan2(y, x)
freq_mag_2d = rho / psize
return freq_mag_2d, angles_rad
def generate_random_ctf_params(batch_size,
df_min=15000,
df_max=20000,
df_diff_min=100,
df_diff_max=500,
df_ang_min=0,
df_ang_max=360,
volt=300,
cs=2.7,
w=0.1,
phase_shift=0,):
dfs = np.random.uniform(low=df_min, high=df_max, size=(batch_size, 1, 1))
df_diff = np.random.uniform(low=df_diff_min, high=df_diff_max, size=(batch_size, 1, 1))
df_ang_deg = np.random.uniform(low=df_ang_min, high=df_ang_max, size=(batch_size, 1, 1))
dfu = dfs - df_diff / 2
dfv = dfs + df_diff / 2
return dfu, dfv, df_ang_deg, np.ones((batch_size, 1, 1)) * volt, np.ones((batch_size, 1, 1)) * cs, np.ones((batch_size, 1, 1)) * w, np.ones((batch_size, 1, 1)) * phase_shift
def compute_ctf(s, a, dfu, dfv, dfang_deg, kv, cs, w, phase=0, bf=0):
s = s[None, ...] # add batch dimension
a = a[None, ...] # add batch dimension
kv = kv * 1e3
cs = cs * 1e7
lamb = 12.2643247 / np.sqrt(kv * (1.0 + kv * 0.978466e-6))
dfang_deg = np.deg2rad(dfang_deg)
def_avg = -(dfu + dfv) * 0.5
def_dev = -(dfu - dfv) * 0.5
k1 = np.pi / 2.0 * 2 * lamb
k2 = np.pi / 2.0 * cs * lamb**3
k3 = np.sqrt(1 - w**2)
k4 = bf / 4.0 # B-factor, follows RELION convention.
k5 = np.deg2rad(phase) # Phase shift.
s_2 = s**2
s_4 = s_2**2
dZ = def_avg + def_dev * (np.cos(2 * (a - dfang_deg)))
gamma = (k1 * dZ * s_2) + (k2 * s_4) - k5
ctf = -(k3 * np.sin(gamma) - w * np.cos(gamma))
if bf != 0: # Enforce envelope.
ctf *= np.exp(-k4 * s_2)
return ctf
def torch_compute_ctf(s, a, dfu, dfv, dfang_deg, kv, cs, w, phase=0, bf=0):
s = s[None, ...] # add batch dimension
a = a[None, ...] # add batch dimension
kv = kv * 1e3
cs = cs * 1e7
lamb = 12.2643247 / torch.sqrt(kv * (1.0 + kv * 0.978466e-6))
dfang_deg = torch.deg2rad(dfang_deg)
def_avg = -(dfu + dfv) * 0.5
def_dev = -(dfu - dfv) * 0.5
k1 = np.pi / 2.0 * 2 * lamb
k2 = np.pi / 2.0 * cs * lamb**3
k3 = torch.sqrt(1 - w**2)
k4 = bf / 4.0 # B-factor, follows RELION convention.
k5 = torch.deg2rad(phase) # Phase shift.
s_2 = s**2
s_4 = s_2**2
dZ = def_avg + def_dev * (torch.cos(2 * (a - dfang_deg)))
gamma = (k1 * dZ * s_2) + (k2 * s_4) - k5
ctf = -(k3 * torch.sin(gamma) - w * torch.cos(gamma))
if bf != 0: # Enforce envelope.
ctf *= torch.exp(-k4 * s_2)
return ctf