From 64fbe9effdf49212610bcd3b8d3cf7e1516c76ae Mon Sep 17 00:00:00 2001 From: Ricky Chen Date: Thu, 6 Apr 2023 11:51:21 -0700 Subject: [PATCH] fix tensor.real for earlier versions of PyTorch --- torchdiffeq/_impl/misc.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/torchdiffeq/_impl/misc.py b/torchdiffeq/_impl/misc.py index cf96cd35..685cc8da 100644 --- a/torchdiffeq/_impl/misc.py +++ b/torchdiffeq/_impl/misc.py @@ -182,7 +182,9 @@ def forward(self, t, y, *, perturb=Perturb.NONE): # This dtype change here might be buggy. # The exact time value should be determined inside the solver, # but this can slightly change it due to numerical differences during casting. - t = t.real.to(y.abs().dtype) + if torch.is_complex(t): + t = t.real + t = t.to(y.abs().dtype) if perturb is Perturb.NEXT: # Replace with next smallest representable value. t = _nextafter(t, t + 1)