-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathdataloader.py
28 lines (26 loc) · 922 Bytes
/
dataloader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
import os

import numpy as np
import scipy.io
import torch
from torch.utils.data import Dataset
class Hdigit(Dataset):
    """Two-view Hdigit dataset read from a MATLAB ``Hdigit.mat`` file.

    The file must contain a ``data`` cell array whose first two cells hold
    the two view matrices, and a ``truelabel`` cell array whose first cell
    holds the labels.

    Args:
        path: Directory containing ``Hdigit.mat``. A trailing path separator
            is optional (the file path is built with ``os.path.join``).
    """

    def __init__(self, path):
        super().__init__()
        data = scipy.io.loadmat(os.path.join(path, 'Hdigit.mat'))
        # Flatten labels to 1-D; the sample count comes from the file rather
        # than a hard-coded constant.
        self.Y = data['truelabel'][0][0].astype(np.int32).reshape(-1)
        # Transposed so that each row is one sample (__getitem__ selects rows).
        self.V1 = data['data'][0][0].T.astype(np.float32)
        self.V2 = data['data'][0][1].T.astype(np.float32)

    def __len__(self):
        """Return the number of samples, i.e. the length of the label vector."""
        return len(self.Y)

    def __getitem__(self, idx):
        """Return ``([view1, view2], label, index)`` for sample ``idx``.

        The views are float32 tensors sharing memory with the stored arrays,
        the label is the stored int32 value, and the index is returned as a
        0-d ``torch.long`` tensor.
        """
        x1 = torch.from_numpy(self.V1[idx])
        x2 = torch.from_numpy(self.V2[idx])
        return [x1, x2], self.Y[idx], torch.as_tensor(idx, dtype=torch.long)
def load_data(dataset, path='./data/'):
    """Construct a dataset by name and return it together with its metadata.

    Args:
        dataset: Dataset name. Only ``"Hdigit"`` is currently supported.
        path: Directory holding the dataset file; defaults to ``'./data/'``.

    Returns:
        Tuple ``(dataset, dims, view, data_size, class_num)``: the dataset
        object, the per-view feature dimensions, the number of views, the
        number of samples, and the number of classes.

    Raises:
        NotImplementedError: If ``dataset`` is not a supported name.
    """
    if dataset == "Hdigit":
        data = Hdigit(path)
        # Read the feature widths from the loaded matrices so they always match
        # the vectors returned by __getitem__ (rows are samples).
        dims = [data.V1.shape[1], data.V2.shape[1]]
        view = 2
        data_size = len(data)
        class_num = 10
    else:
        raise NotImplementedError(f"Unsupported dataset: {dataset!r}")
    return data, dims, view, data_size, class_num