Skip to content

Commit

Permalink
adding max_size_fraction option (#796)
Browse files Browse the repository at this point in the history
  • Loading branch information
carsen-stringer committed Sep 12, 2024
1 parent be48d27 commit 3b63be8
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 14 deletions.
19 changes: 13 additions & 6 deletions cellpose/dynamics.py
Original file line number Diff line number Diff line change
Expand Up @@ -663,7 +663,7 @@ def mem_info():
return masks


def get_masks(p, iscell=None, rpad=20):
def get_masks(p, iscell=None, rpad=20, max_size_fraction=0.4):
"""Create masks using pixel convergence after running dynamics.
Makes a histogram of final pixel locations p, initializes masks
Expand All @@ -677,6 +677,8 @@ def get_masks(p, iscell=None, rpad=20):
iscell (bool, 2D or 3D array): If iscell is not None, set pixels that are
iscell False to stay in their original location.
rpad (int, optional): Histogram edge padding. Default is 20.
max_size_fraction (float, optional): Masks larger than max_size_fraction of
total image size are removed. Default is 0.4.
Returns:
M0 (int, 2D or 3D array): Masks with inconsistent flow masks removed,
Expand Down Expand Up @@ -750,7 +752,7 @@ def get_masks(p, iscell=None, rpad=20):

# remove big masks
uniq, counts = fastremap.unique(M0, return_counts=True)
big = np.prod(shape0) * 0.4
big = np.prod(shape0) * max_size_fraction
bigc = uniq[counts > big]
if len(bigc) > 0 and (len(bigc) > 1 or bigc[0] != 0):
M0 = fastremap.mask(M0, bigc)
Expand All @@ -761,7 +763,7 @@ def get_masks(p, iscell=None, rpad=20):

def resize_and_compute_masks(dP, cellprob, p=None, niter=200, cellprob_threshold=0.0,
flow_threshold=0.4, interp=True, do_3D=False, min_size=15,
resize=None, device=None):
max_size_fraction=0.4, resize=None, device=None):
"""Compute masks using dynamics from dP and cellprob, and resizes masks if resize is not None.
Args:
Expand All @@ -774,6 +776,8 @@ def resize_and_compute_masks(dP, cellprob, p=None, niter=200, cellprob_threshold
interp (bool, optional): Whether to interpolate during dynamics computation. Defaults to True.
do_3D (bool, optional): Whether to perform mask computation in 3D. Defaults to False.
min_size (int, optional): The minimum size of the masks. Defaults to 15.
max_size_fraction (float, optional): Masks larger than max_size_fraction of
total image size are removed. Default is 0.4.
resize (tuple, optional): The desired size for resizing the masks. Defaults to None.
device (str, optional): The torch device to use for computation. Defaults to None.
Expand All @@ -783,7 +787,8 @@ def resize_and_compute_masks(dP, cellprob, p=None, niter=200, cellprob_threshold
mask, p = compute_masks(dP, cellprob, p=p, niter=niter,
cellprob_threshold=cellprob_threshold,
flow_threshold=flow_threshold, interp=interp, do_3D=do_3D,
min_size=min_size, device=device)
min_size=min_size, max_size_fraction=max_size_fraction,
device=device)

if resize is not None:
mask = transforms.resize_image(mask, resize[0], resize[1],
Expand All @@ -798,7 +803,7 @@ def resize_and_compute_masks(dP, cellprob, p=None, niter=200, cellprob_threshold

def compute_masks(dP, cellprob, p=None, niter=200, cellprob_threshold=0.0,
flow_threshold=0.4, interp=True, do_3D=False, min_size=15,
device=None):
max_size_fraction=0.4, device=None):
"""Compute masks using dynamics from dP and cellprob.
Args:
Expand All @@ -811,6 +816,8 @@ def compute_masks(dP, cellprob, p=None, niter=200, cellprob_threshold=0.0,
interp (bool, optional): Whether to interpolate during dynamics computation. Defaults to True.
do_3D (bool, optional): Whether to perform mask computation in 3D. Defaults to False.
min_size (int, optional): The minimum size of the masks. Defaults to 15.
max_size_fraction (float, optional): Masks larger than max_size_fraction of
total image size are removed. Default is 0.4.
device (str, optional): The torch device to use for computation. Defaults to None.
Returns:
Expand All @@ -831,7 +838,7 @@ def compute_masks(dP, cellprob, p=None, niter=200, cellprob_threshold=0.0,
return mask, p

#calculate masks
mask = get_masks(p, iscell=cp_mask)
mask = get_masks(p, iscell=cp_mask, max_size_fraction=max_size_fraction)

# flow thresholding factored out of get_masks
if not do_3D:
Expand Down
21 changes: 13 additions & 8 deletions cellpose/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ def __init__(self, gpu=False, model_type="cyto3", nchan=2, device=None,
self.sz.model_type = model_type

def eval(self, x, batch_size=8, channels=[0, 0], channel_axis=None, invert=False,
normalize=True, diameter=30., do_3D=False, find_masks=True, **kwargs):
normalize=True, diameter=30., do_3D=False, **kwargs):
"""Run cellpose size model and mask model and get masks.
Args:
Expand Down Expand Up @@ -353,9 +353,9 @@ def __init__(self, gpu=False, pretrained_model=False, model_type=None,
def eval(self, x, batch_size=8, resample=True, channels=None, channel_axis=None,
z_axis=None, normalize=True, invert=False, rescale=None, diameter=None,
flow_threshold=0.4, cellprob_threshold=0.0, do_3D=False, anisotropy=None,
stitch_threshold=0.0, min_size=15, niter=None, augment=False, tile=True,
tile_overlap=0.1, bsize=224, interp=True, compute_masks=True,
progress=None):
stitch_threshold=0.0, min_size=15, max_size_fraction=0.4, niter=None,
augment=False, tile=True, tile_overlap=0.1, bsize=224,
interp=True, compute_masks=True, progress=None):
""" segment list of images x, or 4D array - Z x nchan x Y x X
Args:
Expand Down Expand Up @@ -394,6 +394,8 @@ def eval(self, x, batch_size=8, resample=True, channels=None, channel_axis=None,
anisotropy (float, optional): for 3D segmentation, optional rescaling factor (e.g. set to 2.0 if Z is sampled half as dense as X or Y). Defaults to None.
stitch_threshold (float, optional): if stitch_threshold>0.0 and not do_3D, masks are stitched in 3D to return volume segmentation. Defaults to 0.0.
min_size (int, optional): all ROIs below this size, in pixels, will be discarded. Defaults to 15.
max_size_fraction (float, optional): max_size_fraction (float, optional): Masks larger than max_size_fraction of
total image size are removed. Default is 0.4.
niter (int, optional): number of iterations for dynamics computation. if None, it is set proportional to the diameter. Defaults to None.
augment (bool, optional): tiles image with overlapping tiles and flips overlapped regions to augment. Defaults to False.
tile (bool, optional): tiles image to ensure GPU/CPU memory usage limited (recommended). Defaults to True.
Expand Down Expand Up @@ -435,7 +437,8 @@ def eval(self, x, batch_size=8, resample=True, channels=None, channel_axis=None,
tile_overlap=tile_overlap, bsize=bsize, resample=resample,
interp=interp, flow_threshold=flow_threshold,
cellprob_threshold=cellprob_threshold, compute_masks=compute_masks,
min_size=min_size, stitch_threshold=stitch_threshold,
min_size=min_size, max_size_fraction=max_size_fraction,
stitch_threshold=stitch_threshold,
progress=progress, niter=niter)
masks.append(maski)
flows.append(flowi)
Expand Down Expand Up @@ -464,7 +467,7 @@ def eval(self, x, batch_size=8, resample=True, channels=None, channel_axis=None,
rescale=rescale, resample=resample, augment=augment, tile=tile,
tile_overlap=tile_overlap, bsize=bsize, flow_threshold=flow_threshold,
cellprob_threshold=cellprob_threshold, interp=interp, min_size=min_size,
do_3D=do_3D, anisotropy=anisotropy, niter=niter,
max_size_fraction=max_size_fraction, do_3D=do_3D, anisotropy=anisotropy, niter=niter,
stitch_threshold=stitch_threshold)

flows = [plot.dx_to_circ(dP), dP, cellprob, p]
Expand All @@ -473,7 +476,8 @@ def eval(self, x, batch_size=8, resample=True, channels=None, channel_axis=None,
def _run_cp(self, x, compute_masks=True, normalize=True, invert=False, niter=None,
rescale=1.0, resample=True, augment=False, tile=True, tile_overlap=0.1,
cellprob_threshold=0.0, bsize=224, flow_threshold=0.4, min_size=15,
interp=True, anisotropy=1.0, do_3D=False, stitch_threshold=0.0):
max_size_fraction=0.4, interp=True, anisotropy=1.0, do_3D=False,
stitch_threshold=0.0):

if isinstance(normalize, dict):
normalize_params = {**normalize_default, **normalize}
Expand Down Expand Up @@ -538,7 +542,7 @@ def _run_cp(self, x, compute_masks=True, normalize=True, invert=False, niter=Non
masks, p = dynamics.resize_and_compute_masks(
dP, cellprob, niter=niter, cellprob_threshold=cellprob_threshold,
flow_threshold=flow_threshold, interp=interp, do_3D=do_3D,
min_size=min_size, resize=None,
min_size=min_size, max_size_fraction=max_size_fraction, resize=None,
device=self.device if self.gpu else None)
else:
masks, p = [], []
Expand All @@ -557,6 +561,7 @@ def _run_cp(self, x, compute_masks=True, normalize=True, invert=False, niter=Non
resize=resize,
min_size=min_size if stitch_threshold == 0 or nimg == 1 else
-1, # turn off for 3D stitching
max_size_fraction=max_size_fraction,
device=self.device if self.gpu else None)
masks.append(outputs[0])
p.append(outputs[1])
Expand Down

0 comments on commit 3b63be8

Please sign in to comment.