Skip to content

Commit

Permalink
Update detect.py
Browse files Browse the repository at this point in the history
Added some recent updates that were missing, and updated the filename with an if else.
  • Loading branch information
glenn-jocher authored Jun 26, 2020
1 parent 68f6361 commit 496ec33
Showing 1 changed file with 9 additions and 7 deletions.
16 changes: 9 additions & 7 deletions detect.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def detect(save_img=False):
dataset = LoadImages(source, img_size=imgsz)

# Get names and colors
names = model.names if hasattr(model, 'names') else model.modules.names
names = model.module.names if hasattr(model, 'module') else model.names
colors = [[random.randint(0, 255) for _ in range(3)] for _ in range(len(names))]

# Run inference
Expand Down Expand Up @@ -80,6 +80,7 @@ def detect(save_img=False):
p, s, im0 = path, '', im0s

save_path = str(Path(out) / Path(p).name)
txt_path = save_path[:save_path.rfind('.')] + ('_%g' % dataset.frame if dataset.mode == 'video' else '')
s += '%gx%g ' % img.shape[2:] # print string
gn = torch.tensor(im0.shape)[[1, 0, 1, 0]] #  normalization gain whwh
if det is not None and len(det):
Expand All @@ -95,12 +96,8 @@ def detect(save_img=False):
for *xyxy, conf, cls in det:
if save_txt: # Write to file
xywh = (xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist() # normalized xywh
if dataset.frame == 0:
with open(save_path[:save_path.rfind('.')] + '.txt', 'a') as f:
f.write(('%g ' * 5 + '\n') % (cls, *xywh)) # label format
else:
with open(save_path[:save_path.rfind('.')] + '_' + str(dataset.frame) + '.txt', 'a') as f:
f.write(('%g ' * 5 + '\n') % (cls, *xywh)) # label format
with open(txt_path + '.txt', 'a') as f:
f.write(('%g ' * 5 + '\n') % (cls, *xywh)) # label format

if save_img or view_img: # Add bbox to image
label = '%s %.2f' % (names[int(cls)], conf)
Expand Down Expand Up @@ -160,3 +157,8 @@ def detect(save_img=False):

with torch.no_grad():
detect()

# Update all models
# for opt.weights in ['yolov5s.pt', 'yolov5m.pt', 'yolov5l.pt', 'yolov5x.pt', 'yolov3-spp.pt']:
# detect()
# create_pretrained(opt.weights, opt.weights)

0 comments on commit 496ec33

Please sign in to comment.