Skip to content

Commit

Permalink
update: added more tests for the RSSETSMarker
Browse files Browse the repository at this point in the history
  • Loading branch information
LeSasse authored and synchon committed Sep 20, 2022
1 parent cee6212 commit 40dbcd7
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 2 deletions.
2 changes: 1 addition & 1 deletion junifer/markers/etsrss.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def get_output_kind(self, input: List[str]) -> List[str]:
"""
outputs = []
for t_input in input:
if input in ["BOLD"]:
if t_input in ["BOLD"]:
outputs.append("timeseries")
else:
raise ValueError(f"Unknown input kind for {t_input}")
Expand Down
33 changes: 32 additions & 1 deletion junifer/markers/tests/test_ets_rss_marker.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
# Synchon Mandal <[email protected]>
# License: AGPL

from pathlib import Path

from nilearn import image
from nilearn.maskers import NiftiLabelsMasker

Expand All @@ -14,7 +16,7 @@
from junifer.testing.datagrabbers import SPMAuditoryTestingDatagrabber


def test_RSSETS() -> None:
def test_compute() -> None:
"""Test RSS ETS."""
atlas = "Schaefer100x17"
test_atlas, _, _ = load_atlas(atlas)
Expand All @@ -37,3 +39,32 @@ def test_RSSETS() -> None:
assert meta["atlas"] == "Schaefer100x17"
assert meta["aggregation_method"] == "mean"
assert meta["class"] == "RSSETSMarker"


def test_get_output_kind() -> None:
""" Test get_output_kind."""

atlas = "Schaefer100x17"
ets_rss_marker = RSSETSMarker(atlas=atlas)
input_list = ["BOLD"]
input_list = ets_rss_marker.get_output_kind(input_list)
assert len(input_list) == 1
assert input_list[0] in ["timeseries"]


def test_store(tmp_path: Path) -> None:
"""Test store."""

atlas = "Schaefer100x17"
with SPMAuditoryTestingDatagrabber() as dg:
out = dg["sub001"]
niimg = image.load_img(str(out["BOLD"]["path"].absolute()))
input_dict = {"data": niimg, "path": out["BOLD"]["path"]}
# Compute the RSSETSMarker
ets_rss_marker = RSSETSMarker(atlas=atlas)
new_out = ets_rss_marker.compute(input_dict)
storage = {
"kind": "SQLiteFeatureStorage",
"uri": str((tmp_path / "test.db").absolute())
}
ets_rss_marker.store("SQLiteFeatureStorage", new_out, storage)

0 comments on commit 40dbcd7

Please sign in to comment.