txt2image_dataset.py
import io

import h5py
import numpy as np
import torch
from PIL import Image
from torch.utils.data import Dataset, DataLoader


class Text2ImageDataset(Dataset):
    """Dataset of encoded images and text embeddings stored in an HDF5 file."""

    def __init__(self, datasetFile, transform=None, split=0):
        self.datasetFile = datasetFile
        self.transform = transform
        self.dataset = None          # HDF5 handle, opened lazily in __getitem__
        self.dataset_keys = None
        self.split = 'train' if split == 0 else 'valid' if split == 1 else 'test'
        self.h5py2int = lambda x: int(np.array(x))

    def __len__(self):
        # Open the file just long enough to list the keys of this split.
        with h5py.File(self.datasetFile, 'r') as f:
            self.dataset_keys = [str(k) for k in f[self.split].keys()]
            return len(f[self.split])

    def __getitem__(self, idx):
        # Open the HDF5 file lazily so that each DataLoader worker gets its own handle.
        if self.dataset is None:
            self.dataset = h5py.File(self.datasetFile, mode='r')
            self.dataset_keys = [str(k) for k in self.dataset[self.split].keys()]

        example_name = self.dataset_keys[idx]
        example = self.dataset[self.split][example_name]

        # Images are stored as encoded byte strings, embeddings as float arrays.
        right_image = bytes(np.array(example['img']))
        right_embed = np.array(example['embeddings'], dtype=float)
        wrong_image = bytes(np.array(self.find_wrong_image(example['class'])))
        inter_embed = np.array(self.find_inter_embed())

        right_image = Image.open(io.BytesIO(right_image)).resize((64, 64))
        wrong_image = Image.open(io.BytesIO(wrong_image)).resize((64, 64))

        right_image = self.validate_image(right_image)
        wrong_image = self.validate_image(wrong_image)

        txt = np.array(example['txt']).astype(str)

        sample = {
            'right_images': torch.FloatTensor(right_image),
            'right_embed': torch.FloatTensor(right_embed),
            'wrong_images': torch.FloatTensor(wrong_image),
            'inter_embed': torch.FloatTensor(inter_embed),
            'txt': str(txt)
        }

        # Scale pixel values from [0, 255] to [-1, 1].
        sample['right_images'] = sample['right_images'].sub_(127.5).div_(127.5)
        sample['wrong_images'] = sample['wrong_images'].sub_(127.5).div_(127.5)

        return sample

    def find_wrong_image(self, category):
        # Sample a random example and return its image if it comes from a
        # different class than `category`; otherwise retry.
        idx = np.random.randint(len(self.dataset_keys))
        example_name = self.dataset_keys[idx]
        example = self.dataset[self.split][example_name]
        _category = example['class']

        # Compare the stored class values, not the h5py Dataset objects,
        # so the mismatch check actually works across examples.
        if not np.array_equal(np.array(_category), np.array(category)):
            return example['img']

        return self.find_wrong_image(category)

    def find_inter_embed(self):
        # Return the embedding of a randomly chosen example.
        idx = np.random.randint(len(self.dataset_keys))
        example_name = self.dataset_keys[idx]
        example = self.dataset[self.split][example_name]
        return example['embeddings']

    def validate_image(self, img):
        # Ensure the image is a 3-channel array in (C, H, W) order.
        img = np.array(img, dtype=float)
        if len(img.shape) < 3:
            # Grayscale image: replicate the single channel into RGB.
            rgb = np.empty((64, 64, 3), dtype=np.float32)
            rgb[:, :, 0] = img
            rgb[:, :, 1] = img
            rgb[:, :, 2] = img
            img = rgb
        return img.transpose(2, 0, 1)
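

# Minimal usage sketch (not part of the original file): wraps the dataset in a
# torch DataLoader and inspects one batch. The HDF5 path 'dataset.hdf5', the
# batch size and num_workers are placeholder assumptions, not values taken from
# this repository.
if __name__ == '__main__':
    dataset = Text2ImageDataset('dataset.hdf5', split=0)   # hypothetical path
    loader = DataLoader(dataset, batch_size=64, shuffle=True, num_workers=2)

    batch = next(iter(loader))
    print(batch['right_images'].shape,   # (64, 3, 64, 64), pixels scaled to [-1, 1]
          batch['right_embed'].shape)    # (64, embedding_dim), depends on the stored embeddings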