Skip to content

Commit

Permalink
💚 Parallel tests (TissueImageAnalytics#671)
Browse files Browse the repository at this point in the history
Adds the ability to run tests using several workers using [pytest-xdist](https://github.com/pytest-dev/pytest-xdist), significantly improving processing time. For example, on M1 Max (no CUDA), processing time dropped **from 14 minutes to 4 minutes 💨💨💨.**

<img width="1614" alt="image" src="https://github.com/TissueImageAnalytics/tiatoolbox/assets/19199204/fbb607b0-3bf1-48c3-b14a-be4acf2b1ec3">


However, this optimization comes at a cost. Previously, tests depended on serial execution. For example, segmentation and prediction methods used to rely on "output" as a folder to store intermediate results. If many functions modified this folder at the same time, the result would be unpredictable. To address this, I made some tweaks alongside TissueImageAnalytics#641 and TissueImageAnalytics#673 so that functions will not depend on each other.

If we merge this pull request, we will need to start checking that new tests are ready for parallel execution.

**Depends on TissueImageAnalytics#641 and TissueImageAnalytics#673
  • Loading branch information
blaginin authored Aug 23, 2023
1 parent c4ca84e commit 98465d6
Show file tree
Hide file tree
Showing 9 changed files with 137 additions and 83 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ lint: ## check style with flake8
flake8 tiatoolbox tests

test: ## run tests quickly with the default Python
pytest
pytest -n auto

coverage: ## check code coverage quickly with the default Python
pytest --cov=tiatoolbox --cov-report=term --cov-report=html --cov-report=xml
Expand Down
1 change: 1 addition & 0 deletions requirements/requirements_dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ pre-commit>=2.20.0
pytest>=7.2.0
pytest-cov>=4.0.0
pytest-runner>=6.0
pytest-xdist[psutil]
ruff==0.0.285 # This will be updated by pre-commit bot to latest version
toml>=0.10.2
twine>=4.0.1
Expand Down
35 changes: 34 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""pytest fixtures."""

import os
import shutil
from pathlib import Path
from typing import Callable
Expand Down Expand Up @@ -525,3 +525,36 @@ def moving_mask(remote_sample: Callable) -> Path:
Download moving mask for pytest.
"""
return remote_sample("moving_mask")


@pytest.fixture(scope="session")
def chdir() -> Callable:
"""Return a context manager to change the current working directory.
Todo: switch to chdir from contextlib when Python 3.11 is required
"""
try:
from contextlib import chdir
except ImportError:
from contextlib import AbstractContextManager

class chdir(AbstractContextManager): # noqa: N801
"""Non thread-safe context manager to change the current working directory.
See Also: https://github.com/python/cpython/blob/main/Lib/contextlib.py.
"""

def __init__(self, path):
self.path = path
self._old_cwd = []

def __enter__(self):
self._old_cwd.append(os.getcwd()) # noqa: PTH109
os.chdir(self.path)

def __exit__(self, *excinfo):
os.chdir(self._old_cwd.pop())

return chdir
4 changes: 1 addition & 3 deletions tests/models/test_hovernetplus.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
"""Unit test package for HoVerNet+."""

from pathlib import Path
from typing import Callable

import torch
Expand All @@ -11,9 +10,8 @@
from tiatoolbox.utils.transforms import imresize


def test_functionality(remote_sample: Callable, tmp_path: Path) -> None:
def test_functionality(remote_sample: Callable) -> None:
"""Functionality test."""
tmp_path = str(tmp_path)
sample_patch = str(remote_sample("stainnorm-source"))
patch_pre = imread(sample_patch)
patch_pre = imresize(patch_pre, scale_factor=0.5)
Expand Down
69 changes: 40 additions & 29 deletions tests/models/test_patch_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -471,7 +471,7 @@ def test_io_patch_predictor_config() -> None:
# -------------------------------------------------------------------------------------


def test_predictor_crash() -> None:
def test_predictor_crash(tmp_path: Path) -> None:
"""Test for crash when making predictor."""
# without providing any model
with pytest.raises(ValueError, match=r"Must provide.*"):
Expand All @@ -489,20 +489,19 @@ def test_predictor_crash() -> None:
predictor = PatchPredictor(pretrained_model="resnet18-kather100k", batch_size=32)

with pytest.raises(ValueError, match=r".*not a valid mode.*"):
predictor.predict("aaa", mode="random")
predictor.predict("aaa", mode="random", save_dir=tmp_path)
# remove previously generated data
if Path.exists(Path("output")):
shutil.rmtree("output", ignore_errors=True)
shutil.rmtree(tmp_path / "output", ignore_errors=True)
with pytest.raises(TypeError, match=r".*must be a list of file paths.*"):
predictor.predict("aaa", mode="wsi")
predictor.predict("aaa", mode="wsi", save_dir=tmp_path)
# remove previously generated data
shutil.rmtree("output", ignore_errors=True)
shutil.rmtree(tmp_path / "output", ignore_errors=True)
with pytest.raises(ValueError, match=r".*masks.*!=.*imgs.*"):
predictor.predict([1, 2, 3], masks=[1, 2], mode="wsi")
predictor.predict([1, 2, 3], masks=[1, 2], mode="wsi", save_dir=tmp_path)
with pytest.raises(ValueError, match=r".*labels.*!=.*imgs.*"):
predictor.predict([1, 2, 3], labels=[1, 2], mode="patch")
predictor.predict([1, 2, 3], labels=[1, 2], mode="patch", save_dir=tmp_path)
# remove previously generated data
shutil.rmtree("output", ignore_errors=True)
shutil.rmtree(tmp_path / "output", ignore_errors=True)


def test_io_config_delegation(remote_sample: Callable, tmp_path: Path) -> None:
Expand Down Expand Up @@ -622,34 +621,41 @@ def test_patch_predictor_api(sample_patch1, sample_patch2, tmp_path: Path) -> No
output = predictor.predict(
inputs,
on_gpu=ON_GPU,
save_dir=save_dir_path,
)
assert sorted(output.keys()) == ["predictions"]
assert len(output["predictions"]) == 2
shutil.rmtree(save_dir_path, ignore_errors=True)

output = predictor.predict(
inputs,
labels=[1, "a"],
return_labels=True,
on_gpu=ON_GPU,
save_dir=save_dir_path,
)
assert sorted(output.keys()) == sorted(["labels", "predictions"])
assert len(output["predictions"]) == len(output["labels"])
assert output["labels"] == [1, "a"]
shutil.rmtree(save_dir_path, ignore_errors=True)

output = predictor.predict(
inputs,
return_probabilities=True,
on_gpu=ON_GPU,
save_dir=save_dir_path,
)
assert sorted(output.keys()) == sorted(["predictions", "probabilities"])
assert len(output["predictions"]) == len(output["probabilities"])
shutil.rmtree(save_dir_path, ignore_errors=True)

output = predictor.predict(
inputs,
return_probabilities=True,
labels=[1, "a"],
return_labels=True,
on_gpu=ON_GPU,
save_dir=save_dir_path,
)
assert sorted(output.keys()) == sorted(["labels", "predictions", "probabilities"])
assert len(output["predictions"]) == len(output["labels"])
Expand Down Expand Up @@ -693,13 +699,14 @@ def test_patch_predictor_api(sample_patch1, sample_patch2, tmp_path: Path) -> No
labels=[1, "a"],
return_labels=True,
on_gpu=ON_GPU,
save_dir=save_dir_path,
)
assert sorted(output.keys()) == sorted(["labels", "predictions", "probabilities"])
assert len(output["predictions"]) == len(output["labels"])
assert len(output["predictions"]) == len(output["probabilities"])


def test_wsi_predictor_api(sample_wsi_dict, tmp_path: Path) -> None:
def test_wsi_predictor_api(sample_wsi_dict, tmp_path: Path, chdir: Callable) -> None:
"""Test normal run of wsi predictor."""
save_dir_path = tmp_path

Expand All @@ -711,6 +718,8 @@ def test_wsi_predictor_api(sample_wsi_dict, tmp_path: Path) -> None:
patch_size = np.array([224, 224])
predictor = PatchPredictor(pretrained_model="resnet18-kather100k", batch_size=32)

save_dir = f"{save_dir_path}/model_wsi_output"

# wrapper to make this more clean
kwargs = {
"return_probabilities": True,
Expand All @@ -720,6 +729,7 @@ def test_wsi_predictor_api(sample_wsi_dict, tmp_path: Path) -> None:
"stride_shape": patch_size,
"resolution": 1.0,
"units": "baseline",
"save_dir": save_dir,
}
# ! add this test back once the read at `baseline` is fixed
# sanity check, both output should be the same with same resolution read args
Expand All @@ -730,6 +740,8 @@ def test_wsi_predictor_api(sample_wsi_dict, tmp_path: Path) -> None:
**kwargs,
)

