-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
42 changed files
with
2,890 additions
and
198 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
from .dacapo_create_target import DaCapoTargetFilter | ||
from .gamma_noise import GammaAugment | ||
from .elastic_augment_fuse import ElasticAugment | ||
from .reject_if_empty import RejectIfEmpty | ||
from .copy import CopyMask | ||
from .dacapo_points_source import GraphSource | ||
from .product import Product |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,103 @@ | ||
import gunpowder as gp | ||
|
||
|
||
class CopyMask(gp.BatchFilter): | ||
""" | ||
A class to copy a mask into a new key with the option to drop channels via max collapse. | ||
Attributes: | ||
array_key (gp.ArrayKey): Original key of the array from where the mask will be copied. | ||
copy_key (gp.ArrayKey): New key where the copied mask will reside. | ||
drop_channels (bool): If True, channels will be dropped via a max collapse. | ||
Methods: | ||
setup: Sets up the filter by enabling autoskip and providing the copied key. | ||
prepare: Prepares the filter by copying the request of copy_key into a dependency. | ||
process: Processes the batch by copying the mask from the array_key to the copy_key. | ||
Note: | ||
This class is a subclass of gunpowder.BatchFilter and is used to | ||
copy a mask into a new key with the option to drop channels via max collapse. | ||
""" | ||
|
||
def __init__( | ||
self, array_key: gp.ArrayKey, copy_key: gp.ArrayKey, drop_channels: bool = False | ||
): | ||
""" | ||
Constructs the necessary attributes for the CopyMask object. | ||
Args: | ||
array_key (gp.ArrayKey): Original key of the array from where the mask will be copied. | ||
copy_key (gp.ArrayKey): New key where the copied mask will reside. | ||
drop_channels (bool): If True, channels will be dropped via a max collapse. Default is False. | ||
Raises: | ||
TypeError: If array_key is not of type gp.ArrayKey. | ||
TypeError: If copy_key is not of type gp.ArrayKey. | ||
Examples: | ||
>>> array_key = gp.ArrayKey("ARRAY") | ||
>>> copy_key = gp.ArrayKey("COPY") | ||
>>> copy_mask = CopyMask(array_key, copy_key) | ||
""" | ||
self.array_key = array_key | ||
self.copy_key = copy_key | ||
self.drop_channels = drop_channels | ||
|
||
def setup(self): | ||
""" | ||
Sets up the filter by enabling autoskip and providing the copied key. | ||
Raises: | ||
RuntimeError: If the key is already provided. | ||
Examples: | ||
>>> copy_mask.setup() | ||
""" | ||
self.enable_autoskip() | ||
self.provides(self.copy_key, self.spec[self.array_key].copy()) | ||
|
||
def prepare(self, request): | ||
""" | ||
Prepares the filter by copying the request of copy_key into a dependency. | ||
Args: | ||
request: The request to prepare. | ||
Returns: | ||
deps: The prepared dependencies. | ||
Raises: | ||
NotImplementedError: If the copy_key is not provided. | ||
Examples: | ||
>>> request = gp.BatchRequest() | ||
>>> request[self.copy_key] = gp.ArraySpec(roi=gp.Roi((0, 0, 0), (1, 1, 1))) | ||
>>> copy_mask.prepare(request) | ||
""" | ||
deps = gp.BatchRequest() | ||
deps[self.array_key] = request[self.copy_key].copy() | ||
return deps | ||
|
||
def process(self, batch, request): | ||
""" | ||
Processes the batch by copying the mask from the array_key to the copy_key. | ||
If "drop_channels" attribute is True, it performs max collapse. | ||
Args: | ||
batch: The batch to process. | ||
request: The request for processing. | ||
Returns: | ||
outputs: The processed outputs. | ||
Raises: | ||
KeyError: If the requested key is not in the request. | ||
Examples: | ||
>>> request = gp.BatchRequest() | ||
>>> request[gp.ArrayKey("ARRAY")] = gp.ArraySpec(roi=gp.Roi((0, 0, 0), (1, 1, 1))) | ||
>>> copy_mask.process(batch, request) | ||
""" | ||
outputs = gp.Batch() | ||
|
||
outputs[self.copy_key] = batch[self.array_key] | ||
if self.drop_channels: | ||
while ( | ||
outputs[self.copy_key].data.ndim | ||
> outputs[self.copy_key].spec.voxel_size.dims | ||
): | ||
outputs[self.copy_key].data = outputs[self.copy_key].data.max(axis=0) | ||
|
||
return outputs |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,177 @@ | ||
from dacapo_toolbox.tasks.predictors import Predictor | ||
from dacapo_toolbox.tmp import gp_to_funlib_array | ||
|
||
import gunpowder as gp | ||
|
||
from typing import Optional | ||
|
||
|
||
class DaCapoTargetFilter(gp.BatchFilter): | ||
""" | ||
A Gunpowder node for generating the target from the ground truth | ||
Attributes: | ||
Predictor (Predictor): | ||
The DaCapo Predictor to use to transform gt into target | ||
gt (``Array``): | ||
The dataset to use for generating the target. | ||
target_key (``gp.ArrayKey``): | ||
The key with which to provide the target. | ||
weights_key (``gp.ArrayKey``): | ||
The key with which to provide the weights. | ||
mask_key (``gp.ArrayKey``): | ||
The key with which to provide the mask. | ||
Methods: | ||
setup(): Set up the provider. | ||
prepare(request): Prepare the request. | ||
process(batch, request): Process the batch. | ||
Note: | ||
This class is a subclass of gunpowder.BatchFilter and is used to | ||
generate the target from the ground truth. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
predictor: Predictor, | ||
gt_key: gp.ArrayKey, | ||
target_key: Optional[gp.ArrayKey] = None, | ||
weights_key: Optional[gp.ArrayKey] = None, | ||
mask_key: Optional[gp.ArrayKey] = None, | ||
): | ||
""" | ||
Initialize the DacapoCreateTarget object. | ||
Args: | ||
predictor (Predictor): The predictor object used for prediction. | ||
gt_key (gp.ArrayKey): The ground truth key. | ||
target_key (Optional[gp.ArrayKey]): The target key. Defaults to None. | ||
weights_key (Optional[gp.ArrayKey]): The weights key. Defaults to None. | ||
mask_key (Optional[gp.ArrayKey]): The mask key. Defaults to None. | ||
Raises: | ||
AssertionError: If neither target_key nor weights_key is provided. | ||
Examples: | ||
>>> from dacapo_toolbox.tasks.predictors import Predictor | ||
>>> from gunpowder import ArrayKey | ||
>>> from gunpowder import ArrayKey | ||
>>> from gunpowder import ArrayKey | ||
>>> predictor = Predictor() | ||
>>> gt_key = ArrayKey("GT") | ||
>>> target_key = ArrayKey("TARGET") | ||
>>> weights_key = ArrayKey("WEIGHTS") | ||
>>> mask_key = ArrayKey("MASK") | ||
>>> target_filter = DaCapoTargetFilter(predictor, gt_key, target_key, weights_key, mask_key) | ||
Note: | ||
The target filter is used to generate the target from the ground truth. | ||
""" | ||
self.predictor = predictor | ||
self.gt_key = gt_key | ||
self.target_key = target_key | ||
self.weights_key = weights_key | ||
self.mask_key = mask_key | ||
|
||
self.moving_counts = None | ||
|
||
assert ( | ||
target_key is not None or weights_key is not None | ||
), "Must provide either target or weights" | ||
|
||
def setup(self): | ||
""" | ||
Set up the provider. This function sets the provider to provide the | ||
target with the given key. | ||
Raises: | ||
RuntimeError: If the key is already provided. | ||
Examples: | ||
>>> target_filter.setup() | ||
""" | ||
provided_spec = gp.ArraySpec( | ||
roi=self.spec[self.gt_key].roi, | ||
voxel_size=self.spec[self.gt_key].voxel_size, | ||
interpolatable=self.predictor.output_array_type.interpolatable, | ||
) | ||
if self.target_key is not None: | ||
self.provides(self.target_key, provided_spec) | ||
|
||
provided_spec = gp.ArraySpec( | ||
roi=self.spec[self.gt_key].roi, | ||
voxel_size=self.spec[self.gt_key].voxel_size, | ||
interpolatable=True, | ||
) | ||
if self.weights_key is not None: | ||
self.provides(self.weights_key, provided_spec) | ||
|
||
def prepare(self, request): | ||
""" | ||
Prepare the request. | ||
Args: | ||
request (gp.BatchRequest): The request to prepare. | ||
Returns: | ||
deps (gp.BatchRequest): The dependencies. | ||
Raises: | ||
NotImplementedError: If the target_key is not provided. | ||
Examples: | ||
>>> request = gp.BatchRequest() | ||
>>> request[gp.ArrayKey("GT")] = gp.ArraySpec(roi=gp.Roi((0, 0, 0), (1, 1, 1))) | ||
>>> target_filter.prepare(request) | ||
""" | ||
deps = gp.BatchRequest() | ||
# TODO: Does the gt depend on weights too? | ||
request_spec = None | ||
if self.target_key is not None: | ||
request_spec = request[self.target_key] | ||
request_spec.voxel_size = self.spec[self.gt_key].voxel_size | ||
request_spec = self.predictor.gt_region_for_roi(request_spec) | ||
elif self.weights_key is not None: | ||
request_spec = request[self.weights_key].copy() | ||
else: | ||
raise NotImplementedError("Should not be reached!") | ||
assert request_spec is not None | ||
deps[self.gt_key] = request_spec | ||
if self.mask_key is not None: | ||
deps[self.mask_key] = request_spec | ||
return deps | ||
|
||
def process(self, batch, request): | ||
""" | ||
Process the batch. | ||
Args: | ||
batch (gp.Batch): The batch to process. | ||
request (gp.BatchRequest): The request to process. | ||
Returns: | ||
output (gp.Batch): The output batch. | ||
Examples: | ||
>>> request = gp.BatchRequest() | ||
>>> request[gp.ArrayKey("GT")] = gp.ArraySpec(roi=gp.Roi((0, 0, 0), (1, 1, 1))) | ||
>>> target_filter.process(batch, request) | ||
""" | ||
output = gp.Batch() | ||
|
||
gt_array = gp_to_funlib_array(batch[self.gt_key]) | ||
target_array = self.predictor.create_target(gt_array) | ||
mask_array = gp_to_funlib_array(batch[self.mask_key]) | ||
|
||
if self.target_key is not None: | ||
request_spec = request[self.target_key] | ||
request_spec.voxel_size = gt_array.voxel_size | ||
output[self.target_key] = gp.Array( | ||
target_array[request_spec.roi], request_spec | ||
) | ||
if self.weights_key is not None: | ||
weight_array, self.moving_counts = self.predictor.create_weight( | ||
gt_array, | ||
target_array, | ||
mask=mask_array, | ||
moving_class_counts=self.moving_counts, | ||
) | ||
request_spec = request[self.weights_key] | ||
request_spec.voxel_size = gt_array.voxel_size | ||
output[self.weights_key] = gp.Array( | ||
weight_array[request_spec.roi], request_spec | ||
) | ||
return output |
Oops, something went wrong.