[Torch_xla2] torch_xla2.default_env() guard doesn't enforce XLATensor2 #8546

Open
zpcore opened this issue Jan 9, 2025 · 1 comment

zpcore commented Jan 9, 2025

🐛 Bug

torch_xla2.default_env() doesn't guarantee that tensors created inside it are XLATensor2. You need to explicitly move the tensor to the 'jax' device to convert it into an XLATensor2.

To Reproduce

import torch
import torch_xla2

env = torch_xla2.default_env()

with env:
  inputs = torch.randn(1)
  print(type(inputs))  # prints torch.Tensor instead of XLATensor2

Environment

PyTorch/XLA git commit f52e202e825651374e44dbb1c79fab3724be7e0e

qihqi commented Jan 9, 2025

Please try either:

import torch
import torch_xla2

env = torch_xla2.default_env()

with env:
  inputs = torch.randn(1, device='jax')
  print(type(inputs))  # should now be XLATensor2 rather than torch.Tensor

OR,

import torch
import torch_xla2

env = torch_xla2.default_env()
env.config.use_torch_native_for_cpu_tensor = False  # use XLATensor even for CPU tensors

with env:
  inputs = torch.randn(1)
  print(type(inputs))  # should now be XLATensor2 rather than torch.Tensor
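
For intuition, the three behaviours seen in this thread can be mimicked with a small pure-Python toy model. This is only a sketch and does not reflect how torch_xla2 is implemented internally: `ToyEnv`, `ToyConfig`, `NativeTensor`, `WrappedTensor` and this `randn` are all hypothetical stand-ins, and the default value `True` for the flag is inferred from the report above, not confirmed. The only name borrowed from the thread is `use_torch_native_for_cpu_tensor`.

```python
import random
from dataclasses import dataclass


class NativeTensor:
    """Hypothetical stand-in for torch.Tensor."""

    def __init__(self, data):
        self.data = data


class WrappedTensor:
    """Hypothetical stand-in for XLATensor2."""

    def __init__(self, data):
        self.data = data


@dataclass
class ToyConfig:
    # Assumed default: plain (CPU) tensors are left as native tensors.
    use_torch_native_for_cpu_tensor: bool = True


class ToyEnv:
    """Toy context manager that records which environment is active."""

    _active = None

    def __init__(self):
        self.config = ToyConfig()
        self._previous = None

    def __enter__(self):
        self._previous = ToyEnv._active
        ToyEnv._active = self
        return self

    def __exit__(self, *exc_info):
        ToyEnv._active = self._previous
        return False


def randn(n, device=None):
    """Toy factory: pick the tensor class from the active env and device."""
    data = [random.gauss(0.0, 1.0) for _ in range(n)]
    env = ToyEnv._active
    if env is None:
        # Assumption of this toy: outside any env, always return a native tensor.
        return NativeTensor(data)
    if device == "jax":
        # Explicit device request: always wrapped.
        return WrappedTensor(data)
    if device is None and not env.config.use_torch_native_for_cpu_tensor:
        # Flag disabled: wrap even tensors created without a device.
        return WrappedTensor(data)
    return NativeTensor(data)


env = ToyEnv()
with env:
    print(type(randn(1)).__name__)                # NativeTensor
    print(type(randn(1, device="jax")).__name__)  # WrappedTensor

env.config.use_torch_native_for_cpu_tensor = False
with env:
    print(type(randn(1)).__name__)                # WrappedTensor
```

In this toy model, entering the environment alone changes nothing for a tensor created without a device; only the explicit device or the config flag selects the wrapped type, which matches the behaviour reported in this issue.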
