forked from foamliu/Machine-Translation
data_gen.py
# encoding=utf-8
import itertools
import json

import numpy as np
import torch
from torch.utils.data import Dataset

from config import *  # PAD_token, train_split, chunk_size come from here

samples_path = 'data/samples_train.json'
samples = json.load(open(samples_path, 'r'))
# np.random.shuffle(samples)
# Pads a list of index sequences to equal length and transposes it so that
# each inner tuple holds one time step across the batch (time-major layout).
def zeroPadding(l, fillvalue=PAD_token):
    return list(itertools.zip_longest(*l, fillvalue=fillvalue))
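# Example (assuming PAD_token == 0 in config):
#   zeroPadding([[1, 2, 3], [4, 5]]) -> [(1, 4), (2, 5), (3, 0)]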
# Builds a 0/1 mask with the same shape as the padded (time-major) batch:
# 0 where an entry is PAD_token, 1 where it is a real token.
def binaryMatrix(l):
    m = []
    for i, seq in enumerate(l):
        m.append([])
        for token in seq:
            if token == PAD_token:
                m[i].append(0)
            else:
                m[i].append(1)
    return m
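# Example (again assuming PAD_token == 0):
#   binaryMatrix([(1, 4), (2, 5), (3, 0)]) -> [[1, 1], [1, 1], [1, 0]]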
# Returns padded input sequence tensor and lengths
def inputVar(indexes_batch):
    lengths = torch.tensor([len(indexes) for indexes in indexes_batch])
    padList = zeroPadding(indexes_batch)
    padVar = torch.LongTensor(padList)
    return padVar, lengths
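# Note: padVar has shape (max_input_len, batch_size) because zeroPadding
# transposes to time-major order, the layout PyTorch RNNs expect when
# batch_first=False (the default).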
# Returns padded target sequence tensor, padding mask, and max target length
def outputVar(indexes_batch):
    max_target_len = max([len(indexes) for indexes in indexes_batch])
    padList = zeroPadding(indexes_batch)
    mask = binaryMatrix(padList)
    mask = torch.ByteTensor(mask)
    padVar = torch.LongTensor(padList)
    return padVar, mask, max_target_len
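# The mask lets the training loop compute a masked loss that skips padded
# positions. Newer PyTorch releases prefer torch.BoolTensor for masks;
# ByteTensor still works but triggers deprecation warnings in some indexing
# and masking APIs.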
# Returns all items for a given batch of pairs
def batch2TrainData(pair_batch):
    # Sort by source length, longest first, so the padded batch can be fed
    # to torch.nn.utils.rnn.pack_padded_sequence, which by default
    # (enforce_sorted=True) requires lengths in descending order.
    pair_batch.sort(key=lambda x: len(x[0]), reverse=True)
    input_batch, output_batch = [], []
    for pair in pair_batch:
        input_batch.append(pair[0])
        output_batch.append(pair[1])
    inp, lengths = inputVar(input_batch)
    output, mask, max_target_len = outputVar(output_batch)
    return inp, lengths, output, mask, max_target_len
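# Example: batch2TrainData([([5, 6], [7]), ([1, 2, 3], [4, 4])]) sorts the
# pairs by input length, pads inputs to length 3 and targets to length 2,
# and returns tensors shaped (max_len, batch) plus the input lengths [3, 2].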
class TranslationDataset(Dataset):
    def __init__(self, split):
        self.split = split
        assert self.split in {'train', 'valid'}
        print('loading {} samples'.format(split))
        train_count = int(len(samples) * train_split)
        if split == 'train':
            self.samples = samples[:train_count]
        else:
            self.samples = samples[train_count:]
        self.num_chunks = len(self.samples) // chunk_size
        np.random.shuffle(self.samples)
        print('count: ' + str(len(self.samples)))

    def __getitem__(self, i):
        # Each item is a whole chunk of chunk_size pairs, already collated
        # into padded tensors by batch2TrainData.
        start_idx = i * chunk_size
        pair_batch = []
        for i_batch in range(chunk_size):
            sample = self.samples[start_idx + i_batch]
            pair_batch.append((sample['input'], sample['output']))
        return batch2TrainData(pair_batch)

    def __len__(self):
        return self.num_chunks
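# Usage sketch (an assumption about the intended training loop, not code from
# this file): because __getitem__ already returns a collated batch, the
# dataset can be iterated directly instead of going through a batching
# DataLoader:
#
#   train_data = TranslationDataset('train')
#   for inp, lengths, target, mask, max_target_len in train_data:
#       ...  # one training step per chunk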
if __name__ == '__main__':
    print('loading {} samples'.format('train'))
    samples_path = 'data/samples_train.json'
    samples = json.load(open(samples_path, 'r'))

    # Build a small example batch as a quick sanity check
    small_batch_size = 5
    pair_batch = []
    for i in range(small_batch_size):
        sample = samples[i]
        pair_batch.append((sample['input'], sample['output']))

    batches = batch2TrainData(pair_batch)
    input_variable, lengths, target_variable, mask, max_target_len = batches
    print("input_variable:", input_variable)
    print("lengths:", lengths)
    print("target_variable:", target_variable)
    print("mask:", mask)
    print("max_target_len:", max_target_len)