@@ -82,7 +82,7 @@
     is_msamp_available,
     is_npu_available,
     is_torch_version,
-    is_tpu_available,
+    is_torch_xla_available,
     is_xpu_available,
     load_fsdp_model,
     load_fsdp_optimizer,
@@ -133,7 +133,8 @@
 from torch.distributed.algorithms.join import Join
 
 
-if is_tpu_available(check_device=False):
+if is_torch_xla_available():
+    import torch_xla.amp as xamp
     import torch_xla.core.xla_model as xm
     import torch_xla.distributed.xla_multiprocessing as xmp
 
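For orientation, the renamed helper is called in three forms later in this diff: bare `is_torch_xla_available()` (any XLA runtime), `is_torch_xla_available(check_is_tpu=True)`, and `is_torch_xla_available(check_is_gpu=True)`. A minimal sketch of how such a guard could be used outside the library, assuming only the calls shown in this diff (the print messages are illustrative):

```python
# Sketch only: probes the XLA runtime the way this diff does in Accelerator.__init__.
from accelerate.utils import is_torch_xla_available

if is_torch_xla_available():
    # Safe to import torch_xla once the helper reports an XLA runtime.
    import torch_xla.core.xla_model as xm

    device = xm.xla_device()  # this process's XLA device
    if is_torch_xla_available(check_is_tpu=True):
        print(f"XLA on TPU at {device}: fp16 grad scaling is rejected.")
    elif is_torch_xla_available(check_is_gpu=True):
        print(f"XLA on GPU at {device}: torch_xla.amp.GradScaler handles fp16.")
```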
@@ -397,7 +398,7 @@ def __init__(
         if (
             (mixed_precision != "bf16")
             and getattr(self.state, "downcast_bfloat", False)
-            and (self.state.distributed_type != DistributedType.TPU)
+            and (self.state.distributed_type != DistributedType.XLA)
         ):
             raise ValueError("Can only use `downcast_bf16` when using `mixed_precision='bf16'` and on a TPU")
 
@@ -414,7 +415,7 @@ def __init__(
         self.gradient_state = GradientState(
             gradient_accumulation_plugin=gradient_accumulation_plugin,
         )
-        if self.state.distributed_type == DistributedType.TPU:
+        if self.state.distributed_type == DistributedType.XLA:
             if self.gradient_state.num_steps != 1:
                 raise ValueError(
                     "Gradient accumulation is not supported on TPU. Please set `gradient_accumulation_steps` to 1 and don't pass in a `GradientAccumulationPlugin` object."
@@ -436,13 +437,17 @@ def __init__(
             and self.distributed_type not in (DistributedType.DEEPSPEED, DistributedType.MEGATRON_LM)
         ):
             self.native_amp = True
-            if self.device.type not in ("xpu", "cuda", "mps", "npu"):
+            if self.device.type not in ("xpu", "cuda", "mps", "npu", "xla") or is_torch_xla_available(
+                check_is_tpu=True
+            ):
                 raise ValueError(f"fp16 mixed precision requires a GPU (not {self.device.type!r}).")
             kwargs = self.scaler_handler.to_kwargs() if self.scaler_handler is not None else {}
             if self.distributed_type == DistributedType.FSDP:
                 from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
 
                 self.scaler = ShardedGradScaler(**kwargs)
+            elif is_torch_xla_available(check_is_gpu=True):
+                self.scaler = xamp.GradScaler(**kwargs)
             elif is_npu_available():
                 self.scaler = torch.npu.amp.GradScaler(**kwargs)
             else:
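The new `elif` routes fp16 scaling on XLA:GPU through `torch_xla.amp.GradScaler`, which is designed to mirror the `torch.cuda.amp.GradScaler` API. A minimal single-device training-step sketch under that assumption; `model`, `optimizer`, and `loader` are placeholders, not objects from this diff:

```python
import torch
import torch_xla.amp as xamp
import torch_xla.core.xla_model as xm

device = xm.xla_device()
scaler = xamp.GradScaler()  # assumed to behave like torch.cuda.amp.GradScaler

for batch, labels in loader:  # placeholder dataloader yielding XLA tensors
    optimizer.zero_grad()
    with xamp.autocast(device):
        loss = torch.nn.functional.cross_entropy(model(batch), labels)
    scaler.scale(loss).backward()  # scale the loss before backward
    scaler.step(optimizer)         # unscale and skip the step on inf/nan
    scaler.update()
    xm.mark_step()                 # flush the lazily traced graph
```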
@@ -456,7 +461,7 @@ def __init__(
                 self.native_amp = True
             else:
                 self.native_amp = is_bf16_available(True)
-            if mixed_precision == "bf16" and not self.native_amp and not is_tpu_available():
+            if mixed_precision == "bf16" and not self.native_amp and not is_torch_xla_available(check_is_gpu=True):
                 raise ValueError("bf16 mixed precision requires PyTorch >= 1.10 and a supported device.")
 
         # Start of internal step tracking
@@ -1193,7 +1198,7 @@ def prepare(self, *args, device_placement=None):
         # On TPUs, putting the model on the XLA device will create new parameters, so the corresponding optimizer will
         # have parameters disconnected from the model (so no training :-( ).
         # If the model and optimizer have parameters on different devices we raise an error.
-        if self.distributed_type == DistributedType.TPU:
+        if self.distributed_type == DistributedType.XLA:
             model_device, optimizer_device = self._get_devices()
             if model_device is not None and optimizer_device is not None and model_device != optimizer_device:
                 raise ValueError(
@@ -1205,7 +1210,7 @@ def prepare(self, *args, device_placement=None):
                 )
 
         # If we're dealing with device placement, this deals with that by...
-        tpu_should_fix_optimizer = self.device_placement and self.distributed_type == DistributedType.TPU
+        tpu_should_fix_optimizer = self.device_placement and self.distributed_type == DistributedType.XLA
         if tpu_should_fix_optimizer or (self.mixed_precision == "fp8" and self.fp8_recipe_handler.backend == "TE"):
             # 1. grabbing old model parameters
             old_named_params = self._get_named_parameters(*args)
@@ -1406,7 +1411,7 @@ def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, e
         elif self.distributed_type == DistributedType.MULTI_CPU:
             kwargs = self.ddp_handler.to_kwargs() if self.ddp_handler is not None else {}
             model = torch.nn.parallel.DistributedDataParallel(model, **kwargs)
-        elif self.distributed_type == DistributedType.TPU and self.state.fork_launched:
+        elif self.distributed_type == DistributedType.XLA and self.state.fork_launched:
             model = xmp.MpModelWrapper(model).to(self.device)
         # torch.compile should be called last and only if the model isn't already compiled.
         if self.state.dynamo_plugin.backend != DynamoBackend.NO and not is_compiled_module(model):
@@ -1843,7 +1848,7 @@ def prepare_data_loader(
                 self._dataloaders.append(data_loader)
             return data_loader
         if device_placement is None:
-            device_placement = self.device_placement if self.distributed_type != DistributedType.TPU else False
+            device_placement = self.device_placement if self.distributed_type != DistributedType.XLA else False
         prepared_data_loader = prepare_data_loader(
             data_loader,
             self.device,
@@ -2056,10 +2061,6 @@ def unscale_gradients(self, optimizer=None):
             for opt in optimizer:
                 while isinstance(opt, AcceleratedOptimizer):
                     opt = opt.optimizer
-                # Reduce gradients first for XLA
-                if self.distributed_type == DistributedType.TPU:
-                    gradients = xm._fetch_gradients(opt)
-                    self.reduce(gradients, scale=1.0 / self.num_processes)
                 self.scaler.unscale_(opt)
 
     def clip_grad_norm_(self, parameters, max_norm, norm_type=2):
@@ -2097,6 +2098,19 @@ def clip_grad_norm_(self, parameters, max_norm, norm_type=2):
             # `accelerator.backward(loss)` is doing that automatically. Therefore, its implementation is not needed
             # We cannot return the gradient norm because DeepSpeed does it.
             return None
+        elif self.distributed_type == DistributedType.XLA:
+            # Reduce gradients first for XLA
+            for acc_opt in self._optimizers:
+                if not acc_opt.gradient_state.is_xla_gradients_synced:
+                    opt = acc_opt
+                    while isinstance(opt, AcceleratedOptimizer):
+                        opt = opt.optimizer
+                    gradients = xm._fetch_gradients(opt)
+                    # Use xm.all_reduce to perform an in-place all-reduce: recursively
+                    # all-reducing each tensor one by one through self.reduce is not in-place.
+                    xm.all_reduce("sum", gradients, scale=1.0 / self.num_processes)
+                    # Set is_xla_gradients_synced to True to avoid a second all-reduce in the AcceleratedOptimizer step.
+                    acc_opt.gradient_state.is_xla_gradients_synced = True
         self.unscale_gradients()
         return torch.nn.utils.clip_grad_norm_(parameters, max_norm, norm_type=norm_type)
 
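The added branch leans on the fact that `xm.all_reduce` mutates a list of tensors in place, so all of an optimizer's gradients are reduced in one call rather than tensor by tensor through `self.reduce` (the path removed from `unscale_gradients` above). A standalone sketch of the same pattern, with a plain torch optimizer standing in for the wrapped ones:

```python
import torch
import torch_xla.core.xla_model as xm


def all_reduce_gradients(optimizer: torch.optim.Optimizer, num_processes: int) -> None:
    # Gather the gradient tensors attached to the optimizer's parameters
    # (same private helper the diff uses).
    gradients = xm._fetch_gradients(optimizer)
    # One in-place all-reduce over the whole list, averaged across processes.
    xm.all_reduce("sum", gradients, scale=1.0 / num_processes)
```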
@@ -2714,7 +2728,7 @@ def _inner(folder):
         os.makedirs(output_dir, exist_ok=True)
         logger.info(f"Saving current state to {output_dir}")
 
-        if self.distributed_type == DistributedType.TPU:
+        if self.distributed_type == DistributedType.XLA:
             # Finish running the previous step before checkpointing
             xm.mark_step()
 