"""Dataset and collate helpers that feed augmented samples to YOLO training.

This is the content of ``utils/dataloader.py``.
"""
import math
from random import shuffle

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from matplotlib.colors import hsv_to_rgb, rgb_to_hsv
from PIL import Image
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torch.utils.data.dataset import Dataset

from nets.yolo_training import Generator


class YoloDataset(Dataset):
    """Map-style dataset that yields one randomly augmented sample per line.

    Every annotation line is whitespace separated: the first token is an
    image path and each following token is a comma-separated group of five
    integers ``x1,y1,x2,y2,label`` (the fifth value is passed through
    untouched; presumably it is the class id -- confirm against the
    annotation writer).
    """

    def __init__(self, train_lines, image_size, mosaic=True):
        """Store the annotation lines and the target image size.

        Args:
            train_lines: Mutable sequence of annotation lines.  It is
                shuffled in place by ``__getitem__`` (see there).
            image_size: Indexable whose first two entries are
                ``(height, width)`` of the network input.
            mosaic: Stored on the instance; not read anywhere in this module.
        """
        super().__init__()

        self.train_lines = train_lines
        self.train_batches = len(train_lines)
        self.image_size = image_size
        self.mosaic = mosaic
        # Not read anywhere in this module; kept for external users.
        self.flag = True

    def __len__(self):
        """Return the number of annotation lines given at construction."""
        return self.train_batches

    def rand(self, a=0, b=1):
        """Return a uniform random float in ``[a, b)`` from NumPy's global RNG."""
        return np.random.rand() * (b - a) + a

    def get_random_data(self, annotation_line, input_shape, jitter=.3, hue=.1, sat=1.5, val=1.5):
        """Load one image and apply random augmentation to it and its boxes.

        The image is randomly rescaled (with aspect-ratio jitter), pasted at
        a random offset onto a randomly coloured canvas of ``input_shape``,
        optionally mirrored horizontally and colour-jittered in HSV space.
        The boxes are transformed the same way, clipped to the canvas, and
        boxes whose width or height is not larger than one pixel are dropped.

        Args:
            annotation_line: One annotation line (see the class docstring).
            input_shape: ``(height, width)`` of the output canvas.
            jitter: Half-width of the uniform aspect-ratio jitter interval.
            hue: Maximum absolute hue shift (hue is on a 0..1 scale).
            sat: Upper bound of the saturation gain; the gain is drawn from
                ``[1, sat)`` or its reciprocal with equal probability.
            val: Same as ``sat`` but for the value (brightness) channel.

        Returns:
            ``(image_data, boxes)`` where ``image_data`` is a float array of
            shape ``(height, width, 3)`` with values in 0..255, and ``boxes``
            is a float array of shape ``(k, 5)`` holding
            ``x1, y1, x2, y2, label`` in canvas pixels, or an empty list when
            no box survives.
        """
        line = annotation_line.split()
        h, w = input_shape
        box = np.array([list(map(int, item.split(','))) for item in line[1:]])

        # Random target size: jitter the aspect ratio, then pick a scale.
        new_ar = w / h * self.rand(1 - jitter, 1 + jitter) / self.rand(1 - jitter, 1 + jitter)
        scale = self.rand(.25, 2)
        if new_ar < 1:
            nh = int(scale * h)
            nw = int(nh * new_ar)
        else:
            nw = int(scale * w)
            nh = int(nw / new_ar)

        # Close the file as soon as the resized copy exists so that long
        # training runs do not accumulate open file handles.
        with Image.open(line[0]) as source:
            iw, ih = source.size
            image = source.resize((nw, nh), Image.BICUBIC)

        # Paste the resized image at a random offset on a random-colour canvas.
        dx = int(self.rand(0, w - nw))
        dy = int(self.rand(0, h - nh))
        new_image = Image.new('RGB', (w, h),
                              (np.random.randint(0, 255), np.random.randint(0, 255), np.random.randint(0, 255)))
        new_image.paste(image, (dx, dy))
        image = new_image

        # Random horizontal flip.
        flip = self.rand() < .5
        if flip:
            image = image.transpose(Image.FLIP_LEFT_RIGHT)

        # Colour jitter in HSV space.
        hue = self.rand(-hue, hue)
        sat = self.rand(1, sat) if self.rand() < .5 else 1 / self.rand(1, sat)
        val = self.rand(1, val) if self.rand() < .5 else 1 / self.rand(1, val)
        x = rgb_to_hsv(np.array(image) / 255.)
        x[..., 0] += hue
        # Hue is cyclic: wrap it back into [0, 1].
        x[..., 0][x[..., 0] > 1] -= 1
        x[..., 0][x[..., 0] < 0] += 1
        x[..., 1] *= sat
        x[..., 2] *= val
        x[x > 1] = 1
        x[x < 0] = 0
        image_data = hsv_to_rgb(x) * 255  # float array, values in 0..255

        if len(box) == 0:
            return image_data, []

        # Apply the same geometric transform to the boxes.  ``box`` keeps its
        # integer dtype, so the assigned coordinates are truncated.
        np.random.shuffle(box)
        box[:, [0, 2]] = box[:, [0, 2]] * nw / iw + dx
        box[:, [1, 3]] = box[:, [1, 3]] * nh / ih + dy
        if flip:
            box[:, [0, 2]] = w - box[:, [2, 0]]
        box[:, 0:2][box[:, 0:2] < 0] = 0
        box[:, 2][box[:, 2] > w] = w
        box[:, 3][box[:, 3] > h] = h
        box_w = box[:, 2] - box[:, 0]
        box_h = box[:, 3] - box[:, 1]
        box = box[np.logical_and(box_w > 1, box_h > 1)]  # keep valid boxes only
        if len(box) == 0:
            return image_data, []

        # Every kept box has x1 >= 0 and width > 1, hence x2 > 0, so no
        # further "any positive coordinate" check is needed here.
        box_data = np.zeros((len(box), 5))
        box_data[:] = box
        return image_data, box_data

    def __getitem__(self, index):
        """Return one augmented ``(image, targets)`` pair.

        Requesting index 0 reshuffles ``self.train_lines`` in place before
        the sample is read.

        Returns:
            ``(image, targets)``: ``image`` is a float32 array of shape
            ``(3, height, width)`` scaled to 0..1.  ``targets`` is a float32
            array of shape ``(k, 5)`` holding ``cx, cy, w, h, label`` with
            the first four values normalised to 0..1, or an empty array of
            shape ``(0,)`` when the sample has no valid box.
        """
        if index == 0:
            shuffle(self.train_lines)
        lines = self.train_lines
        index = index % self.train_batches
        img, y = self.get_random_data(lines[index], self.image_size[0:2])
        if len(y) != 0:
            # Convert pixel coordinates into 0..1 fractions of the canvas.
            boxes = np.array(y[:, :4], dtype=np.float32)
            boxes[:, [0, 2]] = boxes[:, [0, 2]] / self.image_size[1]
            boxes[:, [1, 3]] = boxes[:, [1, 3]] / self.image_size[0]
            boxes = np.maximum(np.minimum(boxes, 1), 0)

            # Corner format (x1, y1, x2, y2) -> centre format (cx, cy, w, h).
            boxes[:, 2] = boxes[:, 2] - boxes[:, 0]
            boxes[:, 3] = boxes[:, 3] - boxes[:, 1]
            boxes[:, 0] = boxes[:, 0] + boxes[:, 2] / 2
            boxes[:, 1] = boxes[:, 1] + boxes[:, 3] / 2
            y = np.concatenate([boxes, y[:, -1:]], axis=-1)

        img = np.array(img, dtype=np.float32)

        # HWC -> CHW and 0..255 -> 0..1.
        tmp_inp = np.transpose(img / 255.0, (2, 0, 1))
        tmp_targets = np.array(y, dtype=np.float32)
        return tmp_inp, tmp_targets


def yolo_dataset_collate(batch):
    """Collate ``(image, boxes)`` samples; intended as a DataLoader ``collate_fn``.

    Args:
        batch: Iterable of ``(image, boxes)`` pairs as produced by
            ``YoloDataset.__getitem__``.

    Returns:
        ``(images, bboxes)``.  ``images`` is the samples' images stacked with
        ``np.array``.  ``bboxes`` is ``np.array`` of the box arrays when they
        all share one shape; otherwise it is a 1-D object array with one box
        array per sample.
    """
    images = []
    bboxes = []
    for img, box in batch:
        images.append(img)
        bboxes.append(box)
    images = np.array(images)

    if len({np.shape(box) for box in bboxes}) <= 1:
        bboxes = np.array(bboxes)
    else:
        # Samples usually contain different numbers of boxes.  NumPy >= 1.24
        # refuses to build an array from such ragged input implicitly, so the
        # per-sample arrays are stored in an explicit object array instead.
        ragged = np.empty(len(bboxes), dtype=object)
        for position, box in enumerate(bboxes):
            ragged[position] = box
        bboxes = ragged
    return images, bboxes