test_Harvard.py
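"""Evaluate a trained DNU model on the Harvard hyperspectral test set.

For each test scene, the script simulates CASSI snapshot measurements,
reconstructs the hyperspectral patches with the network, prints the MSE
against the ground truth, and saves the results to .mat files.
"""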
import os

import scipy.io as sio
import torch
import torch.nn as nn
from torch import optim
from torch.nn.parallel import DataParallel
from torch.utils.data import DataLoader

from utils import *
from utils_reid import *
from datasets import HavardDataset_Test as HavardDataset
from model import DNU


def get_dataloader(batch_size=64):
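    """Build one DataLoader per .mat test scene under base_path.

    Every scene shares the same coded-aperture mask (Cu_48.mat). Shuffling
    is disabled so network outputs stay aligned with the ground-truth scene.
    """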
    base_path = '/Harvard/test'
    data_names = os.listdir(base_path)
    data_paths = [os.path.join(base_path, i) for i in data_names if '.mat' in i]
    mask_path = '/Harvard/Cu_48.mat'
    datasets = [HavardDataset(data_paths[i], mask_path) for i in range(len(data_paths))]
    dataloaders = [DataLoader(datasets[i], batch_size=batch_size, shuffle=False, drop_last=False) for i in range(len(data_paths))]
    return dataloaders, datasets
def convert_2_tensor(patch, mask, TVT):
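    """Build the tensors DNU consumes from a ground-truth patch and its mask.

    Simulates the snapshot measurement g = A(x, Phi) and its back-projection
    f0 = At(g, Phi) with the A_torch/At_torch operators imported from utils,
    and precomputes the per-pixel sum of Phi**2 (PhiPhiT) that the network
    uses. TVT (from utils_reid) is assumed to move tensors to the device.
    """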
    [bs, r, c, nC] = patch.shape
    patch_Phi = mask  # / 31.0 (normalization disabled)
    # (bs, r, c, nC) -> (bs, nC, r, c)
    patch_x_tensor = patch.permute(0, 3, 1, 2)
    patch_Phi_tensor = patch_Phi.view(bs, r, c, nC).permute(0, 3, 1, 2)
    patch_x_tensor = TVT(patch_x_tensor).float()
    patch_Phi_tensor = TVT(patch_Phi_tensor).float()
    # simulated snapshot measurement and its back-projection
    patch_g_tensor = A_torch(patch_x_tensor, patch_Phi_tensor)
    patch_f0_tensor = At_torch(patch_g_tensor, patch_Phi_tensor)
    patch_PhiPhiT_tensor = torch.sum(patch_Phi_tensor * patch_Phi_tensor, dim=1)
    patch_f0_tensor = patch_f0_tensor.float()
    patch_g_tensor = patch_g_tensor.float()
    patch_PhiPhiT_tensor = patch_PhiPhiT_tensor.float()
return patch_x_tensor, patch_f0_tensor, patch_g_tensor, patch_Phi_tensor, patch_PhiPhiT_tensor
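
# A_torch and At_torch come in through the wildcard import from utils, so
# their definitions are not visible in this file. The helpers below are a
# minimal sketch (an assumption about utils, not the repository's actual
# code) of the standard CASSI forward operator and its adjoint; they are
# illustrative only and are never called by this script.
def _A_torch_sketch(x, Phi):
    # mask the cube and sum over the spectral bands:
    # (bs, nC, r, c) -> snapshot measurement (bs, r, c)
    return torch.sum(x * Phi, dim=1)

def _At_torch_sketch(g, Phi):
    # broadcast the measurement back across bands, weighted by the mask:
    # (bs, r, c) -> initialization (bs, nC, r, c)
    return g.unsqueeze(1) * Phi
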
def test():
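    """Reconstruct every Harvard test scene with the trained DNU and save
    the outputs plus ground truth to SAVE_PATH as .mat files."""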
    batch_size = 11
    devices = [0]
    SAVE_PATH = './results'
    os.makedirs(SAVE_PATH, exist_ok=True)  # make sure the output directory exists
mse_loss = nn.MSELoss()
model = DNU(31, K=11)
    # load the trained weights for the Harvard split
    model.load_state_dict(torch.load('./ckpts_Harvard/ep_160.pth'))
model.eval()
TVT, TMO = set_devices(devices)
model_w = DataParallel(model)
    # the optimizer is presumably only built so TMO can place modules and
    # optimizer state on the selected device; no training happens here
    lr = 0.001
    optimizer = optim.Adam(model.parameters(), lr=lr)
dataloaders, datasets = get_dataloader(batch_size=batch_size)
modules_optims = [model, optimizer]
TMO(modules_optims)
base_path = '/Harvard/test'
file_names = os.listdir(base_path)
data_names = [ i for i in file_names if '.mat' in i]
with torch.no_grad():
        for i, dataloader in enumerate(dataloaders):
            print('-------------------', data_names[i], '-------------------')
            rec_patch = None  # accumulates reconstructed patches for this scene
for patch, mask in dataloader:
                # convert_2_tensor already places everything on the device via TVT
                x, f0, g, Phi, PhiPhiT = convert_2_tensor(patch, mask, TVT)
dnu_out = model_w(f0, g, Phi, PhiPhiT)
if rec_patch is None:
rec_patch = dnu_out.cpu()
else:
rec_patch = torch.cat((rec_patch, dnu_out.cpu()), 0)
                loss = mse_loss(dnu_out, x)
                print('loss=%.6f' % loss.item())
            rec_patch = rec_patch.permute(2, 3, 1, 0).numpy()  # (h, w, c, N)
            sio.savemat(os.path.join(SAVE_PATH, data_names[i]), {'output': rec_patch, 'label': datasets[i].hyper})
if __name__ == '__main__':
test()