diff --git a/torchrl/objectives/cql.py b/torchrl/objectives/cql.py index 720e63052db..f10dcab1e0e 100644 --- a/torchrl/objectives/cql.py +++ b/torchrl/objectives/cql.py @@ -324,7 +324,7 @@ def __init__( try: device = next(self.parameters()).device except AttributeError: - device = torch.device("cpu") + device = getattr(torch, "get_default_device", lambda: torch.device("cpu"))() self.register_buffer("alpha_init", torch.tensor(alpha_init, device=device)) if bool(min_alpha) ^ bool(max_alpha): min_alpha = min_alpha if min_alpha else 0.0 diff --git a/torchrl/objectives/crossq.py b/torchrl/objectives/crossq.py index b0696675902..d7dc5976308 100644 --- a/torchrl/objectives/crossq.py +++ b/torchrl/objectives/crossq.py @@ -307,7 +307,7 @@ def __init__( try: device = next(self.parameters()).device except AttributeError: - device = torch.device("cpu") + device = getattr(torch, "get_default_device", lambda: torch.device("cpu"))() self.register_buffer("alpha_init", torch.tensor(alpha_init, device=device)) if bool(min_alpha) ^ bool(max_alpha): min_alpha = min_alpha if min_alpha else 0.0 diff --git a/torchrl/objectives/decision_transformer.py b/torchrl/objectives/decision_transformer.py index 770ae5d05ce..acb7156ad6e 100644 --- a/torchrl/objectives/decision_transformer.py +++ b/torchrl/objectives/decision_transformer.py @@ -104,7 +104,7 @@ def __init__( try: device = next(self.parameters()).device except AttributeError: - device = torch.device("cpu") + device = getattr(torch, "get_default_device", lambda: torch.device("cpu"))() self.register_buffer("alpha_init", torch.tensor(alpha_init, device=device)) if bool(min_alpha) ^ bool(max_alpha): diff --git a/torchrl/objectives/deprecated.py b/torchrl/objectives/deprecated.py index 1e736e878dc..57cea971bc0 100644 --- a/torchrl/objectives/deprecated.py +++ b/torchrl/objectives/deprecated.py @@ -203,7 +203,7 @@ def __init__( try: device = next(self.parameters()).device except AttributeError: - device = torch.device("cpu") + device = getattr(torch, "get_default_device", lambda: torch.device("cpu"))() self.register_buffer("alpha_init", torch.as_tensor(alpha_init, device=device)) self.register_buffer( diff --git a/torchrl/objectives/ppo.py b/torchrl/objectives/ppo.py index bf7831d518c..b8425e085b1 100644 --- a/torchrl/objectives/ppo.py +++ b/torchrl/objectives/ppo.py @@ -393,7 +393,7 @@ def __init__( try: device = next(self.parameters()).device except (AttributeError, StopIteration): - device = torch.device("cpu") + device = getattr(torch, "get_default_device", lambda: torch.device("cpu"))() self.register_buffer("entropy_coef", torch.tensor(entropy_coef, device=device)) if critic_coef is not None: diff --git a/torchrl/objectives/redq.py b/torchrl/objectives/redq.py index 6e280e1f0fa..bec40681f92 100644 --- a/torchrl/objectives/redq.py +++ b/torchrl/objectives/redq.py @@ -319,7 +319,7 @@ def __init__( try: device = next(self.parameters()).device except AttributeError: - device = torch.device("cpu") + device = getattr(torch, "get_default_device", lambda: torch.device("cpu"))() self.register_buffer("alpha_init", torch.tensor(alpha_init, device=device)) self.register_buffer( diff --git a/torchrl/objectives/sac.py b/torchrl/objectives/sac.py index 64b09ea0433..9d60d51334a 100644 --- a/torchrl/objectives/sac.py +++ b/torchrl/objectives/sac.py @@ -394,7 +394,7 @@ def __init__( try: device = next(self.parameters()).device except AttributeError: - device = torch.device("cpu") + device = getattr(torch, "get_default_device", lambda: torch.device("cpu"))() self.register_buffer("alpha_init", torch.tensor(alpha_init, device=device)) if bool(min_alpha) ^ bool(max_alpha): min_alpha = min_alpha if min_alpha else 0.0 @@ -1121,7 +1121,7 @@ def __init__( try: device = next(self.parameters()).device except AttributeError: - device = torch.device("cpu") + device = getattr(torch, "get_default_device", lambda: torch.device("cpu"))() self.register_buffer("alpha_init", torch.tensor(alpha_init, device=device)) if bool(min_alpha) ^ bool(max_alpha):