Skip to content

Commit

Permalink
fix the bug of shape in random affine, add torch version check. (meit…
Browse files Browse the repository at this point in the history
  • Loading branch information
mtjhl authored May 15, 2023
1 parent eee6b74 commit 75f662c
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 3 deletions.
6 changes: 4 additions & 2 deletions yolov6/assigners/anchor_generator.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import torch
from yolov6.utils.general import check_version

torch_1_10_plus = check_version(torch.__version__, minimum='1.10.0')

def generate_anchors(feats, fpn_strides, grid_cell_size=5.0, grid_cell_offset=0.5, device='cpu', is_eval=False, mode='af'):
'''Generate anchors from features.'''
Expand All @@ -13,7 +15,7 @@ def generate_anchors(feats, fpn_strides, grid_cell_size=5.0, grid_cell_offset=0.
_, _, h, w = feats[i].shape
shift_x = torch.arange(end=w, device=device) + grid_cell_offset
shift_y = torch.arange(end=h, device=device) + grid_cell_offset
shift_y, shift_x = torch.meshgrid(shift_y, shift_x, indexing='ij')
shift_y, shift_x = torch.meshgrid(shift_y, shift_x, indexing='ij') if torch_1_10_plus else torch.meshgrid(shift_y, shift_x)
anchor_point = torch.stack(
[shift_x, shift_y], axis=-1).to(torch.float)
if mode == 'af': # anchor-free
Expand All @@ -35,7 +37,7 @@ def generate_anchors(feats, fpn_strides, grid_cell_size=5.0, grid_cell_offset=0.
cell_half_size = grid_cell_size * stride * 0.5
shift_x = (torch.arange(end=w, device=device) + grid_cell_offset) * stride
shift_y = (torch.arange(end=h, device=device) + grid_cell_offset) * stride
shift_y, shift_x = torch.meshgrid(shift_y, shift_x, indexing='ij')
shift_y, shift_x = torch.meshgrid(shift_y, shift_x, indexing='ij') if torch_1_10_plus else torch.meshgrid(shift_y, shift_x)
anchor = torch.stack(
[
shift_x - cell_half_size, shift_y - cell_half_size,
Expand Down
5 changes: 4 additions & 1 deletion yolov6/data/data_augment.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,10 @@ def random_affine(img, labels=(), degrees=10, translate=.1, scale=.1, shear=10,
new_shape=(640, 640)):
'''Applies Random affine transformation.'''
n = len(labels)
height, width = new_shape
if isinstance(new_shape, int):
height = width = new_shape
else:
height, width = new_shape

M, s = get_transform_matrix(img.shape[:2], (height, width), degrees, scale, shear, translate)
if (M != np.eye(3)).any(): # image changed
Expand Down
11 changes: 11 additions & 0 deletions yolov6/utils/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import math
import torch
import requests
import pkg_resources as pkg
from pathlib import Path
from yolov6.utils.events import LOGGER

Expand Down Expand Up @@ -114,3 +115,13 @@ def check_img_size(imgsz, s=32, floor=0):
if new_size != imgsz:
LOGGER.warning(f'--img-size {imgsz} must be multiple of max stride {s}, updating to {new_size}')
return new_size


def check_version(current='0.0.0', minimum='0.0.0', name='version ', pinned=False, hard=False, verbose=False):
# Check whether the package's version is match the required version.
current, minimum = (pkg.parse_version(x) for x in (current, minimum))
result = (current == minimum) if pinned else (current >= minimum) # bool
if hard:
info = f'⚠️ {name}{minimum} is required by YOLOv6, but {name}{current} is currently installed'
assert result, info # assert minimum version requirement
return result

0 comments on commit 75f662c

Please sign in to comment.