forked from renmengye/rec-attend-public
-
Notifications
You must be signed in to change notification settings - Fork 0
/
full_model_pack.py
executable file
·74 lines (61 loc) · 2.29 KB
/
full_model_pack.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
#!/usr/bin/env python
"""Run model inference, output to H5."""
from __future__ import division
import cv2
import os
import h5py
import numpy as np
from cmd_args_parser import EvalArgsParser, DataArgsParser
from evaluation import OneTimeEvalBase
from experiment import EvalExperimentBase
from analysis import RenderInstanceAnalyzer
from full_model import get_model
class PackRunner(OneTimeEvalBase):
    """Run the model once over a dataset and pack its outputs into the H5 file.

    For every example of a batch, each per-instance slice of ``y_out`` is
    PNG-encoded and stored under ``instance_pred/NN`` in that example's H5
    group, and the matching row of ``s_out`` is stored as ``score_pred``.
    Entries left over from an earlier run are replaced.
    """

    def __init__(self, sess, model, dataset, train_opt, model_opt):
        """Create the runner.

        All arguments are forwarded unchanged to ``OneTimeEvalBase.__init__``
        together with the list of model outputs to fetch
        (``'y_out'`` and ``'s_out'``).

        Args:
            sess: Session object handed to the base class.
            model: Model object handed to the base class.
            dataset: Dataset; this class uses its ``get_batch``,
                ``get_str_id`` and ``h5_fname`` members.
            train_opt: Options handed to the base class.
            model_opt: Model options handed to the base class.
        """
        outputs = ['y_out', 's_out']
        # Assigned ahead of the base constructor call so the attribute exists
        # by the time the base class runs.
        # NOTE(review): whether the base constructor needs it is not visible
        # here -- the original ordering is simply kept.
        self.input_variables = set(['x', 'y_out', 'd_out', 'idx_map'])
        super(PackRunner, self).__init__(sess, model, dataset, train_opt,
                                         model_opt, outputs)

    def get_batch(self, idx):
        """Transform a dataset get_batch into a dictionary to feed.

        Args:
            idx: Batch index specification, passed straight through to
                ``self.dataset.get_batch``.

        Returns:
            A dict with keys ``'x'``, ``'y_in'``, ``'d_in'`` and ``'idx_map'``.
            The dataset's ``'y_out'`` and ``'d_out'`` entries are exposed
            under the names ``'y_in'`` and ``'d_in'``.
        """
        _batch = self.dataset.get_batch(idx, variables=self.input_variables)
        return {
            'x': _batch['x'],
            'y_in': _batch['y_out'],
            'd_in': _batch['d_out'],
            'idx_map': _batch['idx_map']
        }

    def write_log(self, results):
        """Write one batch of model outputs into the dataset's H5 file.

        Args:
            results: Mapping with the keys
                ``'_batches'``: sequence whose first element is the input
                    batch; only its ``'idx_map'`` entry is read.
                ``'y_out'``: array indexed as ``[example, instance]``; each
                    slice is multiplied by 255, cast to ``uint8`` and
                    PNG-encoded.
                ``'s_out'``: array indexed by example; each row is stored
                    as-is.

        Raises:
            RuntimeError: If OpenCV reports that PNG encoding of a slice
                failed.
        """
        inp = results['_batches'][0]
        y_out = results['y_out']
        s_out = results['s_out']
        idx_map = inp['idx_map']
        with h5py.File(self.dataset.h5_fname, 'r+') as h5f:
            # Single-argument form prints identically on Python 2 and 3.
            print(idx_map)
            for ii in range(y_out.shape[0]):
                idx = idx_map[ii]
                group = h5f[self.dataset.get_str_id(idx)]

                # Encode every instance before modifying the group, so that
                # an encoding failure cannot leave the example with its old
                # predictions deleted and only some new ones written.
                encoded = []
                for ins in range(y_out.shape[1]):
                    y_out_arr = (y_out[ii, ins] * 255).astype('uint8')
                    success, y_out_str = cv2.imencode('.png', y_out_arr)
                    if not success:
                        raise RuntimeError(
                            'PNG encoding failed for example {} instance {}'
                            .format(idx, ins))
                    encoded.append(y_out_str)

                # Drop results of a previous run before writing new ones.
                if 'instance_pred' in group:
                    del group['instance_pred']
                for ins, y_out_str in enumerate(encoded):
                    group['instance_pred/{:02d}'.format(ins)] = y_out_str

                if 'score_pred' in group:
                    del group['score_pred']
                group['score_pred'] = s_out[ii]
class PackExperiment(EvalExperimentBase):
    """Evaluation experiment whose runner is a ``PackRunner``."""

    def get_runner(self, split):
        """Return a ``PackRunner`` bound to the dataset of the given split.

        Args:
            split: Key used to select an entry of ``self.dataset``.
        """
        return PackRunner(
            sess=self.sess,
            model=self.model,
            dataset=self.dataset[split],
            train_opt=self.opt,
            model_opt=self.model_opt)

    def get_model(self):
        """Build the model with the ``'use_knob'`` option forced to False."""
        options = self.model_opt
        options['use_knob'] = False
        # Inside a method body the bare name resolves to the module-level
        # ``get_model`` imported from ``full_model``, not to this method.
        return get_model(options)
if __name__ == '__main__':
    # One parser per argument group: evaluation options under 'default',
    # dataset options under 'data'.
    parsers = dict(default=EvalArgsParser(), data=DataArgsParser())
    experiment = PackExperiment.create_from_main(
        'ris_pack',
        parsers=parsers,
        description='Pack ris output')
    experiment.run()