Skip to content
This repository has been archived by the owner on Sep 10, 2024. It is now read-only.

Commit

Permalink
predict tailored to three-frames output
Browse files Browse the repository at this point in the history
  • Loading branch information
mareksubocz committed Mar 17, 2023
1 parent 18c94a3 commit 9621f58
Showing 1 changed file with 11 additions and 9 deletions.
20 changes: 11 additions & 9 deletions predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ def get_ball_position(img, original_img_=None):
ret, thresh = cv.threshold(img, 0.9, 1, 0)
thresh = cv.convertScaleAbs(thresh)

contours,hierarchy = cv.findContours(thresh, cv.RETR_EXTERNAL, cv.CHAIN_APPROX_NONE)
contours, hierarchy = cv.findContours(thresh, cv.RETR_EXTERNAL, cv.CHAIN_APPROX_NONE)
# numpy image
if len(contours) != 0:

Expand Down Expand Up @@ -45,9 +45,9 @@ def parse_opt():
opt.dropout = 0
device = torch.device(opt.device)
model = TrackNet(opt).to(device)
model.load(opt.weights, device = device)
model.load(opt.weights, device = opt.device)
model.eval()

cap = cv.VideoCapture(opt.video)
prev_frame_1 = None
prev_frame_2 = None
Expand All @@ -60,11 +60,13 @@ def parse_opt():
if prev_frame_1 is not None and prev_frame_2 is not None:
frames = torch.cat([prev_frame_2, prev_frame_1, frame_torch], dim=0).unsqueeze(0)
pred = model(frames)
pred = pred[0,0,:,:].detach().cpu().numpy()
frame_resized = cv.resize(frame, pred.shape[::-1], interpolation = cv.INTER_AREA)
get_ball_position(pred, original_img_=frame_resized)
cv.imshow('prediction', pred)
cv.imshow('original', frame_resized)
pred = pred[0,:,:,:].detach().cpu().numpy()
for i in range(pred.shape[0]):
pred_cut = pred[i,:,:]
frame_resized = cv.resize(frame, pred_cut.shape[::-1], interpolation = cv.INTER_AREA)
get_ball_position(pred_cut, original_img_=frame_resized)
cv.imshow('prediction', pred_cut)
cv.imshow('original', frame_resized)
prev_frame_2 = prev_frame_1
prev_frame_1 = frame_torch

Expand All @@ -73,4 +75,4 @@ def parse_opt():


cap.release()
cv.destroyAllWindows() # destroy all opened windows
cv.destroyAllWindows() # destroy all opened windows

0 comments on commit 9621f58

Please sign in to comment.