
Commit

1. Modify cache logic: when the dataset changes, the cache file is regenerated; 2. Modify the logic for identifying whether the COCO dataset is used; 3. Update the config files; 4. Some formatting changes.
MTChengMeng committed Jun 29, 2022
1 parent 3ef29fc commit 6642dee
Showing 6 changed files with 140 additions and 75 deletions.
2 changes: 2 additions & 0 deletions data/coco.yaml
@@ -5,6 +5,8 @@ test: ../coco/images/test2017
anno_path: ../coco/annotations/instances_val2017.json
# number of classes
nc: 80
# whether this is the COCO dataset; only the COCO dataset should be set to True.
is_coco: True

# class names
names: [ 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 'traffic light',
10 changes: 10 additions & 0 deletions data/dataset.yaml
@@ -0,0 +1,10 @@
train: path/to/data/images/train # train images
val: path/to/data/images/val # val images
test: path/to/data/images/test # test images (optional)

# whether this is the COCO dataset; only the COCO dataset should be set to True.
is_coco: False
# Classes
nc: 20 # number of classes
names: ['aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog',
'horse', 'motorbike', 'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor'] # class names
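
For context, a config like this is parsed into the data_dict that the dataloader changes below consume. A minimal sketch of that loading step, assuming PyYAML (the parsing helper itself is not part of this diff):

import yaml

# Sketch: parse a dataset yaml into the data_dict passed through the
# training code. yaml.safe_load is an assumption; the actual loading
# helper is not shown in this diff.
with open("data/dataset.yaml", "r") as f:
    data_dict = yaml.safe_load(f)

is_coco = data_dict.get("is_coco", False)  # defaults to False for custom datasets
assert data_dict["nc"] == len(data_dict["names"]), "nc should match the number of names"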
10 changes: 5 additions & 5 deletions yolov6/core/engine.py
@@ -98,9 +98,9 @@ def train_in_steps(self):
self.update_optimizer()

def eval_and_save(self):
remaining_epochs = self.max_epoch - self.epoch
eval_interval = self.args.eval_interval if remaining_epochs > self.args.heavy_eval_range else 1
is_val_epoch = (not self.args.eval_final_only or (remaining_epochs == 1)) and (self.epoch % eval_interval == 0)
epoch_sub = self.max_epoch - self.epoch
val_period = 20 if epoch_sub > 100 else 1 # to speed up training, evaluate only every 20 epochs during the early stage.
is_val_epoch = (not self.args.noval or (epoch_sub == 1)) and (self.epoch % val_period == 0)
if self.main_process:
self.ema.update_attr(self.model, include=['nc', 'names', 'stride']) # update attributes for ema model
if is_val_epoch:
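
To make the schedule above concrete, here is a small sketch of which epochs trigger validation, assuming max_epoch=400 and noval=False (both values are illustrative):

# Sketch: reproduce the validation schedule from eval_and_save above.
max_epoch, noval = 400, False  # illustrative assumptions
val_epochs = []
for epoch in range(max_epoch):
    epoch_sub = max_epoch - epoch
    val_period = 20 if epoch_sub > 100 else 1
    if (not noval or epoch_sub == 1) and epoch % val_period == 0:
        val_epochs.append(epoch)
# -> every 20th epoch while more than 100 epochs remain, then every epoch.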
@@ -206,14 +206,14 @@ def get_data_loader(args, cfg, data_dict):
train_loader = create_dataloader(train_path, args.img_size, args.batch_size // args.world_size, grid_size,
hyp=dict(cfg.data_aug), augment=True, rect=False, rank=args.local_rank,
workers=args.workers, shuffle=True, check_images=args.check_images,
check_labels=args.check_labels, class_names=class_names, task='train')[0]
check_labels=args.check_labels, data_dict=data_dict, task='train')[0]
# create val dataloader
val_loader = None
if args.rank in [-1, 0]:
val_loader = create_dataloader(val_path, args.img_size, args.batch_size // args.world_size * 2, grid_size,
hyp=dict(cfg.data_aug), rect=True, rank=-1, pad=0.5,
workers=args.workers, check_images=args.check_images,
check_labels=args.check_labels, class_names=class_names, task='val')[0]
check_labels=args.check_labels, data_dict=data_dict, task='val')[0]

return train_loader, val_loader
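
Since the signature now takes the whole parsed dataset config instead of class_names, a call would look roughly like this (all literal values are illustrative assumptions):

# Sketch of the updated calling convention: data_dict (the parsed dataset
# yaml) replaces the old class_names argument. Values are illustrative.
data_dict = {
    "train": "../coco/images/train2017",
    "val": "../coco/images/val2017",
    "nc": 80,
    "names": ["person", "bicycle"],  # truncated for brevity
    "is_coco": True,
    "anno_path": "../coco/annotations/instances_val2017.json",
}
train_loader, train_dataset = create_dataloader(
    data_dict["train"], 640, 32, 32,  # path, img_size, batch_size, stride
    hyp=None, augment=True, rank=-1, workers=8, shuffle=True,
    data_dict=data_dict, task="train",
)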

8 changes: 4 additions & 4 deletions yolov6/core/evaler.py
@@ -63,13 +63,13 @@ def init_data(self, dataloader, task):
'''Initialize dataloader.
Returns a dataloader for task val or speed.
'''
self.is_coco = isinstance(self.data.get('val'), str) and 'coco' in self.data['val'] # COCO dataset
self.is_coco = self.data.get("is_coco", False)
self.ids = self.coco80_to_coco91_class() if self.is_coco else list(range(1000))
if task != 'train':
pad = 0.0 if task == 'speed' else 0.5
dataloader = create_dataloader(self.data[task if task in ('train', 'val', 'test') else 'val'],
self.img_size, self.batch_size, self.stride, pad=pad, rect=True,
class_names=self.data['names'], task=task)[0]
self.img_size, self.batch_size, self.stride, check_labels=True, pad=pad, rect=True,
data_dict=self.data, task=task)[0]
return dataloader

def predict_model(self, model, dataloader, task):
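
The self.ids line matters when predictions are serialized for pycocotools: COCO's json annotations use 91 category ids with gaps, so the 80 contiguous training classes have to be remapped. A sketch of that lookup using the standard mapping table (the wrapper function here is illustrative; the method above returns the table itself):

def to_coco_category(cls_index, is_coco):
    # Standard COCO 80-to-91 category-id table; non-COCO datasets keep
    # the identity mapping, mirroring `self.ids` above.
    coco80_to_coco91 = [
        1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17, 18, 19, 20,
        21, 22, 23, 24, 25, 27, 28, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40,
        41, 42, 43, 44, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58,
        59, 60, 61, 62, 63, 64, 65, 67, 70, 72, 73, 74, 75, 76, 77, 78, 79,
        80, 81, 82, 84, 85, 86, 87, 88, 89, 90,
    ]
    ids = coco80_to_coco91 if is_coco else list(range(1000))
    return ids[cls_index]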
@@ -105,7 +105,7 @@ def predict_model(self, model, dataloader, task):
def eval_model(self, pred_results, model, dataloader, task):
'''Evaluate models
For task speed, this function only evaluates the speed of model and outputs inference time.
For task val, this function evaluates the speed and mAP by pycocotools, and returns
inference time and mAP value.
'''
LOGGER.info(f'\nEvaluating speed.')
88 changes: 62 additions & 26 deletions yolov6/data/data_load.py
@@ -11,47 +11,83 @@
from yolov6.utils.torch_utils import torch_distributed_zero_first


def create_dataloader(path, img_size, batch_size, stride, hyp=None, augment=False, check_images=False, check_labels=False, pad=0.0, rect=False, rank=-1, workers=8, shuffle=False,class_names=None, task='Train'):
'''Create general dataloader.
def create_dataloader(
path,
img_size,
batch_size,
stride,
hyp=None,
augment=False,
check_images=False,
check_labels=False,
pad=0.0,
rect=False,
rank=-1,
workers=8,
shuffle=False,
data_dict=None,
task="Train",
):
"""Create general dataloader.
Returns dataloader and dataset
'''
"""
if rect and shuffle:
LOGGER.warning('WARNING: --rect is incompatible with DataLoader shuffle, setting shuffle=False')
LOGGER.warning(
"WARNING: --rect is incompatible with DataLoader shuffle, setting shuffle=False"
)
shuffle = False
with torch_distributed_zero_first(rank):
dataset = TrainValDataset(path, img_size, batch_size,
augment=augment,
hyp=hyp,
rect=rect,
check_images=check_images,
stride=int(stride),
pad=pad,
rank=rank,
class_names=class_names,
task=task)
dataset = TrainValDataset(
path,
img_size,
batch_size,
augment=augment,
hyp=hyp,
rect=rect,
check_images=check_images,
check_labels=check_labels,
stride=int(stride),
pad=pad,
rank=rank,
data_dict=data_dict,
task=task,
)

batch_size = min(batch_size, len(dataset))
workers = min([os.cpu_count() // int(os.getenv('WORLD_SIZE', 1)), batch_size if batch_size > 1 else 0, workers]) # number of workers
sampler = None if rank == -1 else distributed.DistributedSampler(dataset, shuffle=shuffle)
return TrainValDataLoader(dataset,
batch_size=batch_size,
shuffle=shuffle and sampler is None,
num_workers=workers,
sampler=sampler,
pin_memory=True,
collate_fn=TrainValDataset.collate_fn), dataset
workers = min(
[
os.cpu_count() // int(os.getenv("WORLD_SIZE", 1)),
batch_size if batch_size > 1 else 0,
workers,
]
) # number of workers
sampler = (
None if rank == -1 else distributed.DistributedSampler(dataset, shuffle=shuffle)
)
return (
TrainValDataLoader(
dataset,
batch_size=batch_size,
shuffle=shuffle and sampler is None,
num_workers=workers,
sampler=sampler,
pin_memory=True,
collate_fn=TrainValDataset.collate_fn,
),
dataset,
)
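
As a worked example of the worker clamp above (all numbers are assumptions): on a 16-core machine with WORLD_SIZE=2, batch_size=32, and 8 workers requested, each process gets min(16 // 2, 32, 8) = 8 dataloader workers, while batch_size=1 forces 0 workers.

# Worked example of the worker-count clamp (all values are assumptions):
cpu_count, world_size, batch_size, requested = 16, 2, 32, 8
num_workers = min([
    cpu_count // world_size,               # share of the cores per process
    batch_size if batch_size > 1 else 0,   # tiny batches get no workers
    requested,
])
assert num_workers == 8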


class TrainValDataLoader(dataloader.DataLoader):
""" Dataloader that reuses workers
"""Dataloader that reuses workers
Uses same syntax as vanilla DataLoader
"""

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
object.__setattr__(self, 'batch_sampler', _RepeatSampler(self.batch_sampler))
object.__setattr__(self, "batch_sampler", _RepeatSampler(self.batch_sampler))
self.iterator = super().__iter__()

def __len__(self):
@@ -63,7 +99,7 @@ def __iter__(self):


class _RepeatSampler:
""" Sampler that repeats forever
"""Sampler that repeats forever
Args:
sampler (Sampler)
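
The __iter__ body of _RepeatSampler is truncated in this diff; in the YOLOv5-style implementation this class follows, it presumably just loops over the wrapped sampler forever:

def __iter__(self):
    # Re-iterate the wrapped sampler indefinitely so DataLoader workers
    # are never torn down between epochs.
    while True:
        yield from iter(self.sampler)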
97 changes: 57 additions & 40 deletions yolov6/data/datasets.py
@@ -7,6 +7,7 @@
import random
import json
import time
import hashlib

from multiprocessing.pool import Pool

@@ -16,7 +17,6 @@
from PIL import ExifTags, Image, ImageOps
from torch.utils.data import Dataset
from tqdm import tqdm
from pathlib import Path

from .data_augment import (
augment_hsv,
@@ -51,14 +51,15 @@ def __init__(
stride=32,
pad=0.0,
rank=-1,
class_names=None,
data_dict=None,
task="train",
):
assert task.lower() in ("train", "val", "speed"), f"Not supported task: {task}"
t1 = time.time()
self.__dict__.update(locals())
self.main_process = self.rank in (-1, 0)
self.task = self.task.capitalize()
self.class_names = data_dict["names"]
self.img_paths, self.labels = self.get_imgs_labels(self.img_dir)
if self.rect:
shapes = [self.img_info[p]["shape"] for p in self.img_paths]
@@ -201,18 +202,28 @@ def get_imgs_labels(self, img_dir):
valid_img_record = osp.join(
osp.dirname(img_dir), "." + osp.basename(img_dir) + ".json"
)
img_info = {}
NUM_THREADS = min(8, os.cpu_count())
# check images
if (
self.check_images or not osp.exists(valid_img_record)
) and self.main_process:
img_paths = glob.glob(osp.join(img_dir, "*"), recursive=True)
img_paths = sorted(
p for p in img_paths if p.split(".")[-1].lower() in IMG_FORMATS
)
assert img_paths, f"No images found in {img_dir}."

img_paths = glob.glob(osp.join(img_dir, "*"), recursive=True)
img_paths = sorted(
p for p in img_paths if p.split(".")[-1].lower() in IMG_FORMATS
)
assert img_paths, f"No images found in {img_dir}."

img_hash = self.get_hash(img_paths)
if osp.exists(valid_img_record):
with open(valid_img_record, "r") as f:
cache_info = json.load(f)
if "image_hash" in cache_info and cache_info["image_hash"] == img_hash:
img_info = cache_info["information"]
else:
self.check_images = True
else:
self.check_images = True

# check images
if self.check_images and self.main_process:
img_info = {}
nc, msgs = 0, [] # number corrupt, messages
LOGGER.info(
f"{self.task}: Checking formats of images with {NUM_THREADS} process(es): "
@@ -233,29 +244,28 @@ def get_imgs_labels(self, img_dir):
if msgs:
LOGGER.info("\n".join(msgs))

cache_info = {"information": img_info, "image_hash": img_hash}
# save valid image paths.
with open(valid_img_record, "w") as f:
json.dump(img_info, f)
json.dump(cache_info, f)

# check and load anns
label_dir = osp.join(
osp.dirname(osp.dirname(img_dir)), "labels", osp.basename(img_dir)
)
assert osp.exists(label_dir), f"{label_dir} is an invalid directory path!"
if not img_info:
with open(valid_img_record, "r") as f:
img_info = json.load(f)
assert (
img_info
), "No information in record files, please add option --check_images."

img_paths = list(img_info.keys())
label_paths = [
label_paths = sorted(
osp.join(label_dir, osp.basename(p).split(".")[0] + ".txt")
for p in img_paths
]
if (
self.check_labels or "labels" not in img_info[img_paths[0]]
): # key 'labels' not saved in img_info
)
label_hash = self.get_hash(label_paths)
if "label_hash" not in cache_info or cache_info["label_hash"] != label_hash:
self.check_labels = True

if self.check_labels:
cache_info["label_hash"] = label_hash
nm, nf, ne, nc, msgs = 0, 0, 0, 0, [] # number corrupt, messages
LOGGER.info(
f"{self.task}: Checking formats of labels with {NUM_THREADS} process(es): "
@@ -289,27 +299,27 @@
if self.main_process:
pbar.close()
with open(valid_img_record, "w") as f:
json.dump(img_info, f)
json.dump(cache_info, f)
if msgs:
LOGGER.info("\n".join(msgs))
if nf == 0:
LOGGER.warning(
f"WARNING: No labels found in {osp.dirname(self.img_paths[0])}. "
)
else:
with open(valid_img_record) as f:
img_info = json.load(f)

if self.task.lower() == "val":
assert (
self.class_names
), "Class names is required when converting labels to coco format for evaluating."
save_dir = osp.join(osp.dirname(osp.dirname(img_dir)), "annotations")
if not osp.exists(save_dir):
os.mkdir(save_dir)
save_path = osp.join(
save_dir, "instances_" + osp.basename(img_dir) + ".json"
)
if not osp.exists(save_path):
if self.data_dict.get("is_coco", False): # use original json file when evaluating on coco dataset.
assert osp.exists(self.data_dict["anno_path"]), "Eval on coco dataset must provide valid path of the annotation file in config file: data/coco.yaml"
else:
assert (
self.class_names
), "Class names is required when converting labels to coco format for evaluating."
save_dir = osp.join(osp.dirname(osp.dirname(img_dir)), "annotations")
if not osp.exists(save_dir):
os.mkdir(save_dir)
save_path = osp.join(
save_dir, "instances_" + osp.basename(img_dir) + ".json"
)
TrainValDataset.generate_coco_format_labels(
img_info, self.class_names, save_path
)
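
For non-COCO datasets, the else branch converts the txt labels into a pycocotools-readable json. A sketch of the skeleton generate_coco_format_labels fills in, with field names following the pycocotools convention (keys not visible in this diff are assumptions):

def empty_coco_dataset(class_names):
    # Skeleton of the generated annotation file; "images" entries carry
    # id/file_name/width/height, "annotations" entries carry
    # id/image_id/category_id/bbox/area/iscrowd.
    return {
        "categories": [
            {"id": i, "name": name, "supercategory": ""}
            for i, name in enumerate(class_names)
        ],
        "images": [],
        "annotations": [],
    }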
@@ -489,8 +499,8 @@ def generate_coco_format_labels(img_info, class_names, save_path):
LOGGER.info(f"Convert to COCO format")
for i, (img_path, info) in enumerate(tqdm(img_info.items())):
labels = info["labels"] if info["labels"] else []
path = Path(img_path)
img_id = int(path.stem) if path.stem.isnumeric() else path.stem
img_id = osp.splitext(osp.basename(img_path))[0]
img_id = int(img_id) if img_id.isnumeric() else img_id
img_w, img_h = info["shape"]
dataset["images"].append(
{
Expand Down Expand Up @@ -531,3 +541,10 @@ def generate_coco_format_labels(img_info, class_names, save_path):
LOGGER.info(
f"Convert to COCO format finished. Resutls saved in {save_path}"
)

@staticmethod
def get_hash(paths):
"""Get the hash value of paths"""
assert isinstance(paths, list), "Only support list currently."
h = hashlib.md5("".join(paths).encode())
return h.hexdigest()
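
Taken together, the new cache flow reduces to: hash the path lists, compare against the hashes stored in the record file, and re-check only on mismatch. A condensed sketch (the helper name and record layout mirror the code above; treat it as illustrative):

import hashlib
import json
import os.path as osp

def load_cached_info(record_path, img_paths):
    # md5 of the joined paths, matching get_hash above.
    img_hash = hashlib.md5("".join(img_paths).encode()).hexdigest()
    if osp.exists(record_path):
        with open(record_path, "r") as f:
            cache_info = json.load(f)
        if cache_info.get("image_hash") == img_hash:
            return cache_info["information"]  # cache hit: reuse image info
    return None  # cache miss: caller re-checks images and rewrites the record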
