Skip to content

Commit

Permalink
Feature/plugin updates (AllenCellModeling#345)
Browse files Browse the repository at this point in the history
* add timeout

* add autothreshold as default postprocessing;

* update with Dannys comments

* precommit

---------

Co-authored-by: Benjamin Morris <[email protected]>
  • Loading branch information
benjijamorris and Benjamin Morris authored Mar 6, 2024
1 parent 1c11e5b commit 568fdc3
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 6 deletions.
6 changes: 2 additions & 4 deletions configs/model/im2im/segmentation_plugin.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,8 @@ task_heads:
_target_: cyto_dl.models.im2im.utils.postprocessing.ActThreshLabel
rescale_dtype: numpy.uint8
prediction:
_target_: cyto_dl.models.im2im.utils.postprocessing.ActThreshLabel
activation:
_target_: torch.nn.Sigmoid
rescale_dtype: numpy.uint8
_target_: cyto_dl.models.im2im.utils.postprocessing.AutoThreshold
method: "threshold_otsu"
save_input: True

optimizer:
Expand Down
2 changes: 2 additions & 0 deletions configs/trainer/default.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,5 @@ check_val_every_n_epoch: 1
# makes training slower but gives more reproducibility than just setting seeds
deterministic: False
detect_anomaly: False

max_time: null
8 changes: 6 additions & 2 deletions cyto_dl/models/im2im/utils/postprocessing/auto_thresh.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import importlib
from typing import Union
from typing import Optional, Union


class AutoThreshold:
def __init__(self, method: Union[float, str]):
def __init__(self, method: Optional[Union[float, str]] = None):
if isinstance(method, float):

def thresh_func(image):
Expand All @@ -14,9 +14,13 @@ def thresh_func(image):
thresh_func = getattr(importlib.import_module("skimage.filters"), method)
except AttributeError:
raise AttributeError(f"method {method} not found in skimage.filters")
elif method is None:
thresh_func = None
else:
raise TypeError("method must be a float or a string")
self.thresh_func = thresh_func

def __call__(self, image):
if self.thresh_func is None:
return image
return image > self.thresh_func(image)

0 comments on commit 568fdc3

Please sign in to comment.