PyTorch/XLA is a Python package that uses the XLA deep learning compiler to connect the PyTorch deep learning framework and Cloud TPUs. You can try it right now, for free, on a single Cloud TPU VM with Kaggle!
Take a look at one of our Kaggle notebooks to get started:
To install the PyTorch/XLA stable build in a new TPU VM:
pip install torch~=2.6.0 'torch_xla[tpu]~=2.6.0' \
-f https://storage.googleapis.com/libtpu-releases/index.html \
-f https://storage.googleapis.com/libtpu-wheels/index.html
# Optional: if you're using custom kernels, install pallas dependencies
pip install 'torch_xla[pallas]' \
-f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
-f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
To install the PyTorch/XLA nightly build in a new TPU VM:
pip install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cpu
pip install 'torch_xla[tpu] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.7.0.dev+cxx11-cp310-cp310-linux_x86_64.whl' \
-f https://storage.googleapis.com/libtpu-releases/index.html \
-f https://storage.googleapis.com/libtpu-wheels/index.html
# Optional: if you're using custom kernels, install pallas dependencies
pip install 'torch_xla[pallas]' \
-f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
-f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
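After installing, it can help to confirm that torch and torch_xla come from the same release series, since the commands above pin them together (for example torch~=2.6.0 with torch_xla~=2.6.0). The following is a minimal sketch using only the Python standard library. The helper names installed_version, release_series and versions_match are hypothetical and not part of PyTorch/XLA, and comparing only major.minor is an assumption about what counts as a match:

```python
from importlib import metadata


def installed_version(package):
    """Return the installed version string of `package`, or None if absent."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


def release_series(version):
    """Reduce '2.6.0+cxx11' or '2.7.0.dev20250124+cpu' to (major, minor)."""
    public = version.split("+", 1)[0]  # drop a local tag such as +cxx11 or +cpu
    major, minor = public.split(".")[:2]
    return int(major), int(minor)


def versions_match(torch_version, xla_version):
    """True when both versions belong to the same major.minor series."""
    return release_series(torch_version) == release_series(xla_version)


if __name__ == "__main__":
    torch_v = installed_version("torch")
    xla_v = installed_version("torch_xla")
    if torch_v is None or xla_v is None:
        print("torch and/or torch_xla is not installed")
    else:
        ok = versions_match(torch_v, xla_v)
        print(f"torch {torch_v}, torch_xla {xla_v}, same series: {ok}")
```

Run it with the same Python interpreter that pip installed into; a mismatch usually means one of the two packages was upgraded on its own.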
Starting from PyTorch/XLA 2.6, we provide wheels and Docker images built with two C++ ABI flavors: C++11 and pre-C++11. Pre-C++11 is the default to align with PyTorch upstream, but C++11 ABI wheels and Docker images have better lazy tensor tracing performance.
To install C++11 ABI flavored 2.6 wheels:
pip install torch==2.6.0+cpu.cxx11.abi 'torch_xla[tpu]==2.6.0+cxx11' \
-f https://storage.googleapis.com/libtpu-releases/index.html \
-f https://storage.googleapis.com/libtpu-wheels/index.html \
-f https://download.pytorch.org/whl/torch
To access the C++11 ABI flavored Docker image:
us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.6.0_3.10_tpuvm_cxx11
If your model is tracing bound (e.g. you see that the host CPU is busy tracing the model while TPUs are idle), switching to the C++11 ABI wheels/docker images can improve performance. Mixtral 8x7B benchmarking results on v5p-256, global batch size 1024:
- Pre-C++11 ABI MFU: 33%
- C++11 ABI MFU: 39%
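To put the two MFU figures in perspective, the short calculation below derives the absolute and relative improvement from the numbers quoted above. It is plain arithmetic on the reported values, not an additional measurement, and the variable names are chosen here for illustration:

```python
# MFU values quoted above for Mixtral 8x7B on v5p-256, in percent.
pre_cxx11_mfu = 33
cxx11_mfu = 39

# Absolute difference in percentage points.
absolute_points = cxx11_mfu - pre_cxx11_mfu

# Relative improvement over the pre-C++11 baseline.
relative_gain = absolute_points / pre_cxx11_mfu

print(f"{absolute_points} percentage points, {relative_gain:.1%} relative")
# prints "6 percentage points, 18.2% relative"
```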
PyTorch/XLA now provides GPU support through a plugin package similar to libtpu:
pip install torch~=2.5.0 torch_xla~=2.5.0 https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.1/torch_xla_cuda_plugin-2.5.0-py3-none-any.whl
The newest stable version where a PyTorch/XLA:GPU wheel is available is torch_xla 2.5.
We do not offer a PyTorch/XLA:GPU wheel in the PyTorch/XLA 2.6 release. We
understand this is important and plan to reinstate GPU support by the 2.7 release.
PyTorch/XLA remains an open-source project and we welcome contributions from the
community to help maintain and improve the project. To contribute, please start
with the contributors guide.
To update your existing training loop, make the following changes:
-import torch.multiprocessing as mp
+import torch_xla as xla
+import torch_xla.core.xla_model as xm

 def _mp_fn(index):
   ...

+  # Move the model parameters to your XLA device
+  model.to(xla.device())

   for inputs, labels in train_loader:
+    with xla.step():
+      # Transfer data to the XLA device. This happens asynchronously.
+      inputs, labels = inputs.to(xla.device()), labels.to(xla.device())
       optimizer.zero_grad()
       outputs = model(inputs)
       loss = loss_fn(outputs, labels)
       loss.backward()
-      optimizer.step()
+      # `xm.optimizer_step` combines gradients across replicas
+      xm.optimizer_step(optimizer)

 if __name__ == '__main__':
-  mp.spawn(_mp_fn, args=(), nprocs=world_size)
+  # xla.launch automatically selects the correct world size
+  xla.launch(_mp_fn, args=())
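The comment in the diff says that xm.optimizer_step combines gradients across replicas. As a rough mental model only, the sketch below shows what such a combination could look like in plain Python, assuming the combination is an element-wise average over replicas (an assumption made here, not something this README states). It does not call torch_xla, and the function names average_gradients and sgd_step are hypothetical:

```python
def average_gradients(per_replica_grads):
    """Average gradients element-wise across replicas (illustrative only)."""
    world_size = len(per_replica_grads)
    return [sum(values) / world_size for values in zip(*per_replica_grads)]


def sgd_step(params, grads, lr):
    """Apply one plain SGD update using the already combined gradients."""
    return [p - lr * g for p, g in zip(params, grads)]


# Two replicas, each holding gradients for the same two parameters.
replica_grads = [
    [0.5, -1.0],  # replica 0
    [1.5, 3.0],   # replica 1
]

combined = average_gradients(replica_grads)
new_params = sgd_step([2.0, 2.0], combined, lr=0.5)

print(combined)    # [1.0, 1.0]
print(new_params)  # [1.5, 1.5]
```

Because every replica applies the same combined gradient, the model copies stay in sync after each step.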
If you're using DistributedDataParallel, make the following changes:
 import torch.distributed as dist
-import torch.multiprocessing as mp
+import torch_xla as xla
+import torch_xla.distributed.xla_backend

 def _mp_fn(rank):
   ...

-  os.environ['MASTER_ADDR'] = 'localhost'
-  os.environ['MASTER_PORT'] = '12355'
-  dist.init_process_group("gloo", rank=rank, world_size=world_size)
+  # Rank and world size are inferred from the XLA device runtime
+  dist.init_process_group("xla", init_method='xla://')
+
+  model.to(xla.device())
+  ddp_model = DDP(model, gradient_as_bucket_view=True)

-  model = model.to(rank)
-  ddp_model = DDP(model, device_ids=[rank])

   for inputs, labels in train_loader:
+    with xla.step():
+      inputs, labels = inputs.to(xla.device()), labels.to(xla.device())
       optimizer.zero_grad()
       outputs = ddp_model(inputs)
       loss = loss_fn(outputs, labels)
       loss.backward()
       optimizer.step()

 if __name__ == '__main__':
-  mp.spawn(_mp_fn, args=(), nprocs=world_size)
+  xla.launch(_mp_fn, args=())
Additional information on PyTorch/XLA, including a description of its semantics and functions, is available at PyTorch.org. See the API Guide for best practices when writing networks that run on XLA devices (TPU, CUDA, CPU and...).
Our comprehensive user guides are available at:
Documentation for the latest release
Documentation for master branch
The AI-Hypercomputer/tpu-recipes repo contains examples for training and serving many LLM and diffusion models.
PyTorch/XLA releases starting with version r2.1 are available on PyPI. You
can install the main build with pip install torch_xla. To also install the
Cloud TPU plugin corresponding to your installed torch_xla, install the optional
tpu dependencies after installing the main build with:
pip install 'torch_xla[tpu]' \
-f https://storage.googleapis.com/libtpu-wheels/index.html \
-f https://storage.googleapis.com/libtpu-releases/index.html
GPU and nightly builds are available in our public GCS bucket.
Version | Cloud GPU VM Wheels |
---|---|
2.5 (CUDA 12.1 + Python 3.9) | https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.1/torch_xla-2.5.0-cp39-cp39-manylinux_2_28_x86_64.whl |
2.5 (CUDA 12.1 + Python 3.10) | https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.1/torch_xla-2.5.0-cp310-cp310-manylinux_2_28_x86_64.whl |
2.5 (CUDA 12.1 + Python 3.11) | https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.1/torch_xla-2.5.0-cp311-cp311-manylinux_2_28_x86_64.whl |
2.5 (CUDA 12.4 + Python 3.9) | https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.4/torch_xla-2.5.0-cp39-cp39-manylinux_2_28_x86_64.whl |
2.5 (CUDA 12.4 + Python 3.10) | https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.4/torch_xla-2.5.0-cp310-cp310-manylinux_2_28_x86_64.whl |
2.5 (CUDA 12.4 + Python 3.11) | https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.4/torch_xla-2.5.0-cp311-cp311-manylinux_2_28_x86_64.whl |
nightly (Python 3.9) | https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.7.0.dev-cp39-cp39-linux_x86_64.whl |
nightly (Python 3.10) | https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.7.0.dev-cp310-cp310-linux_x86_64.whl |
nightly (Python 3.11) | https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.7.0.dev-cp311-cp311-linux_x86_64.whl |
nightly (CUDA 12.1 + Python 3.8) | https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.1/torch_xla-2.6.0.dev-cp38-cp38-linux_x86_64.whl |
You can also add yyyymmdd, as in torch_xla-2.7.0.devyyyymmdd+cxx11 (or the latest dev version), to get the nightly wheel for a specific date. Here is an example:
pip3 install torch==2.7.0.dev20250124+cpu --index-url https://download.pytorch.org/whl/nightly/cpu
pip3 install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.7.0.dev20250124+cxx11-cp310-cp310-linux_x86_64.whl
The torch wheel version 2.7.0.dev20250124+cpu can be found at https://download.pytorch.org/whl/nightly/torch/.
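The dated wheel name follows a regular pattern, so the URL can be assembled programmatically. The sketch below builds it from a date and a Python tag. The function name nightly_wheel_url and its parameters are hypothetical, and the assumption that wheels without the +cxx11 suffix use the same dated pattern is inferred from the table above rather than confirmed:

```python
BASE_URL = "https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm"


def nightly_wheel_url(date, python_tag="cp310", dev_version="2.7.0",
                      cxx11_abi=True):
    """Build the URL of a dated TPU VM nightly wheel (date as yyyymmdd)."""
    if len(date) != 8 or not date.isdigit():
        raise ValueError("date must look like yyyymmdd, e.g. 20250124")
    abi_suffix = "+cxx11" if cxx11_abi else ""
    filename = (
        f"torch_xla-{dev_version}.dev{date}{abi_suffix}"
        f"-{python_tag}-{python_tag}-linux_x86_64.whl"
    )
    return f"{BASE_URL}/{filename}"


print(nightly_wheel_url("20250124"))
```

With the default arguments this reproduces the wheel URL used in the pip3 install example above; whether a wheel exists for any other date or Python tag has to be checked against the bucket.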
older versions
Version | Cloud TPU VMs Wheel |
---|---|
2.5 (Python 3.10) | https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.5.0-cp310-cp310-manylinux_2_28_x86_64.whl |
2.4 (Python 3.10) | https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.4.0-cp310-cp310-manylinux_2_28_x86_64.whl |
2.3 (Python 3.10) | https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.3.0-cp310-cp310-manylinux_2_28_x86_64.whl |
2.2 (Python 3.10) | https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.2.0-cp310-cp310-manylinux_2_28_x86_64.whl |
2.1 (XRT + Python 3.10) | https://storage.googleapis.com/pytorch-xla-releases/wheels/xrt/tpuvm/torch_xla-2.1.0%2Bxrt-cp310-cp310-manylinux_2_28_x86_64.whl |
2.1 (Python 3.8) | https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.1.0-cp38-cp38-linux_x86_64.whl |
Version | GPU Wheel |
---|---|
2.5 (CUDA 12.1 + Python 3.9) | https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.1/torch_xla-2.5.0-cp39-cp39-manylinux_2_28_x86_64.whl |
2.5 (CUDA 12.1 + Python 3.10) | https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.1/torch_xla-2.5.0-cp310-cp310-manylinux_2_28_x86_64.whl |
2.5 (CUDA 12.1 + Python 3.11) | https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.1/torch_xla-2.5.0-cp311-cp311-manylinux_2_28_x86_64.whl |
2.5 (CUDA 12.4 + Python 3.9) | https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.4/torch_xla-2.5.0-cp39-cp39-manylinux_2_28_x86_64.whl |
2.5 (CUDA 12.4 + Python 3.10) | https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.4/torch_xla-2.5.0-cp310-cp310-manylinux_2_28_x86_64.whl |
2.5 (CUDA 12.4 + Python 3.11) | https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.4/torch_xla-2.5.0-cp311-cp311-manylinux_2_28_x86_64.whl |
2.4 (CUDA 12.1 + Python 3.9) | https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.1/torch_xla-2.4.0-cp39-cp39-manylinux_2_28_x86_64.whl |
2.4 (CUDA 12.1 + Python 3.10) | https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.1/torch_xla-2.4.0-cp310-cp310-manylinux_2_28_x86_64.whl |
2.4 (CUDA 12.1 + Python 3.11) | https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.1/torch_xla-2.4.0-cp311-cp311-manylinux_2_28_x86_64.whl |
2.3 (CUDA 12.1 + Python 3.8) | https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.1/torch_xla-2.3.0-cp38-cp38-manylinux_2_28_x86_64.whl |
2.3 (CUDA 12.1 + Python 3.10) | https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.1/torch_xla-2.3.0-cp310-cp310-manylinux_2_28_x86_64.whl |
2.3 (CUDA 12.1 + Python 3.11) | https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.1/torch_xla-2.3.0-cp311-cp311-manylinux_2_28_x86_64.whl |
2.2 (CUDA 12.1 + Python 3.8) | https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.1/torch_xla-2.2.0-cp38-cp38-manylinux_2_28_x86_64.whl |
2.2 (CUDA 12.1 + Python 3.10) | https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.1/torch_xla-2.2.0-cp310-cp310-manylinux_2_28_x86_64.whl |
2.1 + CUDA 11.8 | https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/11.8/torch_xla-2.1.0-cp38-cp38-manylinux_2_28_x86_64.whl |
nightly + CUDA 12.0 >= 2023/06/27 | https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.0/torch_xla-nightly-cp38-cp38-linux_x86_64.whl |
Version | Cloud TPU VMs Docker |
---|---|
2.6 | us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.6.0_3.10_tpuvm |
2.5 | us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.5.0_3.10_tpuvm |
2.4 | us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.4.0_3.10_tpuvm |
2.3 | us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.3.0_3.10_tpuvm |
2.2 | us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.2.0_3.10_tpuvm |
2.1 | us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.1.0_3.10_tpuvm |
nightly python | us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.10_tpuvm |
nightly python (C++11 ABI) | us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.10_tpuvm_cxx11 |
To use the above Docker images, pass --privileged --net host --shm-size=16G to docker run. Here is an example:
docker run --privileged --net host --shm-size=16G -it us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.10_tpuvm /bin/bash
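The image tags in these tables follow a visible naming scheme (release, Python version, accelerator, optional ABI suffix), so the docker run command can be generated instead of copied. The sketch below is one way to do that. The helper names release_tag, image_ref and docker_run_command are hypothetical, and the tag scheme is inferred from the table entries rather than documented:

```python
import shlex

REGISTRY = "us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla"


def release_tag(version, python_version="3.10", accelerator="tpuvm",
                cxx11_abi=False):
    """Compose a release tag such as r2.6.0_3.10_tpuvm_cxx11."""
    tag = f"r{version}_{python_version}_{accelerator}"
    return f"{tag}_cxx11" if cxx11_abi else tag


def image_ref(tag):
    """Return the full image reference for a tag in the release registry."""
    return f"{REGISTRY}:{tag}"


def docker_run_command(tag, command="/bin/bash"):
    """Return the docker run argv with the flags recommended above."""
    return [
        "docker", "run",
        "--privileged", "--net", "host", "--shm-size=16G",
        "-it", image_ref(tag), command,
    ]


# Print a shell-ready command line for the nightly TPU VM image.
print(shlex.join(docker_run_command("nightly_3.10_tpuvm")))
```

Returning an argument list keeps the result usable with subprocess.run as well as for printing.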
Version | GPU CUDA 12.4 Docker |
---|---|
2.5 | us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.5.0_3.10_cuda_12.4 |
2.4 | us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.4.0_3.10_cuda_12.4 |
Version | GPU CUDA 12.1 Docker |
---|---|
2.5 | us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.5.0_3.10_cuda_12.1 |
2.4 | us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.4.0_3.10_cuda_12.1 |
2.3 | us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.3.0_3.10_cuda_12.1 |
2.2 | us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.2.0_3.10_cuda_12.1 |
2.1 | us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.1.0_3.10_cuda_12.1 |
nightly | us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.8_cuda_12.1 |
nightly at date | us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.8_cuda_12.1_YYYYMMDD |
Version | GPU CUDA 11.8 Docker |
---|---|
2.1 | us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.1.0_3.10_cuda_11.8 |
2.0 | us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.0_3.8_cuda_11.8 |
To run on compute instances with GPUs.
If PyTorch/XLA isn't performing as expected, see the troubleshooting guide, which has suggestions for debugging and optimizing your network(s).
The PyTorch/XLA team is always happy to hear from users and OSS contributors! The best way to reach out is by filing an issue on this GitHub repository. Questions, bug reports, feature requests, build issues, etc. are all welcome!
See the contribution guide.
This repository is jointly operated and maintained by Google, Meta and a number of individual contributors listed in the CONTRIBUTORS file. For questions directed at Meta, please send an email to [email protected]. For questions directed at Google, please send an email to [email protected]. For all other questions, please open up an issue in this repository here.
You can find additional useful reading materials in