pep8 requirements
imelekhov committed Apr 4, 2019
1 parent b13851a commit b640d59
Showing 3 changed files with 36 additions and 27 deletions.
8 changes: 4 additions & 4 deletions data/dataset.py
@@ -278,13 +278,13 @@ def get_grid(self, H, ccrop):
 
         # getting the central patch from the pivot
         Xwarp_crop = X_grid_pivot[Y_CCROP:Y_CCROP + H_SCALE,
-                     X_CCROP:X_CCROP + W_SCALE]
+                                  X_CCROP:X_CCROP + W_SCALE]
         Ywarp_crop = Y_grid_pivot[Y_CCROP:Y_CCROP + H_SCALE,
-                     X_CCROP:X_CCROP + W_SCALE]
+                                  X_CCROP:X_CCROP + W_SCALE]
         X_crop = X_[Y_CCROP:Y_CCROP + H_SCALE,
-                     X_CCROP:X_CCROP + W_SCALE]
+                    X_CCROP:X_CCROP + W_SCALE]
         Y_crop = Y_[Y_CCROP:Y_CCROP + H_SCALE,
-                     X_CCROP:X_CCROP + W_SCALE]
+                    X_CCROP:X_CCROP + W_SCALE]
 
         # crop grid
         Xwarp_crop_range = \
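The dataset.py change only re-indents slice continuation lines so they align with the opening bracket, the visual-indent style PEP 8 recommends (the old alignment is what pycodestyle flags as E128). A minimal sketch of the pattern, with hypothetical array and offset names:

import numpy as np

grid = np.arange(64).reshape(8, 8)
y0, x0, h, w = 2, 2, 4, 4

# PEP 8 visual indent: continuation lines line up with the opening '['
patch = grid[y0:y0 + h,
             x0:x0 + w]
print(patch.shape)  # -> (4, 4)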
49 changes: 28 additions & 21 deletions train.py
@@ -19,6 +19,7 @@
 from utils.optimize import train_epoch, validate_epoch
 from model.net import DGCNet
 
+
 if __name__ == "__main__":
     # Argument parsing
     parser = argparse.ArgumentParser(description='DGC-Net train script')
@@ -27,15 +28,16 @@
                         help='path to TokyoTimeMachine dataset and csv files')
     parser.add_argument('--metadata-path', type=str, default='./data/',
                         help='path to the CSV files')
-    parser.add_argument('--model', type=str, default='dgc', help='Model to use',
-                        choices=['dgc', 'dgcm'])
+    parser.add_argument('--model', type=str, default='dgc',
+                        help='Model to use', choices=['dgc', 'dgcm'])
     parser.add_argument('--snapshots', type=str, default='./snapshots')
     parser.add_argument('--logs', type=str, default='./logs')
     # Optimization parameters
     parser.add_argument('--lr', type=float, default=0.01, help='learning rate')
     parser.add_argument('--momentum', type=float,
                         default=0.9, help='momentum constant')
-    parser.add_argument('--start_epoch', type=int, default=-1, help='start epoch')
+    parser.add_argument('--start_epoch', type=int, default=-1,
+                        help='start epoch')
     parser.add_argument('--n_epoch', type=int, default=70,
                         help='number of training epochs')
     parser.add_argument('--batch-size', type=int, default=32,
@@ -44,7 +46,8 @@
                         help='number of parallel threads for dataloaders')
     parser.add_argument('--weight-decay', type=float, default=0.00001,
                         help='weight decay constant')
-    parser.add_argument('--seed', type=int, default=1984, help='Pseudo-RNG seed')
+    parser.add_argument('--seed', type=int, default=1984,
+                        help='Pseudo-RNG seed')
     args = parser.parse_args()
     random.seed(args.seed)
     np.random.seed(args.seed)
@@ -76,19 +79,21 @@
     weights_loss_coeffs = [1, 1, 1, 1, 1]
     weights_loss_feat = [1, 1, 1, 1]
 
-    train_dataset = HomoAffTpsDataset(image_path=args.image_data_path,
-                                      csv_file=osp.join(args.metadata_path,
-                                                        'csv',
-                                                        'homo_aff_tps_train.csv'),
-                                      transforms=dataset_transforms,
-                                      pyramid_param=pyramid_param)
-
-    val_dataset = HomoAffTpsDataset(image_path=args.image_data_path,
-                                    csv_file=osp.join(args.metadata_path,
-                                                      'csv',
-                                                      'homo_aff_tps_test.csv'),
-                                    transforms=dataset_transforms,
-                                    pyramid_param=pyramid_param)
+    train_dataset = \
+        HomoAffTpsDataset(image_path=args.image_data_path,
+                          csv_file=osp.join(args.metadata_path,
+                                            'csv',
+                                            'homo_aff_tps_train.csv'),
+                          transforms=dataset_transforms,
+                          pyramid_param=pyramid_param)
+
+    val_dataset = \
+        HomoAffTpsDataset(image_path=args.image_data_path,
+                          csv_file=osp.join(args.metadata_path,
+                                            'csv',
+                                            'homo_aff_tps_test.csv'),
+                          transforms=dataset_transforms,
+                          pyramid_param=pyramid_param)
 
     train_dataloader = DataLoader(train_dataset,
                                   batch_size=args.batch_size,
@@ -114,9 +119,10 @@
     model = model.to(device)
 
     # Optimizer
-    optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()),
-                           lr=args.lr,
-                           weight_decay=args.weight_decay)
+    optimizer = \
+        optim.Adam(filter(lambda p: p.requires_grad, model.parameters()),
+                   lr=args.lr,
+                   weight_decay=args.weight_decay)
     # Scheduler
     scheduler = lr_scheduler.MultiStepLR(optimizer,
                                          milestones=[2, 15, 30, 45, 60],
@@ -153,7 +159,8 @@
                                        criterion_grid=criterion_grid,
                                        criterion_matchability=criterion_match,
                                        loss_grid_weights=weights_loss_coeffs)
-        print(colored('==> ', 'blue') + 'Val average grid loss :', val_loss_grid)
+        print(colored('==> ', 'blue') + 'Val average grid loss :',
+              val_loss_grid)
         print(colored('==> ', 'blue') + 'epoch :', epoch + 1)
         val_losses.append(val_loss_grid)
 
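The train.py edits only rewrap calls whose lines ran past PEP 8's 79-character limit; behaviour is unchanged. A hedged sketch of how the result can be checked with the pycodestyle package (the renamed pep8 tool); the file list mirrors this commit, and a checkout of the repository root is assumed:

import pycodestyle

# Verify the touched files against the default 79-character limit.
style = pycodestyle.StyleGuide(max_line_length=79)
report = style.check_files(['train.py',
                            'utils/optimize.py',
                            'data/dataset.py'])
print('PEP 8 violations:', report.total_errors)  # expected to drop to 0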
6 changes: 4 additions & 2 deletions utils/optimize.py
@@ -48,7 +48,8 @@ def train_epoch(net,
                                              mini_batch['target_image'].to(device))
 
         if criterion_matchability is not None and estimates_mask is None:
-            raise ValueError('Cannot use `criterion_matchability` without mask estimates')
+            raise ValueError('Cannot use `criterion_matchability` \
+                without mask estimates')
 
         Loss_masked_grid = 0
         EPE_loss = 0
@@ -156,7 +157,8 @@ def validate_epoch(net,
                                                  mini_batch['target_image'].to(device))
 
             if criterion_matchability is not None and estimates_mask is None:
-                raise ValueError('Cannot use criterion_matchability without mask estimates')
+                raise ValueError('Cannot use criterion_matchability \
+                    without mask estimates')
 
             Loss_masked_grid = 0
             # grid loss components (over all layers of the feature pyramid):
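One subtlety in the reworded ValueError calls above: a backslash continuation inside a string literal splices the next source line, including its leading whitespace, into the message, so the raised error text will contain a run of spaces. A small sketch of the effect, plus the implicit-concatenation alternative PEP 8 usually suggests; both strings are illustrative only:

# Backslash continuation: the indentation before 'without' ends up
# inside the string itself.
msg_backslash = 'Cannot use criterion_matchability \
    without mask estimates'

# Adjacent literals inside parentheses are concatenated at compile
# time and keep the message free of stray whitespace.
msg_concat = ('Cannot use criterion_matchability '
              'without mask estimates')

print(repr(msg_backslash))  # note the embedded spaces
print(repr(msg_concat))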
