Skip to content

Commit

Permalink
Merge pull request #83 from JoHof/fix/no_download_when_path
Browse files Browse the repository at this point in the history
Fix/no download when path
  • Loading branch information
JoHof authored Jun 15, 2023
2 parents 8699eb9 + b285ecc commit dfce7e4
Show file tree
Hide file tree
Showing 6 changed files with 91 additions and 31 deletions.
23 changes: 14 additions & 9 deletions lungmask/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import pkg_resources # type: ignore
import SimpleITK as sitk

from lungmask import mask, utils
from lungmask import LMInferer, utils


def path(string):
Expand Down Expand Up @@ -48,7 +48,6 @@ def main():
parser.add_argument(
"--classes",
help="spcifies the number of output classes of the model",
default=3,
)
parser.add_argument(
"--cpu",
Expand Down Expand Up @@ -86,6 +85,11 @@ def main():
argsin = sys.argv[1:]
args = parser.parse_args(argsin)

if args.classes is not None:
logging.warn(
"!!! Warning: The `classes` parameter is deprecated and will be removed in the next version !!!"
)

batchsize = args.batchsize
if args.cpu:
batchsize = 1
Expand All @@ -98,29 +102,30 @@ def main():
assert (
args.modelpath is None
), "Modelpath can not be specified for LTRCLobes_R231 mode"
result = mask.apply_fused(
input_image,
inferer = LMInferer(
modelname="LTRCLobes",
force_cpu=args.cpu,
fillmodel="R231",
batch_size=batchsize,
volume_postprocessing=not (args.nopostprocess),
noHU=args.noHU,
tqdm_disable=args.noprogress,
)
result = inferer.apply(input_image)
else:
model = mask.get_model(args.modelname, args.modelpath, args.classes)
result = mask.apply(
input_image,
model,
inferer = LMInferer(
modelname=args.modelname,
modelpath=args.modelpath,
force_cpu=args.cpu,
batch_size=batchsize,
volume_postprocessing=not (args.nopostprocess),
noHU=args.noHU,
tqdm_disable=args.noprogress,
)
result = inferer.apply(input_image)

if args.noHU:
file_ending = args.output.split(".")[-1]
print(file_ending)
if file_ending in ["jpg", "jpeg", "png"]:
result = (result / (result.max()) * 255).astype(np.uint8)
result = result[0]
Expand Down
43 changes: 31 additions & 12 deletions lungmask/mask.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
import os
import sys
import warnings
from typing import Optional, Union
Expand Down Expand Up @@ -41,15 +42,12 @@
}


def get_model(
modelname: str, modelpath: Optional[str] = None, n_classes: int = 3
) -> torch.nn.Module:
def get_model(modelname: str, modelpath: Optional[str] = None) -> torch.nn.Module:
"""Loads specific model and state
Args:
modelname (str): Modelname (e.g. R231, LTRCLobes or R231CovidWeb)
modelpath (Optional[str], optional): Path to statedict, if not provided will be downloaded automatically. Modelname will be ignored if provided. Defaults to None.
n_classes (int, optional): Number of classes. Will be automatically set if modelname is provided. Defaults to 3.
Returns:
torch.nn.Module: Loaded model in eval state
Expand All @@ -62,6 +60,8 @@ def get_model(
else:
state_dict = torch.load(modelpath, map_location=torch.device("cpu"))

n_classes = len(list(state_dict.values())[-1])

model = UNet(
n_classes=n_classes,
padding=True,
Expand All @@ -78,19 +78,23 @@ def get_model(
class LMInferer:
def __init__(
self,
modelname="R231",
modelname: str = "R231",
modelpath: Optional[str] = None,
fillmodel: Optional[str] = None,
force_cpu=False,
batch_size=20,
volume_postprocessing=True,
noHU=False,
tqdm_disable=False,
fillmodel_path: Optional[str] = None,
force_cpu: bool = False,
batch_size: int = 20,
volume_postprocessing: bool = True,
noHU: bool = False,
tqdm_disable: bool = False,
):
"""LungMaskInference
Args:
modelname (str, optional): Model to be applied. Defaults to 'R231'.
modelpath (str, optional): Path to modeleights. `modelname` parameter will be ignored if provided. Defaults to None.
fillmodel (Optional[str], optional): Fillmodel to be applied. Defaults to None.
fillmodel_path (Optional[str], optional): Path to weights for fillmodel. `fillmodel` parameter will be ignored if provided. Defaults to None.
force_cpu (bool, optional): Will not use GPU is `True`. Defaults to False.
batch_size (int, optional): Batch size. Defaults to 20.
volume_postprocessing (bool, optional): If `Fales` will not perform postprocessing (connected component analysis). Defaults to True.
Expand All @@ -104,6 +108,13 @@ def __init__(
assert (
fillmodel in MODEL_URLS
), "Modelname not found. Please choose from: {}".format(MODEL_URLS.keys())

# if paths provided, overwrite name
if modelpath is not None:
modelname = os.path.basename(modelpath)
if fillmodel_path is not None:
fillmodel = os.path.basename(fillmodel_path)

self.fillmodel = fillmodel
self.modelname = modelname
self.force_cpu = force_cpu
Expand All @@ -112,7 +123,7 @@ def __init__(
self.noHU = noHU
self.tqdm_disable = tqdm_disable

self.model = get_model(self.modelname)
self.model = get_model(self.modelname, modelpath)

self.device = torch.device("cpu")
if not self.force_cpu:
Expand All @@ -124,7 +135,7 @@ def __init__(

self.fillmodelm = None
if self.fillmodel is not None:
self.fillmodelm = get_model(self.fillmodel)
self.fillmodelm = get_model(self.fillmodel, fillmodel_path)
self.fillmodelm.to(self.device)

def _inference(
Expand Down Expand Up @@ -250,6 +261,10 @@ def apply(
noHU=False,
tqdm_disable=False,
):
warnings.warn(
"The function `apply` will be removed in a future version. Please use the LMInferer class!",
DeprecationWarning,
)
inferer = LMInferer(
force_cpu=force_cpu,
batch_size=batch_size,
Expand All @@ -272,6 +287,10 @@ def apply_fused(
noHU=False,
tqdm_disable=False,
):
warnings.warn(
"The function `apply_fused` will be removed in a future version. Please use the LMInferer class!",
DeprecationWarning,
)
inferer = LMInferer(
modelname=basemodel,
force_cpu=force_cpu,
Expand Down
1 change: 1 addition & 0 deletions lungmask/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,7 @@ def postprocessing(
Returns:
np.ndarray: Postprocessed volume
"""
logging.info("Postprocessing")

# CC analysis
regionmask = skimage.measure.label(label_image)
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

setuptools.setup(
name="lungmask",
version="0.2.14",
version="0.2.15",
author="Johannes Hofmanninger",
author_email="[email protected]",
description="Package for automated lung segmentation in CT",
Expand Down
47 changes: 41 additions & 6 deletions tests/test_mask.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import os
import shutil

import numpy as np
import pydicom as pyd
import pytest
import torch

from lungmask.mask import LMInferer, apply, apply_fused
from lungmask.mask import MODEL_URLS, LMInferer
from lungmask.utils import read_dicoms


Expand All @@ -13,13 +14,47 @@ def fixture_testvol():
return read_dicoms(os.path.join(os.path.dirname(__file__), "testdata"))[0]


def test_apply(fixture_testvol):
res = apply(fixture_testvol)
@pytest.fixture(scope="session")
def fixture_weights_path_R231(tmpdir_factory):
# we make sure the model is there
torch.hub.load_state_dict_from_url(
MODEL_URLS["R231"][0], progress=True, map_location=torch.device("cpu")
)
modelbasename = os.path.basename(MODEL_URLS["R231"][0])
modelpath = os.path.join(torch.hub.get_dir(), "checkpoints", modelbasename)
tmppath = str(tmpdir_factory.mktemp("weights").join(modelbasename))
shutil.copy(modelpath, tmppath)
return tmppath


def test_LMInferer(fixture_testvol, fixture_weights_path_R231):
inferer = LMInferer(
force_cpu=True,
tqdm_disable=True,
)
res = inferer.apply(fixture_testvol)
assert np.all(np.unique(res, return_counts=True)[1] == [423000, 64752, 36536])

# here, we provide a path to the R231 weights but specify LTRCLobes (6 channel) as modelname
# The modelname should be ignored and a 3 channel output should be generated
inferer = LMInferer(
modelname="LTRCLobes",
modelpath=fixture_weights_path_R231,
force_cpu=True,
tqdm_disable=True,
)
res = inferer.apply(fixture_testvol)
assert np.all(np.unique(res, return_counts=True)[1] == [423000, 64752, 36536])

def test_apply_fused(fixture_testvol):
res = apply_fused(fixture_testvol)

def test_LMInferer_fused(fixture_testvol):
inferer = LMInferer(
modelname="LTRCLobes",
force_cpu=True,
fillmodel="R231",
tqdm_disable=True,
)
res = inferer.apply(fixture_testvol)
assert np.all(
np.unique(res, return_counts=True)[1] == [423000, 13334, 23202, 23834, 40918]
)
6 changes: 3 additions & 3 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
import os

import numpy as np
import pydicom as pd
import pydicom as pyd
import SimpleITK as sitk
from pydicom.dataset import FileMetaDataset

from lungmask.utils import (
bbox_3D,
Expand All @@ -18,6 +15,9 @@
)

# creating test dicom data for reference
# import pydicom as pd
# import pydicom as pyd
# from pydicom.dataset import FileMetaDataset
#
# studyuid = pyd.uid.generate_uid()
# seriesuid = pyd.uid.generate_uid()
Expand Down

0 comments on commit dfce7e4

Please sign in to comment.