-
Notifications
You must be signed in to change notification settings - Fork 15
/
Copy pathpenn_tree_bank_char.py
142 lines (120 loc) · 4.39 KB
/
penn_tree_bank_char.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
"""
Adapted from https://github.com/locuslab/TCN/blob/master/TCN/
"""
import pickle
from collections import Counter
import os
import numpy as np
import torch
import pathlib
from .utils import load_data, save_data
import observations
class PennTreeBankChar(torch.utils.data.Dataset):
    """Character-level Penn Treebank dataset cut into fixed-width windows.

    On first use the raw corpus is obtained via ``observations.ptb``, mapped to
    integer ids with a :class:`Dictionary`, and cached under
    ``./data/penn/preprocessed_data_char``. The dictionary is pickled as
    ``dictionary_char`` and the id tensors are written with ``save_data``.
    Later instantiations reuse that cache.

    Each split is reshaped into ``batch_size`` parallel streams and cut into
    windows of at most ``seq_length`` characters, starting every
    ``valid_seq_len`` characters (see :meth:`create_seq`). ``self.X[k]`` and
    ``self.y[k]`` hold window ``k`` for every stream; targets are the inputs
    shifted one character to the right.

    Args:
        partition: split to load: ``"train"``, ``"val"`` or ``"test"``.
        seq_length: maximum window length.
        valid_seq_len: stride between the starts of consecutive windows.
        batch_size: number of parallel streams (rows per window).
        **kwargs: accepted and ignored.

    Raises:
        NotImplementedError: if ``partition`` is not one of the names above.
    """

    def __init__(
        self,
        partition: str,
        seq_length: int,
        valid_seq_len: int,
        batch_size: int,
        **kwargs,
    ):
        self.seq_len = seq_length
        self.valid_seq_len = valid_seq_len
        self.batch_size = batch_size
        self.root = pathlib.Path("./data")
        self.base_loc = self.root / "penn"
        data_loc = self.base_loc / "preprocessed_data_char"
        dict_path = data_loc / "dictionary_char"
        if data_loc.exists():
            # Reuse the cached vocabulary that was saved with the cached tensors.
            # NOTE(review): pickle.load can execute code from the file; this is
            # only acceptable while the cache is produced locally by the branch
            # below.
            with open(dict_path, "rb") as f:
                self.dictionary = pickle.load(f)
        else:
            train, valid, test = self._process_data()
            data_loc.mkdir(parents=True, exist_ok=True)
            # `with` closes (and flushes) the file deterministically.
            with open(dict_path, "wb") as f:
                pickle.dump(self.dictionary, f)
            save_data(
                data_loc,
                train=train,
                valid=valid,
                test=test,
            )
        self.X, self.y = self.load_data(data_loc, partition)
        # Only the training split visits its windows in random order.
        self.sampler = SequentialBatchSampler(self, shuffle=(partition == "train"))
        super(PennTreeBankChar, self).__init__()

    def __getitem__(self, ind):
        """Return ``(input, target)`` for row ``i`` of window ``b``.

        The flat index is ``ind = b * rows + i``, where ``rows`` is the number
        of rows (streams) per window.
        """
        rows = len(self.X[0])
        b = ind // rows
        i = ind - b * rows
        return self.X[b][i], self.y[b][i]

    def __len__(self):
        """Total number of rows over all windows."""
        return len(self.X[0]) * len(self.X)

    def create_seq(self, data, batch_size):
        """Split a 1-D id tensor into ``(input, target)`` windows.

        ``data`` is cropped to a multiple of ``batch_size`` and viewed as
        ``(batch_size, L)``, so each row is one contiguous stream.

        Windows start every ``self.valid_seq_len`` columns and span at most
        ``self.seq_len`` columns. They end no later than column ``L - 1``, so
        the one-step-shifted target slice stays in bounds.

        Returns:
            Two equal-length lists ``x`` and ``y`` of contiguous tensors, where
            ``y[k]`` is ``x[k]`` shifted one position along dimension 1.
        """
        nbatch = data.size(0) // batch_size
        data = data.narrow(0, 0, nbatch * batch_size).view(batch_size, -1)  # crop tail
        x = []
        y = []
        L = data.shape[1]
        for i in range(0, L - 1, self.valid_seq_len):
            # Stop once the overlap part of the window, its first
            # (seq_len - valid_seq_len) positions, already reaches the end.
            # `i` only grows, so no later window can pass this test either.
            if i + self.seq_len - self.valid_seq_len >= L - 1:
                break
            end = min(i + self.seq_len, L - 1)
            x.append(data[:, i:end].contiguous())
            y.append(data[:, i + 1:end + 1].contiguous())
        return x, y

    def _process_data(self):
        """Build ``self.dictionary`` from the raw corpus and encode every split.

        Returns:
            ``(train, valid, test)`` as 1-D ``torch.long`` id tensors.
        """
        self.dictionary = Dictionary()
        # NOTE(review): unpacked as (train, test, valid), the order the original
        # code relied on; confirm against observations.ptb.
        train, test, valid = observations.ptb(self.base_loc)
        # The explicit ' ' guarantees the space character is in the vocabulary.
        for c in train + ' ' + test + valid:
            self.dictionary.add_word(c)
        self.dictionary.prep_dict()
        train = self._char_to_tensor(train)
        valid = self._char_to_tensor(valid)
        test = self._char_to_tensor(test)
        return train, valid, test

    def _char_to_tensor(self, string):
        """Encode ``string`` as a 1-D ``torch.long`` tensor of character ids.

        Raises:
            KeyError: if a character is missing from ``self.dictionary.char2idx``.
        """
        char2idx = self.dictionary.char2idx
        # Build the tensor in one call rather than assigning one element per
        # Python-loop iteration.
        return torch.tensor([char2idx[c] for c in string], dtype=torch.long)

    def load_data(self, data_loc, partition):
        """Load the cached tensors for ``partition`` and window them.

        Args:
            data_loc: cache directory passed to the module-level ``load_data``.
            partition: ``"train"``, ``"val"`` or ``"test"``.

        Returns:
            ``(X, y)`` as produced by :meth:`create_seq`.

        Raises:
            NotImplementedError: for any other ``partition`` value (checked
                before the cache is read).
        """
        split_keys = {"train": "train", "val": "valid", "test": "test"}
        if partition not in split_keys:
            raise NotImplementedError(
                "the set {} is not implemented.".format(partition)
            )
        # Inside a method, the bare name `load_data` resolves to the
        # module-level helper, not to this method.
        tensors = load_data(data_loc)
        X, y = self.create_seq(tensors[split_keys[partition]], self.batch_size)
        return X, y
