[TPU] Support multi-host inference (vllm-project#7457)
WoosukKwon authored Aug 13, 2024
1 parent 16422ea commit a08df83
Showing 2 changed files with 11 additions and 4 deletions.
docs/source/getting_started/tpu-installation.rst (2 changes: 1 addition & 1 deletion)
@@ -8,7 +8,7 @@ vLLM supports Google Cloud TPUs using PyTorch XLA.
 Requirements
 ------------
 
-* Google Cloud TPU VM (single host)
+* Google Cloud TPU VM (single & multi host)
 * TPU versions: v5e, v5p, v4
 * Python: 3.10
 
vllm/distributed/device_communicators/tpu_communicator.py (13 changes: 10 additions & 3 deletions)
@@ -1,3 +1,4 @@
+import ray
 import torch
 import torch.distributed as dist
 from torch.distributed import ProcessGroup
@@ -18,9 +19,15 @@ def __init__(self, group: ProcessGroup):
             return
         self.disabled = False
 
-        local_rank = dist.get_rank(group)
-        world_size = dist.get_world_size(group)
-        pjrt.initialize_multiprocess(local_rank, world_size)
+        # NOTE(woosuk): When using TP > 1 on TPUs, every TPU on the same node
+        # must be used together. Therefore, the local rank and world size can
+        # be simply calculated as follows.
+        global_rank = dist.get_rank(group)
+        global_world_size = dist.get_world_size(group)
+        num_nodes = len(ray.nodes())
+        local_world_size = global_world_size // num_nodes
+        local_rank = global_rank % local_world_size
+        pjrt.initialize_multiprocess(local_rank, local_world_size)
         xr._init_world_size_ordinal()
 
     def all_reduce(self, x: torch.Tensor) -> torch.Tensor:
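For context, the new code derives each worker's local rank and local world size from its global rank, the global world size of the process group, and the number of nodes reported by Ray, under the assumption stated in the NOTE that every node contributes the same number of TPU chips. The following standalone sketch mirrors that arithmetic without requiring Ray or torch; the helper name derive_local_coords and the example figures (2 hosts with 4 chips each) are hypothetical and not part of the commit.

# A minimal sketch of the rank arithmetic introduced in this commit.
# Assumes an evenly sized cluster, as the NOTE in the diff states.

def derive_local_coords(global_rank: int, global_world_size: int,
                        num_nodes: int) -> tuple[int, int]:
    """Return (local_rank, local_world_size) for one worker."""
    local_world_size = global_world_size // num_nodes
    local_rank = global_rank % local_world_size
    return local_rank, local_world_size


if __name__ == "__main__":
    # Hypothetical cluster: 2 TPU hosts with 4 chips each -> 8 global ranks.
    num_nodes = 2
    global_world_size = 8
    for global_rank in range(global_world_size):
        local_rank, local_world_size = derive_local_coords(
            global_rank, global_world_size, num_nodes)
        print(f"global rank {global_rank} -> local rank {local_rank} "
              f"of {local_world_size}")
    # Expected mapping: global ranks 0-3 become local ranks 0-3 on one host,
    # and global ranks 4-7 become local ranks 0-3 on the other.

In the actual change, pjrt.initialize_multiprocess is then called with the derived local values instead of the global ones, which is what lets the same initialization path work across multiple TPU hosts.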
