-
Notifications
You must be signed in to change notification settings - Fork 8
/
In_shop_clothes.py
130 lines (101 loc) · 3.93 KB
/
In_shop_clothes.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
from __future__ import absolute_import, print_function
"""
In-shop-clothes data-set for Pytorch
"""
import torch
import torch.utils.data as data
from PIL import Image
import os
from torchvision import transforms
from collections import defaultdict
def default_loader(path):
    """Load the image at ``path`` and return it converted to RGB.

    The file is opened explicitly inside a ``with`` block so the OS handle is
    released when this function returns, instead of lingering until the
    lazily-opened ``Image`` object is garbage-collected (which otherwise
    leaks descriptors / emits ``ResourceWarning`` in a long data-loading run).

    Args:
        path: Filesystem path of an image readable by PIL.

    Returns:
        A new ``PIL.Image.Image`` in ``'RGB'`` mode.
    """
    with open(path, 'rb') as f:
        img = Image.open(f)
        # convert() decodes the pixel data and returns a new image, so the
        # result no longer needs ``f`` once the with-block closes it.
        return img.convert('RGB')
class MyData(data.Dataset):
    """Image/label dataset read from a plain-text index file.

    Each non-blank line of ``label_txt`` starts with an image path (relative
    to ``root``) followed by an integer label. Spaces are turned into tabs
    before splitting, so either separator works, but image paths must
    therefore not contain spaces. Columns after the label are ignored.

    Attributes:
        root: Directory the image paths are resolved against.
        images: Image paths, in file order.
        labels: Integer labels aligned with ``images``.
        classes: Sorted list of the distinct labels.
        transform: Callable applied to each loaded image, or ``None``.
        Index: ``defaultdict(list)`` mapping each label to the positions of
            its samples in ``images`` / ``labels``.
        loader: Callable mapping a full image path to an image.
    """

    def __init__(self, root=None, label_txt=None,
                 transform=None, loader=default_loader):
        """Read the index file and build the per-label index.

        Args:
            root: Image root directory; defaults to ``"/home/xunwang"``.
            label_txt: Index file path; defaults to ``<root>/train.txt``
                only when the caller does not supply one.
            transform: Per-image transform. When ``None``, a training
                pipeline is used: resize to 256, random resized crop to 224
                (scale 0.16-1), random horizontal flip, ``ToTensor`` and
                normalisation with mean=(0.485, 0.456, 0.406),
                std=(0.229, 0.224, 0.225).
            loader: Callable ``path -> image``; defaults to
                ``default_loader``.

        Raises:
            OSError: If ``label_txt`` cannot be opened.
            ValueError: If a non-blank line has no label column or its label
                is not an integer.
        """
        if root is None:
            root = "/home/xunwang"
        # Only fall back to the default index file when none was given, so an
        # explicit label_txt is honoured even when root is left defaulted.
        if label_txt is None:
            label_txt = os.path.join(root, 'train.txt')

        if transform is None:
            normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                             std=[0.229, 0.224, 0.225])
            transform = transforms.Compose([
                transforms.Resize(256),
                transforms.RandomResizedCrop(scale=(0.16, 1), size=224),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ])

        images, labels = self._read_label_file(label_txt)

        # label -> list of sample positions sharing that label.
        index = defaultdict(list)
        for i, label in enumerate(labels):
            index[label].append(i)

        self.root = root
        self.images = images
        self.labels = labels
        # Sorted so the class order is deterministic across runs.
        self.classes = sorted(set(labels))
        self.transform = transform
        self.Index = index
        self.loader = loader

    @staticmethod
    def _read_label_file(label_txt):
        """Parse ``label_txt`` into parallel lists of paths and int labels.

        Blank or whitespace-only lines (e.g. a trailing empty line at end of
        file) are skipped instead of crashing the unpacking below.
        """
        images = []
        labels = []
        with open(label_txt) as f:
            for line in f:
                if not line.strip():
                    continue
                # Treat spaces as tabs so both separators are accepted; the
                # newline left on the label field is tolerated by int().
                img, label = line.replace(' ', '\t').split('\t')[:2]
                images.append(img)
                labels.append(int(label))
        return images, labels

    def __getitem__(self, index):
        """Return ``(image, label)`` for sample ``index``.

        The image is loaded from ``os.path.join(root, path)`` with ``loader``
        and passed through ``transform`` when one is set.
        """
        fn, label = self.images[index], self.labels[index]
        img = self.loader(os.path.join(self.root, fn))
        if self.transform is not None:
            img = self.transform(img)
        return img, label

    def __len__(self):
        """Return the number of samples listed in the index file."""
        return len(self.images)