shutil.rmtree(save_dir, ignore_errors=True)

tile_output = predictor.predict(
[mini_wsi_jpg],
masks=[mini_wsi_msk],
Expand All @@ -744,7 +756,6 @@ def test_wsi_predictor_api(sample_wsi_dict, tmp_path: Path) -> None:
assert accuracy > 0.9, np.nonzero(~diff)

# remove previously generated data
save_dir = save_dir_path / "model_wsi_output"
shutil.rmtree(save_dir, ignore_errors=True)

kwargs = {
Expand Down Expand Up @@ -793,26 +804,26 @@ def test_wsi_predictor_api(sample_wsi_dict, tmp_path: Path) -> None:
)
# remove previously generated data
shutil.rmtree(_kwargs["save_dir"], ignore_errors=True)
shutil.rmtree("output", ignore_errors=True)

# test reading of multiple whole-slide images
_kwargs = copy.deepcopy(kwargs)
_kwargs["save_dir"] = None # default coverage
_kwargs["return_probabilities"] = False
output = predictor.predict(
[mini_wsi_svs, mini_wsi_svs],
masks=[mini_wsi_msk, mini_wsi_msk],
mode="wsi",
**_kwargs,
)
assert Path.exists(Path("output"))
for output_info in output.values():
assert Path(output_info["raw"]).exists()
assert "merged" in output_info
assert Path(output_info["merged"]).exists()
with chdir(save_dir_path):
# test reading of multiple whole-slide images
_kwargs = copy.deepcopy(kwargs)
_kwargs["save_dir"] = None # default coverage
_kwargs["return_probabilities"] = False
output = predictor.predict(
[mini_wsi_svs, mini_wsi_svs],
masks=[mini_wsi_msk, mini_wsi_msk],
mode="wsi",
**_kwargs,
)
assert Path.exists(Path("output"))
for output_info in output.values():
assert Path(output_info["raw"]).exists()
assert "merged" in output_info
assert Path(output_info["merged"]).exists()

# remove previously generated data
shutil.rmtree("output", ignore_errors=True)
# remove previously generated data
shutil.rmtree("output", ignore_errors=True)


def test_wsi_predictor_merge_predictions(sample_wsi_dict) -> None:
Expand Down
Loading

0 comments on commit 98465d6

Please sign in to comment.