# evaluation.py
# Forked from renmengye/rec-attend-public.
from __future__ import division
import cv2
import numpy as np
import os
import time
from utils import logger
from utils import BatchIterator, ConcurrentBatchIterator
from runner import RunnerBase
import analysis
# Module-level logger handle obtained from the project's `utils.logger` helper.
log = logger.get()
class OneTimeEvalBase(RunnerBase):
    """Runner that makes one ordered pass over a dataset or an index range.

    The constructor builds a non-cycling, non-shuffling ``BatchIterator`` over
    the selected examples (optionally wrapped in a
    ``ConcurrentBatchIterator`` when ``opt['prefetch']`` is truthy) and hands
    it to ``RunnerBase`` with ``num_batch=1``, ``phase_train=False`` and
    ``increment_step=False``.

    Attributes set here:
        dataset: The dataset object batches are read from.
        log: Logger handle from ``logger.get()``.
        model_opt: Model options; the optional flags ``'add_d_out'`` and
            ``'add_y_out'`` control which extra inputs are fed.
        opt: Runner options (``'batch_size'``, ``'prefetch'``, and, when
            prefetching, ``'queue_size'`` and ``'num_worker'``).
        input_variables: Set of variable names requested from the dataset.
        all_idx: Array mapping iterator positions to dataset indices.
    """

    # (model_opt flag, key in the dataset batch, key in the feed dictionary).
    # Shared by get_input_variables() and get_batch() so the two cannot drift.
    _OPTIONAL_INPUTS = (
        ('add_d_out', 'd_out', 'd_in'),
        ('add_y_out', 'y_out', 'y_in'),
    )

    def __init__(self,
                 sess,
                 model,
                 dataset,
                 opt,
                 model_opt,
                 outputs,
                 start_idx=-1,
                 end_idx=-1):
        """Set up the batch iterator and initialise the base runner.

        Args:
            sess: Forwarded unchanged to ``RunnerBase``.
            model: Forwarded unchanged to ``RunnerBase``.
            dataset: Object providing ``get_dataset_size()`` and
                ``get_batch(idx, variables=...)``.
            opt: Runner options, indexed by string key (see class docstring).
            model_opt: Model options, indexed by string key.
            outputs: Forwarded unchanged to ``RunnerBase``.
            start_idx: First dataset index to evaluate (inclusive).
            end_idx: Last dataset index to evaluate (exclusive). Values past
                the dataset size are clamped with a warning.

        A sub-range is used only when BOTH ``start_idx`` and ``end_idx``
        differ from -1; otherwise the entire dataset is evaluated.

        Raises:
            ValueError: If the requested range is negative, empty, or starts
                at/after the end of the dataset (only reached when
                ``log.fatal`` returns instead of terminating).
        """
        self.dataset = dataset
        self.log = logger.get()
        self.model_opt = model_opt
        self.opt = opt
        self.input_variables = self.get_input_variables()

        dataset_size = dataset.get_dataset_size()
        if start_idx != -1 and end_idx != -1:
            if start_idx < 0 or end_idx < 0:
                self._reject_index_range('Indices must be non-negative.')
            elif start_idx >= end_idx:
                self._reject_index_range(
                    'End index must be greater than start index.')
            if end_idx > dataset_size:
                self.log.warning('End index exceeds dataset size.')
                end_idx = dataset_size
            # Clamping may have emptied the range (start at/after the end of
            # the dataset); without this check num_ex would be <= 0 and
            # all_idx would be an empty array.
            if start_idx >= end_idx:
                self._reject_index_range(
                    'Start index must be smaller than dataset size.')
            self.log.info('Running partial dataset: start {} end {}'.format(
                start_idx, end_idx))
            self.all_idx = np.arange(start_idx, end_idx)
            num_ex = end_idx - start_idx
        else:
            self.log.info('Running through entire dataset.')
            num_ex = dataset_size
            self.all_idx = np.arange(num_ex)

        # Single ordered pass: no cycling, no shuffling.
        batch_iter = BatchIterator(
            num_ex,
            batch_size=opt['batch_size'],
            get_fn=self.get_batch,
            cycle=False,
            shuffle=False)
        if opt['prefetch']:
            batch_iter = ConcurrentBatchIterator(
                batch_iter,
                max_queue_size=opt['queue_size'],
                num_threads=opt['num_worker'],
                log_queue=-1)

        super(OneTimeEvalBase, self).__init__(
            sess,
            model,
            batch_iter,
            outputs,
            num_batch=1,
            phase_train=False,
            increment_step=False)

    def _reject_index_range(self, message):
        """Report an unusable index range and stop construction."""
        self.log.fatal(message)
        # NOTE(review): whether log.fatal() terminates the process is not
        # visible from this file; raising guarantees construction never
        # continues with an invalid range either way.
        raise ValueError(message)

    def _enabled_optional_inputs(self):
        """Return (dataset_key, feed_key) pairs whose model_opt flag is on."""
        return [(src, dst)
                for flag, src, dst in self._OPTIONAL_INPUTS
                if flag in self.model_opt and self.model_opt[flag]]

    def get_input_variables(self):
        """Return the set of variable names to request from the dataset.

        Always contains ``'x'``, ``'s_gt'`` and ``'idx_map'``; ``'d_out'`` and
        ``'y_out'`` are added when the matching ``model_opt`` flag
        (``'add_d_out'`` / ``'add_y_out'``) is present and truthy.
        """
        variables = {'x', 's_gt', 'idx_map'}
        variables.update(src for src, _ in self._enabled_optional_inputs())
        return variables

    def get_batch(self, idx):
        """Transform a dataset get_batch into a dictionary to feed.

        Args:
            idx: Position(s) within this runner's range; translated to dataset
                indices through ``self.all_idx`` before the dataset is queried.

        Returns:
            dict with keys ``'x'``, ``'idx_map'`` and ``'_s_gt'`` (taken from
            the dataset's ``'s_gt'``), plus ``'d_in'`` / ``'y_in'`` (taken
            from ``'d_out'`` / ``'y_out'``) when the matching ``model_opt``
            flag is enabled.
        """
        idx_new = self.all_idx[idx]
        _batch = self.dataset.get_batch(
            idx_new, variables=self.input_variables)
        batch = {'x': _batch['x']}
        for src, dst in self._enabled_optional_inputs():
            batch[dst] = _batch[src]
        batch['idx_map'] = _batch['idx_map']
        batch['_s_gt'] = _batch['s_gt']
        return batch