Potential fix for training stuck caused by data loader failure. (facebookresearch#638)

Summary:
Pull Request resolved: facebookresearch#638

Potential fix for training getting stuck when the data loader fails.

Reviewed By: rbgirshick

Differential Revision: D9513621

fbshipit-source-id: 123974eac83f40ef2f582a90fedea790fdc442d1
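
The hang this commit targets is a producer/consumer failure mode: if the data-loading workers die, the training loop presumably ends up waiting on mini-batches that will never arrive. The fix is to have the loop poll a stop flag that the loader sets on failure and abort immediately. Below is a minimal sketch of that fail-fast pattern, independent of Detectron's RoIDataLoader and Coordinator; all names in it are illustrative only.

# Minimal fail-fast producer/consumer sketch (illustrative only; not Detectron
# code). A worker that hits an error sets a shared stop flag; the training
# loop polls the flag every iteration instead of blocking forever on a queue
# that will never be refilled.
import queue
import threading
import time

stop_flag = threading.Event()
minibatches = queue.Queue(maxsize=4)


def load_minibatch(i):
    if i == 5:
        raise IOError('simulated data loader failure')
    return {'iter': i}


def loader_worker():
    i = 0
    try:
        while not stop_flag.is_set():
            minibatches.put(load_minibatch(i), timeout=1.0)
            i += 1
    except Exception:
        stop_flag.set()  # make the failure visible to the training loop


def train_loop(max_iter=100):
    for cur_iter in range(max_iter):
        if stop_flag.is_set():
            raise RuntimeError('data loader failed; aborting training')
        batch = minibatches.get(timeout=10.0)
        time.sleep(0.01)  # stand-in for one training step on batch


if __name__ == '__main__':
    threading.Thread(target=loader_worker, daemon=True).start()
    try:
        train_loop()
    except RuntimeError as exc:
        print('training stopped early:', exc)
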
newstzpz authored and facebook-github-bot committed Aug 28, 2018
1 parent c9ed587 commit 1ecd603
Showing 2 changed files with 13 additions and 4 deletions.
3 changes: 3 additions & 0 deletions detectron/roi_data/loader.py
@@ -241,6 +241,9 @@ def start(self, prefill=False):
                     self.shutdown()
                     break
 
+    def should_stop(self):
+        return self.coordinator.should_stop()
+
     def shutdown(self):
        self.coordinator.request_stop()
        self.coordinator.wait_for_stop()
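The new should_stop() simply exposes the coordinator's stop flag to callers outside the loader (the training loop in train.py below). As a rough mental model only, a coordinator of this kind can be little more than a shared threading.Event; the sketch below is a simplification and not a substitute for Detectron's actual Coordinator implementation.

# Simplified coordinator sketch; a rough mental model only, not Detectron's
# actual Coordinator implementation.
import threading


class SimpleCoordinator(object):
    """Shared stop flag that workers set and the training loop polls."""

    def __init__(self):
        self._stop_event = threading.Event()

    def request_stop(self):
        # Called when a worker hits an unrecoverable error (or on shutdown).
        self._stop_event.set()

    def should_stop(self):
        # Polled by consumers, e.g. via RoIDataLoader.should_stop() above.
        return self._stop_event.is_set()

    def wait_for_stop(self, timeout=None):
        # Block until a stop has been requested.
        return self._stop_event.wait(timeout)
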
14 changes: 10 additions & 4 deletions detectron/utils/train.py
@@ -50,7 +50,6 @@
 
 def train_model():
     """Model training loop."""
-    logger = logging.getLogger(__name__)
     model, weights_file, start_iter, checkpoints, output_dir = create_model()
     if 'final' in checkpoints:
         # The final model was found in the output directory, so nothing to do
@@ -61,6 +60,8 @@ def train_model():
     CHECKPOINT_PERIOD = int(cfg.TRAIN.SNAPSHOT_ITERS / cfg.NUM_GPUS)
 
     for cur_iter in range(start_iter, cfg.SOLVER.MAX_ITER):
+        if model.roi_data_loader.should_stop():
+            handle_critical_error(model, 'roi_data_loader failed')
         training_stats.IterTic()
         lr = model.UpdateWorkspaceLr(cur_iter, lr_policy.get_lr_at_iter(cur_iter))
         workspace.RunNet(model.net.Proto().name)
@@ -82,9 +83,7 @@
             training_stats.ResetIterTimer()
 
         if np.isnan(training_stats.iter_total_loss):
-            logger.critical('Loss is NaN, exiting...')
-            model.roi_data_loader.shutdown()
-            envu.exit_on_error()
+            handle_critical_error(model, 'Loss is NaN')
 
     # Save the final model
     checkpoints['final'] = os.path.join(output_dir, 'model_final.pkl')
@@ -94,6 +93,13 @@ def train_model():
     return checkpoints
 
 
+def handle_critical_error(model, msg):
+    logger = logging.getLogger(__name__)
+    logger.critical(msg)
+    model.roi_data_loader.shutdown()
+    raise Exception(msg)
+
+
 def create_model():
     """Build the model and look for saved model checkpoints in case we can
     resume from one.
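The behavioral change in train.py is that a critical error now raises an ordinary Python exception after shutting down the loader, instead of calling envu.exit_on_error() from inside the loop. The hedged usage sketch below is not part of the commit: DummyLoader, run_training, and this standalone handle_critical_error are hypothetical stand-ins that mirror the new helper without needing a Detectron model, just to show how the failure surfaces to the caller.

# Hedged usage sketch (not part of the commit). DummyLoader, run_training and
# this standalone handle_critical_error are hypothetical stand-ins that mirror
# the helper added in detectron/utils/train.py without needing a Detectron model.
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class DummyLoader(object):
    """Minimal stand-in exposing the two methods train.py relies on."""

    def should_stop(self):
        return True  # simulate a data loader that has already failed

    def shutdown(self):
        logger.info('loader shut down cleanly')


def handle_critical_error(loader, msg):
    # Same shape as the new helper: log, shut the loader down, then raise so
    # the failure propagates instead of the process exiting inside the loop.
    logger.critical(msg)
    loader.shutdown()
    raise Exception(msg)


def run_training(loader, max_iter=10):
    for cur_iter in range(max_iter):
        if loader.should_stop():
            handle_critical_error(loader, 'roi_data_loader failed')
        # ... one SGD iteration would run here ...


if __name__ == '__main__':
    try:
        run_training(DummyLoader())
    except Exception as e:
        logger.info('training aborted: %s', e)
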
