# conftest.py (forked from Lightning-AI/litgpt)
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
import os
import sys
from pathlib import Path
from typing import List

import pytest
import torch
from lightning.fabric.utilities.testing import _runif_reasons

wd = Path(__file__).parent.parent.absolute()


@pytest.fixture(autouse=True)
def add_wd_to_path():
    # this adds support for running tests without the package installed
    sys.path.append(str(wd))


@pytest.fixture()
def fake_checkpoint_dir(tmp_path):
    os.chdir(tmp_path)
    checkpoint_dir = tmp_path / "checkpoints" / "tmp"
    checkpoint_dir.mkdir(parents=True)
    (checkpoint_dir / "lit_model.pth").touch()
    (checkpoint_dir / "lit_config.json").touch()
    (checkpoint_dir / "tokenizer.json").touch()
    (checkpoint_dir / "tokenizer_config.json").touch()
    return checkpoint_dir
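
# Illustrative only (not part of the original file): tests that need an on-disk
# checkpoint layout can request the fixture above; the test name is made up.
#
#   def test_checkpoint_files_exist(fake_checkpoint_dir):
#       assert (fake_checkpoint_dir / "lit_config.json").is_file()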


class TensorLike:
    def __eq__(self, other):
        return isinstance(other, torch.Tensor)


@pytest.fixture()
def tensor_like():
    return TensorLike()


class FloatLike:
    def __eq__(self, other):
        return not isinstance(other, int) and isinstance(other, float)


@pytest.fixture()
def float_like():
    return FloatLike()
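
# Illustrative only: `tensor_like` and `float_like` compare equal to any
# torch.Tensor / any non-int float, which makes them usable as wildcards in
# call assertions. A minimal sketch using stdlib unittest.mock (test and
# variable names are made up):
#
#   from unittest.mock import Mock
#
#   def test_logs_float_loss(float_like):
#       logger = Mock()
#       logger.log(0.25)
#       logger.log.assert_called_once_with(float_like)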


@pytest.fixture(autouse=True)
def restore_default_dtype():
    # just in case
    torch.set_default_dtype(torch.float32)


def RunIf(**kwargs):
    reasons, marker_kwargs = _runif_reasons(**kwargs)
    return pytest.mark.skipif(condition=len(reasons) > 0, reason=f"Requires: [{' + '.join(reasons)}]", **marker_kwargs)
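
# Illustrative only: `RunIf` wraps pytest's skipif marker so a test can declare
# hardware/software requirements; the kwargs are forwarded to `_runif_reasons`
# (the test name below is made up):
#
#   @RunIf(min_cuda_gpus=1)
#   def test_forward_on_gpu():
#       ...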


# https://github.com/Lightning-AI/lightning/blob/6e517bd55b50166138ce6ab915abd4547702994b/tests/tests_fabric/conftest.py#L140
def pytest_collection_modifyitems(items: List[pytest.Function], config: pytest.Config) -> None:
    initial_size = len(items)
    conditions = []
    filtered, skipped = 0, 0

    options = {"standalone": "PL_RUN_STANDALONE_TESTS", "min_cuda_gpus": "PL_RUN_CUDA_TESTS"}
    if os.getenv(options["standalone"], "0") == "1" and os.getenv(options["min_cuda_gpus"], "0") == "1":
        # special case: we don't have a CPU job for standalone tests, so we shouldn't run only cuda tests.
        # by deleting the key, we avoid filtering out the CPU tests
        del options["min_cuda_gpus"]

    for kwarg, env_var in options.items():
        # this will compute the intersection of all tests selected per environment variable
        if os.getenv(env_var, "0") == "1":
            conditions.append(env_var)
            for i, test in reversed(list(enumerate(items))):  # loop in reverse, since we are going to pop items
                already_skipped = any(marker.name == "skip" for marker in test.own_markers)
                if already_skipped:
                    # the test was going to be skipped anyway, filter it out
                    items.pop(i)
                    skipped += 1
                    continue
                has_runif_with_kwarg = any(
                    marker.name == "skipif" and marker.kwargs.get(kwarg) for marker in test.own_markers
                )
                if not has_runif_with_kwarg:
                    # the test does not have `@RunIf(kwarg=True)`, filter it out
                    items.pop(i)
                    filtered += 1

    if config.option.verbose >= 0 and (filtered or skipped):
        writer = config.get_terminal_writer()
        writer.write(
            f"\nThe number of tests has been filtered from {initial_size} to {initial_size - filtered} after the"
            f" filters {conditions}.\n{skipped} tests are marked as unconditional skips.\nIn total,"
            f" {len(items)} tests will run.\n",
            flush=True,
            bold=True,
            purple=True,  # oh yeah, branded pytest messages
        )
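

# Illustrative only: the collection hook above filters tests based on
# environment variables, e.g. running
#
#   PL_RUN_CUDA_TESTS=1 pytest tests/
#
# keeps only tests decorated with `@RunIf(min_cuda_gpus=...)`; with neither
# variable set, no filtering happens and all collected tests run.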