-
Notifications
You must be signed in to change notification settings - Fork 1
/
datasets_.py
35 lines (31 loc) · 1.2 KB
/
datasets_.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
from __future__ import print_function
import torch.utils.data as data
from PIL import Image
import numpy as np
class Dataset(data.Dataset):
    """Map-style dataset over in-memory samples and their labels.

    Samples and labels are paired by index: item ``i`` is
    ``(data[i], label[i])``, optionally passed through the transforms.

    Args:
        data (sequence): Indexable collection of samples. ``len(data)`` is
            the length of the dataset.
        label (sequence): Indexable collection of targets, aligned with
            ``data`` by index.
        transform (callable, optional): Function applied to each sample
            returned by ``__getitem__`` (e.g. a ``torchvision.transforms``
            pipeline). No conversion is done before calling it. If the
            transform expects PIL images, the samples must already be PIL
            images. NOTE(review): confirm sample type against callers.
        target_transform (callable, optional): Function applied to each
            target returned by ``__getitem__``.
    """

    def __init__(self, data, label,
                 transform=None, target_transform=None):
        self.transform = transform
        self.target_transform = target_transform
        self.data = data
        # Stored under the plural name; callers may read ``.labels``.
        self.labels = label

    def __getitem__(self, index):
        """Return the (sample, target) pair at ``index``.

        Args:
            index (int): Index into ``self.data`` and ``self.labels``.

        Returns:
            tuple: ``(sample, target)`` after the optional ``transform`` and
            ``target_transform`` have been applied.
        """
        img, target = self.data[index], self.labels[index]
        # Apply the optional transforms only when they were provided.
        # Otherwise the constructor arguments would be silently ignored.
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return img, target

    def __len__(self):
        """Return the number of samples, i.e. ``len(self.data)``."""
        return len(self.data)