Skip to content

Commit cb602f5

Browse files
authored
Add PIL Image Compatibility with Record Components (airctic#849)
* sketch pil img with record components * explicit reference to PIL.Image * add test * better repr * finalize PR * fix np.array <-> PIl.Image related tests * better repr
1 parent 9bd5b4d commit cb602f5

File tree

9 files changed

+48
-24
lines changed

9 files changed

+48
-24
lines changed

icevision/core/record_components.py

+23-15
Original file line numberDiff line numberDiff line change
@@ -99,39 +99,47 @@ def set_record_id(self, record_id: int):
9999
self.record_id = record_id
100100

101101
def _repr(self) -> List[str]:
102-
return [f"Image ID: {self.record_id}"]
102+
return [f"Record ID: {self.record_id}"]
103103

104104
def as_dict(self) -> dict:
105105
return {"record_id": self.record_id}
106106

107107

108108
# TODO: we need a way to combine filepath and image mixin
109-
# TODO: rename to ImageArrayRecordComponent
110109
class ImageRecordComponent(RecordComponent):
111110
def __init__(self, task=tasks.common):
112111
super().__init__(task=task)
113112
self.img = None
114113

115-
def set_img(self, img: np.ndarray):
114+
def set_img(self, img: Union[PIL.Image.Image, np.ndarray]):
115+
assert isinstance(img, (PIL.Image.Image, np.ndarray))
116116
self.img = img
117-
height, width, _ = self.img.shape
117+
if isinstance(img, PIL.Image.Image):
118+
height, width = img.shape
119+
elif isinstance(img, np.ndarray):
120+
# else:
121+
height, width, _ = self.img.shape
118122
# this should set on SizeRecordComponent
119123
self.composite.set_img_size(ImgSize(width=width, height=height), original=True)
120124

121125
def _repr(self) -> List[str]:
122126
if self.img is not None:
123-
ndims = len(self.img.shape)
124-
if ndims == 3: # RGB, RGBA
125-
height, width, channels = self.img.shape
126-
elif ndims == 2: # Grayscale
127-
height, width, channels = [*self.img.shape, 1]
128-
else:
129-
raise ValueError(
130-
f"Expected image to have 2 or 3 dimensions, got {ndims} instead"
131-
)
132-
return [f"Image: {width}x{height}x{channels} <np.ndarray> Image"]
127+
if isinstance(self.img, np.ndarray):
128+
ndims = len(self.img.shape)
129+
if ndims == 3: # RGB, RGBA
130+
height, width, channels = self.img.shape
131+
elif ndims == 2: # Grayscale
132+
height, width, channels = [*self.img.shape, 1]
133+
else:
134+
raise ValueError(
135+
f"Expected image to have 2 or 3 dimensions, got {ndims} instead"
136+
)
137+
return [f"Img: {width}x{height}x{channels} <np.ndarray> Image"]
138+
elif isinstance(self.img, PIL.Image.Image):
139+
height, width = self.img.shape
140+
return [f"Img: {width}x{height} <PIL.Image; mode='{self.img.mode}'>"]
133141
else:
134-
return [f"Image: {self.img}"]
142+
return [f"Img: {self.img}"]
135143

136144
def _unload(self):
137145
self.img = None

icevision/data/dataset.py

+3
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,9 @@ def __getitem__(self, i):
3535
record = self.records[i].load()
3636
if self.tfm is not None:
3737
record = self.tfm(record)
38+
else:
39+
# HACK FIXME
40+
record.set_img(np.array(record.img))
3841
return record
3942

4043
def __repr__(self):

icevision/tfms/albumentations/albumentations_adapter.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,8 @@ def collect(self, record):
4242

4343
class AlbumentationsImgComponent(AlbumentationsAdapterComponent):
4444
def setup_img(self, record):
45-
self.adapter._albu_in["image"] = record.img
45+
# NOTE - assumed that `record.img` is a PIL.Image
46+
self.adapter._albu_in["image"] = np.array(record.img)
4647

4748
self.adapter._collect_ops.append(CollectOp(self.collect))
4849

@@ -298,7 +299,7 @@ def _get_size_without_padding(self, record) -> ImgSize:
298299
height, width, _ = self._albu_out["image"].shape
299300

300301
if get_transform(self.tfms_list, "Pad") is not None:
301-
after_pad_h, after_pad_w, _ = record.img.shape
302+
after_pad_h, after_pad_w, _ = np.array(record.img).shape
302303

303304
t = get_transform(self.tfms_list, "SmallestMaxSize")
304305
if t is not None:

icevision/utils/imageio.py

+8-2
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,19 @@
1717
if PIL.ExifTags.TAGS[_EXIF_ORIENTATION_TAG] == "Orientation":
1818
break
1919

20+
# from enum import Enum
2021

21-
def open_img(fn, gray=False):
22+
# class PILMode(Enum):
23+
# blah
24+
25+
# FIXME
26+
def open_img(fn, gray=False) -> PIL.Image.Image:
27+
"Open an image from disk `fn` as a PIL Image"
2228
color = "L" if gray else "RGB"
2329
image = PIL.Image.open(str(fn))
2430
image = PIL.ImageOps.exif_transpose(image)
2531
image = image.convert(color)
26-
return np.array(image)
32+
return image
2733

2834

2935
# TODO: Deprecated

icevision/visualize/draw_data.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ def draw_sample(
7474
* include_only: (Optional) List of labels that must be exclusively plotted. Takes
7575
precedence over `exclude_labels` (?)
7676
"""
77-
img = sample.img.copy()
77+
img = np.asarray(sample.img).copy() # HACK
7878
num_classification_plotted = 0
7979

8080
# Dynamic font size based on image height

tests/core/test_record.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ def record_wrong_num_annotations(samples_source):
7070
def test_record_load(record):
7171
record_loaded = record.load()
7272

73-
assert isinstance(record_loaded.img, np.ndarray)
73+
assert isinstance(record_loaded.img, PIL.Image.Image)
7474
assert isinstance(record_loaded.detection.masks, MaskArray)
7575

7676
# test original record is not modified

tests/models/torchvision_models/mask_rcnn/test_predict.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ def sample_dataset(samples_source):
99
images_dir = samples_source / "images"
1010
images_files = get_image_files(images_dir)[-2:]
1111

12-
images = [open_img(path) for path in images_files]
12+
images = [np.array(open_img(path)) for path in images_files]
1313
images = [cv2.resize(image, (128, 128)) for image in images]
1414

1515
return Dataset.from_images(images)

tests/transforms/test_albu_transform.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ def test_inference_transform(records, check_attributes_on_component):
2424
ds = Dataset.from_images([img], tfm)
2525

2626
tfmed = ds[0]
27-
assert (tfmed.img == img[:, ::-1, :]).all()
27+
assert (tfmed.img == np.array(img)[:, ::-1, :]).all()
2828
check_attributes_on_component(tfmed)
2929

3030

tests/utils/test_imageio.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,14 @@
1010
],
1111
)
1212
def test_open_img(samples_source, fn, expected):
13-
assert open_img(samples_source / fn).shape == expected
13+
# When returning np arrays
14+
assert np.array(open_img(samples_source / fn)).shape == expected
15+
assert np.array(open_img(samples_source / fn, gray=True)).shape == expected[:-1]
16+
17+
# When returning PIL Images; returns only (W,H) for size, not num. channels
18+
assert open_img(samples_source / fn).shape == expected[:2]
1419
assert open_img(samples_source / fn, gray=True).shape == expected[:-1]
20+
assert isinstance(open_img(samples_source / fn), PIL.Image.Image)
1521

1622

1723
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)