Skip to content

Commit

Permalink
training procedure added
Browse files Browse the repository at this point in the history
  • Loading branch information
imelekhov committed Apr 2, 2019
1 parent e29da6a commit b801109
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 1 deletion.
2 changes: 1 addition & 1 deletion utils/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def calculate_epe_hpatches(net, val_loader, img_size=240):
bs, _, _, _ = source_img.shape

# net prediction
estimates_grid = net(source_img, target_img)
estimates_grid, estimates_mask = net(source_img, target_img)

flow_est = estimates_grid[-1].transpose(1,2).transpose(2,3).to(net.device())
flow_target = mini_batch['correspondence_map'].to(net.device())
Expand Down
80 changes: 80 additions & 0 deletions utils/optimize.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
import numpy as np

from tqdm import tqdm
import torch
import torch.nn.functional as F
import cv2

def train_epoch(net,
optimizer,
train_loader,
criterion_grid,
criterion_matchability=None,
loss_grid_weights=[1, 1, 1, 1, 1],
L_coeff=1):


net.train()
running_total_loss = 0
running_matchability_loss = 0

pbar = tqdm(enumerate(train_loader), total=len(train_loader))
for i, mini_batch in pbar:

optimizer.zero_grad()

# net predictions
estimates_grid, estimates_mask = net(mini_batch['source_image'].to(net.device()),
mini_batch['target_image'].to(net.device()))

if criterion_matchability is None:
assert not estimates_mask, 'Cannot use `criterion_matchability` without mask estimates'

Loss_masked_grid = 0
EPE_loss = 0

# grid loss components (over all layers of the feature pyramid):
for k in range(0, len(estimates_grid)):

grid_gt = mini_batch['correspondence_map_pyro'][k].to(net.device())
bs, s_x, s_y, _ = grid_gt.shape

flow_est = estimates_grid[k].transpose(1,2).transpose(2,3)
flow_target = grid_gt

# calculating mask
mask_x_gt = flow_target[:, :, :, 0].ge(-1) & flow_target[:, :, :, 0].le(1)
mask_y_gt = flow_target[:, :, :, 1].ge(-1) & flow_target[:, :, :, 1].le(1)
mask_gt = mask_x_gt & mask_y_gt

# number of valid pixels based on the mask
N_valid_pxs = mask_gt.view(1, bs * s_x * s_y).data.sum()

# applying mask
mask_gt = torch.cat((mask_gt.unsqueeze(3), mask_gt.unsqueeze(3)), dim=3).float()
flow_target_m = flow_target * mask_gt
flow_est_m = flow_est * mask_gt

# compute grid loss
Loss_masked_grid = Loss_masked_grid + loss_grid_weights[k] * criterion_grid(flow_est_m, flow_target_m, N_valid_pxs)

Loss_matchability = 0
if estimates_mask is not None:
match_mask_gt = mini_batch['mask_x'][-1].to(net.device()) & mini_batch['mask_y'][-1].to(net.device())
Loss_matchability = criterion_matchability(estimates_mask.squeeze(1), match_mask_gt)

Loss = Loss_masked_grid + L_coeff * Loss_matchability
Loss.backward()

optimizer.step()

running_total_loss += Loss.item()
if estimates_mask is not None:
running_matchability_loss += Loss_matchability.item()
pbar.set_description('R_total_loss: %.3f/%.3f | Match_loss: %.3f/%.3f' % (running_total_loss / (i+1), Loss.item(), \
runnining_matchability_loss / (i + 1), Loss_matchability.item()))
else:
pbar.set_description('R_total_loss: %.3f/%.3f' % (running_total_loss / (i+1), Loss.item()))

running_total_loss /= len(train_loader)
return running_total_loss

0 comments on commit b801109

Please sign in to comment.