forked from mosaicml/composer
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest_model_registry.py
38 lines (28 loc) · 1.01 KB
/
test_model_registry.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
# Copyright 2021 MosaicML. All Rights Reserved.
import pytest
from composer.models import ModelHparams
from composer.trainer.trainer_hparams import model_registry
@pytest.mark.parametrize("model_name", model_registry.keys())
def test_model_registry(model_name, request):
    """Check that every registered model name yields a ``ModelHparams`` instance.

    For each key in ``model_registry``, the registered hparams class is
    instantiated with no arguments and checked to be a ``ModelHparams``.
    Afterwards, the per-model field values listed below are assigned to it.

    Args:
        model_name: A key of ``model_registry``, supplied by ``parametrize``.
        request: The pytest ``request`` fixture. Unused; kept so the test
            signature stays unchanged.
    """
    # Skip models whose optional third-party dependency is not installed.
    optional_dependency = {"timm": "timm", "unet": "monai"}.get(model_name)
    if optional_dependency is not None:
        pytest.importorskip(optional_dependency)

    # Create the model hparams object with its no-argument constructor and
    # check its type immediately, so a bad registry entry fails right here.
    model_hparams = model_registry[model_name]()
    assert isinstance(model_hparams, ModelHparams)

    # Per-model field values assigned after construction. Models not listed
    # here are left at their defaults.
    # NOTE(review): these fields are presumably required before the model can
    # be built, but this test never builds it; consider calling the
    # hparams' model constructor here to exercise the configured values.
    overrides = {
        "deeplabv3": {"num_classes": 10, "is_backbone_pretrained": False},
        "resnet_cifar": {"num_classes": 10},
        "efficientnetb0": {"num_classes": 10},
        "resnet": {"num_classes": 10, "model_name": "resnet50"},
        "mnist_classifier": {"num_classes": 10},
        "timm": {"model_name": "resnet18"},
    }
    for field_name, value in overrides.get(model_name, {}).items():
        setattr(model_hparams, field_name, value)