forked from bubbliiiing/yolo3-pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
d41adb5
commit 20975fa
Showing
1 changed file
with
141 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,141 @@ | ||
from random import shuffle | ||
import numpy as np | ||
import torch | ||
import torch.nn as nn | ||
import math | ||
import torch.nn.functional as F | ||
from PIL import Image | ||
from torch.autograd import Variable | ||
from torch.utils.data import DataLoader | ||
from torch.utils.data.dataset import Dataset | ||
from matplotlib.colors import rgb_to_hsv, hsv_to_rgb | ||
from nets.yolo_training import Generator | ||
|
||
|
||
class YoloDataset(Dataset):
    """Dataset yielding randomly augmented images and normalised box targets.

    Every entry of ``train_lines`` is one annotation string: an image path
    followed by whitespace-separated boxes, each written as comma-separated
    integers.  The first four values of a box are treated as ``x1, y1, x2, y2``
    pixel corners; the fifth is carried through untouched (presumably the
    class index -- confirm against the annotation writer).
    """

    def __init__(self, train_lines, image_size, mosaic=True):
        """Store the annotation lines and the target image size.

        Args:
            train_lines: list of annotation strings (shuffled in place by
                ``__getitem__`` whenever index 0 is requested).
            image_size: sequence whose first two items are (height, width).
            mosaic: stored on the instance; not read anywhere in this class.
        """
        super().__init__()

        self.train_lines = train_lines
        self.image_size = image_size
        self.mosaic = mosaic
        # Sample count is frozen at construction time.
        self.train_batches = len(train_lines)
        # Stored for external use; not read anywhere in this class.
        self.flag = True

    def __len__(self):
        """Return the number of annotation lines given at construction."""
        return self.train_batches

    def rand(self, a=0, b=1):
        """Return a uniform random float drawn from ``[a, b)``."""
        span = b - a
        return np.random.rand() * span + a

    def get_random_data(self, annotation_line, input_shape, jitter=.3, hue=.1, sat=1.5, val=1.5):
        """Load one annotated image and apply random on-the-fly augmentation.

        The image is resized with a jittered aspect ratio and random scale,
        pasted at a random offset on a randomly coloured canvas, optionally
        mirrored, and distorted in HSV space.  Boxes are transformed the same
        way, clipped to the canvas and filtered.

        Args:
            annotation_line: ``"path x1,y1,x2,y2,c ..."`` annotation string.
            input_shape: ``(height, width)`` of the output canvas.
            jitter: aspect-ratio jitter amplitude.
            hue: maximum absolute hue shift.
            sat: maximum saturation gain (or its reciprocal).
            val: maximum value/brightness gain (or its reciprocal).

        Returns:
            ``(image_data, boxes)`` where ``image_data`` is a float array of
            shape ``(height, width, 3)`` scaled to 0-255 and ``boxes`` is an
            ``(n, 5)`` float array in canvas pixels, or ``[]`` when no usable
            box remains.
        """
        fields = annotation_line.split()
        image = Image.open(fields[0])
        iw, ih = image.size
        h, w = input_shape
        box = np.array([
            np.array([int(value) for value in entry.split(',')])
            for entry in fields[1:]
        ])

        # Resize: jitter the aspect ratio, then pick a random overall scale.
        ratio_num = self.rand(1 - jitter, 1 + jitter)
        ratio_den = self.rand(1 - jitter, 1 + jitter)
        new_ar = w / h * ratio_num / ratio_den
        scale = self.rand(.25, 2)
        if new_ar < 1:
            nh = int(scale * h)
            nw = int(nh * new_ar)
        else:
            nw = int(scale * w)
            nh = int(nw / new_ar)
        image = image.resize((nw, nh), Image.BICUBIC)

        # Placement: paste at a random offset on a randomly coloured canvas.
        dx = int(self.rand(0, w - nw))
        dy = int(self.rand(0, h - nh))
        fill = tuple(np.random.randint(0, 255) for _ in range(3))
        canvas = Image.new('RGB', (w, h), fill)
        canvas.paste(image, (dx, dy))
        image = canvas

        # Horizontal flip with probability one half.
        flip = self.rand() < .5
        if flip:
            image = image.transpose(Image.FLIP_LEFT_RIGHT)

        # Colour distortion in HSV space.
        hue = self.rand(-hue, hue)
        if self.rand() < .5:
            sat = self.rand(1, sat)
        else:
            sat = 1 / self.rand(1, sat)
        if self.rand() < .5:
            val = self.rand(1, val)
        else:
            val = 1 / self.rand(1, val)
        hsv = rgb_to_hsv(np.array(image) / 255.)
        hue_plane = hsv[..., 0]  # view: edits below write through to hsv
        hue_plane += hue
        hue_plane[hue_plane > 1] -= 1  # hue is cyclic, wrap back into [0, 1]
        hue_plane[hue_plane < 0] += 1
        hsv[..., 1] *= sat
        hsv[..., 2] *= val
        hsv[hsv > 1] = 1
        hsv[hsv < 0] = 0
        image_data = hsv_to_rgb(hsv) * 255  # float array scaled to 0-255

        # Nothing to transform when the line carried no boxes.
        if len(box) == 0:
            return image_data, []

        # Apply the same resize / offset / flip to the box corners.  ``box``
        # is an integer array, so the scaled values are truncated on store.
        np.random.shuffle(box)
        box[:, [0, 2]] = box[:, [0, 2]] * nw / iw + dx
        box[:, [1, 3]] = box[:, [1, 3]] * nh / ih + dy
        if flip:
            box[:, [0, 2]] = w - box[:, [2, 0]]

        # Clip corners to the canvas.
        top_left = box[:, 0:2]
        top_left[top_left < 0] = 0
        right = box[:, 2]
        right[right > w] = w
        bottom = box[:, 3]
        bottom[bottom > h] = h

        # Keep only boxes that are still more than one pixel wide and tall.
        widths = box[:, 2] - box[:, 0]
        heights = box[:, 3] - box[:, 1]
        box = box[(widths > 1) & (heights > 1)]
        box_data = np.zeros((len(box), 5))
        box_data[:len(box)] = box

        if len(box) == 0 or not (box_data[:, :4] > 0).any():
            return image_data, []
        return image_data, box_data

    def __getitem__(self, index):
        """Return one augmented sample as ``(image, targets)``.

        ``image`` is a float32 array of shape ``(3, height, width)`` scaled to
        0-1.  ``targets`` is a float32 array with one row per box holding
        ``cx, cy, w, h`` as fractions of the image size followed by the
        untouched last annotation column; it is empty when no box survived.
        Requesting index 0 shuffles ``train_lines`` in place first.
        """
        if index == 0:
            shuffle(self.train_lines)
        # Wrap the index so it always addresses a valid line.
        position = index % self.train_batches
        img, y = self.get_random_data(self.train_lines[position], self.image_size[0:2])

        if len(y) != 0:
            # Pixel corners -> fractions of the image width / height.
            height, width = self.image_size[0], self.image_size[1]
            corners = np.array(y[:, :4], dtype=np.float32)
            for column, extent in ((0, width), (1, height), (2, width), (3, height)):
                corners[:, column] = corners[:, column] / extent
            corners = np.clip(corners, 0, 1)

            # (x1, y1, x2, y2) -> (cx, cy, w, h).
            corners[:, 2] = corners[:, 2] - corners[:, 0]
            corners[:, 3] = corners[:, 3] - corners[:, 1]
            corners[:, 0] = corners[:, 0] + corners[:, 2] / 2
            corners[:, 1] = corners[:, 1] + corners[:, 3] / 2
            y = np.concatenate([corners, y[:, -1:]], axis=-1)

        pixels = np.array(img, dtype=np.float32)
        # HWC -> CHW, scaled to 0-1.
        tmp_inp = np.transpose(pixels / 255.0, (2, 0, 1))
        tmp_targets = np.array(y, dtype=np.float32)
        return tmp_inp, tmp_targets
|
||
|
||
# Used as the ``collate_fn`` of a DataLoader.
def yolo_dataset_collate(batch):
    """Collate ``(image, boxes)`` samples into a batch.

    Args:
        batch: iterable of ``(image, boxes)`` pairs.  Images are expected to
            share one shape; the per-image box arrays may differ in shape
            (different box counts, or an empty array for no boxes).

    Returns:
        ``(images, bboxes)``.  ``images`` is the stacked image array.
        ``bboxes`` is a stacked array when every sample's box array has the
        same shape, otherwise a 1-D object array holding one box array per
        sample.  Either form yields the per-sample arrays when iterated.
    """
    images = []
    targets = []
    for img, box in batch:
        images.append(img)
        targets.append(box)
    images = np.array(images)

    if len({np.shape(target) for target in targets}) <= 1:
        # Uniform shapes (or an empty batch): plain stacking is well defined.
        bboxes = np.array(targets)
    else:
        # Ragged targets: np.array() raises ValueError on NumPy >= 1.24, so
        # fill an object array element by element instead.
        bboxes = np.empty(len(targets), dtype=object)
        for position, target in enumerate(targets):
            bboxes[position] = target
    return images, bboxes
|