From bbeff8d021daaacc96eab4f3489bf12516cd0a2d Mon Sep 17 00:00:00 2001 From: QIN2DIM <62018067+QIN2DIM@users.noreply.github.com> Date: Mon, 12 Sep 2022 16:07:47 +0800 Subject: [PATCH] Update resnet.py --- src/factories/resnet.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/factories/resnet.py b/src/factories/resnet.py index 23d66060..b11263d7 100644 --- a/src/factories/resnet.py +++ b/src/factories/resnet.py @@ -240,6 +240,11 @@ def _train( def _val(self, model: nn.modules, data_loader: DataLoader): total_acc = 0 + total_tp = 0 + total_tn = 0 + total_fp = 0 + total_fn = 0 + for _, (img, label) in enumerate(data_loader): img = img.to(self.DEVICE) label = label.to(self.DEVICE)