class Dictionary:
    """Character vocabulary assembled from observed counts.

    ``add_word`` only counts a token. ``prep_dict`` then gives every counted
    token that is still unnumbered the next free id, in the order in which
    tokens were first counted.
    """

    def __init__(self):
        self.char2idx = {}  # token -> integer id
        self.idx2char = []  # integer id -> token
        self.counter = Counter()  # occurrence counts recorded by add_word

    def add_word(self, word):
        """Record one more occurrence of ``word``."""
        self.counter[word] += 1

    def prep_dict(self):
        """Assign ids to all counted tokens that do not have one yet."""
        for token in self.counter:
            if token in self.char2idx:
                continue
            self.char2idx[token] = len(self.idx2char)
            self.idx2char.append(token)

    def __len__(self):
        """Number of tokens that have been assigned an id."""
        return len(self.idx2char)
class SequentialBatchSampler(torch.utils.data.Sampler):
    """Batch sampler that yields every row of one window of ``data_source.X``.

    ``data_source`` must expose ``X``, a list of tensors whose first dimension
    is the number of rows per window. Batch ``k`` is the flat index range
    ``k * batch_size`` .. ``k * batch_size + batch_size - 1``.

    Windows are visited in random order when ``shuffle`` is true, and in
    ascending order otherwise.
    """

    def __init__(self, data_source, shuffle=True):
        super(SequentialBatchSampler, self).__init__(data_source)
        self.X = data_source.X
        window_ids = np.arange(len(self.X))
        self.sampler = (
            torch.utils.data.SubsetRandomSampler(window_ids) if shuffle else window_ids
        )
        self.batch_size = self.X[0].shape[0]

    def __iter__(self):
        width = self.batch_size
        for window in self.sampler:
            first = window * width
            yield [first + offset for offset in range(width)]

    def __len__(self):
        """Number of batches, i.e. number of windows."""
        return len(self.X)