
Device agnostic testing (huggingface#5612)
* utils and test modifications to enable device agnostic testing

* device for manual seed in unet1d

* fix generator condition in vae test

* consistency changes to testing

* make style

* add device agnostic testing changes to source and one model test

* make dtype check fns private, log cuda fp16 case

* remove dtype checks from import utils, move to testing_utils

* adding tests for most model classes and one pipeline

* fix vae import
arsalanu authored Dec 5, 2023
1 parent 6e22133 commit f427345
Showing 11 changed files with 306 additions and 67 deletions.
179 changes: 178 additions & 1 deletion src/diffusers/utils/testing_utils.py
@@ -17,7 +17,7 @@
 from distutils.util import strtobool
 from io import BytesIO, StringIO
 from pathlib import Path
-from typing import List, Optional, Union
+from typing import Callable, Dict, List, Optional, Union
 
 import numpy as np
 import PIL.Image
@@ -58,6 +58,17 @@
 if is_torch_available():
     import torch
 
+    # Set a backend environment variable for any extra module import required for a custom accelerator
+    if "DIFFUSERS_TEST_BACKEND" in os.environ:
+        backend = os.environ["DIFFUSERS_TEST_BACKEND"]
+        try:
+            _ = importlib.import_module(backend)
+        except ModuleNotFoundError as e:
+            raise ModuleNotFoundError(
+                f"Failed to import `DIFFUSERS_TEST_BACKEND` '{backend}'! This should be the name of an installed module "
+                f"to enable the specified backend.\n{e}"
+            ) from e
+
     if "DIFFUSERS_TEST_DEVICE" in os.environ:
         torch_device = os.environ["DIFFUSERS_TEST_DEVICE"]
         try:
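For example, a test run against an out-of-tree accelerator could combine the two variables; `torch_npu` and `npu` below are illustrative placeholder names, not modules this commit ships, and in practice the variables would be exported in the shell before invoking pytest:

import os

# Hypothetical sketch: the module named in DIFFUSERS_TEST_BACKEND is imported
# up front, and DIFFUSERS_TEST_DEVICE becomes the suite-wide `torch_device`.
os.environ["DIFFUSERS_TEST_BACKEND"] = "torch_npu"  # extra module to import
os.environ["DIFFUSERS_TEST_DEVICE"] = "npu"  # device string used by the tests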
@@ -210,6 +221,36 @@ def require_torch_gpu(test_case):
     )
 
 
+# These decorators are for accelerator-specific behaviours that are not GPU-specific
+def require_torch_accelerator(test_case):
+    """Decorator marking a test that requires an accelerator backend and PyTorch."""
+    return unittest.skipUnless(is_torch_available() and torch_device != "cpu", "test requires accelerator+PyTorch")(
+        test_case
+    )
+
+
+def require_torch_accelerator_with_fp16(test_case):
+    """Decorator marking a test that requires an accelerator with support for the FP16 data type."""
+    return unittest.skipUnless(_is_torch_fp16_available(torch_device), "test requires accelerator with fp16 support")(
+        test_case
+    )
+
+
+def require_torch_accelerator_with_fp64(test_case):
+    """Decorator marking a test that requires an accelerator with support for the FP64 data type."""
+    return unittest.skipUnless(_is_torch_fp64_available(torch_device), "test requires accelerator with fp64 support")(
+        test_case
+    )
+
+
+def require_torch_accelerator_with_training(test_case):
+    """Decorator marking a test that requires an accelerator with support for training."""
+    return unittest.skipUnless(
+        is_torch_available() and backend_supports_training(torch_device),
+        "test requires accelerator with training support",
+    )(test_case)
+
+
 def skip_mps(test_case):
     """Decorator marking a test to skip if torch_device is 'mps'"""
     return unittest.skipUnless(torch_device != "mps", "test requires non 'mps' device")(test_case)
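A usage sketch for the new decorators (the test class and tensors here are hypothetical, not part of this diff): each decorator runs the test only when the active `torch_device` supports the required feature, instead of keying on `cuda` specifically.

class AcceleratorDecoratorSketch(unittest.TestCase):
    @require_torch_accelerator
    def test_runs_on_any_accelerator(self):
        # Runs on cuda, mps, or a custom backend -- skipped on plain cpu.
        x = torch.ones(2, 2).to(torch_device)
        self.assertEqual(x.device.type, torch.device(torch_device).type)

    @require_torch_accelerator_with_fp16
    def test_fp16_matmul(self):
        # Skipped automatically where _is_torch_fp16_available(torch_device) is False.
        x = torch.ones(2, 2, dtype=torch.float16).to(torch_device)
        self.assertEqual((x @ x).dtype, torch.float16)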
@@ -766,3 +807,139 @@ def disable_full_determinism():
     os.environ["CUDA_LAUNCH_BLOCKING"] = "0"
     os.environ["CUBLAS_WORKSPACE_CONFIG"] = ""
     torch.use_deterministic_algorithms(False)
+
+
+# Utils for custom and alternative accelerator devices
+def _is_torch_fp16_available(device):
+    if not is_torch_available():
+        return False
+
+    import torch
+
+    device = torch.device(device)
+
+    try:
+        x = torch.zeros((2, 2), dtype=torch.float16).to(device)
+        _ = x @ x
+        return True
+    except Exception as e:
+        if device.type == "cuda":
+            raise ValueError(
+                f"You have passed a device of type 'cuda' which should work with 'fp16', but 'cuda' does not seem to be correctly installed on your machine: {e}"
+            )
+
+        return False
+
+
+def _is_torch_fp64_available(device):
+    if not is_torch_available():
+        return False
+
+    import torch
+
+    device = torch.device(device)
+
+    try:
+        x = torch.zeros((2, 2), dtype=torch.float64).to(device)
+        _ = x @ x
+        return True
+    except Exception as e:
+        if device.type == "cuda":
+            raise ValueError(
+                f"You have passed a device of type 'cuda' which should work with 'fp64', but 'cuda' does not seem to be correctly installed on your machine: {e}"
+            )
+
+        return False
+
+
+# Guard these lookups for when Torch is not used - alternative accelerator support is for PyTorch
+if is_torch_available():
+    # Behaviour flags
+    BACKEND_SUPPORTS_TRAINING = {"cuda": True, "cpu": True, "mps": False, "default": True}
+
+    # Function definitions
+    BACKEND_EMPTY_CACHE = {"cuda": torch.cuda.empty_cache, "cpu": None, "mps": None, "default": None}
+    BACKEND_DEVICE_COUNT = {"cuda": torch.cuda.device_count, "cpu": lambda: 0, "mps": lambda: 0, "default": 0}
+    BACKEND_MANUAL_SEED = {"cuda": torch.cuda.manual_seed, "cpu": torch.manual_seed, "default": torch.manual_seed}
+
+
+# This dispatches a defined function according to the accelerator from the function definitions.
+def _device_agnostic_dispatch(device: str, dispatch_table: Dict[str, Callable], *args, **kwargs):
+    if device not in dispatch_table:
+        return dispatch_table["default"](*args, **kwargs)
+
+    fn = dispatch_table[device]
+
+    # Some device agnostic functions return values. Need to guard against 'None' instead at
+    # user level.
+    if fn is None:
+        return None
+
+    return fn(*args, **kwargs)
+
+
+# These are callables which automatically dispatch the function specific to the accelerator
+def backend_manual_seed(device: str, seed: int):
+    return _device_agnostic_dispatch(device, BACKEND_MANUAL_SEED, seed)
+
+
+def backend_empty_cache(device: str):
+    return _device_agnostic_dispatch(device, BACKEND_EMPTY_CACHE)
+
+
+def backend_device_count(device: str):
+    return _device_agnostic_dispatch(device, BACKEND_DEVICE_COUNT)
+
+
+# These are callables which return boolean behaviour flags and can be used to specify some
+# device agnostic alternative where the feature is unsupported.
+def backend_supports_training(device: str):
+    if not is_torch_available():
+        return False
+
+    if device not in BACKEND_SUPPORTS_TRAINING:
+        device = "default"
+
+    return BACKEND_SUPPORTS_TRAINING[device]
+
+
+# Guard for when Torch is not available
+if is_torch_available():
+    # Update device function dict mapping
+    def update_mapping_from_spec(device_fn_dict: Dict[str, Callable], attribute_name: str):
+        try:
+            # Try to import the function directly
+            spec_fn = getattr(device_spec_module, attribute_name)
+            device_fn_dict[torch_device] = spec_fn
+        except AttributeError as e:
+            # If the function doesn't exist, and there is no default, throw an error
+            if "default" not in device_fn_dict:
+                raise AttributeError(
+                    f"`{attribute_name}` not found in '{device_spec_path}' and no default fallback function found."
+                ) from e
+
+    if "DIFFUSERS_TEST_DEVICE_SPEC" in os.environ:
+        device_spec_path = os.environ["DIFFUSERS_TEST_DEVICE_SPEC"]
+        if not Path(device_spec_path).is_file():
+            raise ValueError(f"Specified path to device specification file is not found. Received {device_spec_path}")
+
+        try:
+            import_name = device_spec_path[: device_spec_path.index(".py")]
+        except ValueError as e:
+            raise ValueError(f"Provided device spec file is not a Python file! Received {device_spec_path}") from e
+
+        device_spec_module = importlib.import_module(import_name)
+
+        try:
+            device_name = device_spec_module.DEVICE_NAME
+        except AttributeError:
+            raise AttributeError("Device spec file did not contain `DEVICE_NAME`")
+
+        if "DIFFUSERS_TEST_DEVICE" in os.environ and torch_device != device_name:
+            msg = f"Mismatch between environment variable `DIFFUSERS_TEST_DEVICE` '{torch_device}' and device found in spec '{device_name}'\n"
+            msg += "Either unset `DIFFUSERS_TEST_DEVICE` or ensure it matches device spec name."
+            raise ValueError(msg)
+
+        torch_device = device_name
+
+        # Add one entry here for each `BACKEND_*` dictionary.
+        update_mapping_from_spec(BACKEND_MANUAL_SEED, "MANUAL_SEED_FN")
+        update_mapping_from_spec(BACKEND_EMPTY_CACHE, "EMPTY_CACHE_FN")
+        update_mapping_from_spec(BACKEND_DEVICE_COUNT, "DEVICE_COUNT_FN")
+        update_mapping_from_spec(BACKEND_SUPPORTS_TRAINING, "SUPPORTS_TRAINING")
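Taken together, the dispatch tables and this spec hook let a new accelerator be wired in without editing the test suite itself. A minimal sketch of such a spec file, assuming a hypothetical `torch_npu` backend (the `torch.npu.*` callables are placeholders; only `DEVICE_NAME` and the four attribute names are fixed by the loader above):

# my_npu_spec.py -- referenced via DIFFUSERS_TEST_DEVICE_SPEC=my_npu_spec.py
import torch
import torch_npu  # noqa: F401 -- assumed out-of-tree backend module

DEVICE_NAME = "npu"

MANUAL_SEED_FN = torch.npu.manual_seed
EMPTY_CACHE_FN = torch.npu.empty_cache
DEVICE_COUNT_FN = torch.npu.device_count
SUPPORTS_TRAINING = True

With this in place, `backend_manual_seed(torch_device, 0)` dispatches to `torch.npu.manual_seed(0)`, while a device with no table entry falls back to the `default` callable (e.g. `torch.manual_seed`).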
23 changes: 11 additions & 12 deletions tests/models/test_layers_utils.py
@@ -25,7 +25,11 @@
 from diffusers.models.lora import LoRACompatibleLinear
 from diffusers.models.resnet import Downsample2D, ResnetBlock2D, Upsample2D
 from diffusers.models.transformer_2d import Transformer2DModel
-from diffusers.utils.testing_utils import torch_device
+from diffusers.utils.testing_utils import (
+    backend_manual_seed,
+    require_torch_accelerator_with_fp64,
+    torch_device,
+)
 
 
 class EmbeddingsTests(unittest.TestCase):
@@ -315,8 +319,7 @@ def test_restnet_with_kernel_sde_vp(self):
 class Transformer2DModelTests(unittest.TestCase):
     def test_spatial_transformer_default(self):
         torch.manual_seed(0)
-        if torch.cuda.is_available():
-            torch.cuda.manual_seed_all(0)
+        backend_manual_seed(torch_device, 0)
 
         sample = torch.randn(1, 32, 64, 64).to(torch_device)
         spatial_transformer_block = Transformer2DModel(
@@ -339,8 +342,7 @@ def test_spatial_transformer_default(self):
 
     def test_spatial_transformer_cross_attention_dim(self):
         torch.manual_seed(0)
-        if torch.cuda.is_available():
-            torch.cuda.manual_seed_all(0)
+        backend_manual_seed(torch_device, 0)
 
         sample = torch.randn(1, 64, 64, 64).to(torch_device)
         spatial_transformer_block = Transformer2DModel(
Expand All @@ -363,8 +365,7 @@ def test_spatial_transformer_cross_attention_dim(self):

def test_spatial_transformer_timestep(self):
torch.manual_seed(0)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(0)
backend_manual_seed(torch_device, 0)

num_embeds_ada_norm = 5

@@ -401,8 +402,7 @@ def test_spatial_transformer_timestep(self):
 
     def test_spatial_transformer_dropout(self):
         torch.manual_seed(0)
-        if torch.cuda.is_available():
-            torch.cuda.manual_seed_all(0)
+        backend_manual_seed(torch_device, 0)
 
         sample = torch.randn(1, 32, 64, 64).to(torch_device)
         spatial_transformer_block = (
@@ -427,11 +427,10 @@ def test_spatial_transformer_dropout(self):
         )
         assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)
 
-    @unittest.skipIf(torch_device == "mps", "MPS does not support float64")
+    @require_torch_accelerator_with_fp64
     def test_spatial_transformer_discrete(self):
         torch.manual_seed(0)
-        if torch.cuda.is_available():
-            torch.cuda.manual_seed_all(0)
+        backend_manual_seed(torch_device, 0)
 
         num_embed = 5
 
7 changes: 4 additions & 3 deletions tests/models/test_modeling_common.py
@@ -35,6 +35,7 @@
     CaptureLogger,
     require_python39_or_higher,
     require_torch_2,
+    require_torch_accelerator_with_training,
     require_torch_gpu,
     run_test_in_subprocess,
     torch_device,
@@ -536,7 +537,7 @@ def test_model_from_pretrained(self):
 
         self.assertEqual(output_1.shape, output_2.shape)
 
-    @unittest.skipIf(torch_device == "mps", "Training is not supported in mps")
+    @require_torch_accelerator_with_training
     def test_training(self):
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
 
@@ -553,7 +554,7 @@ def test_training(self):
         loss = torch.nn.functional.mse_loss(output, noise)
         loss.backward()
 
-    @unittest.skipIf(torch_device == "mps", "Training is not supported in mps")
+    @require_torch_accelerator_with_training
     def test_ema_training(self):
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
 
@@ -624,7 +625,7 @@ def recursive_check(tuple_object, dict_object):
 
         recursive_check(outputs_tuple, outputs_dict)
 
-    @unittest.skipIf(torch_device == "mps", "Gradient checkpointing skipped on MPS")
+    @require_torch_accelerator_with_training
     def test_enable_disable_gradient_checkpointing(self):
         if not self.model_class._supports_gradient_checkpointing:
             return  # Skip test if model does not support gradient checkpointing
11 changes: 9 additions & 2 deletions tests/models/test_models_prior.py
@@ -21,7 +21,14 @@
 from parameterized import parameterized
 
 from diffusers import PriorTransformer
-from diffusers.utils.testing_utils import enable_full_determinism, floats_tensor, slow, torch_all_close, torch_device
+from diffusers.utils.testing_utils import (
+    backend_empty_cache,
+    enable_full_determinism,
+    floats_tensor,
+    slow,
+    torch_all_close,
+    torch_device,
+)
 
 from .test_modeling_common import ModelTesterMixin
 
@@ -157,7 +164,7 @@ def tearDown(self):
         # clean up the VRAM after each test
         super().tearDown()
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)
 
     @parameterized.expand(
         [
13 changes: 8 additions & 5 deletions tests/models/test_models_unet_1d.py
@@ -18,7 +18,12 @@
 import torch
 
 from diffusers import UNet1DModel
-from diffusers.utils.testing_utils import floats_tensor, slow, torch_device
+from diffusers.utils.testing_utils import (
+    backend_manual_seed,
+    floats_tensor,
+    slow,
+    torch_device,
+)
 
 from .test_modeling_common import ModelTesterMixin, UNetTesterMixin
 
@@ -103,8 +108,7 @@ def test_from_pretrained_hub(self):
     def test_output_pretrained(self):
         model = UNet1DModel.from_pretrained("bglick13/hopper-medium-v2-value-function-hor32", subfolder="unet")
         torch.manual_seed(0)
-        if torch.cuda.is_available():
-            torch.cuda.manual_seed_all(0)
+        backend_manual_seed(torch_device, 0)
 
         num_features = model.config.in_channels
         seq_len = 16
@@ -244,8 +248,7 @@ def test_output_pretrained(self):
             "bglick13/hopper-medium-v2-value-function-hor32", output_loading_info=True, subfolder="value_function"
         )
         torch.manual_seed(0)
-        if torch.cuda.is_available():
-            torch.cuda.manual_seed_all(0)
+        backend_manual_seed(torch_device, 0)
 
         num_features = value_function.config.in_channels
         seq_len = 14
5 changes: 3 additions & 2 deletions tests/models/test_models_unet_2d.py
@@ -24,6 +24,7 @@
 from diffusers.utils.testing_utils import (
     enable_full_determinism,
     floats_tensor,
+    require_torch_accelerator,
     slow,
     torch_all_close,
     torch_device,
@@ -153,15 +154,15 @@ def test_from_pretrained_hub(self):
 
         assert image is not None, "Make sure output is not None"
 
-    @unittest.skipIf(torch_device != "cuda", "This test is supposed to run on GPU")
+    @require_torch_accelerator
     def test_from_pretrained_accelerate(self):
         model, _ = UNet2DModel.from_pretrained("fusing/unet-ldm-dummy-update", output_loading_info=True)
         model.to(torch_device)
         image = model(**self.dummy_input).sample
 
         assert image is not None, "Make sure output is not None"
 
-    @unittest.skipIf(torch_device != "cuda", "This test is supposed to run on GPU")
+    @require_torch_accelerator
     def test_from_pretrained_accelerate_wont_change_results(self):
         # by default model loading will use accelerate as `low_cpu_mem_usage=True`
         model_accelerate, _ = UNet2DModel.from_pretrained("fusing/unet-ldm-dummy-update", output_loading_info=True)