forked from hpcaitech/ColossalAI
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[tests] diffuser models in model zoo (hpcaitech#3136)
* [tests] diffuser models in model zoo * remove useless code * [tests] add diffusers to requirement-test
- Loading branch information
Showing
5 changed files
with
137 additions
and
107 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,4 @@ | ||
diffusers | ||
fbgemm-gpu==0.2.0 | ||
pytest | ||
pytest-cov | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,4 @@ | ||
from . import timm | ||
from . import diffusers, timm | ||
from .registry import model_zoo | ||
|
||
__all__ = ['model_zoo'] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from .diffusers import * |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
from functools import partial | ||
|
||
import diffusers | ||
import torch | ||
import transformers | ||
|
||
from ..registry import ModelAttribute, model_zoo | ||
|
||
BATCH_SIZE = 2 | ||
SEQ_LENGTH = 5 | ||
HEIGHT = 224 | ||
WIDTH = 224 | ||
IN_CHANNELS = 3 | ||
LATENTS_SHAPE = (BATCH_SIZE, IN_CHANNELS, HEIGHT // 7, WIDTH // 7) | ||
TIME_STEP = 3 | ||
|
||
data_vae_fn = lambda: dict(sample=torch.randn(2, 3, 32, 32)) | ||
data_unet_fn = lambda: dict(sample=torch.randn(2, 3, 32, 32), timestep=3) | ||
|
||
identity_output = lambda x: x | ||
|
||
|
||
def data_clip_model(): | ||
input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64) | ||
attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64) | ||
position_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64) | ||
pixel_values = torch.zeros((BATCH_SIZE, IN_CHANNELS, HEIGHT, WIDTH), dtype=torch.float32) | ||
return dict(input_ids=input_ids, | ||
pixel_values=pixel_values, | ||
attention_mask=attention_mask, | ||
position_ids=position_ids) | ||
|
||
|
||
def data_clip_text(): | ||
input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64) | ||
attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64) | ||
return dict(input_ids=input_ids, attention_mask=attention_mask) | ||
|
||
|
||
def data_clip_vision(): | ||
pixel_values = torch.zeros((BATCH_SIZE, IN_CHANNELS, HEIGHT, WIDTH), dtype=torch.float32) | ||
return dict(pixel_values=pixel_values) | ||
|
||
|
||
model_zoo.register(name='diffusers_auto_encoder_kl', | ||
model_fn=diffusers.AutoencoderKL, | ||
data_gen_fn=data_vae_fn, | ||
output_transform_fn=identity_output) | ||
|
||
model_zoo.register(name='diffusers_vq_model', | ||
model_fn=diffusers.VQModel, | ||
data_gen_fn=data_vae_fn, | ||
output_transform_fn=identity_output) | ||
|
||
model_zoo.register(name='diffusers_clip_model', | ||
model_fn=partial(transformers.CLIPModel, config=transformers.CLIPConfig()), | ||
data_gen_fn=data_clip_model, | ||
output_transform_fn=identity_output) | ||
|
||
model_zoo.register(name='diffusers_clip_text_model', | ||
model_fn=partial(transformers.CLIPTextModel, config=transformers.CLIPTextConfig()), | ||
data_gen_fn=data_clip_text, | ||
output_transform_fn=identity_output) | ||
|
||
model_zoo.register(name='diffusers_clip_vision_model', | ||
model_fn=partial(transformers.CLIPVisionModel, config=transformers.CLIPVisionConfig()), | ||
data_gen_fn=data_clip_vision, | ||
output_transform_fn=identity_output) | ||
|
||
model_zoo.register(name='diffusers_unet2d_model', | ||
model_fn=diffusers.UNet2DModel, | ||
data_gen_fn=data_unet_fn, | ||
output_transform_fn=identity_output) |
167 changes: 61 additions & 106 deletions
167
tests/test_fx/test_tracer/test_hf_model/test_hf_diffuser.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,114 +1,69 @@ | ||
import pytest | ||
import torch | ||
import transformers | ||
from hf_tracer_utils import trace_model_and_compare_output | ||
|
||
from colossalai.fx import symbolic_trace | ||
from colossalai.testing.random import seed_all | ||
from tests.kit.model_zoo import model_zoo | ||
|
||
try: | ||
import diffusers | ||
HAS_DIFFUSERS = True | ||
except ImportError: | ||
HAS_DIFFUSERS = False | ||
|
||
BATCH_SIZE = 2 | ||
SEQ_LENGTH = 5 | ||
HEIGHT = 224 | ||
WIDTH = 224 | ||
IN_CHANNELS = 3 | ||
LATENTS_SHAPE = (BATCH_SIZE, IN_CHANNELS, HEIGHT // 8, WIDTH // 8) | ||
TIME_STEP = 2 | ||
|
||
|
||
@pytest.mark.skipif(not HAS_DIFFUSERS, reason="diffusers has not been installed") | ||
def test_vae(): | ||
MODEL_LIST = [ | ||
diffusers.AutoencoderKL, | ||
diffusers.VQModel, | ||
] | ||
|
||
for model_cls in MODEL_LIST: | ||
model = model_cls() | ||
sample = torch.zeros(LATENTS_SHAPE) | ||
|
||
gm = symbolic_trace(model) | ||
|
||
model.eval() | ||
gm.eval() | ||
|
||
with torch.no_grad(): | ||
fx_out = gm(sample) | ||
non_fx_out = model(sample) | ||
assert torch.allclose( | ||
fx_out['sample'], | ||
non_fx_out['sample']), f'{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}' | ||
|
||
|
||
def test_clip(): | ||
MODEL_LIST = [ | ||
transformers.CLIPModel, | ||
transformers.CLIPTextModel, | ||
transformers.CLIPVisionModel, | ||
] | ||
|
||
CONFIG_LIST = [ | ||
transformers.CLIPConfig, | ||
transformers.CLIPTextConfig, | ||
transformers.CLIPVisionConfig, | ||
] | ||
|
||
def data_gen(): | ||
if isinstance(model, transformers.CLIPModel): | ||
input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64) | ||
attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64) | ||
position_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64) | ||
pixel_values = torch.zeros((BATCH_SIZE, IN_CHANNELS, HEIGHT, WIDTH), dtype=torch.float32) | ||
kwargs = dict(input_ids=input_ids, | ||
attention_mask=attention_mask, | ||
position_ids=position_ids, | ||
pixel_values=pixel_values) | ||
elif isinstance(model, transformers.CLIPTextModel): | ||
input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64) | ||
attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64) | ||
kwargs = dict(input_ids=input_ids, attention_mask=attention_mask) | ||
elif isinstance(model, transformers.CLIPVisionModel): | ||
pixel_values = torch.zeros((BATCH_SIZE, IN_CHANNELS, HEIGHT, WIDTH), dtype=torch.float32) | ||
kwargs = dict(pixel_values=pixel_values) | ||
return kwargs | ||
|
||
for model_cls, config in zip(MODEL_LIST, CONFIG_LIST): | ||
model = model_cls(config=config()) | ||
trace_model_and_compare_output(model, data_gen) | ||
|
||
|
||
@pytest.mark.skipif(not HAS_DIFFUSERS, reason="diffusers has not been installed") | ||
@pytest.mark.skip(reason='cannot pass the test yet') | ||
def test_unet(): | ||
MODEL_LIST = [ | ||
diffusers.UNet2DModel, | ||
diffusers.UNet2DConditionModel, | ||
] | ||
|
||
for model_cls in MODEL_LIST: | ||
model = model_cls() | ||
sample = torch.zeros(LATENTS_SHAPE) | ||
|
||
gm = symbolic_trace(model) | ||
|
||
model.eval() | ||
gm.eval() | ||
|
||
with torch.no_grad(): | ||
fx_out = gm(sample, TIME_STEP) | ||
non_fx_out = model(sample, TIME_STEP) | ||
assert torch.allclose( | ||
fx_out['sample'], | ||
non_fx_out['sample']), f'{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}' | ||
|
||
def assert_dict(da, db, assert_fn): | ||
assert len(da) == len(db) | ||
for k, v in da.items(): | ||
assert k in db | ||
if not torch.is_tensor(v): | ||
continue | ||
u = db.get(k) | ||
assert_fn(u, v) | ||
|
||
if __name__ == "__main__": | ||
test_vae() | ||
test_clip() | ||
|
||
# skip because of failure | ||
# test_unet() | ||
def trace_and_compare(model_cls, data, output_fn): | ||
model = model_cls() | ||
model.eval() | ||
|
||
concrete_args = {k: v for k, v in data.items() if not torch.is_tensor(v)} | ||
meta_args = {k: v.to('meta') for k, v in data.items() if torch.is_tensor(v)} | ||
gm = symbolic_trace(model, concrete_args=concrete_args, meta_args=meta_args) | ||
|
||
# run forward | ||
with torch.no_grad(): | ||
fx_out = gm(**data) | ||
non_fx_out = model(**data) | ||
|
||
# compare output | ||
transformed_fx_out = output_fn(fx_out) | ||
transformed_non_fx_out = output_fn(non_fx_out) | ||
|
||
def assert_fn(ta, tb): | ||
assert torch.equal(ta, tb) | ||
|
||
assert_dict(transformed_fx_out, transformed_non_fx_out, assert_fn) | ||
|
||
|
||
@pytest.mark.skip(reason='cannot pass this test yet') | ||
def test_diffusers(): | ||
seed_all(9091, cuda_deterministic=True) | ||
|
||
sub_model_zoo = model_zoo.get_sub_registry('diffusers') | ||
|
||
for name, (model_fn, data_gen_fn, output_transform_fn, attribute) in sub_model_zoo.items(): | ||
data = data_gen_fn() | ||
trace_and_compare(model_fn, data, output_transform_fn) | ||
torch.cuda.synchronize() | ||
print(f"{name:40s} √") | ||
|
||
|
||
def test_torch_diffusers(): | ||
seed_all(65535, cuda_deterministic=True) | ||
|
||
sub_model_zoo = model_zoo.get_sub_registry('diffusers') | ||
|
||
for name, (model_fn, data_gen_fn, output_transform_fn, attribute) in sub_model_zoo.items(): | ||
data = data_gen_fn() | ||
model = model_fn() | ||
output = model(**data) | ||
torch.cuda.synchronize() | ||
print(f"{name:40s} √") | ||
|
||
|
||
if __name__ == "__main__": | ||
test_torch_diffusers() |