-
Notifications
You must be signed in to change notification settings - Fork 7
/
test.py
156 lines (122 loc) · 5.47 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
import os
import argparse
import torch as th
import torchbox as tb
import torchsar as ts
from ecelms import BaggingECELMs
from dataset import readsamples
# Restrict the GPUs visible to PyTorch, but respect a CUDA_VISIBLE_DEVICES value
# the caller already exported instead of silently overwriting it.
# Setting it after `import torch` still works because PyTorch initialises CUDA lazily.
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "0,1")
parser = argparse.ArgumentParser(description='Evaluate a trained BaggingECELMs model on the test files.')
# configuration files / checkpoint
parser.add_argument('--datacfg', type=str, default='./data.yaml',
                    help='data config YAML (SAR_AF_DATA_PATH, filenames, train/valid/test ids)')
parser.add_argument('--modelcfg', type=str, default='./ecelms.yaml',
                    help='model config YAML (Qas, Qrs, Convs)')
parser.add_argument('--weightfile', type=str, default=None,
                    help='checkpoint to load; a default path is used when omitted')
# NOTE(review): --cstrategy is not read anywhere in this script; kept for CLI compatibility.
parser.add_argument('--cstrategy', type=str, default='Entropy', help='Entropy, AveragePhase')
# run parameters
parser.add_argument('--seed', type=int, default=2020,
                    help='seed for sample reading, phase-error generation and the model')
parser.add_argument('--size_batch', type=int, default=10,
                    help='test batch size (capped at the number of samples)')
# NOTE(review): --snapshot_name is not read anywhere in this script; kept for CLI compatibility.
parser.add_argument('--snapshot_name', type=str, default='2020')
# misc
parser.add_argument('--device', type=str, default='cuda:0', help='device')
parser.add_argument('--mkpetype', type=str, default='RealPE', help='make phase error(RealPE, SimPoly, SimSin...)')
cfg = parser.parse_args()
# Run-wide switches. Only the final value of each flag matters.
isplot = False  # NOTE(review): not referenced in this script; presumably a plotting toggle
ftshift = True  # forwarded as the shift/ftshift option of ppeaxis, focus, defocus and the model
issaveimg = False  # NOTE(review): not referenced in this script; presumably an image-saving toggle
seed, device, size_batch = cfg.seed, cfg.device, cfg.size_batch
# cuDNN / TF32 backend options, applied to torch.backends further down.
cudaTF32 = cudnnTF32 = False
benchmark = deterministic = True
outfolder = './snapshot/tests/'
if cfg.weightfile is None:
    # No checkpoint given on the command line: fall back to this default path.
    cfg.weightfile = './record/RealPE/DiffKernelSize/Entropy/weights/64CELMs.pth.tar'
# Load the data and model configuration files.
datacfg = tb.loadyaml(cfg.datacfg)
modelcfg = tb.loadyaml(cfg.modelcfg)
# The SAR_AF_DATA_PATH environment variable overrides the data folder from the
# data config. The config key is only looked up when the variable is unset,
# so it may be absent from the YAML in that case.
datafolder = os.environ.get('SAR_AF_DATA_PATH')
if datafolder is None:
    datafolder = datacfg['SAR_AF_DATA_PATH']
print(datacfg)
print(modelcfg)
# Number of CELMs = number of conv configurations listed in modelcfg['Convs'].
nCELMs = len(modelcfg['Convs'])
# Experiment hook: to override the same setting for every CELM, e.g.
#   for i in range(nCELMs): modelcfg['Convs'][i][0][1] = 3   # or 17
# (presumably the first conv layer's kernel size -- confirm against the model YAML)
# os.path.join accepts a data folder with or without a trailing separator.
fileTrain = [os.path.join(datafolder, datacfg['filenames'][i]) for i in datacfg['trainid']]  # NOTE(review): unused here
fileValid = [os.path.join(datafolder, datacfg['filenames'][i]) for i in datacfg['validid']]  # NOTE(review): unused here
fileTest = [os.path.join(datafolder, datacfg['filenames'][i]) for i in datacfg['testid']]
modeTest = 'sequentially'
print("--->Test files sampling mode:", modeTest)
print(fileTest)
keys = [['SI', 'ca', 'cr']]  # fields requested from readsamples
ppeaxismode = 'fftfreq'  # axis mode passed to ts.ppeaxis
# NOTE(review): `parts` is built here but readsamples below receives parts=None;
# confirm which is intended.
parts = [1. / nCELMs for _ in range(nCELMs)]
# X: test samples; ca / cr: their polynomial phase-error coefficients
# (paired with the Na- and Nr-length axes below).
X, ca, cr = readsamples(fileTest, keys=keys, nsamples=[4000], groups=[25], mode=modeTest, parts=None, seed=seed)
# X is unpacked as 4-D (N, Na, Nr, last); presumably the last axis holds real/imag -- confirm.
N, Na, Nr, _ = X.size()
numSamples = N
os.makedirs(outfolder + '/images/', exist_ok=True)
# Normalised axes of length Na and Nr, built with the shared shift/mode settings.
xa, xr = (ts.ppeaxis(n, norm=True, shift=ftshift, mode=ppeaxismode) for n in (Na, Nr))
# F is the focused image in the SimPoly branch; otherwise it is just an alias of X.
F = X
# Case-insensitive match: 'SimPoly', 'simpoly', 'SIMPOLY', ... all select simulation.
if cfg.mkpetype.lower() == 'simpoly':
    print("---Focusing...")
    pa, pr = ts.polype(ca, xa), ts.polype(cr, xr)
    # None is passed in the range phase-error slot (pr is unused here);
    # presumably only the azimuth error is compensated -- confirm.
    F = ts.focus(X, pa, None, isfft=True, ftshift=ftshift)
    print("---Done.")
    print("---Making polynominal phase error...")
    # Lower/upper bound rows with 6 entries each (presumably one per polynomial
    # coefficient). Alternative bounds that have been tried: +/-64 and +/-128.
    carange = [[-32] * 6, [32] * 6]
    crrange = [[-32] * 6, [32] * 6]
    print('~~~carange', carange)
    print('~~~crrange', crrange)
    ppeg = ts.PolyPhaseErrorGenerator(carange, crrange, seed)
    # Draw and discard the train (6000) and validation (8000) coefficient sets
    # first, so the test coefficients come from the same point in the seeded
    # generator's stream -- presumably mirroring how the splits were generated.
    ppeg.mkpec(n=6000, seed=None)  # train
    ppeg.mkpec(n=8000, seed=None)  # valid
    ca, cr = ppeg.mkpec(n=N, seed=None)
    pa, pr = ts.polype(ca, xa), ts.polype(cr, xr)
    # Re-defocus the focused images with the fresh azimuth error (range slot None).
    X = ts.defocus(F, pa, None, isfft=True, ftshift=ftshift)
    print("---Making polynominal phase error done.")
