Skip to content

Commit

Permalink
run evaluation in the end of trainer, even if no training iteration h…
Browse files Browse the repository at this point in the history
…appens

Reviewed By: rbgirshick

Differential Revision: D19319879

fbshipit-source-id: d84b9c3e957b0544979bc6706716222f31ebdd27
  • Loading branch information
ppwwyyxx authored and facebook-github-bot committed Jan 8, 2020
1 parent dfc678a commit 0b62f13
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 22 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,4 @@ SOLVER:
STEPS: (5500,)
MAX_ITER: 7000
TEST:
EXPECTED_RESULTS: [["bbox", "AP", 46.80, 1.1], ["segm", "AP", 38.93, 0.7], ["sem_seg", "mIoU", 63.99, 0.9], ["panoptic_seg", "PQ", 48.23, 0.8]]
EXPECTED_RESULTS: [["bbox", "AP", 46.80, 1.1], ["segm", "AP", 38.93, 0.7], ["sem_seg", "mIoU", 64.23, 1.0], ["panoptic_seg", "PQ", 48.23, 0.8]]
50 changes: 29 additions & 21 deletions detectron2/engine/hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,34 +315,42 @@ def __init__(self, eval_period, eval_function):
"""
self._period = eval_period
self._func = eval_function
self._done_eval_at_last = False

def _do_eval(self):
results = self._func()

if results:
assert isinstance(
results, dict
), "Eval function must return a dict. Got {} instead.".format(results)

flattened_results = flatten_results_dict(results)
for k, v in flattened_results.items():
try:
v = float(v)
except Exception:
raise ValueError(
"[EvalHook] eval_function should return a nested dict of float. "
"Got '{}: {}' instead.".format(k, v)
)
self.trainer.storage.put_scalars(**flattened_results, smoothing_hint=False)

# Evaluation may take different time among workers.
# A barrier make them start the next iteration together.
comm.synchronize()

def after_step(self):
next_iter = self.trainer.iter + 1
is_final = next_iter == self.trainer.max_iter
if is_final or (self._period > 0 and next_iter % self._period == 0):
results = self._func()

if results:
assert isinstance(
results, dict
), "Eval function must return a dict. Got {} instead.".format(results)

flattened_results = flatten_results_dict(results)
for k, v in flattened_results.items():
try:
v = float(v)
except Exception:
raise ValueError(
"[EvalHook] eval_function should return a nested dict of float. "
"Got '{}: {}' instead.".format(k, v)
)
self.trainer.storage.put_scalars(**flattened_results, smoothing_hint=False)

# Evaluation may take different time among workers.
# A barrier make them start the next iteration together.
comm.synchronize()
self._do_eval()
if is_final:
self._done_eval_at_last = True

def after_train(self):
if not self._done_eval_at_last:
self._do_eval()
# func is likely a closure that holds reference to the trainer
# therefore we clean it to avoid circular reference in the end
del self._func
Expand Down

0 comments on commit 0b62f13

Please sign in to comment.