forked from MPoL-dev/MPoL
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcommon_data.py
66 lines (54 loc) · 1.39 KB
/
common_data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import matplotlib.pyplot as plt
import numpy as np
import torch
from astropy.utils.data import download_file
from ray import tune
from mpol import (
connectors,
coordinates,
crossval,
datasets,
gridding,
images,
losses,
precomposed,
utils,
)
# Load the HD 143006 continuum visibilities from a local .npz archive.
# For .npz input np.load returns an NpzFile that keeps the archive open;
# using it as a context manager closes it once we are done. Indexing the
# NpzFile reads each array into memory, so uu/vv/weight/data remain valid
# after the archive is closed.
fname = "HD143006_continuum.npz"
with np.load(fname) as d:
    uu = d["uu"]
    vv = d["vv"]
    weight = d["weight"]
    data = d["data"]
# The Gridder below takes the real and imaginary visibility components as
# separate arguments.
data_re = data.real
data_im = data.imag
# Image grid: npix=512 pixels per side with cell_size=0.01 (units as defined
# by mpol.coordinates.GridCoords — arcsec per the MPoL docs, confirm).
coords = coordinates.GridCoords(cell_size=0.01, npix=512)
# Grid the visibilities onto `coords`, then wrap the gridded result as a
# PyTorch dataset for the model/cross-validation code that follows.
gridder = gridding.Gridder(
    coords=coords,
    weight=weight,
    data_re=data_re,
    data_im=data_im,
)
dataset = gridder.to_pytorch_dataset()
# Plot the dataset's gridded mask and save it to grid.png.
# The mask is reordered with utils.packed_cube_to_ground_cube (presumably
# from MPoL's "packed" layout to the "ground" layout that coords.vis_ext
# describes — confirm), moved to NumPy, and squeezed to drop size-1 axes
# (e.g. a single channel) so imshow receives a 2D image.
fig, ax = plt.subplots(nrows=1)
ax.imshow(
    np.squeeze(utils.packed_cube_to_ground_cube(dataset.mask).detach().numpy()),
    interpolation="none",
    origin="lower",
    extent=coords.vis_ext,
    cmap="GnBu",
)
fig.savefig("grid.png", dpi=300)
# pyplot keeps every figure alive until it is explicitly closed; release it
# now that it has been written to disk.
plt.close(fig)
# Cross-validation setup.
# Number of train/test folds the splitter should produce.
k = 5
# A "dartboard" partitions the grid into radial and azimuthal cells; the
# splitter uses it to divide the gridded dataset into k folds. seed=42 is
# passed explicitly, presumably so the fold assignment is reproducible.
dartboard = datasets.Dartboard(coords=coords)
cv = crossval.DartboardSplitGridded(dataset, k, dartboard=dartboard, seed=42)
# Materialize every (train, test) pair produced by the splitter into a list.
k_fold_datasets = [(fold_train, fold_test) for fold_train, fold_test in cv]
# Build the image model on the same coordinate grid and with the same
# number of channels as the gridded dataset.
model = precomposed.SimpleNet(
    coords=coords,
    nchan=dataset.nchan,
)
# Connect the model's Fourier cube to the gridded dataset; as its name
# suggests, the connector presumably yields model-minus-data residuals.
residuals = connectors.GriddedResidualConnector(
    model.fcube,
    dataset,
)