# Sample indices handed to net.plot at the end of the script.
# Alternative selections: range(0, 200), range(7800, 8000), range(1980, 2031),
# [0, 1, 4, 19, 21, 51, 93, 140, 156, 162, 250, 2000, 1999, 7835, 7881, 7887].
# To restrict the evaluation itself to these samples:
#   X, F, ca, cr = X[index], F[index], ca[index], cr[index]
# NOTE(review): readsamples is called with nsamples=[4000]; confirm X really holds
# more than 7884 samples, otherwise index 7884 is out of range.
index = [89, 1994, 7884]
N = X.shape[0]
numSamples = N
# Never use a batch larger than the number of available samples.
size_batch = min(cfg.size_batch, N)
# Alternative (disabled): fall back to the CPU when CUDA is unavailable:
#   device = th.device(cfg.device if th.cuda.is_available() else 'cpu')
# Resolve a printable device name from the parsed torch.device instead of taking
# the last character of the string. That way a bare 'cuda' resolves to the
# current GPU, multi-digit indices such as 'cuda:10' work, and 'cpu:0' is
# recognised as a CPU.
_dev = th.device(device)
if _dev.type == 'cuda':
    devicename = th.cuda.get_device_name(_dev)
elif _dev.type == 'cpu':
    devicename = 'E5 2696v3'  # fixed label for the CPU host; not auto-detected
else:
    devicename = str(_dev)
print(device)
print(devicename)
print("Torch Version: ", th.__version__)
print("Torch CUDA Version: ", th.version.cuda)
print("CUDNN Version: ", th.backends.cudnn.version())
print("CUDA TF32: ", cudaTF32)
print("CUDNN TF32: ", cudnnTF32)
print("CUDNN Benchmark: ", benchmark)
print("CUDNN Deterministic: ", deterministic)
th.backends.cuda.matmul.allow_tf32 = cudaTF32
th.backends.cudnn.allow_tf32 = cudnnTF32
# NOTE: PyTorch's reproducibility notes recommend cudnn.benchmark = False when
# run-to-run determinism matters; with benchmark=True, algorithm selection may
# differ between runs even though deterministic=True.
th.backends.cudnn.benchmark = benchmark
th.backends.cudnn.deterministic = deterministic
for _label, _key in (("--->Orders for azimuth: ", 'Qas'),
                     ("--->Orders for range: ", 'Qrs'),
                     ("--->Convolution params: ", 'Convs')):
    print(_label, modelcfg[_key])
# NOTE(review): Qrs is printed above but not passed to the model; only the
# azimuth size Na and axis xa are. Confirm this is intended.
net = BaggingECELMs(Na, 1, Qas=modelcfg['Qas'], Convs=modelcfg['Convs'], xa=xa, ftshift=ftshift, seed=seed)
# SECURITY NOTE: th.load unpickles the checkpoint; only load weight files from trusted sources.
modelparamsaf = th.load(cfg.weightfile, map_location=device)
net.load_state_dict(modelparamsaf['network'])  # checkpoint stores weights under 'network'
net.to(device=device)
net.eval()
# Rebuild the azimuth axis used for plotting with the same shift/mode settings
# as the axis the model was built with. Hard-coding shift=True / 'fftfreq' here
# would silently diverge from the model if ftshift or ppeaxismode were changed.
xa = ts.ppeaxis(Na, norm=True, shift=ftshift, mode=ppeaxismode)
# Evaluation criteria passed to net.ensemble_test. cdim=-1 presumably marks the
# real/imag axis and dim=(-3, -2) the reduced image axes -- confirm in torchbox.
loss_ent_func = tb.EntropyLoss('natural', cdim=-1, dim=(-3, -2), reduction='mean')
loss_cts_func = tb.ContrastLoss('way1', cdim=-1, dim=(-3, -2), reduction='mean')
# Despite the "fro" name this is a p=1 norm, not a Frobenius norm.
loss_fro_func = tb.Pnorm(p=1, cdim=-1, dim=(-3, -2), reduction='mean')
# NOTE(review): no th.no_grad() here; confirm ensemble_test / plot disable autograd internally.
lossvtest = net.ensemble_test(X, ca, cr, size_batch, loss_ent_func, loss_cts_func, loss_fro_func, device, name='Test')
net.plot(X, ca, cr, xa, index, 'test', outfolder, device)