From 46aa5020d1d21ed75e08da24c0b1d1d295312842 Mon Sep 17 00:00:00 2001 From: Yuval Nirkin Date: Wed, 6 May 2020 15:02:12 +0300 Subject: [PATCH] Updated documentation and added a triplets dataset in image_list_dataset.py + fixed issue in img_lms_pose_transforms.Resize transform --- datasets/image_list_dataset.py | 103 ++++++++++++++++++++-------- datasets/img_lms_pose_transforms.py | 10 ++- 2 files changed, 83 insertions(+), 30 deletions(-) diff --git a/datasets/image_list_dataset.py b/datasets/image_list_dataset.py index 0b6d55e..70da5a7 100644 --- a/datasets/image_list_dataset.py +++ b/datasets/image_list_dataset.py @@ -44,19 +44,13 @@ def get_loader(backend='opencv'): class ImageListDataset(VisionDataset): - """An image list datset with corresponding bounding boxes where the images can be arranged in this way: - - root/id1/xxx.png - root/id1/xxy.png - root/id1/xxz.png - - root/id2/123.png - root/id2/nsdf3.png - root/id2/asd932_.png + """An image list dataset with corresponding bounding boxes and targets. Args: root (string): Root directory path. img_list (string): Image list file path. + bboxes_list (string): Bounding boxes list file path + targets_list (string): Targets list file path transform (callable, optional): A function/transform that takes in an PIL image and returns a transformed version. E.g, ``transforms.RandomCrop`` target_transform (callable, optional): A function/transform that takes in the @@ -134,24 +128,20 @@ def __len__(self): class ImagePairListDataset(ImageListDataset): - """An image list datset with corresponding bounding boxes where the images can be arranged in this way: - - root/id1/xxx.png - root/id1/xxy.png - root/id1/xxz.png - - root/id2/123.png - root/id2/nsdf3.png - root/id2/asd932_.png + """ An image dataset for loading pairs from a list. Args: root (string): Root directory path. - img_list (string): Image list file path. 
+ img_list (string): Image list file path + bboxes_list (string): Bounding boxes list file path + targets_list (string): Targets list file path transform (callable, optional): A function/transform that takes in an PIL image and returns a transformed version. E.g, ``transforms.RandomCrop`` target_transform (callable, optional): A function/transform that takes in the target and transforms it. loader (string, optional): 'opencv', 'accimage', or 'pil' + same_prob (float): The probability to return images of the same class + return_targets (bool): If True, return the targets together with the images Attributes: classes (list): List of the class names. @@ -170,13 +160,13 @@ def __init__(self, root, img_list, bboxes_list=None, targets_list=None, transfor for i, target in enumerate(self.targets): self.label_ranges[target] = min(self.label_ranges[target], i) - """ - Args: - index (int): Index - Returns: - tuple: (image1, image2, target1, target2) if return_targets is True else (image1, image2, same) - """ def __getitem__(self, index): + """ + Args: + index (int): Index + Returns: + tuple: (image1, image2, target1, target2) if return_targets is True else (image1, image2, same) + """ # Get pair indices label1 = self.targets[index] total_imgs1 = self.label_ranges[label1 + 1] - self.label_ranges[label1] @@ -211,11 +201,63 @@ def __getitem__(self, index): return (img1, img2), target1 == target2 +class ImageTripletListDataset(VisionDataset): + """ An image dataset for loading triplets from a list. + + Args: + root (string): Root directory path. + img_list (string): Image list file path. + bboxes_list (string): Bounding boxes list file path + targets_list (string): Targets list file path + transform (callable, optional): A function/transform that takes in an PIL image + and returns a transformed version. E.g, ``transforms.RandomCrop`` + target_transform (callable, optional): A function/transform that takes in the + target and transforms it. 
+ loader (string, optional): 'opencv', 'accimage', or 'pil' + + Attributes: + classes (list): List of the class names. + class_to_idx (dict): Dict with items (class_name, class_index). + imgs (list): List of image path triplets, one list of image paths per line of the image list file + """ + def __init__(self, root, img_list, bboxes_list=None, targets_list=None, transform=None, target_transform=None, + loader='opencv'): + super(ImageTripletListDataset, self).__init__(root, transform=transform, target_transform=target_transform) + self.loader = get_loader(loader) + + # Load image paths from list + img_list_path = img_list if os.path.exists(img_list) else os.path.join(root, img_list) + assert os.path.isfile(img_list_path), f'Could not find image list file: "{img_list}"' + with open(img_list_path, 'r') as f: + img_rel_path_triplets = f.read().splitlines() + self.imgs = [] + for img_rel_path_triplet in img_rel_path_triplets: + self.imgs.append([os.path.join(root, p) for p in img_rel_path_triplet.split()]) + + def __getitem__(self, index): + """ + Args: + index (int): Index + Returns: + tuple: The loaded (and, if a transform is set, transformed) images of the triplet, e.g. (image1, image2, image3) + """ + img_triplet_paths = self.imgs[index] + img_triplet = [self.loader(p) for p in img_triplet_paths] + if self.transform is not None: + img_triplet = self.transform(img_triplet) + + return tuple(img_triplet) + + def __len__(self): + return len(self.imgs) + + def main(dataset='fake_detection.datasets.image_list_dataset.ImageListDataset', np_transforms=None, tensor_transforms=('img_landmarks_transforms.ToTensor()', 'transforms.Normalize(mean=[0.5,0.5,0.5],std=[0.5,0.5,0.5])'), workers=4, batch_size=4): import time + import fsgan from fsgan.utils.obj_factory import obj_factory from fsgan.utils.img_utils import tensor2bgr @@ -227,7 +269,7 @@ def main(dataset='fake_detection.datasets.image_list_dataset.ImageListDataset', shuffle=True) start = time.time() - if isinstance(dataset, ImageListDataset): + if isinstance(dataset, 
fsgan.datasets.image_list_dataset.ImageListDataset): for img, target in dataloader: print(img.shape) print(target) @@ -238,7 +280,7 @@ def main(dataset='fake_detection.datasets.image_list_dataset.ImageListDataset', cv2.imshow('render_img', render_img) if cv2.waitKey(0) & 0xFF == ord('q'): break - else: + elif isinstance(dataset, fsgan.datasets.image_list_dataset.ImagePairListDataset): for img1, img2, target in dataloader: print(img1.shape) print(img2.shape) @@ -252,9 +294,16 @@ def main(dataset='fake_detection.datasets.image_list_dataset.ImageListDataset', cv2.imshow('render_img', render_img) if cv2.waitKey(0) & 0xFF == ord('q'): break + elif isinstance(dataset, fsgan.datasets.image_list_dataset.ImageTripletListDataset): + for img1, img2, img3 in dataloader: + print(img1.shape) + print(img2.shape) + print(img3.shape) end = time.time() print('elapsed time: %f[s]' % (end - start)) + return 0 + if __name__ == "__main__": # Parse program arguments diff --git a/datasets/img_lms_pose_transforms.py b/datasets/img_lms_pose_transforms.py index a070446..4aa89f7 100644 --- a/datasets/img_lms_pose_transforms.py +++ b/datasets/img_lms_pose_transforms.py @@ -203,9 +203,13 @@ def __call__(self, x, interpolation=None): numpy.ndarray or list of numpy.ndarray: Transformed images or poses """ interpolation = self.interpolation_id if interpolation is None else interpolation - if isinstance(interpolation, list) and isinstance(x, (list, tuple)) and len(x) == len(interpolation): - return [self.__call__(a, interpolation[i]) for i, a in enumerate(x)] - elif is_img(x): # x is an image + if isinstance(x, (list, tuple)): + if isinstance(interpolation, list): + assert len(x) == len(interpolation) + return [self.__call__(a, interpolation[i]) for i, a in enumerate(x)] + else: + return [self.__call__(a, interpolation) for a in x] + elif is_img(x): # x is an image interpolation = interpolation[0] if isinstance(interpolation, list) else interpolation x = cv2.resize(x, (self.size[1], self.size[0]), 
interpolation=interpolation)