Skip to content

Commit

Permalink
update paths to use dacapo_toolbox
Browse files Browse the repository at this point in the history
  • Loading branch information
pattonw committed Mar 4, 2025
1 parent 73fdb9c commit 57c0a8a
Show file tree
Hide file tree
Showing 42 changed files with 2,890 additions and 198 deletions.
2 changes: 1 addition & 1 deletion src/dacapo_toolbox/architectures/architecture.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def save_bioimage_io_model(
checkpoint: int | str | None = None,
in_voxel_size: Coordinate | None = None,
):
from dacapo_toolbox.run_config import RunConfig
from dacapo.experiments.run_config import RunConfig

run = RunConfig(name=f"{self.name}-bioimage-io", architecture_config=self)
run.save_bioimage_io_model(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -714,7 +714,7 @@ def __init__(
)

else:
layers.append(torch.nn.Upsample(scale_factor=scale_factor, mode=mode))
layers.append(torch.nn.Upsample(scale_factor=tuple(scale_factor), mode=mode))
conv = {2: torch.nn.Conv2d, 3: torch.nn.Conv3d}[self.dims]
layers.append(
conv(
Expand Down
2 changes: 1 addition & 1 deletion src/dacapo_toolbox/architectures/wrapped.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,6 @@ def num_out_channels(self):

def scale(self, input_voxel_size: Coordinate) -> Coordinate:
if self._scale is not None:
return input_voxel_size // self._scale
return input_voxel_size // Coordinate(self._scale)
else:
return input_voxel_size
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from funlib.persistence import Array

from typing import List, Tuple
from dacapo.tmp import num_channels_from_array
from dacapo_toolbox.tmp import num_channels_from_array

import dask.array as da

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from .array_config import ArrayConfig
from funlib.persistence import Array
import dask.array as da
from dacapo.tmp import num_channels_from_array
from dacapo_toolbox.tmp import num_channels_from_array


@attr.s
Expand Down
7 changes: 7 additions & 0 deletions src/dacapo_toolbox/gp/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from .dacapo_create_target import DaCapoTargetFilter
from .gamma_noise import GammaAugment
from .elastic_augment_fuse import ElasticAugment
from .reject_if_empty import RejectIfEmpty
from .copy import CopyMask
from .dacapo_points_source import GraphSource
from .product import Product
103 changes: 103 additions & 0 deletions src/dacapo_toolbox/gp/copy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
import gunpowder as gp


class CopyMask(gp.BatchFilter):
"""
A class to copy a mask into a new key with the option to drop channels via max collapse.
Attributes:
array_key (gp.ArrayKey): Original key of the array from where the mask will be copied.
copy_key (gp.ArrayKey): New key where the copied mask will reside.
drop_channels (bool): If True, channels will be dropped via a max collapse.
Methods:
setup: Sets up the filter by enabling autoskip and providing the copied key.
prepare: Prepares the filter by copying the request of copy_key into a dependency.
process: Processes the batch by copying the mask from the array_key to the copy_key.
Note:
This class is a subclass of gunpowder.BatchFilter and is used to
copy a mask into a new key with the option to drop channels via max collapse.
"""

def __init__(
self, array_key: gp.ArrayKey, copy_key: gp.ArrayKey, drop_channels: bool = False
):
"""
Constructs the necessary attributes for the CopyMask object.
Args:
array_key (gp.ArrayKey): Original key of the array from where the mask will be copied.
copy_key (gp.ArrayKey): New key where the copied mask will reside.
drop_channels (bool): If True, channels will be dropped via a max collapse. Default is False.
Raises:
TypeError: If array_key is not of type gp.ArrayKey.
TypeError: If copy_key is not of type gp.ArrayKey.
Examples:
>>> array_key = gp.ArrayKey("ARRAY")
>>> copy_key = gp.ArrayKey("COPY")
>>> copy_mask = CopyMask(array_key, copy_key)
"""
self.array_key = array_key
self.copy_key = copy_key
self.drop_channels = drop_channels

def setup(self):
"""
Sets up the filter by enabling autoskip and providing the copied key.
Raises:
RuntimeError: If the key is already provided.
Examples:
>>> copy_mask.setup()
"""
self.enable_autoskip()
self.provides(self.copy_key, self.spec[self.array_key].copy())

def prepare(self, request):
"""
Prepares the filter by copying the request of copy_key into a dependency.
Args:
request: The request to prepare.
Returns:
deps: The prepared dependencies.
Raises:
NotImplementedError: If the copy_key is not provided.
Examples:
>>> request = gp.BatchRequest()
>>> request[self.copy_key] = gp.ArraySpec(roi=gp.Roi((0, 0, 0), (1, 1, 1)))
>>> copy_mask.prepare(request)
"""
deps = gp.BatchRequest()
deps[self.array_key] = request[self.copy_key].copy()
return deps

def process(self, batch, request):
"""
Processes the batch by copying the mask from the array_key to the copy_key.
If "drop_channels" attribute is True, it performs max collapse.
Args:
batch: The batch to process.
request: The request for processing.
Returns:
outputs: The processed outputs.
Raises:
KeyError: If the requested key is not in the request.
Examples:
>>> request = gp.BatchRequest()
>>> request[gp.ArrayKey("ARRAY")] = gp.ArraySpec(roi=gp.Roi((0, 0, 0), (1, 1, 1)))
>>> copy_mask.process(batch, request)
"""
outputs = gp.Batch()

outputs[self.copy_key] = batch[self.array_key]
if self.drop_channels:
while (
outputs[self.copy_key].data.ndim
> outputs[self.copy_key].spec.voxel_size.dims
):
outputs[self.copy_key].data = outputs[self.copy_key].data.max(axis=0)

return outputs
177 changes: 177 additions & 0 deletions src/dacapo_toolbox/gp/dacapo_create_target.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,177 @@
from dacapo_toolbox.tasks.predictors import Predictor
from dacapo_toolbox.tmp import gp_to_funlib_array

import gunpowder as gp

from typing import Optional


class DaCapoTargetFilter(gp.BatchFilter):
"""
A Gunpowder node for generating the target from the ground truth
Attributes:
Predictor (Predictor):
The DaCapo Predictor to use to transform gt into target
gt (``Array``):
The dataset to use for generating the target.
target_key (``gp.ArrayKey``):
The key with which to provide the target.
weights_key (``gp.ArrayKey``):
The key with which to provide the weights.
mask_key (``gp.ArrayKey``):
The key with which to provide the mask.
Methods:
setup(): Set up the provider.
prepare(request): Prepare the request.
process(batch, request): Process the batch.
Note:
This class is a subclass of gunpowder.BatchFilter and is used to
generate the target from the ground truth.
"""

def __init__(
self,
predictor: Predictor,
gt_key: gp.ArrayKey,
target_key: Optional[gp.ArrayKey] = None,
weights_key: Optional[gp.ArrayKey] = None,
mask_key: Optional[gp.ArrayKey] = None,
):
"""
Initialize the DacapoCreateTarget object.
Args:
predictor (Predictor): The predictor object used for prediction.
gt_key (gp.ArrayKey): The ground truth key.
target_key (Optional[gp.ArrayKey]): The target key. Defaults to None.
weights_key (Optional[gp.ArrayKey]): The weights key. Defaults to None.
mask_key (Optional[gp.ArrayKey]): The mask key. Defaults to None.
Raises:
AssertionError: If neither target_key nor weights_key is provided.
Examples:
>>> from dacapo_toolbox.tasks.predictors import Predictor
>>> from gunpowder import ArrayKey
>>> from gunpowder import ArrayKey
>>> from gunpowder import ArrayKey
>>> predictor = Predictor()
>>> gt_key = ArrayKey("GT")
>>> target_key = ArrayKey("TARGET")
>>> weights_key = ArrayKey("WEIGHTS")
>>> mask_key = ArrayKey("MASK")
>>> target_filter = DaCapoTargetFilter(predictor, gt_key, target_key, weights_key, mask_key)
Note:
The target filter is used to generate the target from the ground truth.
"""
self.predictor = predictor
self.gt_key = gt_key
self.target_key = target_key
self.weights_key = weights_key
self.mask_key = mask_key

self.moving_counts = None

assert (
target_key is not None or weights_key is not None
), "Must provide either target or weights"

def setup(self):
"""
Set up the provider. This function sets the provider to provide the
target with the given key.
Raises:
RuntimeError: If the key is already provided.
Examples:
>>> target_filter.setup()
"""
provided_spec = gp.ArraySpec(
roi=self.spec[self.gt_key].roi,
voxel_size=self.spec[self.gt_key].voxel_size,
interpolatable=self.predictor.output_array_type.interpolatable,
)
if self.target_key is not None:
self.provides(self.target_key, provided_spec)

provided_spec = gp.ArraySpec(
roi=self.spec[self.gt_key].roi,
voxel_size=self.spec[self.gt_key].voxel_size,
interpolatable=True,
)
if self.weights_key is not None:
self.provides(self.weights_key, provided_spec)

def prepare(self, request):
"""
Prepare the request.
Args:
request (gp.BatchRequest): The request to prepare.
Returns:
deps (gp.BatchRequest): The dependencies.
Raises:
NotImplementedError: If the target_key is not provided.
Examples:
>>> request = gp.BatchRequest()
>>> request[gp.ArrayKey("GT")] = gp.ArraySpec(roi=gp.Roi((0, 0, 0), (1, 1, 1)))
>>> target_filter.prepare(request)
"""
deps = gp.BatchRequest()
# TODO: Does the gt depend on weights too?
request_spec = None
if self.target_key is not None:
request_spec = request[self.target_key]
request_spec.voxel_size = self.spec[self.gt_key].voxel_size
request_spec = self.predictor.gt_region_for_roi(request_spec)
elif self.weights_key is not None:
request_spec = request[self.weights_key].copy()
else:
raise NotImplementedError("Should not be reached!")
assert request_spec is not None
deps[self.gt_key] = request_spec
if self.mask_key is not None:
deps[self.mask_key] = request_spec
return deps

def process(self, batch, request):
"""
Process the batch.
Args:
batch (gp.Batch): The batch to process.
request (gp.BatchRequest): The request to process.
Returns:
output (gp.Batch): The output batch.
Examples:
>>> request = gp.BatchRequest()
>>> request[gp.ArrayKey("GT")] = gp.ArraySpec(roi=gp.Roi((0, 0, 0), (1, 1, 1)))
>>> target_filter.process(batch, request)
"""
output = gp.Batch()

gt_array = gp_to_funlib_array(batch[self.gt_key])
target_array = self.predictor.create_target(gt_array)
mask_array = gp_to_funlib_array(batch[self.mask_key])

if self.target_key is not None:
request_spec = request[self.target_key]
request_spec.voxel_size = gt_array.voxel_size
output[self.target_key] = gp.Array(
target_array[request_spec.roi], request_spec
)
if self.weights_key is not None:
weight_array, self.moving_counts = self.predictor.create_weight(
gt_array,
target_array,
mask=mask_array,
moving_class_counts=self.moving_counts,
)
request_spec = request[self.weights_key]
request_spec.voxel_size = gt_array.voxel_size
output[self.weights_key] = gp.Array(
weight_array[request_spec.roi], request_spec
)
return output
Loading

0 comments on commit 57c0a8a

Please sign in to comment.