forked from slaclab/lume-model
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathconftest.py
131 lines (101 loc) · 4.21 KB
/
conftest.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
import json
import os
from typing import Any, Union

import numpy as np
import pytest

from lume_model.utils import variables_from_yaml
from lume_model.variables import ScalarVariable

# torch/botorch-backed names are optional; the fixtures that use them call
# pytest.importorskip first so tests skip cleanly when they are absent.
try:
    import torch
    from botorch.models.transforms.input import AffineInputTransform
    from lume_model.models import TorchModel, TorchModule
except ModuleNotFoundError:
    pass
@pytest.fixture(scope="session")
def rootdir() -> str:
    """Return the absolute path of the directory that contains this conftest.

    Session-scoped, so all tests share one value; the fixtures below build
    their ``test_files/...`` paths relative to it.
    """
    this_file = os.path.abspath(__file__)
    return os.path.dirname(this_file)
@pytest.fixture(scope="session")
def simple_variables() -> dict[str, list[ScalarVariable]]:
    """Provide a small, fixed set of scalar variables for lightweight tests.

    Inputs:
        ``input1`` -- default 1.0, range (0.0, 5.0)
        ``input2`` -- default 2.0, range (1.0, 3.0)
    Outputs:
        ``output1`` and ``output2`` with no default or range given here.

    Returns:
        A dict with keys ``"input_variables"`` and ``"output_variables"``,
        each mapping to a list of ``ScalarVariable``.
    """
    input_variables = [
        ScalarVariable(name="input1", default_value=1.0, value_range=(0.0, 5.0)),
        ScalarVariable(name="input2", default_value=2.0, value_range=(1.0, 3.0)),
    ]
    output_variables = [
        ScalarVariable(name="output1"),
        ScalarVariable(name="output2"),
    ]
    return {"input_variables": input_variables, "output_variables": output_variables}
@pytest.fixture(scope="module")
def california_model_info(rootdir) -> dict[str, Any]:
    """Load the California-regression ``model_info.json`` metadata.

    Values are not all strings: ``"model_in_list"``, for example, is iterated
    as a sequence elsewhere in this file, so the mapping is typed
    ``dict[str, Any]``.

    Skips the requesting test if the JSON file does not exist.
    """
    path = f"{rootdir}/test_files/california_regression/model_info.json"
    try:
        # JSON is UTF-8 by spec; don't depend on the platform's locale encoding.
        with open(path, "r", encoding="utf-8") as f:
            model_info = json.load(f)
    except FileNotFoundError as e:
        pytest.skip(str(e))
    return model_info
@pytest.fixture(scope="module")
def california_variables(rootdir) -> tuple[list[ScalarVariable], list[ScalarVariable]]:
    """Parse the California-regression input/output variables from YAML.

    Returns:
        ``(input_variables, output_variables)`` as unpacked from
        ``variables_from_yaml``.

    Skips the requesting test if the YAML file cannot be found.
    """
    yaml_path = f"{rootdir}/test_files/california_regression/variables.yml"
    try:
        inputs, outputs = variables_from_yaml(yaml_path)
    except FileNotFoundError as exc:
        pytest.skip(str(exc))
    return inputs, outputs
@pytest.fixture(scope="module")
def california_transformers(rootdir):
    """Build botorch affine transforms from the stored normalization statistics.

    ``normalization.json`` is read for the keys ``x_mean``/``x_scale`` (inputs)
    and ``y_mean``/``y_scale`` (outputs). Each pair becomes an
    ``AffineInputTransform`` whose dimension is the length of the mean array,
    with the scale as ``coefficient`` and the mean as ``offset``.

    Returns:
        ``(input_transformer, output_transformer)``.

    Skips when botorch is not installed or the JSON file is missing.
    """
    botorch = pytest.importorskip("botorch")
    norm_path = f"{rootdir}/test_files/california_regression/normalization.json"
    try:
        with open(norm_path, "r") as fh:
            stats = json.load(fh)
    except FileNotFoundError as exc:
        pytest.skip(str(exc))

    affine_cls = botorch.models.transforms.input.AffineInputTransform

    def make_affine(mean_key, scale_key):
        # Same construction for both inputs and outputs; only the keys differ.
        return affine_cls(
            len(stats[mean_key]),
            coefficient=torch.tensor(stats[scale_key]),
            offset=torch.tensor(stats[mean_key]),
        )

    input_transformer = make_affine("x_mean", "x_scale")
    output_transformer = make_affine("y_mean", "y_scale")
    return input_transformer, output_transformer
@pytest.fixture(scope="module")
def california_model_kwargs(
    rootdir,
    california_model_info,
    california_variables,
    california_transformers,
) -> dict[str, Any]:
    """Assemble the keyword arguments used to construct the California ``TorchModel``.

    The network loaded from ``model.pt`` is combined with the parsed variables
    and the input/output affine transformers; ``output_format`` is fixed to
    ``"tensor"``.

    Skips when botorch is not installed or ``model.pt`` is missing, matching the
    other file-backed fixtures in this module.
    """
    pytest.importorskip("botorch")
    input_variables, output_variables = california_variables
    input_transformer, output_transformer = california_transformers
    model_path = f"{rootdir}/test_files/california_regression/model.pt"
    try:
        # The loaded object is passed directly as ``model`` (presumably a whole
        # pickled torch.nn.Module, not a state dict -- confirm). torch>=2.6
        # defaults to weights_only=True, which refuses such pickles; this is a
        # trusted, repository-local test file, so full unpickling is acceptable.
        model = torch.load(model_path, weights_only=False)
    except FileNotFoundError as e:
        pytest.skip(str(e))
    return {
        "model": model,
        "input_variables": input_variables,
        "output_variables": output_variables,
        "input_transformers": [input_transformer],
        "output_transformers": [output_transformer],
        "output_format": "tensor",
    }
@pytest.fixture(scope="module")
def california_test_input_tensor(rootdir: str):
    """Load the saved California test-input tensor from ``test_input_tensor.pt``.

    Skips when torch is not installed or the file is missing.
    """
    torch_lib = pytest.importorskip("torch")
    tensor_path = f"{rootdir}/test_files/california_regression/test_input_tensor.pt"
    try:
        loaded = torch_lib.load(tensor_path)
    except FileNotFoundError as exc:
        pytest.skip(str(exc))
    else:
        return loaded
@pytest.fixture(scope="module")
def california_test_input_dict(california_test_input_tensor, california_model_info) -> dict:
    """Key the first row of the test tensor by the entries of ``model_in_list``.

    The i-th entry of ``california_model_info["model_in_list"]`` maps to
    ``california_test_input_tensor[0, i]``. Skips when botorch is not installed.
    """
    pytest.importorskip("botorch")
    inputs_by_name = {}
    for column, name in enumerate(california_model_info["model_in_list"]):
        inputs_by_name[name] = california_test_input_tensor[0, column]
    return inputs_by_name
@pytest.fixture(scope="module")
def california_model(california_model_kwargs):
    """Construct a ``TorchModel`` from ``california_model_kwargs``.

    Skips when botorch is not installed.
    """
    # Only the skip side effect is needed; the module object itself is unused.
    pytest.importorskip("botorch")
    return TorchModel(**california_model_kwargs)
@pytest.fixture(scope="module")
def california_module(california_model):
    """Wrap the ``california_model`` fixture in a ``TorchModule``.

    Skips when botorch is not installed.
    """
    # Only the skip side effect is needed; the module object itself is unused.
    pytest.importorskip("botorch")
    return TorchModule(model=california_model)