Skip to content

Commit

Permalink
Replaces xxx_required with requires_backends (huggingface#20715)
Browse files Browse the repository at this point in the history
* Replaces xxx_required with requires_backends

* Fixup
  • Loading branch information
amyeroberts authored Dec 14, 2022
1 parent 7c9e2f2 commit 7b23a58
Show file tree
Hide file tree
Showing 9 changed files with 30 additions and 58 deletions.
10 changes: 5 additions & 5 deletions src/transformers/benchmark/benchmark_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from dataclasses import dataclass, field
from typing import Tuple

from ..utils import cached_property, is_torch_available, is_torch_tpu_available, logging, torch_required
from ..utils import cached_property, is_torch_available, is_torch_tpu_available, logging, requires_backends
from .benchmark_args_utils import BenchmarkArguments


Expand Down Expand Up @@ -76,8 +76,8 @@ def __init__(self, **kwargs):
)

@cached_property
@torch_required
def _setup_devices(self) -> Tuple["torch.device", int]:
requires_backends(self, ["torch"])
logger.info("PyTorch: setting up devices")
if not self.cuda:
device = torch.device("cpu")
Expand All @@ -95,19 +95,19 @@ def is_tpu(self):
return is_torch_tpu_available() and self.tpu

@property
@torch_required
def device_idx(self) -> int:
requires_backends(self, ["torch"])
# TODO(PVP): currently only single GPU is supported
return torch.cuda.current_device()

@property
@torch_required
def device(self) -> "torch.device":
requires_backends(self, ["torch"])
return self._setup_devices[0]

@property
@torch_required
def n_gpu(self):
requires_backends(self, ["torch"])
return self._setup_devices[1]

@property
Expand Down
14 changes: 7 additions & 7 deletions src/transformers/benchmark/benchmark_args_tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from dataclasses import dataclass, field
from typing import Tuple

from ..utils import cached_property, is_tf_available, logging, tf_required
from ..utils import cached_property, is_tf_available, logging, requires_backends
from .benchmark_args_utils import BenchmarkArguments


Expand Down Expand Up @@ -77,8 +77,8 @@ def __init__(self, **kwargs):
)

@cached_property
@tf_required
def _setup_tpu(self) -> Tuple["tf.distribute.cluster_resolver.TPUClusterResolver"]:
requires_backends(self, ["tf"])
tpu = None
if self.tpu:
try:
Expand All @@ -91,8 +91,8 @@ def _setup_tpu(self) -> Tuple["tf.distribute.cluster_resolver.TPUClusterResolver
return tpu

@cached_property
@tf_required
def _setup_strategy(self) -> Tuple["tf.distribute.Strategy", "tf.distribute.cluster_resolver.TPUClusterResolver"]:
requires_backends(self, ["tf"])
if self.is_tpu:
tf.config.experimental_connect_to_cluster(self._setup_tpu)
tf.tpu.experimental.initialize_tpu_system(self._setup_tpu)
Expand All @@ -111,23 +111,23 @@ def _setup_strategy(self) -> Tuple["tf.distribute.Strategy", "tf.distribute.clus
return strategy

@property
@tf_required
def is_tpu(self) -> bool:
requires_backends(self, ["tf"])
return self._setup_tpu is not None

@property
@tf_required
def strategy(self) -> "tf.distribute.Strategy":
requires_backends(self, ["tf"])
return self._setup_strategy

@property
@tf_required
def gpu_list(self):
requires_backends(self, ["tf"])
return tf.config.list_physical_devices("GPU")

@property
@tf_required
def n_gpu(self) -> int:
requires_backends(self, ["tf"])
if self.cuda:
return len(self.gpu_list)
return 0
Expand Down
4 changes: 2 additions & 2 deletions src/transformers/feature_extraction_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
is_torch_device,
is_torch_dtype,
logging,
torch_required,
requires_backends,
)


Expand Down Expand Up @@ -175,7 +175,6 @@ def as_tensor(value):

return self

@torch_required
def to(self, *args, **kwargs) -> "BatchFeature":
"""
Send all values to device by calling `v.to(*args, **kwargs)` (PyTorch only). This should support casting in
Expand All @@ -190,6 +189,7 @@ def to(self, *args, **kwargs) -> "BatchFeature":
Returns:
[`BatchFeature`]: The same instance after modification.
"""
requires_backends(self, ["torch"])
import torch # noqa

new_data = {}
Expand Down
2 changes: 0 additions & 2 deletions src/transformers/file_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,10 +127,8 @@
is_vision_available,
replace_return_docstrings,
requires_backends,
tf_required,
to_numpy,
to_py_obj,
torch_only_method,
torch_required,
torch_version,
)
4 changes: 2 additions & 2 deletions src/transformers/tokenization_utils_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,8 @@
is_torch_device,
is_torch_tensor,
logging,
requires_backends,
to_py_obj,
torch_required,
)


