-
Notifications
You must be signed in to change notification settings - Fork 1
/
data_iterator.py
103 lines (84 loc) · 2.74 KB
/
data_iterator.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
'''
Python 3.6
PyTorch >= 0.4
Written by Hongyu Wang at Beihang University
'''
import numpy
import pickle as pkl
import sys
import torch
def dataIterator(feature_file, label_file, dictionary, batch_size, batch_Imagesize, maxlen, maxImagesize):
    """Load features and labels from disk and group them into size-bounded batches.

    Images are visited from largest to smallest (by ``shape[1] * shape[2]``)
    and packed greedily, so every batch holds images of similar size.

    Args:
        feature_file: Path to a pickle file holding a dict that maps a sample
            id to an array-like object exposing ``.shape`` with at least three
            dimensions (presumably (channels, height, width) -- confirm
            against the code that writes this file).
        label_file: Path to a text file with one sample per line, formatted as
            ``<uid> <token> <token> ...`` separated by whitespace. Blank lines
            are skipped.
        dictionary: Mapping from token string to integer id.
        batch_size: Maximum number of samples per batch.
        batch_Imagesize: Upper bound on ``biggest image in batch * number of
            samples in batch``.
        maxlen: Samples whose label has more than this many tokens are ignored.
        maxImagesize: Samples whose image size exceeds this value are ignored.

    Returns:
        A tuple ``(feature_total, label_total)``: two parallel lists of
        batches, where each batch is a list of features / a list of
        integer-id label sequences. Both lists are empty when no sample
        passes the filters.

    Raises:
        SystemExit: If a label contains a token missing from ``dictionary``
            (a message is printed first).
        KeyError: If a sample id in the feature file has no line in the
            label file.
    """
    # NOTE(security): unpickling can execute arbitrary code; only load feature
    # files that come from a trusted source.
    with open(feature_file, 'rb') as fp:
        features = pkl.load(fp, encoding="latin1")

    with open(label_file, 'r') as fp2:
        # Drop blank lines (e.g. a trailing newline) so they neither crash the
        # parser below nor inflate the label count.
        labels = [line for line in fp2 if line.strip()]
    len_label = len(labels)

    # Map every token of every label to its integer id.
    targets = {}
    for line in labels:
        tmp = line.strip().split()
        uid = tmp[0]
        w_list = []
        for w in tmp[1:]:
            if w in dictionary:
                w_list.append(dictionary[w])
            else:
                print('a word not in the dictionary !! sentence ',uid,'word ', w)
                sys.exit()
        targets[uid] = w_list

    # List of (uid, image size) pairs sorted by image size, largest first.
    imageSize = {uid: fea.shape[1] * fea.shape[2] for uid, fea in features.items()}
    imageSize = sorted(imageSize.items(), key=lambda d: d[1], reverse=True)

    feature_batch = []
    label_batch = []
    feature_total = []
    label_total = []
    loaded = 0  # number of samples that made it into some batch

    biggest_image_size = 0
    i = 0  # number of samples already in the current batch
    for uid, size in imageSize:
        fea = features[uid]
        lab = targets[uid]

        # Filter first, so a rejected image never influences the size budget
        # of the batch being built.
        if len(lab) > maxlen or size > maxImagesize:
            continue

        if size > biggest_image_size:
            biggest_image_size = size
        # Cost of the batch if this sample were added to it.
        batch_image_size = biggest_image_size * (i + 1)

        if batch_image_size > batch_Imagesize or i == batch_size:
            # The current batch is full: flush it and start a new one whose
            # biggest image is this sample.
            if label_batch:
                feature_total.append(feature_batch)
                label_total.append(label_batch)
            i = 0
            biggest_image_size = size
            feature_batch = []
            label_batch = []

        feature_batch.append(fea)
        label_batch.append(lab)
        loaded += 1
        i += 1

    # Flush the last, partially filled batch -- but never emit an empty one.
    if label_batch:
        feature_total.append(feature_batch)
        label_total.append(label_batch)

    # Ignored samples = labelled samples minus samples actually batched.
    len_ignore = len_label - loaded
    print('total ',len(feature_total), 'batch data loaded')
    print('ignore',len_ignore,'images')
    return feature_total, label_total