
🚨 Add flake8-annotations Checks for examples (TissueImageAnalytics#699)

- Add `flake8-annotations` checks for `examples`
shaneahmed authored Aug 21, 2023
1 parent 7d9c944 commit e633e56
Showing 5 changed files with 180 additions and 208 deletions.
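
For context: `flake8-annotations` is a flake8 plugin that reports `ANN*` codes for signatures missing type hints, e.g. ANN001 (unannotated function argument), ANN101 (unannotated `self` in a method), and ANN201 (missing return type on a public function), which is why the hunks below annotate every parameter, including `self`. Here is a minimal before/after sketch of the enforced style, reusing the `load_json` helper from the slide-graph notebook; the linter configuration itself lives elsewhere in the repository and is not part of this diff.

import json
from pathlib import Path

# Before: flagged by flake8-annotations with
#   ANN001 - missing type annotation for argument "path"
#   ANN201 - missing return type annotation for public function
#
#     def load_json(path):
#         with path.open() as fptr:
#             return json.load(fptr)

# After (as committed below): json.load() may return any JSON value type.
def load_json(path: Path) -> dict | list | int | float | str:
    """Load JSON from a file path."""
    with path.open() as fptr:
        return json.load(fptr)

Note that the PEP 604 unions used throughout the diff (e.g. `dict | list | int | float | str`, `str | Path`) require Python 3.10+ at runtime unless the module enables `from __future__ import annotations`.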
206 changes: 74 additions & 132 deletions examples/07-advanced-modeling.ipynb

Large diffs are not rendered by default.

114 changes: 65 additions & 49 deletions examples/full-pipelines/slide-graph.ipynb
@@ -132,7 +132,7 @@
"import warnings\n",
"from collections import OrderedDict\n",
"from pathlib import Path\n",
"from typing import Callable\n",
"from typing import Callable, Iterator\n",
"\n",
"# Third party imports\n",
"import joblib\n",
@@ -237,26 +237,26 @@
},
"outputs": [],
"source": [
"def load_json(path: Path):\n",
"def load_json(path: Path) -> dict | list | int | float | str:\n",
" \"\"\"Load JSON from a file path.\"\"\"\n",
" with path.open() as fptr:\n",
" return json.load(fptr)\n",
"\n",
"\n",
"def rmdir(dir_path: Path):\n",
"def rmdir(dir_path: Path) -> None:\n",
" \"\"\"Remove a directory.\"\"\"\n",
" if dir_path.is_dir():\n",
" shutil.rmtree(dir_path)\n",
"\n",
"\n",
"def rm_n_mkdir(dir_path: Path):\n",
"def rm_n_mkdir(dir_path: Path) -> None:\n",
" \"\"\"Remove then re-create a directory.\"\"\"\n",
" if dir_path.is_dir():\n",
" shutil.rmtree(dir_path)\n",
" dir_path.mkdir(parents=True)\n",
"\n",
"\n",
"def mkdir(dir_path: Path):\n",
"def mkdir(dir_path: Path) -> None:\n",
" \"\"\"Create a directory if it does not exist.\"\"\"\n",
" if not dir_path.is_dir():\n",
" dir_path.mkdir(parents=True)\n",
@@ -450,7 +450,7 @@
" test: float,\n",
" num_folds: int,\n",
" seed: int = 5,\n",
"):\n",
") -> list:\n",
" \"\"\"Helper to generate stratified splits.\n",
"\n",
" Split `x` and `y` in to N number of `num_folds` sets\n",
@@ -662,7 +662,7 @@
" msk_paths: list[str],\n",
" save_dir: str,\n",
" preproc_func: Callable | None = None,\n",
"):\n",
") -> list:\n",
" \"\"\"Helper function to extract deep features.\"\"\"\n",
" ioconfig = IOSegmentorConfig(\n",
" input_resolutions=[\n",
@@ -754,7 +754,7 @@
" stride_shape: tuple[int] = (512, 512),\n",
" resolution: Resolution = 0.25,\n",
" units: Units = \"mpp\",\n",
"):\n",
") -> None:\n",
" \"\"\"Estimates cellular composition.\"\"\"\n",
" reader = WSIReader.open(wsi_path)\n",
" inst_pred = joblib.load(inst_pred_path)\n",
@@ -815,7 +815,7 @@
" msk_paths: list[str],\n",
" save_dir: str,\n",
" preproc_func: Callable,\n",
"):\n",
") -> list:\n",
" \"\"\"Extract cellular composition features.\"\"\"\n",
" inst_segmentor = NucleusInstanceSegmentor(\n",
" pretrained_model=\"hovernet_fast-pannuke\",\n",
@@ -891,7 +891,7 @@
"stain_normalizer.fit(target_image)\n",
"\n",
"\n",
"def stain_norm_func(img):\n",
"def stain_norm_func(img: np.ndarray) -> np.ndarray:\n",
" \"\"\"Helper function to perform stain normalization.\"\"\"\n",
" return stain_normalizer.transform(img)"
]
@@ -1035,7 +1035,7 @@
},
"outputs": [],
"source": [
"def construct_graph(wsi_name, save_path):\n",
"def construct_graph(wsi_name: str, save_path: Path) -> None:\n",
" \"\"\"Construct graph for one WSI and save to file.\"\"\"\n",
" positions = np.load(f\"{WSI_FEATURE_DIR}/{wsi_name}.position.npy\")\n",
" features = np.load(f\"{WSI_FEATURE_DIR}/{wsi_name}.features.npy\")\n",
@@ -1271,13 +1271,18 @@
"\n",
" \"\"\"\n",
"\n",
" def __init__(self, info_list, mode=\"train\", preproc=None):\n",
" def __init__(\n",
" self: Dataset,\n",
" info_list: list,\n",
" mode: str = \"train\",\n",
" preproc: Callable | None = None,\n",
" ) -> None:\n",
" \"\"\"Initialize SlideGraphDataset.\"\"\"\n",
" self.info_list = info_list\n",
" self.mode = mode\n",
" self.preproc = preproc\n",
"\n",
" def __getitem__(self, idx):\n",
" def __getitem__(self: Dataset, idx: int) -> Dataset:\n",
" \"\"\"Get an element from SlideGraphDataset.\"\"\"\n",
" info = self.info_list[idx]\n",
" if any(v in self.mode for v in [\"train\", \"valid\"]):\n",
@@ -1301,7 +1306,7 @@
" return {\"graph\": graph, \"label\": label}\n",
" return {\"graph\": graph}\n",
"\n",
" def __len__(self):\n",
" def __len__(self: Dataset) -> int:\n",
" \"\"\"Length of SlideGraphDataset.\"\"\"\n",
" return len(self.info_list)"
]
@@ -1387,7 +1392,7 @@
"\n",
"\n",
"# we must define the function after training/loading\n",
"def nodes_preproc_func(node_features):\n",
"def nodes_preproc_func(node_features: np.ndarray) -> np.ndarray:\n",
" \"\"\"Pre-processing function for nodes.\"\"\"\n",
" return node_scaler.transform(node_features)"
]
@@ -1414,17 +1419,17 @@
" \"\"\"Define SlideGraph architecture.\"\"\"\n",
"\n",
" def __init__(\n",
" self,\n",
" dim_features,\n",
" dim_target,\n",
" layers=None,\n",
" pooling=\"max\",\n",
" dropout=0.0,\n",
" conv=\"GINConv\",\n",
" self: nn.Module,\n",
" dim_features: int,\n",
" dim_target: int,\n",
" layers: list[int, int] | None = None,\n",
" pooling: str = \"max\",\n",
" dropout: float = 0.0,\n",
" conv: str = \"GINConv\",\n",
" *,\n",
" gembed=False,\n",
" **kwargs,\n",
" ):\n",
" gembed: bool = False,\n",
" **kwargs: dict,\n",
" ) -> None:\n",
" \"\"\"Initialize SlideGraphArch.\"\"\"\n",
" super().__init__()\n",
" if layers is None:\n",
@@ -1449,7 +1454,7 @@
" msg = f'Not support `conv=\"{conv}\".'\n",
" raise ValueError(msg)\n",
"\n",
" def create_linear(in_dims, out_dims):\n",
" def create_linear(in_dims: int, out_dims: int) -> Linear:\n",
" return nn.Sequential(\n",
" Linear(in_dims, out_dims),\n",
" BatchNorm1d(out_dims),\n",
@@ -1480,19 +1485,19 @@
" # as they can be sklearn model etc.\n",
" self.aux_model = {}\n",
"\n",
" def save(self, path, aux_path):\n",
" def save(self: nn.Module, path: str | Path, aux_path: str | Path) -> None:\n",
" \"\"\"Save torch model.\"\"\"\n",
" state_dict = self.state_dict()\n",
" torch.save(state_dict, path)\n",
" joblib.dump(self.aux_model, aux_path)\n",
"\n",
" def load(self, path, aux_path):\n",
" def load(self: nn.Module, path: str | Path, aux_path: str | Path) -> None:\n",
" \"\"\"Load torch model.\"\"\"\n",
" state_dict = torch.load(path)\n",
" self.load_state_dict(state_dict)\n",
" self.aux_model = joblib.load(aux_path)\n",
"\n",
" def forward(self, data):\n",
" def forward(self: nn.Module, data: np.ndarray | torch.Tensor) -> tuple:\n",
" \"\"\"Torch model forward function.\"\"\"\n",
" feature, edge_index, batch = data.x, data.edge_index, data.batch\n",
"\n",
@@ -1536,7 +1541,13 @@
"\n",
" # Run one single step\n",
" @staticmethod\n",
" def train_batch(model, batch_data, on_gpu, optimizer: torch.optim.Optimizer):\n",
" def train_batch(\n",
" model: nn.Module,\n",
" batch_data: np.ndarray | torch.Tensor,\n",
" optimizer: torch.optim.Optimizer,\n",
" *,\n",
" on_gpu: bool,\n",
" ) -> list:\n",
" \"\"\"Helper function for model training.\"\"\"\n",
" device = select_device(on_gpu=on_gpu)\n",
" wsi_graphs = batch_data[\"graph\"].to(device)\n",
@@ -1570,7 +1581,12 @@
"\n",
" # Run one inference step\n",
" @staticmethod\n",
" def infer_batch(model, batch_data, on_gpu):\n",
" def infer_batch(\n",
" model: nn.Module,\n",
" batch_data: torch.Tensor,\n",
" *,\n",
" on_gpu: bool,\n",
" ) -> list:\n",
" \"\"\"Model inference.\"\"\"\n",
" device = select_device(on_gpu=on_gpu)\n",
" wsi_graphs = batch_data[\"graph\"].to(device)\n",
@@ -1759,25 +1775,25 @@
"\n",
" \"\"\"\n",
"\n",
" def __init__(self, labels, batch_size=10):\n",
" def __init__(self: Sampler, labels: list, batch_size: int = 10) -> None:\n",
" \"\"\"Initialize StratifiedSampler.\"\"\"\n",
" self.batch_size = batch_size\n",
" self.num_splits = int(len(labels) / self.batch_size)\n",
" self.labels = labels\n",
" self.num_steps = self.num_splits\n",
"\n",
" def _sampling(self):\n",
" def _sampling(self: Sampler) -> list:\n",
" \"\"\"Do we want to control randomness here.\"\"\"\n",
" skf = StratifiedKFold(n_splits=self.num_splits, shuffle=True)\n",
" indices = np.arange(len(self.labels)) # idx holder\n",
" # return array of arrays of indices in each batch\n",
" return [tidx for _, tidx in skf.split(indices, self.labels)]\n",
"\n",
" def __iter__(self):\n",
" def __iter__(self: Sampler) -> Iterator:\n",
" \"\"\"Define Iterator.\"\"\"\n",
" return iter(self._sampling())\n",
"\n",
" def __len__(self):\n",
" def __len__(self: Sampler) -> int:\n",
" \"\"\"The length of the sampler.\n",
"\n",
" This value actually corresponds to the number of steps to query\n",
@@ -1837,7 +1853,7 @@
},
"outputs": [],
"source": [
"def create_pbar(subset_name: str, num_steps: int):\n",
"def create_pbar(subset_name: str, num_steps: int) -> tqdm:\n",
" \"\"\"Create a nice progress bar.\"\"\"\n",
" pbar_format = (\n",
" \"Processing: |{bar}| {n_fmt}/{total_fmt}[{elapsed}<{remaining},{rate_fmt}]\"\n",
@@ -1860,13 +1876,13 @@
"class ScalarMovingAverage:\n",
" \"\"\"Class to calculate running average.\"\"\"\n",
"\n",
" def __init__(self, alpha=0.95):\n",
" def __init__(self: ScalarMovingAverage, alpha: float = 0.95) -> None:\n",
" \"\"\"Initialize ScalarMovingAverage.\"\"\"\n",
" super().__init__()\n",
" self.alpha = alpha\n",
" self.tracking_dict = {}\n",
"\n",
" def __call__(self, step_output):\n",
" def __call__(self: ScalarMovingAverage, step_output: dict) -> None:\n",
" \"\"\"ScalarMovingAverage instances behave and can be called like a function.\"\"\"\n",
" for key, current_value in step_output.items():\n",
" if key in self.tracking_dict:\n",
@@ -1929,16 +1945,16 @@
"outputs": [],
"source": [
"def run_once( # noqa: C901, PLR0912, PLR0915\n",
" dataset_dict,\n",
" num_epochs,\n",
" save_dir,\n",
" pretrained=None,\n",
" loader_kwargs=None,\n",
" arch_kwargs=None,\n",
" optim_kwargs=None,\n",
" dataset_dict: dict,\n",
" num_epochs: int,\n",
" save_dir: str | Path,\n",
" pretrained: str | None = None,\n",
" loader_kwargs: dict | None = None,\n",
" arch_kwargs: dict | None = None,\n",
" optim_kwargs: dict | None = None,\n",
" *,\n",
" on_gpu=True,\n",
"):\n",
" on_gpu: bool = True,\n",
") -> list:\n",
" \"\"\"Running the inference or training loop once.\n",
"\n",
" The actual running mode is defined via the code name of the dataset\n",
@@ -2092,7 +2108,7 @@
")\n",
"\n",
"\n",
"def reset_logging(save_path):\n",
"def reset_logging(save_path: str | Path) -> None:\n",
" \"\"\"Reset logger handler.\"\"\"\n",
" log_formatter = logging.Formatter(\n",
" \"|%(asctime)s.%(msecs)03d| [%(levelname)s] %(message)s\",\n",
@@ -2262,7 +2278,7 @@
" top_k: int = 2,\n",
" metric: str = \"infer-valid-auprc\",\n",
" epoch_range: tuple[int] | None = None,\n",
"):\n",
") -> tuple[list, list]:\n",
" \"\"\"Select checkpoints basing on training statistics.\n",
"\n",
" Args:\n",
2 changes: 1 addition & 1 deletion examples/inference-pipelines/idars.ipynb
@@ -250,7 +250,7 @@
"save_dir = Path(\"./tmp/\")\n",
"\n",
"\n",
"def rmdir(dir_path):\n",
"def rmdir(dir_path: Path) -> None:\n",
" \"\"\"Helper function to delete directory.\"\"\"\n",
" if dir_path.is_dir():\n",
" shutil.rmtree(dir_path)\n",
