Very minor refactoring. Still issues in the dataset.
lext committed Apr 3, 2019
1 parent 11a6f5a commit 0e28a45
Showing 3 changed files with 199 additions and 196 deletions.
data/dataset.py: 46 changes (25 additions, 21 deletions)
@@ -1,12 +1,12 @@
-from os import path as osp
-
import cv2
import numpy as np
import pandas as pd
+from os import path as osp
+
import torch
import torch.nn as nn
-from torch.utils.data import Dataset
import torch.nn.functional as F
+from torch.utils.data import Dataset


def center_crop(img, size):
@@ -64,6 +64,7 @@ class HPatchesDataset(Dataset):
between source and target views
mask: valid/invalid correspondences
"""

def __init__(self,
csv_file,
image_path_orig,
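The docstring above says each HPatches sample carries a dense correspondence map between the source and target views together with a validity mask. As a rough sketch of how such a pair of outputs is typically consumed (the tensor names, shapes, and the [-1, 1] grid normalisation below are assumptions, not code from this file), the map can drive torch.nn.functional.grid_sample while the mask zeroes out invalid pixels:

import torch
import torch.nn.functional as F

def warp_source_to_target(source_img, corr_map, mask):
    # source_img: (B, C, H, W) float tensor
    # corr_map:   (B, H, W, 2) sampling locations, assumed normalised to [-1, 1]
    # mask:       (B, H, W) tensor, 1 where the correspondence is valid
    warped = F.grid_sample(source_img, corr_map)
    return warped * mask.unsqueeze(1).float()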
@@ -140,12 +141,12 @@ def __getitem__(self, idx):
img1 = \
cv2.resize(cv2.imread(osp.join(self.image_path_orig,
obj,
-im1_id+'.ppm'), -1),
+im1_id + '.ppm'), -1),
self.image_size)
img2 = \
cv2.resize(cv2.imread(osp.join(self.image_path_orig,
obj,
-im2_id+'.ppm'), -1),
+im2_id + '.ppm'), -1),
self.image_size)
_, _, ch = img1.shape
if ch == 3:
@@ -192,6 +193,7 @@ class HomoAffTpsDataset(Dataset):
mask_x: X component of the mask (valid/invalid correspondences)
mask_y: Y component of the mask (valid/invalid correspondences)
"""

def __init__(self,
image_path,
csv_file,
@@ -276,13 +278,13 @@ def get_grid(self, H, ccrop):

# getting the central patch from the pivot
Xwarp_crop = X_grid_pivot[Y_CCROP:Y_CCROP + H_SCALE,
-X_CCROP:X_CCROP + W_SCALE]
+X_CCROP:X_CCROP + W_SCALE]
Ywarp_crop = Y_grid_pivot[Y_CCROP:Y_CCROP + H_SCALE,
-X_CCROP:X_CCROP + W_SCALE]
+X_CCROP:X_CCROP + W_SCALE]
X_crop = X_[Y_CCROP:Y_CCROP + H_SCALE,
-X_CCROP:X_CCROP + W_SCALE]
+X_CCROP:X_CCROP + W_SCALE]
Y_crop = Y_[Y_CCROP:Y_CCROP + H_SCALE,
-X_CCROP:X_CCROP + W_SCALE]
+X_CCROP:X_CCROP + W_SCALE]

# crop grid
Xwarp_crop_range = \
@@ -293,7 +295,8 @@ def get_grid(self, H, ccrop):
Ywarp_crop_range], dim=-1)
return grid_full.unsqueeze(0), grid_crop.unsqueeze(0)

-def symmetric_image_pad(self, image_batch, padding_factor):
+@staticmethod
+def symmetric_image_pad(image_batch, padding_factor):
"""
Pad an input image mini-batch symmetrically
Args:
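The body of symmetric_image_pad is collapsed in this view; only the signature change to a @staticmethod is shown. For orientation, symmetric (mirror) padding of an image batch by a fractional padding_factor can be written with index_select along the spatial axes, in the spirit of Rocco's cnngeometric_pytorch; a minimal sketch under that assumption, not the actual implementation:

import torch

def symmetric_image_pad(image_batch, padding_factor):
    # image_batch: (B, C, H, W); pad each spatial side by padding_factor * size,
    # mirroring border rows/columns instead of zero-padding
    # (padding_factor=0.5, as used below, pads by half the image size per side).
    b, c, h, w = image_batch.size()
    pad_h, pad_w = int(h * padding_factor), int(w * padding_factor)
    idx_left = torch.arange(pad_w - 1, -1, -1)          # first pad_w columns, reversed
    idx_right = torch.arange(w - 1, w - pad_w - 1, -1)  # last pad_w columns, reversed
    idx_top = torch.arange(pad_h - 1, -1, -1)
    idx_bottom = torch.arange(h - 1, h - pad_h - 1, -1)
    image_batch = torch.cat((image_batch.index_select(3, idx_left),
                             image_batch,
                             image_batch.index_select(3, idx_right)), dim=3)
    image_batch = torch.cat((image_batch.index_select(2, idx_top),
                             image_batch,
                             image_batch.index_select(2, idx_bottom)), dim=2)
    return image_batch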
@@ -328,13 +331,13 @@ def __getitem__(self, idx):
transform_type = data['aff/tps/homo'].astype('uint8')

# affine transformation
-if ((transform_type == 0) or (transform_type == 1)):
+if transform_type == 0 or transform_type == 1:
# read image
source_img_name = osp.join(self.img_path, data.fname)
source_img = cv2.cvtColor(cv2.imread(source_img_name),
cv2.COLOR_BGR2RGB)

-if (transform_type == 0):
+if transform_type == 0:
theta = data.iloc[2:8].values.astype('float').reshape(2, 3)
theta = torch.Tensor(theta.astype(np.float32)).expand(1, 2, 3)
else:
@@ -346,7 +349,7 @@ def __getitem__(self, idx):
# make arrays float tensor for subsequent processing
image = torch.Tensor(source_img.astype(np.float32))

-if (image.numpy().ndim == 2):
+if image.numpy().ndim == 2:
image = \
torch.Tensor(np.dstack((source_img.astype(np.float32),
source_img.astype(np.float32),
@@ -364,24 +367,24 @@ def __getitem__(self, idx):
image_pad = self.symmetric_image_pad(image, padding_factor=0.5)

# get cropped source image (240x240)
-cropped_source_image = \
+img_src_crop = \
self.transform_image(image_pad,
self.H_OUT,
self.W_OUT,
padding_factor=0.5,
-crop_factor=9/16).squeeze()
+crop_factor=9 / 16).squeeze()

# get cropped target image (240x240)
-cropped_target_image = \
+img_target_crop = \
self.transform_image(image_pad,
self.H_OUT,
self.W_OUT,
padding_factor=0.5,
-crop_factor=9/16,
+crop_factor=9 / 16,
theta=theta).squeeze(0)

# Homography transformation
-elif (transform_type == 2):
+elif transform_type == 2:
# Homography matrix for 768x576 image resolution
theta = data.iloc[2:11].values.astype('double').reshape(3, 3)
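In the affine branch, theta is a 1 x 2 x 3 matrix and transform_image produces the 240 x 240 source and target crops; the homography branch just above instead reshapes a 3 x 3 matrix. A minimal sketch of the standard PyTorch mechanism for applying an affine theta, using F.affine_grid followed by F.grid_sample; this illustrates the idea and is not necessarily the exact body of transform_image:

import torch
import torch.nn.functional as F

def apply_affine(img, theta, out_h=240, out_w=240):
    # img:   (1, C, H, W) float tensor, e.g. the symmetrically padded source image
    # theta: (1, 2, 3) affine matrix expressed in normalised [-1, 1] coordinates
    # 240 x 240 matches the crop size mentioned in the comments above.
    grid = F.affine_grid(theta, torch.Size((1, img.size(1), out_h, out_w)))
    return F.grid_sample(img, grid)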

@@ -433,7 +436,7 @@ def __getitem__(self, idx):
grid_pyramid = []
mask_x = []
mask_y = []
-if (transform_type == 0):
+if transform_type == 0:
for layer_size in self.pyramid_param:
grid = self.generate_grid(layer_size,
layer_size,
@@ -442,7 +445,7 @@ def __getitem__(self, idx):
grid_pyramid.append(grid)
mask_x.append(mask[:, :, 0])
mask_y.append(mask[:, :, 1])
-elif (transform_type == 1):
+elif transform_type == 1:
grid = self.generate_grid(self.H_OUT,
self.W_OUT,
theta).squeeze(0)
@@ -453,7 +456,7 @@ def __getitem__(self, idx):
grid_pyramid.append(grid_m)
mask_x.append(mask[:, :, 0])
mask_y.append(mask[:, :, 1])
-elif (transform_type == 2):
+elif transform_type == 2:
grid = grid_crop.squeeze(0)
for layer_size in self.pyramid_param:
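The loops above (cut off by the diff view) resize the full-resolution grid to every level in self.pyramid_param and collect the X and Y components of the validity mask. A sketch of that idea with assumed names; the grid is taken as an (H, W, 2) tensor of normalised coordinates and the in-range test used for the mask is an assumption:

import cv2
import torch

def build_grid_pyramid(grid, pyramid_param):
    # grid: (H, W, 2) float tensor of sampling coordinates in [-1, 1]
    grid_pyramid, mask_x, mask_y = [], [], []
    for layer_size in pyramid_param:
        grid_m = torch.from_numpy(cv2.resize(grid.numpy(),
                                             (layer_size, layer_size)))
        mask = grid_m.ge(-1) & grid_m.le(1)   # valid where sampling stays in range
        grid_pyramid.append(grid_m)
        mask_x.append(mask[:, :, 0])
        mask_y.append(mask[:, :, 1])
    return grid_pyramid, mask_x, mask_y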
grid_m = torch.from_numpy(cv2.resize(grid.numpy(),
Expand All @@ -475,6 +478,7 @@ class TpsGridGen(nn.Module):
Adopted version of synthetically transformed pairs dataset by I.Rocco
https://github.com/ignacio-rocco/cnngeometric_pytorch
"""

def __init__(self,
out_h=240,
out_w=240,
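TpsGridGen is adapted from Rocco's cnngeometric_pytorch, where the module's forward pass maps a batch of thin-plate-spline parameters (18 values for a 3 x 3 control-point grid) to a (B, out_h, out_w, 2) sampling grid for grid_sample. A hypothetical usage sketch under that assumption; the import path, the 18-dimensional theta, the default values of the remaining constructor arguments, and the output shape are assumed rather than confirmed by this diff:

import torch
import torch.nn.functional as F
from data.dataset import TpsGridGen  # the file shown in this diff

tps = TpsGridGen(out_h=240, out_w=240)      # other constructor args left at their (assumed) defaults
theta_tps = torch.rand(1, 18) * 0.2 - 0.1   # small perturbations of a 3 x 3 control grid
sampling_grid = tps(theta_tps)              # expected shape: (1, 240, 240, 2)
dummy_img = torch.rand(1, 3, 240, 240)
warped = F.grid_sample(dummy_img, sampling_grid)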
