
Commit b443be7

anw90 and muellerzr authored
Make torch xla available on GPU (huggingface#2176)
* Make torch xla available on GPU
* format code
* fix documentation build error
* update according to the comments
* Replace DistributedType.TPU with DistributedType.XLA
* make all ut pass
* format code
* update comments
* skip test
* format code
* skip FSDPPluginIntegration for torchxla
* bring back custom_sampler_check
* fix ut
* format code
* format code

---------

Co-authored-by: Zach Mueller <[email protected]>
1 parent 613ad70 commit b443be7

Note: this is a large commit, so some file diffs are hidden by default and only a subset of the 50 changed files appears below.

50 files changed: +356 -150 lines
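On the user-facing side, the changes below mostly come down to two renames: `is_tpu_available()` becomes `is_torch_xla_available()` (which, per the `accelerator.py` hunks further down, also accepts `check_is_tpu=True` / `check_is_gpu=True`), and `DistributedType.TPU` becomes `DistributedType.XLA`. A minimal migration sketch; the surrounding script is hypothetical and only the renamed names come from this commit:

```python
from accelerate import Accelerator, DistributedType
from accelerate.utils import is_torch_xla_available

accelerator = Accelerator()

# Old code branched on DistributedType.TPU; DistributedType.XLA now covers
# both TPU and GPU XLA backends.
if accelerator.distributed_type == DistributedType.XLA:
    max_length = 128  # XLA recompiles per shape, so keep shapes static
else:
    max_length = None  # dynamic shapes are fine elsewhere

# Old code called is_tpu_available(); the new check is backend-agnostic,
# with optional flags (used in accelerator.py below) to narrow it down.
if is_torch_xla_available():
    print("torch_xla is usable")
if is_torch_xla_available(check_is_gpu=True):
    print("XLA is backed by a GPU")
```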

docs/source/concept_guides/performance.md (+1 -1)

@@ -45,7 +45,7 @@ Why is this important? Under the hood this will set **5** different seed setting
      torch.manual_seed(seed)
      torch.cuda.manual_seed_all(seed)
      # ^^ safe to call this function even if cuda is not available
-     if is_tpu_available():
+     if is_torch_xla_available():
          xm.set_rng_state(seed)
  ```
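The hunk above is from the snippet in performance.md showing what `set_seed` does internally; after this change the XLA RNG is seeded whenever torch_xla is available, TPU or GPU alike. A minimal usage sketch (the seed value is arbitrary):

```python
from accelerate.utils import set_seed

# Seeds Python's random, NumPy, torch, CUDA and, when torch_xla is
# available, the XLA RNG via xm.set_rng_state.
set_seed(42)
```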

docs/source/package_reference/utilities.md (+1 -1)

@@ -150,7 +150,7 @@ These functionalities check the state of the current working environment includi

  [[autodoc]] utils.is_torch_version

- [[autodoc]] utils.is_tpu_available
+ [[autodoc]] utils.is_torch_xla_available

  [[autodoc]] utils.is_xpu_available

docs/source/quicktour.md (+1 -1)

@@ -258,7 +258,7 @@ To introduce special behavior in your script for TPUs you can check the `distrib
  ```python docstyle-ignore
  from accelerate import DistributedType

- if accelerator.distributed_type == DistributedType.TPU:
+ if accelerator.distributed_type == DistributedType.XLA:
      # do something of static shape
  else:
      # go crazy and be dynamic

examples/by_feature/checkpointing.py (+2 -2)

@@ -85,7 +85,7 @@ def tokenize_function(examples):

  def collate_fn(examples):
      # On TPU it's best to pad everything to the same length or training will be very slow.
-     max_length = 128 if accelerator.distributed_type == DistributedType.TPU else None
+     max_length = 128 if accelerator.distributed_type == DistributedType.XLA else None
      # When using mixed precision we want round multiples of 8/16
      if accelerator.mixed_precision == "fp8":
          pad_to_multiple_of = 16

@@ -153,7 +153,7 @@ def training_function(config, args):

  # If the batch size is too big we use gradient accumulation
  gradient_accumulation_steps = 1
- if batch_size > MAX_GPU_BATCH_SIZE and accelerator.distributed_type != DistributedType.TPU:
+ if batch_size > MAX_GPU_BATCH_SIZE and accelerator.distributed_type != DistributedType.XLA:
      gradient_accumulation_steps = batch_size // MAX_GPU_BATCH_SIZE
      batch_size = MAX_GPU_BATCH_SIZE

examples/by_feature/cross_validation.py (+2 -2)

@@ -105,7 +105,7 @@ def tokenize_function(examples):

  def collate_fn(examples):
      # On TPU it's best to pad everything to the same length or training will be very slow.
-     max_length = 128 if accelerator.distributed_type == DistributedType.TPU else None
+     max_length = 128 if accelerator.distributed_type == DistributedType.XLA else None
      # When using mixed precision we want round multiples of 8/16
      if accelerator.mixed_precision == "fp8":
          pad_to_multiple_of = 16

@@ -156,7 +156,7 @@ def training_function(config, args):

  # If the batch size is too big we use gradient accumulation
  gradient_accumulation_steps = 1
- if batch_size > MAX_GPU_BATCH_SIZE and accelerator.distributed_type != DistributedType.TPU:
+ if batch_size > MAX_GPU_BATCH_SIZE and accelerator.distributed_type != DistributedType.XLA:
      gradient_accumulation_steps = batch_size // MAX_GPU_BATCH_SIZE
      batch_size = MAX_GPU_BATCH_SIZE

examples/by_feature/deepspeed_with_config_support.py (+1 -1)

@@ -511,7 +511,7 @@ def group_texts(examples):
  optimizer = optimizer_cls(optimizer_grouped_parameters, lr=args.learning_rate)

  # On TPU, the tie weights in our model have been disconnected, so we need to restore the ties.
- if accelerator.distributed_type == DistributedType.TPU:
+ if accelerator.distributed_type == DistributedType.XLA:
      model.tie_weights()

  # Scheduler and math around the number of training steps.

examples/by_feature/early_stopping.py (+2 -2)

@@ -80,7 +80,7 @@ def tokenize_function(examples):

  def collate_fn(examples):
      # On TPU it's best to pad everything to the same length or training will be very slow.
-     max_length = 128 if accelerator.distributed_type == DistributedType.TPU else None
+     max_length = 128 if accelerator.distributed_type == DistributedType.XLA else None
      # When using mixed precision we want round multiples of 8/16
      if accelerator.mixed_precision == "fp8":
          pad_to_multiple_of = 16

@@ -150,7 +150,7 @@ def training_function(config, args):

  # If the batch size is too big we use gradient accumulation
  gradient_accumulation_steps = 1
- if batch_size > MAX_GPU_BATCH_SIZE and accelerator.distributed_type != DistributedType.TPU:
+ if batch_size > MAX_GPU_BATCH_SIZE and accelerator.distributed_type != DistributedType.XLA:
      gradient_accumulation_steps = batch_size // MAX_GPU_BATCH_SIZE
      batch_size = MAX_GPU_BATCH_SIZE

examples/by_feature/fsdp_with_peak_mem_tracking.py (+2 -2)

@@ -208,13 +208,13 @@ def tokenize_function(examples):

  # If the batch size is too big we use gradient accumulation
  gradient_accumulation_steps = 1
- if batch_size > MAX_GPU_BATCH_SIZE and accelerator.distributed_type != DistributedType.TPU:
+ if batch_size > MAX_GPU_BATCH_SIZE and accelerator.distributed_type != DistributedType.XLA:
      gradient_accumulation_steps = batch_size // MAX_GPU_BATCH_SIZE
      batch_size = MAX_GPU_BATCH_SIZE

  def collate_fn(examples):
      # On TPU it's best to pad everything to the same length or training will be very slow.
-     max_length = 128 if accelerator.distributed_type == DistributedType.TPU else None
+     max_length = 128 if accelerator.distributed_type == DistributedType.XLA else None
      # When using mixed precision we want round multiples of 8/16
      if accelerator.mixed_precision == "fp8":
          pad_to_multiple_of = 16

examples/by_feature/gradient_accumulation.py (+2 -2)

@@ -80,7 +80,7 @@ def tokenize_function(examples):

  def collate_fn(examples):
      # On TPU it's best to pad everything to the same length or training will be very slow.
-     max_length = 128 if accelerator.distributed_type == DistributedType.TPU else None
+     max_length = 128 if accelerator.distributed_type == DistributedType.XLA else None
      # When using mixed precision we want round multiples of 8/16
      if accelerator.mixed_precision == "fp8":
          pad_to_multiple_of = 16

@@ -125,7 +125,7 @@ def training_function(config, args):
  accelerator = Accelerator(
      cpu=args.cpu, mixed_precision=args.mixed_precision, gradient_accumulation_steps=gradient_accumulation_steps
  )
- if accelerator.distributed_type == DistributedType.TPU and gradient_accumulation_steps > 1:
+ if accelerator.distributed_type == DistributedType.XLA and gradient_accumulation_steps > 1:
      raise NotImplementedError(
          "Gradient accumulation on TPUs is currently not supported. Pass `gradient_accumulation_steps=1`"
      )

examples/by_feature/local_sgd.py (+1 -1)

@@ -83,7 +83,7 @@ def tokenize_function(examples):

  def collate_fn(examples):
      # On TPU it's best to pad everything to the same length or training will be very slow.
-     max_length = 128 if accelerator.distributed_type == DistributedType.TPU else None
+     max_length = 128 if accelerator.distributed_type == DistributedType.XLA else None
      # When using mixed precision we want round multiples of 8/16
      if accelerator.mixed_precision == "fp8":
          pad_to_multiple_of = 16

examples/by_feature/megatron_lm_gpt_pretraining.py (+1 -1)

@@ -505,7 +505,7 @@ def group_texts(examples):
  )

  # On TPU, the tie weights in our model have been disconnected, so we need to restore the ties.
- if accelerator.distributed_type == DistributedType.TPU:
+ if accelerator.distributed_type == DistributedType.XLA:
      model.tie_weights()

  # We need to recalculate our total training steps as the size of the training dataloader may have changed.

examples/by_feature/memory.py (+1 -1)

@@ -86,7 +86,7 @@ def tokenize_function(examples):

  def collate_fn(examples):
      # On TPU it's best to pad everything to the same length or training will be very slow.
-     max_length = 128 if accelerator.distributed_type == DistributedType.TPU else None
+     max_length = 128 if accelerator.distributed_type == DistributedType.XLA else None
      # When using mixed precision we want round multiples of 8/16
      if accelerator.mixed_precision == "fp8":
          pad_to_multiple_of = 16

examples/by_feature/multi_process_metrics.py (+2 -2)

@@ -87,7 +87,7 @@ def tokenize_function(examples):

  def collate_fn(examples):
      # On TPU it's best to pad everything to the same length or training will be very slow.
-     max_length = 128 if accelerator.distributed_type == DistributedType.TPU else None
+     max_length = 128 if accelerator.distributed_type == DistributedType.XLA else None
      # When using mixed precision we want round multiples of 8/16
      if accelerator.mixed_precision == "fp8":
          pad_to_multiple_of = 16

@@ -138,7 +138,7 @@ def training_function(config, args):

  # If the batch size is too big we use gradient accumulation
  gradient_accumulation_steps = 1
- if batch_size > MAX_GPU_BATCH_SIZE and accelerator.distributed_type != DistributedType.TPU:
+ if batch_size > MAX_GPU_BATCH_SIZE and accelerator.distributed_type != DistributedType.XLA:
      gradient_accumulation_steps = batch_size // MAX_GPU_BATCH_SIZE
      batch_size = MAX_GPU_BATCH_SIZE

examples/by_feature/tracking.py (+2 -2)

@@ -85,7 +85,7 @@ def tokenize_function(examples):

  def collate_fn(examples):
      # On TPU it's best to pad everything to the same length or training will be very slow.
-     max_length = 128 if accelerator.distributed_type == DistributedType.TPU else None
+     max_length = 128 if accelerator.distributed_type == DistributedType.XLA else None
      # When using mixed precision we want round multiples of 8/16
      if accelerator.mixed_precision == "fp8":
          pad_to_multiple_of = 16

@@ -148,7 +148,7 @@ def training_function(config, args):

  # If the batch size is too big we use gradient accumulation
  gradient_accumulation_steps = 1
- if batch_size > MAX_GPU_BATCH_SIZE and accelerator.distributed_type != DistributedType.TPU:
+ if batch_size > MAX_GPU_BATCH_SIZE and accelerator.distributed_type != DistributedType.XLA:
      gradient_accumulation_steps = batch_size // MAX_GPU_BATCH_SIZE
      batch_size = MAX_GPU_BATCH_SIZE

examples/complete_nlp_example.py (+2 -2)

@@ -102,13 +102,13 @@ def tokenize_function(examples):

  # If the batch size is too big we use gradient accumulation
  gradient_accumulation_steps = 1
- if batch_size > MAX_GPU_BATCH_SIZE and accelerator.distributed_type != DistributedType.TPU:
+ if batch_size > MAX_GPU_BATCH_SIZE and accelerator.distributed_type != DistributedType.XLA:
      gradient_accumulation_steps = batch_size // MAX_GPU_BATCH_SIZE
      batch_size = MAX_GPU_BATCH_SIZE

  def collate_fn(examples):
      # On TPU it's best to pad everything to the same length or training will be very slow.
-     max_length = 128 if accelerator.distributed_type == DistributedType.TPU else None
+     max_length = 128 if accelerator.distributed_type == DistributedType.XLA else None
      # When using mixed precision we want round multiples of 8/16
      if accelerator.mixed_precision == "fp8":
          pad_to_multiple_of = 16

examples/nlp_example.py (+3 -3)

@@ -77,8 +77,8 @@ def tokenize_function(examples):
  tokenized_datasets = tokenized_datasets.rename_column("label", "labels")

  def collate_fn(examples):
-     # On TPU it's best to pad everything to the same length or training will be very slow.
-     max_length = 128 if accelerator.distributed_type == DistributedType.TPU else None
+     # For Torchxla, it's best to pad everything to the same length or training will be very slow.
+     max_length = 128 if accelerator.distributed_type == DistributedType.XLA else None
      # When using mixed precision we want round multiples of 8/16
      if accelerator.mixed_precision == "fp8":
          pad_to_multiple_of = 16

@@ -123,7 +123,7 @@ def training_function(config, args):

  # If the batch size is too big we use gradient accumulation
  gradient_accumulation_steps = 1
- if batch_size > MAX_GPU_BATCH_SIZE and accelerator.distributed_type != DistributedType.TPU:
+ if batch_size > MAX_GPU_BATCH_SIZE and accelerator.distributed_type != DistributedType.XLA:
      gradient_accumulation_steps = batch_size // MAX_GPU_BATCH_SIZE
      batch_size = MAX_GPU_BATCH_SIZE

src/accelerate/accelerator.py (+29 -15)

@@ -82,7 +82,7 @@
  is_msamp_available,
  is_npu_available,
  is_torch_version,
- is_tpu_available,
+ is_torch_xla_available,
  is_xpu_available,
  load_fsdp_model,
  load_fsdp_optimizer,

@@ -133,7 +133,8 @@
  from torch.distributed.algorithms.join import Join


- if is_tpu_available(check_device=False):
+ if is_torch_xla_available():
+     import torch_xla.amp as xamp
      import torch_xla.core.xla_model as xm
      import torch_xla.distributed.xla_multiprocessing as xmp

@@ -397,7 +398,7 @@ def __init__(
  if (
      (mixed_precision != "bf16")
      and getattr(self.state, "downcast_bfloat", False)
-     and (self.state.distributedType != DistributedType.TPU)
+     and (self.state.distributedType != DistributedType.XLA)
  ):
      raise ValueError("Can only use `downcast_bf16` when using `mixed_precision='bf16'` and on a TPU")

@@ -414,7 +415,7 @@ def __init__(
  self.gradient_state = GradientState(
      gradient_accumulation_plugin=gradient_accumulation_plugin,
  )
- if self.state.distributed_type == DistributedType.TPU:
+ if self.state.distributed_type == DistributedType.XLA:
      if self.gradient_state.num_steps != 1:
          raise ValueError(
              "Gradient accumulation is not supported on TPU. Please set `gradient_accumulation_steps` to 1 and don't pass in a `GradientAccumulationPlugin` object."

@@ -436,13 +437,17 @@ def __init__(
      and self.distributed_type not in (DistributedType.DEEPSPEED, DistributedType.MEGATRON_LM)
  ):
      self.native_amp = True
-     if self.device.type not in ("xpu", "cuda", "mps", "npu"):
+     if self.device.type not in ("xpu", "cuda", "mps", "npu", "xla") or is_torch_xla_available(
+         check_is_tpu=True
+     ):
          raise ValueError(f"fp16 mixed precision requires a GPU (not {self.device.type!r}).")
      kwargs = self.scaler_handler.to_kwargs() if self.scaler_handler is not None else {}
      if self.distributed_type == DistributedType.FSDP:
          from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler

          self.scaler = ShardedGradScaler(**kwargs)
+     elif is_torch_xla_available(check_is_gpu=True):
+         self.scaler = xamp.GradScaler(**kwargs)
      elif is_npu_available():
          self.scaler = torch.npu.amp.GradScaler(**kwargs)
      else:

@@ -456,7 +461,7 @@ def __init__(
      self.native_amp = True
  else:
      self.native_amp = is_bf16_available(True)
- if mixed_precision == "bf16" and not self.native_amp and not is_tpu_available():
+ if mixed_precision == "bf16" and not self.native_amp and not is_torch_xla_available(check_is_gpu=True):
      raise ValueError("bf16 mixed precision requires PyTorch >= 1.10 and a supported device.")

  # Start of internal step tracking

@@ -1193,7 +1198,7 @@ def prepare(self, *args, device_placement=None):
  # On TPUs, putting the model on the XLA device will create new parameters, so the corresponding optimizer will
  # have parameters disconnected from the model (so no training :-( ).
  # If the model and optimizer have parameters on different devices we raise an error.
- if self.distributed_type == DistributedType.TPU:
+ if self.distributed_type == DistributedType.XLA:
      model_device, optimizer_device = self._get_devices()
      if model_device is not None and optimizer_device is not None and model_device != optimizer_device:
          raise ValueError(

@@ -1205,7 +1210,7 @@ def prepare(self, *args, device_placement=None):
  )

  # If we're dealing with device placement, this deals with that by...
- tpu_should_fix_optimizer = self.device_placement and self.distributed_type == DistributedType.TPU
+ tpu_should_fix_optimizer = self.device_placement and self.distributed_type == DistributedType.XLA
  if tpu_should_fix_optimizer or (self.mixed_precision == "fp8" and self.fp8_recipe_handler.backend == "TE"):
      # 1. grabbing old model parameters
      old_named_params = self._get_named_parameters(*args)

@@ -1406,7 +1411,7 @@ def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, e
  elif self.distributed_type == DistributedType.MULTI_CPU:
      kwargs = self.ddp_handler.to_kwargs() if self.ddp_handler is not None else {}
      model = torch.nn.parallel.DistributedDataParallel(model, **kwargs)
- elif self.distributed_type == DistributedType.TPU and self.state.fork_launched:
+ elif self.distributed_type == DistributedType.XLA and self.state.fork_launched:
      model = xmp.MpModelWrapper(model).to(self.device)
  # torch.compile should be called last and only if the model isn't already compiled.
  if self.state.dynamo_plugin.backend != DynamoBackend.NO and not is_compiled_module(model):

@@ -1843,7 +1848,7 @@ def prepare_data_loader(
      self._dataloaders.append(data_loader)
      return data_loader
  if device_placement is None:
-     device_placement = self.device_placement if self.distributed_type != DistributedType.TPU else False
+     device_placement = self.device_placement if self.distributed_type != DistributedType.XLA else False
  prepared_data_loader = prepare_data_loader(
      data_loader,
      self.device,

@@ -2056,10 +2061,6 @@ def unscale_gradients(self, optimizer=None):
  for opt in optimizer:
      while isinstance(opt, AcceleratedOptimizer):
          opt = opt.optimizer
-     # Reduce gradients first for XLA
-     if self.distributed_type == DistributedType.TPU:
-         gradients = xm._fetch_gradients(opt)
-         self.reduce(gradients, scale=1.0 / self.num_processes)
      self.scaler.unscale_(opt)

  def clip_grad_norm_(self, parameters, max_norm, norm_type=2):

@@ -2097,6 +2098,19 @@ def clip_grad_norm_(self, parameters, max_norm, norm_type=2):
      # `accelerator.backward(loss)` is doing that automatically. Therefore, its implementation is not needed
      # We cannot return the gradient norm because DeepSpeed does it.
      return None
+ elif self.distributed_type == DistributedType.XLA:
+     # Reduce gradients first for XLA
+     for acc_opt in self._optimizers:
+         if not acc_opt.gradient_state.is_xla_gradients_synced:
+             opt = acc_opt
+             while isinstance(opt, AcceleratedOptimizer):
+                 opt = opt.optimizer
+             gradients = xm._fetch_gradients(opt)
+             # Use xm.all_reduce to perform an in-place all-reduce. Recursive all-reduce each tensor
+             # one by one in self.reduce is non-inplace.
+             xm.all_reduce("sum", gradients, scale=1.0 / self.num_processes)
+             # Set is_xla_gradients_synced to True to avoid all-reduce twice in the AcceleratedOptimizer step.
+             acc_opt.gradient_state.is_xla_gradients_synced = True
  self.unscale_gradients()
  return torch.nn.utils.clip_grad_norm_(parameters, max_norm, norm_type=norm_type)

@@ -2714,7 +2728,7 @@ def _inner(folder):
  os.makedirs(output_dir, exist_ok=True)
  logger.info(f"Saving current state to {output_dir}")

- if self.distributed_type == DistributedType.TPU:
+ if self.distributed_type == DistributedType.XLA:
      # Finish running the previous step before checkpointing
      xm.mark_step()
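For context on the `clip_grad_norm_` hunk above: on `DistributedType.XLA` the gradients are now all-reduced once with `xm.all_reduce`, flagged as synced via `is_xla_gradients_synced`, and only then unscaled and clipped. A hedged, self-contained sketch of the call pattern that exercises this path, assuming an XLA-backed GPU run; the toy model and data are illustrative, not from the commit:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

from accelerate import Accelerator

# Per this commit, fp16 on an XLA GPU selects torch_xla.amp.GradScaler.
accelerator = Accelerator(mixed_precision="fp16")

model = torch.nn.Linear(16, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
dataset = TensorDataset(torch.randn(64, 16), torch.randint(0, 2, (64,)))
loader = DataLoader(dataset, batch_size=8)

model, optimizer, loader = accelerator.prepare(model, optimizer, loader)

for inputs, labels in loader:
    optimizer.zero_grad()
    loss = torch.nn.functional.cross_entropy(model(inputs), labels)
    accelerator.backward(loss)
    # On DistributedType.XLA this all-reduces gradients once and marks them
    # synced, so the optimizer step does not all-reduce them a second time.
    accelerator.clip_grad_norm_(model.parameters(), max_norm=1.0)
    optimizer.step()
```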