Expand Down Expand Up @@ -739,7 +739,6 @@ def convert_to_tensors(

return self

@torch_required
def to(self, device: Union[str, "torch.device"]) -> "BatchEncoding":
"""
Send all values to device by calling `v.to(device)` (PyTorch only).
Expand All @@ -750,6 +749,7 @@ def to(self, device: Union[str, "torch.device"]) -> "BatchEncoding":
Returns:
[`BatchEncoding`]: The same instance after modification.
"""
requires_backends(self, ["torch"])

# This check catches things like APEX blindly calling "to" on all inputs to a module
# Otherwise it passes the casts down and casts the LongTensor containing the token idxs
Expand Down
16 changes: 8 additions & 8 deletions src/transformers/training_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,6 @@
is_torch_tpu_available,
logging,
requires_backends,
torch_required,
)


Expand Down Expand Up @@ -1386,8 +1385,8 @@ def ddp_timeout_delta(self) -> timedelta:
return timedelta(seconds=self.ddp_timeout)

@cached_property
@torch_required
def _setup_devices(self) -> "torch.device":
requires_backends(self, ["torch"])
logger.info("PyTorch: setting up devices")
if torch.distributed.is_available() and torch.distributed.is_initialized() and self.local_rank == -1:
logger.warning(
Expand Down Expand Up @@ -1537,15 +1536,14 @@ def _setup_devices(self) -> "torch.device":
return device

@property
@torch_required
def device(self) -> "torch.device":
"""
The device used by this process.
"""
requires_backends(self, ["torch"])
return self._setup_devices

@property
@torch_required
def n_gpu(self):
"""
The number of GPUs used by this process.
Expand All @@ -1554,12 +1552,12 @@ def n_gpu(self):
This will only be greater than one when you have multiple GPUs available but are not using distributed
training. For distributed training, it will always be 1.
"""
requires_backends(self, ["torch"])
# Make sure `self._n_gpu` is properly setup.
_ = self._setup_devices
return self._n_gpu

@property
@torch_required
def parallel_mode(self):
"""
The current mode used for parallelism if multiple GPUs/TPU cores are available. One of:
Expand All @@ -1570,6 +1568,7 @@ def parallel_mode(self):
`torch.nn.DistributedDataParallel`).
- `ParallelMode.TPU`: several TPU cores.
"""
requires_backends(self, ["torch"])
if is_torch_tpu_available():
return ParallelMode.TPU
elif is_sagemaker_mp_enabled():
Expand All @@ -1584,11 +1583,12 @@ def parallel_mode(self):
return ParallelMode.NOT_PARALLEL

@property
@torch_required
def world_size(self):
"""
The number of processes used in parallel.
"""
requires_backends(self, ["torch"])

if is_torch_tpu_available():
return xm.xrt_world_size()
elif is_sagemaker_mp_enabled():
Expand All @@ -1600,11 +1600,11 @@ def world_size(self):
return 1

@property
@torch_required
def process_index(self):
"""
The index of the current process used.
"""
requires_backends(self, ["torch"])
if is_torch_tpu_available():
return xm.get_ordinal()
elif is_sagemaker_mp_enabled():
Expand All @@ -1616,11 +1616,11 @@ def process_index(self):
return 0

@property
@torch_required
def local_process_index(self):
"""
The index of the local process used.
"""
requires_backends(self, ["torch"])
if is_torch_tpu_available():
return xm.get_local_ordinal()
elif is_sagemaker_mp_enabled():
Expand Down
10 changes: 5 additions & 5 deletions src/transformers/training_args_tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from typing import Optional, Tuple

from .training_args import TrainingArguments
from .utils import cached_property, is_tf_available, logging, tf_required
from .utils import cached_property, is_tf_available, logging, requires_backends


logger = logging.get_logger(__name__)
Expand Down Expand Up @@ -185,8 +185,8 @@ class TFTrainingArguments(TrainingArguments):
xla: bool = field(default=False, metadata={"help": "Whether to activate the XLA compilation or not"})

@cached_property
@tf_required
def _setup_strategy(self) -> Tuple["tf.distribute.Strategy", int]:
requires_backends(self, ["tf"])
logger.info("Tensorflow: setting up strategy")

gpus = tf.config.list_physical_devices("GPU")
Expand Down Expand Up @@ -234,19 +234,19 @@ def _setup_strategy(self) -> Tuple["tf.distribute.Strategy", int]:
return strategy

@property
@tf_required
def strategy(self) -> "tf.distribute.Strategy":
"""
The strategy used for distributed training.
"""
requires_backends(self, ["tf"])
return self._setup_strategy

@property
@tf_required
def n_replicas(self) -> int:
"""
The number of replicas (CPUs, GPUs or TPU cores) used in this training.
"""
requires_backends(self, ["tf"])
return self._setup_strategy.num_replicas_in_sync

@property
Expand Down Expand Up @@ -276,11 +276,11 @@ def eval_batch_size(self) -> int:
return per_device_batch_size * self.n_replicas

@property
@tf_required
def n_gpu(self) -> int:
"""
The number of replicas (CPUs, GPUs or TPU cores) used in this training.
"""
requires_backends(self, ["tf"])
warnings.warn(
"The n_gpu argument is deprecated and will be removed in a future version, use n_replicas instead.",
FutureWarning,
Expand Down
2 changes: 0 additions & 2 deletions src/transformers/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,9 +163,7 @@
is_training_run_on_sagemaker,
is_vision_available,
requires_backends,
tf_required,
torch_only_method,
torch_required,
torch_version,
)

Expand Down
26 changes: 1 addition & 25 deletions src/transformers/utils/import_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
import sys
import warnings
from collections import OrderedDict
from functools import lru_cache, wraps
from functools import lru_cache
from itertools import chain
from types import ModuleType
from typing import Any
Expand Down Expand Up @@ -1039,30 +1039,6 @@ def __getattribute__(cls, key):
requires_backends(cls, cls._backends)


def torch_required(func):
# Chose a different decorator name than in tests so it's clear they are not the same.
@wraps(func)
def wrapper(*args, **kwargs):
if is_torch_available():
return func(*args, **kwargs)
else:
raise ImportError(f"Method `{func.__name__}` requires PyTorch.")

return wrapper


def tf_required(func):
# Chose a different decorator name than in tests so it's clear they are not the same.
@wraps(func)
def wrapper(*args, **kwargs):
if is_tf_available():
return func(*args, **kwargs)
else:
raise ImportError(f"Method `{func.__name__}` requires TF.")

return wrapper


def is_torch_fx_proxy(x):
if is_torch_fx_available():
import torch.fx
Expand Down

0 comments on commit 7b23a58

Please sign in to comment.