#
# For licensing see accompanying LICENSE file.
# Copyright (C) 2023 Apple Inc. All Rights Reserved.
#
import sys
from pathlib import Path

import pytest

sys.path.append("..")

from typing import Dict

from torch import Tensor

from cvnets import get_model
from loss_fn import build_loss_fn
from tests.configs import get_config
from tests.test_utils import unset_pretrained_models_from_opts


# We use a batch size of 1 to catch errors that may arise from reshaping operations inside the model.
@pytest.mark.parametrize("batch_size", [1, 2])
def test_model(config_file: str, batch_size: int):
    opts = get_config(config_file=config_file)
    setattr(opts, "common.debug_mode", True)

    # Remove pretrained models (if any) to reduce test time and avoid access issues.
    unset_pretrained_models_from_opts(opts)

    model = get_model(opts)
    criteria = build_loss_fn(opts)

    inputs = None
    targets = None
    if hasattr(model, "dummy_input_and_label"):
        inputs_and_targets = model.dummy_input_and_label(batch_size)
        inputs = inputs_and_targets["samples"]
        targets = inputs_and_targets["targets"]

    assert inputs is not None, (
        "Input tensor can't be None. This is likely because "
        "{} does not implement the dummy_input_and_label function".format(
            model.__class__.__name__
        )
    )
    assert targets is not None, (
        "Label tensor can't be None. This is likely because "
        "{} does not implement the dummy_input_and_label function".format(
            model.__class__.__name__
        )
    )
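
    # The dict returned by dummy_input_and_label carries the model inputs under
    # "samples" and the labels under "targets". A minimal sketch of such a
    # method (hypothetical model, assuming an image-classification task and
    # assumed sizes; not the implementation of any model in this repo):
    #
    #   def dummy_input_and_label(self, batch_size: int) -> Dict:
    #       img_channels, height, width, n_classes = 3, 32, 32, 10
    #       samples = torch.randn(batch_size, img_channels, height, width)
    #       targets = torch.randint(0, n_classes, size=(batch_size,)).long()
    #       return {"samples": samples, "targets": targets}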

    # if getattr(opts, "common.channels_last", False):
    #     inputs = inputs.to(memory_format=torch.channels_last)
    #     model = model.to(memory_format=torch.channels_last)

    try:
        outputs = model(inputs)
        loss = criteria(
            input_sample=inputs,
            prediction=outputs,
            target=targets,
            epoch=0,
            iterations=0,
        )
        print(f"Loss: {loss}")

        if isinstance(loss, Tensor):
            loss.backward()
        elif isinstance(loss, Dict):
            loss["total_loss"].backward()
        else:
            raise RuntimeError(
                "The output of criteria should be either Dict or Tensor"
            )
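
        # When the loss function returns a Dict, this test assumes the scalar
        # to backpropagate lives under "total_loss"; e.g. a composite loss may
        # return {"total_loss": total, "aux_loss": aux} (keys other than
        # "total_loss" are hypothetical and are not used here).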

        # If there are unused parameters in gradient computation, print them.
        # This may be useful for debugging purposes.
        unused_params = []
        for name, param in model.named_parameters():
            if param.grad is None:
                unused_params.append(name)
        if len(unused_params) > 0:
            print("Unused parameters: {}".format(unused_params))
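
        # NOTE: a parameter whose grad is still None after backward() never
        # contributed to the loss; with DistributedDataParallel such parameters
        # raise a runtime error during backward unless the wrapper is built
        # with find_unused_parameters=True.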
    except Exception as e:
        if (
            isinstance(e, ValueError)
            and str(e).find("Expected more than 1 value per channel when training") > -1
            and batch_size == 1
        ):
            # For segmentation models (e.g., PSPNet), we pool the tensor so that it has a spatial size of 1.
            # In that case, batch norm needs a batch size > 1; otherwise the statistics can't be computed and
            # ValueError("Expected more than 1 value per channel when training") is raised. If we encounter
            # this error with a batch size of 1, we skip the test.
            pytest.skip(str(e))
        else:
            raise e


def exclude_yaml_from_test(yaml_file_path: str) -> bool:
    """Check whether a yaml file should be excluded from testing based on a first-line marker.

    A file is excluded when its first line is a comment that, after lowercasing
    and removing spaces, equals "#pytest:disable" (e.g., "# pytest: disable").

    Args:
        yaml_file_path: Path to the yaml file to check.

    Returns:
        True if the yaml file should be excluded, False otherwise.
    """
    with open(yaml_file_path, "r") as f:
        first_line = f.readline().rstrip()
    return (
        first_line.startswith("#")
        and first_line.lower().replace(" ", "") == "#pytest:disable"
    )
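
# Example of a config that opts out of this test via the marker above
# (the keys below are hypothetical):
#
#   # pytest: disable
#   model:
#     name: some_experimental_model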


def pytest_generate_tests(metafunc):
    # Parametrize test_model over every yaml config reachable from the current
    # working directory, skipping configs that opt out via the marker above.
    configs = [
        str(x)
        for x in Path(".").rglob("**/*.yaml")
        if not exclude_yaml_from_test(str(x))
    ]
    metafunc.parametrize("config_file", configs)
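
# Usage sketch (assumed invocation, not prescribed by this file): the glob
# above is relative to ".", so the set of collected configs depends on the
# working directory the test is launched from, e.g.:
#
#   python -m pytest test_model.py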