Skip to content

Commit

Permalink
WIP(transforms): remove masks from tissue_objects
Browse files Browse the repository at this point in the history
  • Loading branch information
Qovaxx committed Jun 29, 2020
1 parent e33b5fb commit 07b31d0
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 30 deletions.
26 changes: 14 additions & 12 deletions src/psga/transforms/atlas.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from . import functional as F
from .entity import (
TissueObject,
TissueObjects,
Rectangle
)

Expand All @@ -21,27 +21,29 @@ def show(image):
plt.show()


def convert_to_atlas(image: np.ndarray, tissue_objects: Optional[List[TissueObject]] = None,
def convert_to_atlas(image: np.ndarray, tissue_objects: Optional[TissueObjects] = None,
min_contour_area: int = 200,
background_value: int = 255) -> Tuple[np.ndarray, List[TissueObject]]:
background_value: int = 255) -> Tuple[np.ndarray, TissueObjects]:
if tissue_objects is None:
gray = cv2.cvtColor(image, code=cv2.COLOR_RGB2GRAY)
_, not_background_mask = cv2.threshold(gray, thresh=background_value - 1, maxval=image.max(),
type=cv2.THRESH_BINARY_INV)
found_contours, _ = cv2.findContours(not_background_mask, mode=cv2.RETR_EXTERNAL, method=cv2.CHAIN_APPROX_SIMPLE)
tissue_objects = list()
rectangles = list()
mask = np.full(image.shape[:2], fill_value=0, dtype=np.uint8)

for contour in found_contours:
if cv2.contourArea(contour) >= min_contour_area:
contour_mask = np.full(image.shape[:2], fill_value=0, dtype=np.uint8)
cv2.drawContours(contour_mask, contours=[contour], contourIdx=0, color=1, thickness=-1)
color = len(rectangles) + 1
cv2.drawContours(mask, contours=[contour], contourIdx=0, color=color, thickness=-1)
rectangle = cv2.minAreaRect(points=contour)
rectangle = Rectangle(center_x=rectangle[0][0],
center_y=rectangle[0][1],
width=rectangle[1][0],
height=rectangle[1][1],
angle=rectangle[2])
tissue_objects.append(TissueObject(mask=contour_mask.astype(np.bool), rectangle=rectangle))
rectangles.append(Rectangle(center_x=rectangle[0][0],
center_y=rectangle[0][1],
width=rectangle[1][0],
height=rectangle[1][1],
angle=rectangle[2]))

tissue_objects = TissueObjects(mask, rectangles)

image = F.pack_atlas(image, tissue_objects)
return image, tissue_objects
16 changes: 8 additions & 8 deletions src/psga/transforms/entity.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,24 +81,24 @@ def rescale(self, scale: int) -> NoReturn:


@dataclass
class TissueObject(BaseEntity):
class TissueObjects(BaseEntity):
mask: np.ndarray
rectangle: Rectangle
rectangles: List[Rectangle]

def rescale(self, scale: int, **kwargs) -> NoReturn:
shape = tuple(np.asarray(self.mask.shape) * scale)
mask = self.mask.astype(np.uint8)
self.mask = cv2.resize(mask, dsize=shape[::-1], interpolation=cv2.INTER_NEAREST).astype(np.bool)
self.rectangle.rescale(scale)
self.mask = cv2.resize(self.mask, dsize=shape[::-1], interpolation=cv2.INTER_NEAREST)
for rectangle in self.rectangles:
rectangle.rescale(scale)


@dataclass
class Intermediates(BaseEntity):
external_bbox: Optional[BBox] = field(default=None)
inner_slice: Optional[Slice2D] = field(default=None)
tissue_objects: Optional[List[TissueObject]] = field(default=None)
tissue_objects: Optional[TissueObjects] = field(default=None)

def rescale(self, scale: int) -> NoReturn:
self.external_bbox.rescale(scale)
self.inner_slice.rescale(scale)
for obj in self.tissue_objects:
obj.rescale(scale)
self.tissue_objects.rescale(scale)
21 changes: 11 additions & 10 deletions src/psga/transforms/functional/pack.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from .misc import apply_mask
from .rectangle import fast_crop_rectangle
from ..entity import TissueObject
from ..entity import TissueObjects

__all__ = ["create_bin", "pack_atlas"]

Expand All @@ -23,9 +23,9 @@ def show(image):
plt.show()


def create_bin(tissue_objects: List[TissueObject], step_size: int = 10
def create_bin(tissue_objects: TissueObjects, step_size: int = 10
) -> Tuple[Tuple[int, int], List[RECTPACK_RECT_TYPE]]:
sides = [(int(obj.rectangle.width), int(obj.rectangle.height)) for obj in tissue_objects]
sides = [(int(rect.width), int(rect.height)) for rect in tissue_objects.rectangles]
sides = list(chain(*sides))
height = max(sides)
width = min(sides)
Expand All @@ -36,24 +36,25 @@ def create_bin(tissue_objects: List[TissueObject], step_size: int = 10
sort_algo=rectpack.SORT_LSIDE,
rotation=True)
packer.add_bin(width=width, height=height)
for index, obj in enumerate(tissue_objects):
packer.add_rect(width=int(obj.rectangle.width), height=int(obj.rectangle.height), rid=index)
for index, rect in enumerate(tissue_objects.rectangles):
packer.add_rect(width=int(rect.width), height=int(rect.height), rid=index)

packer.pack()
if len(tissue_objects) == len(packer[0].rectangles):
if len(tissue_objects.rectangles) == len(packer[0].rectangles):
break
else:
width += step_size

return (height, width), packer.rect_list()


def pack_atlas(image: np.ndarray, tissue_objects: List[TissueObject]) -> np.ndarray:
def pack_atlas(image: np.ndarray, tissue_objects: TissueObjects) -> np.ndarray:
fill_value = 255 if len(image.shape) == 3 else 0
crops = list()
for obj in tissue_objects:
contoured_image = apply_mask(image, mask=obj.mask.astype(np.uint8), add=fill_value)
crops.append(fast_crop_rectangle(contoured_image, obj.rectangle))
for index, rectangle in enumerate(tissue_objects.rectangles):
mask = (tissue_objects.mask == index + 1).astype(np.uint8)
contoured_image = apply_mask(image, mask=mask, add=fill_value)
crops.append(fast_crop_rectangle(contoured_image, rectangle))

shape, rectangles = create_bin(tissue_objects)
if len(image.shape) == 3:
Expand Down

0 comments on commit 07b31d0

Please sign in to comment.