diff --git a/src/models/mnist_module.py b/src/models/mnist_module.py index dd00e41b5..8827f05ce 100644 --- a/src/models/mnist_module.py +++ b/src/models/mnist_module.py @@ -56,7 +56,9 @@ def forward(self, x: torch.Tensor): def on_train_start(self): # by default lightning executes validation step sanity checks before training starts, - # so we need to make sure val_acc_best doesn't store accuracy from these checks + # so it's worth to make sure validation metrics don't store results from these checks + self.val_loss.reset() + self.val_acc.reset() self.val_acc_best.reset() def model_step(self, batch: Any):