# common_functions.py
import numpy as np
import torch
from mpol import (
    losses,
    coordinates,
    images,
    precomposed,
    gridding,
    datasets,
    connectors,
    utils,
)
from astropy.utils.data import download_file
from ray import tune
import matplotlib.pyplot as plt


# These routines live in a separate module so that the (expensive)
# data-loading code is not re-run several times when they are shared
# across scripts.
def train(
    model, dataset, optimizer, config, device, writer=None, report=False, logevery=50
):
    """
    Train the model against the dataset for ``config["epochs"]`` iterations.

    Args:
        model: neural net model
        dataset: gridded dataset to train against
        optimizer: tied to model parameters and used to take a step
        config: dictionary including epochs and regularizer hyperparameters
        device: "cpu" or "cuda"
        writer: TensorBoard writer object (optional)
        report: if True, report the final loss to Ray Tune
        logevery: log the loss and current image every ``logevery`` iterations

    Returns:
        the final loss value (float)
    """
    model = model.to(device)
    model.train()
    dataset = dataset.to(device)

    # connector used to compute residual image and visibility products
    # for logging
    residuals = connectors.GriddedResidualConnector(model.fcube, dataset)
    residuals.to(device)

    for iteration in range(config["epochs"]):
        optimizer.zero_grad()

        # forward pass: model visibilities and the current sky image
        vis = model()
        sky_cube = model.icube.sky_cube

        # negative log likelihood of the data plus the regularizers,
        # weighted by the hyperparameters in config
        loss = (
            losses.nll_gridded(vis, dataset)
            + config["lambda_sparsity"] * losses.sparsity(sky_cube)
            + config["lambda_TV"] * losses.TV_image(sky_cube)
            + config["entropy"] * losses.entropy(sky_cube, config["prior_intensity"])
        )

        # periodically log the loss and the current image to TensorBoard
        if (iteration % logevery == 0) and writer is not None:
            writer.add_scalar("loss", loss.item(), iteration)
            writer.add_figure("image", log_figure(model, residuals), iteration)

        loss.backward()
        optimizer.step()

    if report:
        tune.report(loss=loss.item())

    return loss.item()
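

# Example usage (sketch): the values below are hypothetical and assume a
# precomposed SimpleNet model and a pre-gridded dataset named `dset`; adapt
# the cell size, pixel count, and hyperparameters to the actual data.
#
#   coords = coordinates.GridCoords(cell_size=0.005, npix=800)
#   model = precomposed.SimpleNet(coords=coords)
#   optimizer = torch.optim.Adam(model.parameters(), lr=0.3)
#   config = {
#       "epochs": 500,
#       "lambda_sparsity": 1e-4,
#       "lambda_TV": 1e-4,
#       "entropy": 1e-4,
#       "prior_intensity": 1e-10,
#   }
#   final_loss = train(model, dset, optimizer, config, device="cpu")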


def test(model, dataset, device):
    """
    Evaluate the model against a dataset, returning the negative log
    likelihood as the test score.
    """
    model = model.to(device)
    model.eval()
    dataset = dataset.to(device)
    # gradients are not needed for evaluation
    with torch.no_grad():
        vis = model()
        loss = losses.nll_gridded(vis, dataset)
    return loss.item()


def cross_validate(model, config, device, k_fold_datasets, MODEL_PATH, writer=None):
    """
    Run k-fold cross validation: for each fold, reset the model to its saved
    starting weights, train against the fold's train set, and score the
    fold's test set. The cross validation metric is the sum of the test
    scores across all folds.
    """
    test_scores = []

    for k_fold, (train_dset, test_dset) in enumerate(k_fold_datasets):
        # reset model to its starting weights
        model.load_state_dict(torch.load(MODEL_PATH))

        # create a new optimizer for this k_fold
        optimizer = torch.optim.Adam(model.parameters(), lr=config["lr"])

        # train for a while
        train(model, train_dset, optimizer, config, device, writer=writer)

        # evaluate the test metric
        test_scores.append(test(model, test_dset, device))

    # aggregate all test scores and sum to evaluate the cross validation metric
    test_score = np.sum(np.array(test_scores))

    # log to Ray Tune
    tune.report(cv_score=test_score)

    return test_score
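

# Example (sketch): wiring cross_validate into a Ray Tune run. The trainable
# wrapper and the search-space values below are hypothetical; `coords`,
# `k_fold_datasets`, and `MODEL_PATH` must be defined by the calling script.
#
#   def trainable(config):
#       model = precomposed.SimpleNet(coords=coords)
#       cross_validate(model, config, "cpu", k_fold_datasets, MODEL_PATH)
#
#   analysis = tune.run(
#       trainable,
#       config={
#           "lr": 0.3,
#           "epochs": 500,
#           "lambda_sparsity": tune.loguniform(1e-8, 1e-4),
#           "lambda_TV": tune.loguniform(1e-5, 1e-2),
#           "entropy": 1e-4,
#           "prior_intensity": 1e-10,
#       },
#       num_samples=4,
#   )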


def log_figure(model, residuals):
    """
    Create a matplotlib figure showing the current image state.

    Args:
        model: neural net model
        residuals: GriddedResidualConnector tied to the model and dataset
    """
    # run the forward pass of the connector to update the residual products
    residuals()

    fig, ax = plt.subplots(ncols=2, nrows=2, figsize=(10, 10))

    # top left: current model sky image
    im = ax[0, 0].imshow(
        np.squeeze(model.icube.sky_cube.detach().cpu().numpy()),
        origin="lower",
        interpolation="none",
        extent=model.icube.coords.img_ext,
    )
    plt.colorbar(im, ax=ax[0, 0])

    # top right: residual sky image
    im = ax[0, 1].imshow(
        np.squeeze(residuals.sky_cube.detach().cpu().numpy()),
        origin="lower",
        interpolation="none",
        extent=residuals.coords.img_ext,
    )
    plt.colorbar(im, ax=ax[0, 1])

    # bottom left: log amplitude of the model visibilities
    im = ax[1, 0].imshow(
        np.squeeze(torch.log(model.fcube.ground_amp.detach()).cpu().numpy()),
        origin="lower",
        interpolation="none",
        extent=residuals.coords.vis_ext,
    )
    plt.colorbar(im, ax=ax[1, 0])

    # bottom right: log amplitude of the residual visibilities
    im = ax[1, 1].imshow(
        np.squeeze(torch.log(residuals.ground_amp.detach()).cpu().numpy()),
        origin="lower",
        interpolation="none",
        extent=residuals.coords.vis_ext,
    )
    plt.colorbar(im, ax=ax[1, 1])

    return fig
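

# Example (sketch): logging a diagnostic figure outside of train(). The run
# directory name is hypothetical; `model` and `dset` must already exist.
#
#   from torch.utils.tensorboard import SummaryWriter
#
#   writer = SummaryWriter("runs/example")
#   residuals = connectors.GriddedResidualConnector(model.fcube, dset)
#   writer.add_figure("image", log_figure(model, residuals), 0)
#   writer.close()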