Skip to content

Commit

Permalink
Updated documentation and added a triplets dataset in image_list_data…
Browse files Browse the repository at this point in the history
…set.py + fixed issue in img_lms_pose_transforms.Resize transform
  • Loading branch information
YuvalNirkin committed May 6, 2020
1 parent 2b350e9 commit 46aa502
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 30 deletions.
103 changes: 76 additions & 27 deletions datasets/image_list_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,19 +44,13 @@ def get_loader(backend='opencv'):


class ImageListDataset(VisionDataset):
"""An image list datset with corresponding bounding boxes where the images can be arranged in this way:
root/id1/xxx.png
root/id1/xxy.png
root/id1/xxz.png
root/id2/123.png
root/id2/nsdf3.png
root/id2/asd932_.png
"""An image list dataset with corresponding bounding boxes and targets.
Args:
root (string): Root directory path.
img_list (string): Image list file path.
bboxes_list (string): Bounding boxes list file path
targets_list (string): Targets list file path
transform (callable, optional): A function/transform that takes in an PIL image
and returns a transformed version. E.g, ``transforms.RandomCrop``
target_transform (callable, optional): A function/transform that takes in the
Expand Down Expand Up @@ -134,24 +128,20 @@ def __len__(self):


class ImagePairListDataset(ImageListDataset):
"""An image list datset with corresponding bounding boxes where the images can be arranged in this way:
root/id1/xxx.png
root/id1/xxy.png
root/id1/xxz.png
root/id2/123.png
root/id2/nsdf3.png
root/id2/asd932_.png
""" An image dataset for loading pairs from a list.
Args:
root (string): Root directory path.
img_list (string): Image list file path.
img_list (string): Image list file path
bboxes_list (string): Bounding boxes list file path
targets_list (string): Targets list file path
transform (callable, optional): A function/transform that takes in an PIL image
and returns a transformed version. E.g, ``transforms.RandomCrop``
target_transform (callable, optional): A function/transform that takes in the
target and transforms it.
loader (string, optional): 'opencv', 'accimage', or 'pil'
same_prob (float): The probability to return images of the same class
return_targets (bool): If True, return the targets together with the images
Attributes:
classes (list): List of the class names.
Expand All @@ -170,13 +160,13 @@ def __init__(self, root, img_list, bboxes_list=None, targets_list=None, transfor
for i, target in enumerate(self.targets):
self.label_ranges[target] = min(self.label_ranges[target], i)

"""
Args:
index (int): Index
Returns:
tuple: (image1, image2, target1, target2) if return_targets is True else (image1, image2, same)
"""
def __getitem__(self, index):
"""
Args:
index (int): Index
Returns:
tuple: (image1, image2, target1, target2) if return_targets is True else (image1, image2, same)
"""
# Get pair indices
label1 = self.targets[index]
total_imgs1 = self.label_ranges[label1 + 1] - self.label_ranges[label1]
Expand Down Expand Up @@ -211,11 +201,63 @@ def __getitem__(self, index):
return (img1, img2), target1 == target2


class ImageTripletListDataset(VisionDataset):
    """An image dataset for loading triplets from a list.

    Every non-blank line of the list file describes one sample: a
    whitespace-separated sequence of image paths relative to ``root``
    (expected to be three per line; the count is not enforced). Because lines
    are split on whitespace, paths that contain spaces are not supported.

    Args:
        root (string): Root directory path.
        img_list (string): Image list file path. Used as given if it exists,
            otherwise it is resolved relative to ``root``.
        bboxes_list (string, optional): Bounding boxes list file path.
            Accepted but not used by this class.
        targets_list (string, optional): Targets list file path.
            Accepted but not used by this class.
        transform (callable, optional): A function/transform that takes in the
            list of loaded images of one sample and returns a transformed
            version of that list.
        target_transform (callable, optional): Forwarded to ``VisionDataset``;
            this class returns no targets and never applies it itself.
        loader (string, optional): 'opencv', 'accimage', or 'pil'

    Attributes:
        imgs (list): One entry per sample; each entry is a list of image paths
            already joined with ``root``.
        loader (callable): Image loading function returned by ``get_loader``.
    """

    def __init__(self, root, img_list, bboxes_list=None, targets_list=None, transform=None, target_transform=None,
                 loader='opencv'):
        super().__init__(root, transform=transform, target_transform=target_transform)
        self.loader = get_loader(loader)

        # Load image paths from list
        img_list_path = img_list if os.path.exists(img_list) else os.path.join(root, img_list)
        assert os.path.isfile(img_list_path), f'Could not find image list file: "{img_list}"'
        with open(img_list_path, 'r') as f:
            lines = f.read().splitlines()

        self.imgs = []
        for line in lines:
            rel_paths = line.split()
            if not rel_paths:
                # Skip blank / whitespace-only lines: they would otherwise become
                # empty samples that inflate __len__ and yield empty tuples.
                continue
            self.imgs.append([os.path.join(root, p) for p in rel_paths])

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image1, image2, image3) - the loaded images of the sample at ``index``, one per path listed
            on the corresponding line, after ``transform`` has been applied to the whole list (if given).
        """
        img_triplet_paths = self.imgs[index]
        img_triplet = [self.loader(p) for p in img_triplet_paths]
        if self.transform is not None:
            # The transform receives the full list so that all images of a sample are processed together
            img_triplet = self.transform(img_triplet)

        return tuple(img_triplet)

    def __len__(self):
        """Return the number of samples (non-blank lines read from the list file)."""
        return len(self.imgs)


def main(dataset='fake_detection.datasets.image_list_dataset.ImageListDataset', np_transforms=None,
tensor_transforms=('img_landmarks_transforms.ToTensor()',
'transforms.Normalize(mean=[0.5,0.5,0.5],std=[0.5,0.5,0.5])'),
workers=4, batch_size=4):
import time
import fsgan
from fsgan.utils.obj_factory import obj_factory
from fsgan.utils.img_utils import tensor2bgr

Expand All @@ -227,7 +269,7 @@ def main(dataset='fake_detection.datasets.image_list_dataset.ImageListDataset',
shuffle=True)

start = time.time()
if isinstance(dataset, ImageListDataset):
if isinstance(dataset, fsgan.datasets.image_list_dataset.ImageListDataset):
for img, target in dataloader:
print(img.shape)
print(target)
Expand All @@ -238,7 +280,7 @@ def main(dataset='fake_detection.datasets.image_list_dataset.ImageListDataset',
cv2.imshow('render_img', render_img)
if cv2.waitKey(0) & 0xFF == ord('q'):
break
else:
elif isinstance(dataset, fsgan.datasets.image_list_dataset.ImagePairListDataset):
for img1, img2, target in dataloader:
print(img1.shape)
print(img2.shape)
Expand All @@ -252,9 +294,16 @@ def main(dataset='fake_detection.datasets.image_list_dataset.ImageListDataset',
cv2.imshow('render_img', render_img)
if cv2.waitKey(0) & 0xFF == ord('q'):
break
elif isinstance(dataset, fsgan.datasets.image_list_dataset.ImageTripletListDataset):
for img1, img2, img3 in dataloader:
print(img1.shape)
print(img2.shape)
print(img3.shape)
end = time.time()
print('elapsed time: %f[s]' % (end - start))

return 0


if __name__ == "__main__":
# Parse program arguments
Expand Down
10 changes: 7 additions & 3 deletions datasets/img_lms_pose_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,9 +203,13 @@ def __call__(self, x, interpolation=None):
numpy.ndarray or list of numpy.ndarray: Transformed images or poses
"""
interpolation = self.interpolation_id if interpolation is None else interpolation
if isinstance(interpolation, list) and isinstance(x, (list, tuple)) and len(x) == len(interpolation):
return [self.__call__(a, interpolation[i]) for i, a in enumerate(x)]
elif is_img(x): # x is an image
if isinstance(x, (list, tuple)):
if isinstance(interpolation, list):
assert len(x) == len(interpolation)
return [self.__call__(a, interpolation[i]) for i, a in enumerate(x)]
else:
return [self.__call__(a, interpolation) for a in x]
elif is_img(x): # x is an image
interpolation = interpolation[0] if isinstance(interpolation, list) else interpolation
x = cv2.resize(x, (self.size[1], self.size[0]), interpolation=interpolation)

Expand Down

0 comments on commit 46aa502

Please sign in to comment.