class InShopClothes:
    """Train, gallery and query splits of the In-Shop Clothes data set.

    Attributes:
        train: ``MyData`` over ``train.txt`` with the training transform.
        gallery: ``MyData`` over ``gallery.txt`` with the evaluation
            transform.
        query: ``MyData`` over ``query.txt`` with the evaluation transform.
    """

    # Default locations from the original environment; callers on other
    # machines should pass ``root`` / ``label_root`` explicitly.
    DEFAULT_ROOT = '/home/xunwang'
    DEFAULT_CROP_ROOT = ('/opt/intern/users/xunwang/DataSet/'
                         'In_shop_clothes_retrieval/cropIms')
    DEFAULT_LABEL_ROOT = ('/opt/intern/users/xunwang/DataSet/'
                          'In_shop_clothes_retrieval')

    def __init__(self, root=None, transform=None, crop=False,
                 label_root=None):
        """Build the three splits.

        Args:
            root: Image root directory. When ``None`` it defaults to
                ``DEFAULT_CROP_ROOT`` if ``crop`` is true, else to
                ``DEFAULT_ROOT``. An explicit value is used as given.
            transform: A sequence whose item 0 is the training transform and
                item 1 the evaluation transform; a single callable, which is
                then shared by all splits; or ``None`` for the defaults
                (training: resize 256, random resized crop 224, random
                horizontal flip; evaluation: resize 256, center crop 224;
                both followed by ``ToTensor`` and normalisation with
                mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)).
            crop: Pick the cropped-image directory as the default root. Has
                no effect when ``root`` is given.
            label_root: Directory holding ``train.txt``, ``gallery.txt`` and
                ``query.txt``; defaults to ``DEFAULT_LABEL_ROOT``.
        """
        if transform is None:
            normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                             std=[0.229, 0.224, 0.225])
            transform = [
                transforms.Compose([
                    transforms.Resize(256),
                    transforms.RandomResizedCrop(scale=(0.16, 1), size=224),
                    transforms.RandomHorizontalFlip(),
                    transforms.ToTensor(),
                    normalize,
                ]),
                transforms.Compose([
                    transforms.Resize(256),
                    transforms.CenterCrop(224),
                    transforms.ToTensor(),
                    normalize,
                ]),
            ]
        elif not isinstance(transform, (list, tuple)):
            # A single transform object is reused for every split.
            transform = [transform, transform]
        train_transform, eval_transform = transform[0], transform[1]

        # Only apply the built-in directories when the caller gave none;
        # previously an explicit ``root`` was always discarded.
        if root is None:
            root = self.DEFAULT_CROP_ROOT if crop else self.DEFAULT_ROOT
        if label_root is None:
            label_root = self.DEFAULT_LABEL_ROOT

        train_txt = os.path.join(label_root, 'train.txt')
        gallery_txt = os.path.join(label_root, 'gallery.txt')
        query_txt = os.path.join(label_root, 'query.txt')

        self.train = MyData(root, label_txt=train_txt,
                            transform=train_transform)
        self.gallery = MyData(root, label_txt=gallery_txt,
                              transform=eval_transform)
        self.query = MyData(root, label_txt=query_txt,
                            transform=eval_transform)
def testIn_Shop_Clothes():
    """Smoke-test the data set with its default locations.

    Prints the number of gallery and query samples, then the training sample
    at position 1.
    """
    splits = InShopClothes()
    for split in (splits.gallery, splits.query):
        print(len(split))
    print(splits.train[1])
# Run the smoke test only when this module is executed as a script.
if __name__ == "__main__":
    testIn_Shop_Clothes()