diff --git a/nerf/utils.py b/nerf/utils.py index 17b38c76..3d99ffaa 100644 --- a/nerf/utils.py +++ b/nerf/utils.py @@ -241,10 +241,12 @@ def report(self): class SSIMMeter: - def __init__(self): + def __init__(self, device=None): self.V = 0 self.N = 0 + self.device = device if device is not None else torch.device('cuda' if torch.cuda.is_available() else 'cpu') + def clear(self): self.V = 0 self.N = 0