diff --git a/train.py b/train.py index afe9c10232..ce05403767 100644 --- a/train.py +++ b/train.py @@ -117,6 +117,10 @@ def train(args, trainer, task, epoch_itr): valid_subsets = args.valid_subset.split(',') max_update = args.max_update or math.inf for i, samples in enumerate(progress, start=epoch_itr.iterations_in_epoch): + samples = [s for s in samples if len(s) > 0] + if len(samples) == 0: + continue + log_output = trainer.train_step(samples) if log_output is None: continue