datasetKITTIEval.py
import os

import cv2
import torch
from torchvision import transforms
from torch.utils.data import Dataset

# Target RGB image size (width, height) used by cv2.resize.
resolution = (1248, 368)
# Instance-mask size (width, height); 1/4 of the image resolution.
dresolution = (312, 92)
class KITTIDataset(Dataset):
    """KITTI RGB frames paired with instance-segmentation masks for evaluation."""

    def __init__(self, split='train', root=None):
        super(KITTIDataset, self).__init__()
        self.resolution = resolution
        self.root_dir = root
        # RGB frames and instance masks share file names under rgb/ and instance/.
        self.rgb_dir = os.path.join(self.root_dir, 'rgb')
        self.instance_dir = os.path.join(self.root_dir, 'instance')
        self.files = os.listdir(self.rgb_dir)
        self.files.sort()
        if split == 'eval':
            # Use only the first five frames for a quick evaluation pass.
            self.files = self.files[0:5]
        # Map RGB values from [0, 1] to [-1, 1].
        self.img_transform = transforms.Compose([
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
    def __getitem__(self, index):
        path = self.files[index]
        # cv2 loads images as HxWxC arrays in BGR channel order.
        image = cv2.imread(os.path.join(self.rgb_dir, path))
        # Flag -1 keeps the original bit depth of the instance mask.
        mask = cv2.imread(os.path.join(self.instance_dir, path), -1)
        image = cv2.resize(image, resolution, interpolation=cv2.INTER_LINEAR)
        # Nearest-neighbour resize so instance ids are not interpolated.
        mask = cv2.resize(mask, dresolution, interpolation=cv2.INTER_NEAREST)
        mask = torch.Tensor(mask).long()
        image = torch.Tensor(image).float()
        image = image / 255.0
        # HWC -> CHW, then normalize to [-1, 1].
        image = image.permute(2, 0, 1)
        image = self.img_transform(image)
        sample = {'image': image, 'mask': mask}
        return sample
def __len__(self):
return len(self.files)
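

# Minimal usage sketch (not part of the original file): it assumes a KITTI-style
# directory with 'rgb/' and 'instance/' subfolders; the root path below is a
# placeholder, not taken from the original repository.
if __name__ == '__main__':
    from torch.utils.data import DataLoader

    dataset = KITTIDataset(split='eval', root='/path/to/kitti')  # hypothetical root
    loader = DataLoader(dataset, batch_size=2, shuffle=False, num_workers=2)
    for batch in loader:
        # image: (B, 3, 368, 1248) floats in [-1, 1]; mask: (B, 92, 312) integer ids.
        print(batch['image'].shape, batch['mask'].shape)
        break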