Skip to content

Commit

Permalink
[tests] diffuser models in model zoo (hpcaitech#3136)
Browse files Browse the repository at this point in the history
* [tests] diffuser models in model zoo

* remove useless code

* [tests] add diffusers to requirement-test
  • Loading branch information
1SAA authored Mar 14, 2023
1 parent 1a46e71 commit 1216d1e
Show file tree
Hide file tree
Showing 5 changed files with 137 additions and 107 deletions.
1 change: 1 addition & 0 deletions requirements/requirements-test.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
diffusers
fbgemm-gpu==0.2.0
pytest
pytest-cov
Expand Down
2 changes: 1 addition & 1 deletion tests/kit/model_zoo/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from . import timm
from . import diffusers, timm
from .registry import model_zoo

__all__ = ['model_zoo']
1 change: 1 addition & 0 deletions tests/kit/model_zoo/diffusers/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .diffusers import *
73 changes: 73 additions & 0 deletions tests/kit/model_zoo/diffusers/diffusers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
from functools import partial

import diffusers
import torch
import transformers

from ..registry import ModelAttribute, model_zoo

# Dimensions of the dummy inputs generated for the registered models.
BATCH_SIZE = 2
SEQ_LENGTH = 5
HEIGHT = 224
WIDTH = 224
IN_CHANNELS = 3
LATENTS_SHAPE = (BATCH_SIZE, IN_CHANNELS, HEIGHT // 7, WIDTH // 7)
TIME_STEP = 3


def data_vae_fn():
    """Return dummy inputs for VAE-style models (AutoencoderKL, VQModel)."""
    return dict(sample=torch.randn(2, 3, 32, 32))


def data_unet_fn():
    """Return dummy inputs for UNet2DModel: a latent sample plus a timestep."""
    return dict(sample=torch.randn(2, 3, 32, 32), timestep=3)


def identity_output(x):
    """Output transform that returns the model output unchanged."""
    return x


def data_clip_model():
    """Return dummy inputs exercising both the text and vision branches of CLIPModel."""
    token_shape = (BATCH_SIZE, SEQ_LENGTH)
    image_shape = (BATCH_SIZE, IN_CHANNELS, HEIGHT, WIDTH)
    return dict(
        input_ids=torch.zeros(token_shape, dtype=torch.int64),
        pixel_values=torch.zeros(image_shape, dtype=torch.float32),
        attention_mask=torch.zeros(token_shape, dtype=torch.int64),
        position_ids=torch.zeros(token_shape, dtype=torch.int64),
    )


def data_clip_text():
    """Return dummy token inputs for CLIPTextModel."""
    token_shape = (BATCH_SIZE, SEQ_LENGTH)
    return dict(
        input_ids=torch.zeros(token_shape, dtype=torch.int64),
        attention_mask=torch.zeros(token_shape, dtype=torch.int64),
    )


def data_clip_vision():
    """Return dummy image inputs for CLIPVisionModel."""
    image_shape = (BATCH_SIZE, IN_CHANNELS, HEIGHT, WIDTH)
    return dict(pixel_values=torch.zeros(image_shape, dtype=torch.float32))


# (name, model factory, data generator) for every diffusers-related entry.
# Registration order matches the order entries are listed here.
_DIFFUSER_ENTRIES = (
    ('diffusers_auto_encoder_kl', diffusers.AutoencoderKL, data_vae_fn),
    ('diffusers_vq_model', diffusers.VQModel, data_vae_fn),
    ('diffusers_clip_model', partial(transformers.CLIPModel, config=transformers.CLIPConfig()), data_clip_model),
    ('diffusers_clip_text_model', partial(transformers.CLIPTextModel,
                                          config=transformers.CLIPTextConfig()), data_clip_text),
    ('diffusers_clip_vision_model', partial(transformers.CLIPVisionModel,
                                            config=transformers.CLIPVisionConfig()), data_clip_vision),
    ('diffusers_unet2d_model', diffusers.UNet2DModel, data_unet_fn),
)

for _name, _model_fn, _data_gen_fn in _DIFFUSER_ENTRIES:
    model_zoo.register(name=_name,
                       model_fn=_model_fn,
                       data_gen_fn=_data_gen_fn,
                       output_transform_fn=identity_output)
167 changes: 61 additions & 106 deletions tests/test_fx/test_tracer/test_hf_model/test_hf_diffuser.py
Original file line number Diff line number Diff line change
@@ -1,114 +1,69 @@
import pytest
import torch
import transformers
from hf_tracer_utils import trace_model_and_compare_output

from colossalai.fx import symbolic_trace
from colossalai.testing.random import seed_all
from tests.kit.model_zoo import model_zoo

# ``diffusers`` is an optional test dependency: import it defensively so the
# tests below are skipped (via HAS_DIFFUSERS) instead of erroring at import.
try:
    import diffusers
    HAS_DIFFUSERS = True
except ImportError:
    HAS_DIFFUSERS = False

# Dummy-input dimensions shared by the tests in this module.
BATCH_SIZE = 2
SEQ_LENGTH = 5
HEIGHT = 224
WIDTH = 224
IN_CHANNELS = 3
# Latent resolution is 1/8 of the image resolution here — presumably to match
# the VAE downsampling factor; confirm against the model configs.
LATENTS_SHAPE = (BATCH_SIZE, IN_CHANNELS, HEIGHT // 8, WIDTH // 8)
TIME_STEP = 2


@pytest.mark.skipif(not HAS_DIFFUSERS, reason="diffusers has not been installed")
def test_vae():
    """Symbolically trace VAE-style models and compare traced vs eager outputs."""
    MODEL_LIST = [
        diffusers.AutoencoderKL,
        diffusers.VQModel,
    ]

    for model_cls in MODEL_LIST:
        # Default-config model fed an all-zeros latent tensor.
        model = model_cls()
        sample = torch.zeros(LATENTS_SHAPE)

        gm = symbolic_trace(model)

        # eval() on both copies so nondeterministic layers behave identically.
        model.eval()
        gm.eval()

        with torch.no_grad():
            fx_out = gm(sample)
            non_fx_out = model(sample)
            assert torch.allclose(
                fx_out['sample'],
                non_fx_out['sample']), f'{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}'


def test_clip():
    """Trace each CLIP variant with its matching config and compare outputs."""
    MODEL_LIST = [
        transformers.CLIPModel,
        transformers.CLIPTextModel,
        transformers.CLIPVisionModel,
    ]

    # Paired element-wise with MODEL_LIST via zip() below.
    CONFIG_LIST = [
        transformers.CLIPConfig,
        transformers.CLIPTextConfig,
        transformers.CLIPVisionConfig,
    ]

    def data_gen():
        # NOTE: ``model`` is resolved late from the enclosing loop below, so
        # this closure generates inputs for whichever model is current when
        # it is actually called.
        if isinstance(model, transformers.CLIPModel):
            input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
            attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
            position_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
            pixel_values = torch.zeros((BATCH_SIZE, IN_CHANNELS, HEIGHT, WIDTH), dtype=torch.float32)
            kwargs = dict(input_ids=input_ids,
                          attention_mask=attention_mask,
                          position_ids=position_ids,
                          pixel_values=pixel_values)
        elif isinstance(model, transformers.CLIPTextModel):
            input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
            attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
            kwargs = dict(input_ids=input_ids, attention_mask=attention_mask)
        elif isinstance(model, transformers.CLIPVisionModel):
            pixel_values = torch.zeros((BATCH_SIZE, IN_CHANNELS, HEIGHT, WIDTH), dtype=torch.float32)
            kwargs = dict(pixel_values=pixel_values)
        return kwargs

    for model_cls, config in zip(MODEL_LIST, CONFIG_LIST):
        model = model_cls(config=config())
        trace_model_and_compare_output(model, data_gen)


@pytest.mark.skipif(not HAS_DIFFUSERS, reason="diffusers has not been installed")
@pytest.mark.skip(reason='cannot pass the test yet')
def test_unet():
    """Trace UNet models and compare traced vs eager outputs (known failing)."""
    MODEL_LIST = [
        diffusers.UNet2DModel,
        diffusers.UNet2DConditionModel,
    ]

    for model_cls in MODEL_LIST:
        model = model_cls()
        sample = torch.zeros(LATENTS_SHAPE)

        gm = symbolic_trace(model)

        model.eval()
        gm.eval()

        with torch.no_grad():
            # UNets additionally take a timestep as the second positional arg.
            fx_out = gm(sample, TIME_STEP)
            non_fx_out = model(sample, TIME_STEP)
            assert torch.allclose(
                fx_out['sample'],
                non_fx_out['sample']), f'{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}'

def assert_dict(da, db, assert_fn):
    """Check that two dicts have the same keys and matching tensor values.

    Non-tensor values in ``da`` are only checked for key presence; tensor
    values are compared by calling ``assert_fn(db_value, da_value)``.
    """
    assert len(da) == len(db)
    for key, expected in da.items():
        assert key in db
        if torch.is_tensor(expected):
            assert_fn(db.get(key), expected)

# Manual entry point for running the tests outside pytest.
if __name__ == "__main__":
    test_vae()
    test_clip()

    # skip because of failure
    # test_unet()
def trace_and_compare(model_cls, data, output_fn):
    """Symbolically trace a model and check traced output equals eager output.

    ``data`` maps argument names to values; tensors are traced as meta args,
    everything else is treated as a concrete (baked-in) argument.
    ``output_fn`` transforms both outputs before the dict comparison.
    """
    model = model_cls()
    model.eval()

    # Partition the inputs: tensors -> meta args, non-tensors -> concrete args.
    concrete_args = {}
    meta_args = {}
    for key, value in data.items():
        if torch.is_tensor(value):
            meta_args[key] = value.to('meta')
        else:
            concrete_args[key] = value
    gm = symbolic_trace(model, concrete_args=concrete_args, meta_args=meta_args)

    # Run both the traced module and the eager model on identical inputs.
    with torch.no_grad():
        traced_out = gm(**data)
        eager_out = model(**data)

    def _check_equal(ta, tb):
        assert torch.equal(ta, tb)

    assert_dict(output_fn(traced_out), output_fn(eager_out), _check_equal)


@pytest.mark.skip(reason='cannot pass this test yet')
def test_diffusers():
    """Trace every diffusers model-zoo entry and compare with eager execution."""
    seed_all(9091, cuda_deterministic=True)

    registry = model_zoo.get_sub_registry('diffusers')
    for name, (model_fn, data_gen_fn, output_transform_fn, attribute) in registry.items():
        trace_and_compare(model_fn, data_gen_fn(), output_transform_fn)
        torch.cuda.synchronize()
        print(f"{name:40s} √")


def test_torch_diffusers():
    """Smoke-test every diffusers model-zoo entry with a plain eager forward pass."""
    seed_all(65535, cuda_deterministic=True)

    registry = model_zoo.get_sub_registry('diffusers')
    for name, (model_fn, data_gen_fn, output_transform_fn, attribute) in registry.items():
        inputs = data_gen_fn()
        model = model_fn()
        output = model(**inputs)
        torch.cuda.synchronize()
        print(f"{name:40s} √")


# Manual entry point: run the eager-mode smoke test when executed directly.
if __name__ == "__main__":
    test_torch_diffusers()

0 comments on commit 1216d1e

Please sign in to comment.