Skip to content

Commit

Permalink
The repo code has been updated to PyTorch v1.0
Browse files Browse the repository at this point in the history
  • Loading branch information
akamaster committed Oct 5, 2019
1 parent 829abc3 commit 4e4f8da
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 43 deletions.
4 changes: 2 additions & 2 deletions resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,9 @@

def _weights_init(m):
classname = m.__class__.__name__
print(classname)
#print(classname)
if isinstance(m, nn.Linear) or isinstance(m, nn.Conv2d):
init.kaiming_normal(m.weight)
init.kaiming_normal_(m.weight)

class LambdaLayer(nn.Module):
def __init__(self, lambd):
Expand Down
87 changes: 46 additions & 41 deletions trainer.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import argparse
import os
import shutil
import time

import torch
Expand All @@ -17,6 +18,8 @@
and name.startswith("resnet")
and callable(resnet.__dict__[name]))

print(model_names)

parser = argparse.ArgumentParser(description='Propert ResNets for CIFAR10 in pytorch')
parser.add_argument('--arch', '-a', metavar='ARCH', default='resnet32',
choices=model_names,
Expand Down Expand Up @@ -119,10 +122,11 @@ def main():

if args.arch in ['resnet1202', 'resnet110']:
# for resnet1202 original paper uses lr=0.01 for first 400 minibatches for warm-up
# then switch back. In this implementation it will correspond for first epoch.
# then switch back. In this setup it will correspond for first epoch.
for param_group in optimizer.param_groups:
param_group['lr'] = args.lr*0.1


if args.evaluate:
validate(val_loader, model, criterion)
return
Expand All @@ -148,10 +152,10 @@ def main():
'best_prec1': best_prec1,
}, is_best, filename=os.path.join(args.save_dir, 'checkpoint.th'))

save_checkpoint({
'state_dict': model.state_dict(),
'best_prec1': best_prec1,
}, is_best, filename=os.path.join(args.save_dir, 'model.th'))
save_checkpoint({
'state_dict': model.state_dict(),
'best_prec1': best_prec1,
}, is_best, filename=os.path.join(args.save_dir, 'model.th'))


def train(train_loader, model, criterion, optimizer, epoch):
Expand All @@ -172,9 +176,9 @@ def train(train_loader, model, criterion, optimizer, epoch):
# measure data loading time
data_time.update(time.time() - end)

target = target.cuda(async=True)
input_var = torch.autograd.Variable(input).cuda()
target_var = torch.autograd.Variable(target)
target = target.cuda()
input_var = input.cuda()
target_var = target
if args.half:
input_var = input_var.half()

Expand All @@ -191,8 +195,8 @@ def train(train_loader, model, criterion, optimizer, epoch):
loss = loss.float()
# measure accuracy and record loss
prec1 = accuracy(output.data, target)[0]
losses.update(loss.data[0], input.size(0))
top1.update(prec1[0], input.size(0))
losses.update(loss.item(), input.size(0))
top1.update(prec1.item(), input.size(0))

# measure elapsed time
batch_time.update(time.time() - end)
Expand Down Expand Up @@ -220,37 +224,38 @@ def validate(val_loader, model, criterion):
model.eval()

end = time.time()
for i, (input, target) in enumerate(val_loader):
target = target.cuda(async=True)
input_var = torch.autograd.Variable(input, volatile=True).cuda()
target_var = torch.autograd.Variable(target, volatile=True)

if args.half:
input_var = input_var.half()

# compute output
output = model(input_var)
loss = criterion(output, target_var)

output = output.float()
loss = loss.float()

# measure accuracy and record loss
prec1 = accuracy(output.data, target)[0]
losses.update(loss.data[0], input.size(0))
top1.update(prec1[0], input.size(0))

# measure elapsed time
batch_time.update(time.time() - end)
end = time.time()

if i % args.print_freq == 0:
print('Test: [{0}/{1}]\t'
'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
'Prec@1 {top1.val:.3f} ({top1.avg:.3f})'.format(
i, len(val_loader), batch_time=batch_time, loss=losses,
top1=top1))
with torch.no_grad():
for i, (input, target) in enumerate(val_loader):
target = target.cuda()
input_var = input.cuda()
target_var = target.cuda()

if args.half:
input_var = input_var.half()

# compute output
output = model(input_var)
loss = criterion(output, target_var)

output = output.float()
loss = loss.float()

# measure accuracy and record loss
prec1 = accuracy(output.data, target)[0]
losses.update(loss.item(), input.size(0))
top1.update(prec1.item(), input.size(0))

# measure elapsed time
batch_time.update(time.time() - end)
end = time.time()

if i % args.print_freq == 0:
print('Test: [{0}/{1}]\t'
'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
'Prec@1 {top1.val:.3f} ({top1.avg:.3f})'.format(
i, len(val_loader), batch_time=batch_time, loss=losses,
top1=top1))

print(' * Prec@1 {top1.avg:.3f}'
.format(top1=top1))
Expand Down

0 comments on commit 4e4f8da

Please sign in to comment.