dataset_m3.py
import os
import sys
import time
import warnings

import torch
from torch.utils.data import Dataset

# os.urandom-backed RNG: unaffected by random.seed(), so forked DataLoader
# workers do not replay identical problem sequences.
from random import SystemRandom
random = SystemRandom()
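# Each sample is one multiplication problem rendered as text:
#   "<a1> * <a2> = <product>^"  with '^' (EOS) right-padding the answer to
# aout_size characters and BLANK left-padding the whole string to ain_size + 1.
# __getitem__ returns the shift-by-one (inputs, targets) pair plus a boolean
# mask that selects the answer positions inside `targets`.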
class MyDataset(Dataset):
    def __init__(self, ain_size=48, aout_size=8):
        self.digits = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']
        self.a_z = [chr(i + ord('a')) for i in range(26)]
        self.BLANK = chr(2)  # padding token, index 0 in the alphabet
        self.EOS = "^"       # terminates and pads the answer
        self.alphabet = [self.BLANK] + self.digits + self.a_z + [' ', '+', '-', '*', '/', '=', self.EOS]
        self.ain_size = ain_size    # length of the inputs/targets sequences
        self.aout_size = aout_size  # fixed number of characters reserved for the answer
        self.char2idx = {c: i for i, c in enumerate(self.alphabet)}

    def __len__(self):
        # Nominal epoch length only; every item is generated on the fly.
        return 1024 * 128

    def __getitem__(self, index):
        _ = index  # index is ignored: each call draws a fresh random problem
        a1 = random.randint(0, 999)
        a2 = random.randint(0, 999)
        # Answer = product followed by EOS, right-padded with EOS to aout_size.
        # With factors <= 999 the product has at most 6 digits, so len_aout <= 7
        # and the mask slice below never gets a zero end index.
        aout = f"{a1 * a2}{self.EOS}"
        len_aout = len(aout)
        aout += self.EOS * (self.aout_size - len_aout)
        ain = f"{a1} * {a2} = {aout}"
        ain = [self.char2idx[c] for c in ain]
        # Left-pad with BLANK (index 0) to ain_size + 1 tokens.
        ain = [0] * (self.ain_size + 1 - len(ain)) + ain
        # Mark the answer characters (digits plus the first EOS) in the targets;
        # the trailing EOS padding stays unmasked.
        mask = torch.zeros(size=(self.ain_size,), dtype=torch.bool)
        mask[-self.aout_size:-(self.aout_size - len_aout)] = True
        # Shift-by-one language-modelling pair: inputs = ain[:-1], targets = ain[1:].
        return torch.LongTensor(ain[:-1]), torch.LongTensor(ain[1:]), mask
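# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original file): one way to feed
# MyDataset to a DataLoader and restrict the loss to the masked answer
# positions. `model` is a hypothetical stand-in for any sequence model that
# returns per-position logits of shape (batch, ain_size, len(alphabet)).
# ---------------------------------------------------------------------------
def _training_step_sketch(model):
    import torch.nn.functional as F
    from torch.utils.data import DataLoader

    loader = DataLoader(MyDataset(), batch_size=64)
    inputs, targets, mask = next(iter(loader))  # each of shape (64, 48)
    logits = model(inputs)                      # assumed shape (64, 48, vocab)
    # Cross-entropy only over the answer characters; prompt and padding are ignored.
    return F.cross_entropy(logits[mask], targets[mask])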
if __name__ == '__main__':
    ds = MyDataset()

    def decode(indices):
        # Map token indices back to characters, skipping BLANK padding.
        return ''.join(ds.alphabet[idx] for idx in indices if ds.alphabet[idx] != ds.BLANK)

    # Smoke test: print shapes, the decoded input sequence, and the decoded answer.
    for i in range(5):
        ain, aout, mask = ds[i]
        aout = aout[mask]  # keep only the masked answer positions of the targets
        print(ain.shape, aout.shape, mask.shape)
        print(decode(ain))
        print(decode(aout))