forked from FoundryOfTitans/AlphaFold3
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathconftest.py
107 lines (79 loc) · 3.77 KB
/
conftest.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
"""This file prepares config fixtures for other tests."""
import copy
from pathlib import Path
from typing import Generator

import pytest
import rootutils
from hydra import compose, initialize
from hydra.core.global_hydra import GlobalHydra
from omegaconf import DictConfig, open_dict
@pytest.fixture(scope="package")
def cfg_train_global() -> DictConfig:
    """Compose the default training config once for the whole test package.

    The config is built from ``train.yaml`` in ``../configs`` (including the ``hydra`` node)
    and then trimmed for tests: one epoch, small ``limit_*_batches`` fractions, a single CPU
    device, ``num_workers=0``, no pinned memory, no config printing or tag enforcement, and
    no logger.

    :return: A DictConfig holding the training configuration with test defaults applied.
    """
    with initialize(version_base="1.3", config_path="../configs"):
        train_cfg = compose(config_name="train.yaml", return_hydra_config=True, overrides=[])

        # open_dict lifts struct mode so the assignments below are accepted even for keys
        # the composed config does not already define.
        with open_dict(train_cfg):
            train_cfg.paths.root_dir = str(rootutils.find_root(indicator=".project-root"))

            trainer_cfg = train_cfg.trainer
            trainer_cfg.max_epochs = 1
            trainer_cfg.limit_train_batches = 0.01
            trainer_cfg.limit_val_batches = 0.1
            trainer_cfg.limit_test_batches = 0.1
            trainer_cfg.accelerator = "cpu"
            trainer_cfg.devices = 1

            data_cfg = train_cfg.data
            data_cfg.num_workers = 0
            data_cfg.pin_memory = False

            extras_cfg = train_cfg.extras
            extras_cfg.print_config = False
            extras_cfg.enforce_tags = False

            train_cfg.logger = None

    return train_cfg
@pytest.fixture(scope="package")
def cfg_eval_global() -> DictConfig:
    """Compose the default evaluation config once for the whole test package.

    The config is built from ``eval.yaml`` in ``../configs`` (including the ``hydra`` node)
    and then trimmed for tests: one epoch, a small ``limit_test_batches`` fraction, a single
    CPU device, ``num_workers=0``, no pinned memory, no config printing or tag enforcement,
    and no logger.

    :return: A DictConfig holding the evaluation configuration with test defaults applied.
    """
    # NOTE(review): ``ckpt_path=.`` looks like a placeholder for a value eval.yaml requires;
    # confirm against configs/eval.yaml.
    eval_overrides = ["ckpt_path=."]

    with initialize(version_base="1.3", config_path="../configs"):
        eval_cfg = compose(config_name="eval.yaml", return_hydra_config=True, overrides=eval_overrides)

        # open_dict lifts struct mode so the assignments below are accepted even for keys
        # the composed config does not already define.
        with open_dict(eval_cfg):
            eval_cfg.paths.root_dir = str(rootutils.find_root(indicator=".project-root"))

            trainer_cfg = eval_cfg.trainer
            trainer_cfg.max_epochs = 1
            trainer_cfg.limit_test_batches = 0.1
            trainer_cfg.accelerator = "cpu"
            trainer_cfg.devices = 1

            data_cfg = eval_cfg.data
            data_cfg.num_workers = 0
            data_cfg.pin_memory = False

            extras_cfg = eval_cfg.extras
            extras_cfg.print_config = False
            extras_cfg.enforce_tags = False

            eval_cfg.logger = None

    return eval_cfg
@pytest.fixture(scope="function")
def cfg_train(cfg_train_global: DictConfig, tmp_path: Path) -> Generator[DictConfig, None, None]:
    """Yield a per-test copy of the training config whose outputs go to ``tmp_path``.

    Built on top of the package-scoped ``cfg_train_global`` fixture. Every test that requests
    ``cfg_train`` gets its own independent copy with its own temporary output/log directory.

    :param cfg_train_global: The package-scoped training DictConfig to copy.
    :param tmp_path: Per-test temporary directory provided by pytest.
    :yield: A DictConfig whose ``paths.output_dir`` and ``paths.log_dir`` point at ``tmp_path``.
    """
    # DictConfig.copy() is shallow: nested nodes such as ``paths`` and ``trainer`` would stay
    # shared with the package-scoped config, so any change made here or inside a test would
    # leak into every later test. A deep copy isolates each test.
    cfg = copy.deepcopy(cfg_train_global)

    with open_dict(cfg):
        cfg.paths.output_dir = str(tmp_path)
        cfg.paths.log_dir = str(tmp_path)

    yield cfg

    # Teardown: clear Hydra's global singleton so later tests start from a clean Hydra state.
    GlobalHydra.instance().clear()
@pytest.fixture(scope="function")
def cfg_eval(cfg_eval_global: DictConfig, tmp_path: Path) -> Generator[DictConfig, None, None]:
    """Yield a per-test copy of the evaluation config whose outputs go to ``tmp_path``.

    Built on top of the package-scoped ``cfg_eval_global`` fixture. Every test that requests
    ``cfg_eval`` gets its own independent copy with its own temporary output/log directory.

    :param cfg_eval_global: The package-scoped evaluation DictConfig to copy.
    :param tmp_path: Per-test temporary directory provided by pytest.
    :yield: A DictConfig whose ``paths.output_dir`` and ``paths.log_dir`` point at ``tmp_path``.
    """
    # DictConfig.copy() is shallow: nested nodes such as ``paths`` and ``trainer`` would stay
    # shared with the package-scoped config, so any change made here or inside a test would
    # leak into every later test. A deep copy isolates each test.
    cfg = copy.deepcopy(cfg_eval_global)

    with open_dict(cfg):
        cfg.paths.output_dir = str(tmp_path)
        cfg.paths.log_dir = str(tmp_path)

    yield cfg

    # Teardown: clear Hydra's global singleton so later tests start from a clean Hydra state.
    GlobalHydra.instance().clear()