Skip to content

Commit

Permalink
Added video support, read and write
Browse files Browse the repository at this point in the history
  • Loading branch information
1chimaruGin committed Jul 5, 2022
1 parent b6cef87 commit d737515
Show file tree
Hide file tree
Showing 5 changed files with 112 additions and 32 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@ __pycache__/
# C extensions

# Distribution / packaging
*.pt
.Python
videos/
build/
runs/
weights/
Expand Down
4 changes: 3 additions & 1 deletion tools/infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ def get_args_parser(add_help=True):
parser.add_argument('--device', default='0', help='device to run our model i.e. 0 or 0,1,2,3 or cpu.')
parser.add_argument('--save-txt', action='store_true', help='save results to *.txt.')
parser.add_argument('--save-img', action='store_false', help='save visuallized inference results.')
parser.add_argument('--view-img', action='store_true', help='show inference results')
parser.add_argument('--classes', nargs='+', type=int, help='filter by classes, e.g. --classes 0, or --classes 0 2 3.')
parser.add_argument('--agnostic-nms', action='store_true', help='class-agnostic NMS.')
parser.add_argument('--project', default='runs/inference', help='save inference results to project/name.')
Expand All @@ -50,6 +51,7 @@ def run(weights=osp.join(ROOT, 'yolov6s.pt'),
device='',
save_txt=False,
save_img=True,
view_img=False,
classes=None,
agnostic_nms=False,
project=osp.join(ROOT, 'runs/inference'),
Expand Down Expand Up @@ -93,7 +95,7 @@ def run(weights=osp.join(ROOT, 'yolov6s.pt'),

# Inference
inferer = Inferer(source, weights, device, yaml, img_size, half)
inferer.infer(conf_thres, iou_thres, classes, agnostic_nms, max_det, save_dir, save_txt, save_img, hide_labels, hide_conf)
inferer.infer(conf_thres, iou_thres, classes, agnostic_nms, max_det, save_dir, save_txt, save_img, hide_labels, hide_conf, view_img)

if save_txt or save_img:
LOGGER.info(f"Results saved to {save_dir}")
Expand Down
Empty file added yolov6/__init__.py
Empty file.
74 changes: 44 additions & 30 deletions yolov6/core/inferer.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,28 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
import warnings
warnings.filterwarnings("ignore")

import os
import os.path as osp
import math
from tqdm import tqdm
import numpy as np
import cv2
import math
import torch
import numpy as np
import os.path as osp

from tqdm import tqdm
from pathlib import Path
from PIL import ImageFont

from yolov6.utils.events import LOGGER, load_yaml
from yolov6.layers.common import DetectBackend
from yolov6.data.data_augment import letterbox
from yolov6.utils.nms import non_max_suppression


class Inferer:
def __init__(self, source, weights, device, yaml, img_size, half):
import glob
from yolov6.data.datasets import IMG_FORMATS
from yolov6.data.datasets import LoadData

self.__dict__.update(locals())

Expand All @@ -43,19 +47,13 @@ def __init__(self, source, weights, device, yaml, img_size, half):
self.model(torch.zeros(1, 3, *self.img_size).to(self.device).type_as(next(self.model.model.parameters()))) # warmup

# Load data
if os.path.isdir(source):
img_paths = sorted(glob.glob(os.path.join(source, '*.*'))) # dir
elif os.path.isfile(source):
img_paths = [source] # files
else:
raise Exception(f'Invalid path: {source}')
self.img_paths = [img_path for img_path in img_paths if img_path.split('.')[-1].lower() in IMG_FORMATS]
self.files = LoadData(source)

def infer(self, conf_thres, iou_thres, classes, agnostic_nms, max_det, save_dir, save_txt, save_img, hide_labels, hide_conf):
def infer(self, conf_thres, iou_thres, classes, agnostic_nms, max_det, save_dir, save_txt, save_img, hide_labels, hide_conf, view_img):
''' Model Inference and results visualization '''

for img_path in tqdm(self.img_paths):
img, img_src = self.precess_image(img_path, self.img_size, self.stride, self.half)
vid_path, vid_writer, windows = None, None, []
for img_src, img_path, vid_cap in tqdm(self.files):
img, img_src = self.precess_image(img_src, self.img_size, self.stride, self.half)
img = img.to(self.device)
if len(img.shape) == 3:
img = img[None]
Expand All @@ -67,15 +65,14 @@ def infer(self, conf_thres, iou_thres, classes, agnostic_nms, max_det, save_dir,
txt_path = osp.join(save_dir, 'labels', osp.splitext(osp.basename(img_path))[0])

gn = torch.tensor(img_src.shape)[[1, 0, 1, 0]] # normalization gain whwh
img_ori = img_src

img_ori = img_src.copy()
# check image and font
assert img_ori.data.contiguous, 'Image needs to be contiguous. Please apply to input images with np.ascontiguousarray(im).'
self.font_check()

if len(det):
det[:, :4] = self.rescale(img.shape[2:], det[:, :4], img_src.shape).round()

for *xyxy, conf, cls in reversed(det):
if save_txt: # Write to file
xywh = (self.box_convert(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist() # normalized xywh
Expand All @@ -91,20 +88,37 @@ def infer(self, conf_thres, iou_thres, classes, agnostic_nms, max_det, save_dir,

img_src = np.asarray(img_ori)

# Save results (image with detections)
if save_img:
if view_img:
if img_path not in windows:
windows.append(img_path)
cv2.namedWindow(str(img_path), cv2.WINDOW_NORMAL | cv2.WINDOW_KEEPRATIO) # allow window resize (Linux)
cv2.resizeWindow(str(img_path), img_src.shape[1], img_src.shape[0])
cv2.imshow(str(img_path), img_src)
cv2.waitKey(1) # 1 millisecond

# Save results (image with detections)
if save_img:
if self.files.type == 'image':
cv2.imwrite(save_path, img_src)
else: # 'video' or 'stream'
if vid_path != save_path: # new video
vid_path = save_path
if isinstance(vid_writer, cv2.VideoWriter):
vid_writer.release() # release previous video writer
if vid_cap: # video
fps = vid_cap.get(cv2.CAP_PROP_FPS)
w = int(vid_cap.get(cv2.CAP_PROP_FRAME_WIDTH))
h = int(vid_cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
else: # stream
fps, w, h = 30, img_ori.shape[1], img_ori.shape[0]
save_path = str(Path(save_path).with_suffix('.mp4')) # force *.mp4 suffix on results videos
vid_writer = cv2.VideoWriter(save_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h))
vid_writer.write(img_src)

@staticmethod
def precess_image(path, img_size, stride, half):
def precess_image(img_src, img_size, stride, half):
'''Process image before image inference.'''
try:
img_src = cv2.imread(path)
assert img_src is not None, f'Invalid image: {path}'
except Exception as e:
LOGGER.warning(e)
image = letterbox(img_src, img_size, stride=stride)[0]

image = letterbox(img_src, img_size, stride=stride)[0]
# Convert
image = image.transpose((2, 0, 1))[::-1] # HWC to CHW, BGR to RGB
image = torch.from_numpy(np.ascontiguousarray(image))
Expand Down
64 changes: 63 additions & 1 deletion yolov6/data/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,14 @@
# -*- coding:utf-8 -*-

import glob
from io import UnsupportedOperation
import os
import os.path as osp
import random
import json
import time
import hashlib

from pathlib import Path
from multiprocessing.pool import Pool

import cv2
Expand All @@ -29,6 +30,7 @@

# Parameters
IMG_FORMATS = ["bmp", "jpg", "jpeg", "png", "tif", "tiff", "dng", "webp", "mpo"]
VID_FORMATS = ["mp4", "mov", "avi", "mkv"]
# Get orientation exif tag
for k, v in ExifTags.TAGS.items():
if v == "Orientation":
Expand Down Expand Up @@ -548,3 +550,63 @@ def get_hash(paths):
assert isinstance(paths, list), "Only support list currently."
h = hashlib.md5("".join(paths).encode())
return h.hexdigest()

class LoadData:
def __init__(self, path):
p = str(Path(path).resolve()) # os-agnostic absolute path
if os.path.isdir(p):
files = sorted(glob.glob(os.path.join(p, '*.*'))) # dir
elif os.path.isfile(p):
files = [p] # files
else:
raise FileNotFoundError(f'Invalid path {p}')

imgp = [i for i in files if i.split('.')[-1] in IMG_FORMATS]
vidp = [v for v in files if v.split('.')[-1] in VID_FORMATS]
self.files = imgp + vidp
self.nf = len(self.files)
self.type = 'image'
if any(vidp):
self.add_video(vidp[0]) # new video
else:
self.cap = None

@staticmethod
def checkext(path):
file_type = 'image' if path.split('.')[-1].lower() in IMG_FORMATS else 'video'
return file_type

def __iter__(self):
self.count = 0
return self

def __next__(self):
if self.count == self.nf:
raise StopIteration
path = self.files[self.count]

if self.checkext(path) == 'video':
self.type = 'video'
ret_val, img = self.cap.read()
while not ret_val:
self.count += 1
self.cap.release()
if self.count == self.nf: # last video
raise StopIteration
path = self.files[self.count]
self.add_video(path)
ret_val, img = self.cap.read()
else:
# Read image
self.count += 1
img = cv2.imread(path) # BGR

return img, path, self.cap

def add_video(self, path):
self.frame = 0
self.cap = cv2.VideoCapture(path)
self.frames = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT))

def __len__(self):
return self.nf # number of files

0 comments on commit d737515

Please sign in to comment.