Skip to content

Commit

Permalink
test: add test for extract_tile_shape_features cli
Browse files Browse the repository at this point in the history
  • Loading branch information
raylim committed Sep 11, 2023
1 parent 696bcbb commit 739654e
Show file tree
Hide file tree
Showing 5 changed files with 335 additions and 14 deletions.
4 changes: 2 additions & 2 deletions src/luna/common/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def wrapper(*args, **kwargs):

tmp_dir_dest = []
for key, write_mode in dir_key_write_mode.items():
if not args_dict[key]:
if key not in args_dict or not args_dict[key]:
continue
storage_options_key = "storage_options"
if "w" in write_mode:
Expand All @@ -107,7 +107,7 @@ def wrapper(*args, **kwargs):
result = None
with ExitStack() as stack:
for key, write_mode in file_key_write_mode.items():
if not args_dict[key]:
if key not in args_dict or not args_dict[key]:
continue
storage_options_key = "storage_options"
if "w" in write_mode:
Expand Down
17 changes: 13 additions & 4 deletions src/luna/pathology/cli/extract_tile_shape_features.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,8 @@ def cli(
statistical_descriptors: str = StatisticalDescriptors.ALL,
cellular_features: str = CellularFeatures.ALL,
property_type: str = PropertyType.ALL,
include_smaller_regions: bool = False,
label_cols: List[str] = None,
storage_options: dict = {},
output_storage_options: dict = {},
local_config: str = "",
Expand All @@ -92,6 +94,8 @@ def cli(
statistical_descriptors (str): statistical descriptors to calculate. One of All, Quantiles, Stats, or Density
cellular_features (str): cellular features to include. One of All, Nucleus, Cell, Cytoplasm, and Membrane
property_type (str): properties to include. One of All, Geometric, or Stain
include_smaller_regions (bool): include smaller regions in output
label_cols (List[str]): list of score columns to use for the classification. Tile is classified as the column with the max score
storage_options (dict): storage options to pass to reading functions
output_storage_options (dict): storage options to pass to writing functions
local_config (str): local config yaml file
Expand Down Expand Up @@ -141,6 +145,8 @@ def cli(
statistical_descriptors,
cellular_features,
property_type,
config["include_smaller_regions"],
config["label_cols"],
)

with fs.open(output_fpath, "wb") as of:
Expand All @@ -167,6 +173,7 @@ def extract_tile_shape_features(
cellular_features: CellularFeatures = CellularFeatures.ALL,
property_type: PropertyType = PropertyType.ALL,
include_smaller_regions: bool = False,
label_cols: List[str] = None,
properties: List[str] = [
"area",
"convex_area",
Expand Down Expand Up @@ -195,15 +202,14 @@ def extract_tile_shape_features(
statistical_descriptors (StatisticalDescriptors): statistical descriptors to calculate
cellular_features (CellularFeatures): cellular features to include
property_type (PropertyType): properties to include
label_cols (List[str]): list of score columns to use for the classification. Tile is classified as the column with the max score
properties (List[str]): list of whole slide image properties to
extract. Needs to be parquet compatible (numeric).
Returns:
dict: output paths and the number of features generated
"""
import pdb

pdb.set_trace()
if label_cols:
tiles_df["Classification"] = tiles_df[label_cols].idxmax(axis=1)
LabeledTileSchema.validate(tiles_df.reset_index())

tile_area = tiles_df.iloc[0].tile_size ** 2
Expand Down Expand Up @@ -260,6 +266,9 @@ def extract_tile_shape_features(

logger.info("Spatially joining tiles and objects")
gdf = object_gdf.sjoin(tiles_gdf, how="inner", predicate="within")
if len(gdf) == 0:
logger.info("No objects found within tiles")
return None
try:
measurement_keys = list(gdf.measurements.iloc[0].keys())
gdf = gdf.join(gdf.measurements.apply(lambda x: pd.Series(x)))
Expand Down
17 changes: 9 additions & 8 deletions src/luna/pathology/cli/generate_tile_mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,8 @@ def convert_tiles_to_mask(
tiles_df: pd.DataFrame,
slide: tiffslide.TiffSlide,
label_cols: Union[str, List[str]],
output_urlpath: str,
output_storage_options: dict,
output_urlpath: str = "",
output_storage_options: dict = {},
):
"""Converts categorical tile labels to a slide image mask. This mask can be used for feature extraction and spatial analysis.
Expand Down Expand Up @@ -113,8 +113,8 @@ def convert_tiles_to_mask(
slide_width: int,
slide_height: int,
label_cols: Union[str, List[str]],
output_urlpath: str,
output_storage_options: dict,
output_urlpath: str = "",
output_storage_options: dict = {},
):
"""Converts categorical tile labels to a slide image mask. This mask can be used for feature extraction and spatial analysis.
Expand Down Expand Up @@ -154,10 +154,11 @@ def convert_tiles_to_mask(

logger.info(f"{address}, {row['mask']}, {value}")

slide_mask = Path(output_urlpath) / "tile_mask.tif"
logger.info(f"Saving output mask to {slide_mask}")
with open(slide_mask, "wb") as of:
tifffile.imwrite(of, mask_arr)
if output_urlpath:
slide_mask = Path(output_urlpath) / "tile_mask.tif"
logger.info(f"Saving output mask to {slide_mask}")
with open(slide_mask, "wb") as of:
tifffile.imwrite(of, mask_arr)

return mask_arr, mask_values

Expand Down
64 changes: 64 additions & 0 deletions tests/luna/pathology/cli/test_extract_tile_shape_features.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
import os

import fire
import pandas as pd

from luna.pathology.cli.extract_tile_shape_features import cli


def test_cli_extract_tile_shape_features(tmp_path):
    """Run the extract_tile_shape_features CLI against local test data.

    The CLI is driven through ``fire`` with local slide, object and tile
    inputs, writing into the pytest-provided ``tmp_path``. The test then
    checks that both the feature table and the metadata file were written
    and that the feature table has the expected number of rows.
    """
    cli_args = [
        "--slide_urlpath",
        "tests/testdata/pathology/123.svs",
        "--object_urlpath",
        "tests/testdata/pathology/test_cell_detections.geojson",
        "--tiles_urlpath",
        "tests/testdata/pathology/infer_tumor_background/123/tile_scores_and_labels_pytorch_inference.parquet",
        "--output_urlpath",
        str(tmp_path),
        "--label_cols",
        "Background,Tumor",
    ]
    fire.Fire(cli, cli_args)

    # Both output artifacts must exist before the table is inspected.
    assert os.path.exists(f"{tmp_path}/shape_features.parquet")
    assert os.path.exists(f"{tmp_path}/metadata.yml")

    shape_features = pd.read_parquet(f"{tmp_path}/shape_features.parquet")
    n_rows = len(shape_features)

    assert n_rows == 866


def test_cli_extract_tile_shape_features_s3(s3fs_client):
    """Run the extract_tile_shape_features CLI with ``s3://`` inputs and output.

    The local test fixtures are first uploaded under ``testtile/test/`` via
    ``s3fs_client`` (presumably a fixture backed by a test S3 endpoint --
    confirm in conftest). The CLI is then pointed at the uploaded objects,
    with storage options that carry the client's endpoint URL, and the test
    checks that both output files exist under ``s3://testtile/out/``.
    """
    s3fs_client.mkdirs("testtile", exist_ok=True)

    # Upload the slide, the cell detections and the tile scores, in that order.
    local_fixtures = (
        "tests/testdata/pathology/123.svs",
        "tests/testdata/pathology/test_cell_detections.geojson",
        "tests/testdata/pathology/infer_tumor_background/123/tile_scores_and_labels_pytorch_inference.parquet",
    )
    for local_path in local_fixtures:
        s3fs_client.put(local_path, "testtile/test/")

    # The storage options are passed as a dict literal string for fire to
    # parse; the endpoint URL is taken from the client under test.
    endpoint_url = s3fs_client.client_kwargs["endpoint_url"]
    storage_options_arg = (
        "{'key': '', 'secret': '', 'client_kwargs': {'endpoint_url': '"
        + endpoint_url
        + "'}}"
    )

    cli_args = [
        "--slide_urlpath",
        "s3://testtile/test/123.svs",
        "--object_urlpath",
        "s3://testtile/test/test_cell_detections.geojson",
        "--tiles_urlpath",
        "s3://testtile/test/tile_scores_and_labels_pytorch_inference.parquet",
        "--output_urlpath",
        "s3://testtile/out/",
        "--label_cols",
        "Background,Tumor",
        "--storage_options",
        storage_options_arg,
    ]
    fire.Fire(cli, cli_args)

    assert s3fs_client.exists("s3://testtile/out/shape_features.parquet")
    assert s3fs_client.exists("s3://testtile/out/metadata.yml")
Loading

0 comments on commit 739654e

Please sign in to comment.