From 1e08009384014957b436aa13042223ee5b62a726 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Mon, 13 Jan 2025 10:48:03 +0000 Subject: [PATCH] Fix RAFT input dimension check --- torchvision/models/optical_flow/raft.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/models/optical_flow/raft.py b/torchvision/models/optical_flow/raft.py index c294777ee6f..3622887e3a0 100644 --- a/torchvision/models/optical_flow/raft.py +++ b/torchvision/models/optical_flow/raft.py @@ -486,7 +486,7 @@ def forward(self, image1, image2, num_flow_updates: int = 12): batch_size, _, h, w = image1.shape if (h, w) != image2.shape[-2:]: raise ValueError(f"input images should have the same shape, instead got ({h}, {w}) != {image2.shape[-2:]}") - if not (h % 8 == 0) and (w % 8 == 0): + if not ((h % 8 == 0) and (w % 8 == 0)): raise ValueError(f"input image H and W should be divisible by 8, instead got {h} (h) and {w} (w)") fmaps = self.feature_encoder(torch.cat([image1, image2], dim=0))