Skip to content

Commit

Permalink
add flush to print (pytorch#81)
Browse files Browse the repository at this point in the history
With flush, log info can appear immediately when it is directed to a pipe or file.
  • Loading branch information
fyu authored and apaszke committed Feb 24, 2017
1 parent 1c16b6c commit 409a726
Showing 1 changed file with 11 additions and 8 deletions.
19 changes: 11 additions & 8 deletions imagenet/main.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from __future__ import print_function
import argparse
import os
import shutil
Expand Down Expand Up @@ -59,10 +60,10 @@ def main():

# create model
if args.pretrained:
print("=> using pre-trained model '{}'".format(args.arch))
print("=> using pre-trained model '{}'".format(args.arch), flush=True)
model = models.__dict__[args.arch](pretrained=True)
else:
print("=> creating model '{}'".format(args.arch))
print("=> creating model '{}'".format(args.arch), flush=True)
model = models.__dict__[args.arch]()

if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
Expand All @@ -74,15 +75,16 @@ def main():
# optionally resume from a checkpoint
if args.resume:
if os.path.isfile(args.resume):
print("=> loading checkpoint '{}'".format(args.resume))
print("=> loading checkpoint '{}'".format(args.resume), flush=True)
checkpoint = torch.load(args.resume)
args.start_epoch = checkpoint['epoch']
best_prec1 = checkpoint['best_prec1']
model.load_state_dict(checkpoint['state_dict'])
print("=> loaded checkpoint '{}' (epoch {})"
.format(args.evaluate, checkpoint['epoch']))
.format(args.evaluate, checkpoint['epoch']), flush=True)
else:
print("=> no checkpoint found at '{}'".format(args.resume))
print("=> no checkpoint found at '{}'".format(args.resume),
flush=True)

cudnn.benchmark = True

Expand Down Expand Up @@ -189,7 +191,8 @@ def train(train_loader, model, criterion, optimizer, epoch):
'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
epoch, i, len(train_loader), batch_time=batch_time,
data_time=data_time, loss=losses, top1=top1, top5=top5))
data_time=data_time, loss=losses, top1=top1, top5=top5),
flush=True)


def validate(val_loader, model, criterion):
Expand Down Expand Up @@ -228,10 +231,10 @@ def validate(val_loader, model, criterion):
'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
i, len(val_loader), batch_time=batch_time, loss=losses,
top1=top1, top5=top5))
top1=top1, top5=top5), flush=True)

print(' * Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f}'
.format(top1=top1, top5=top5))
.format(top1=top1, top5=top5), flush=True)

return top1.avg

Expand Down

0 comments on commit 409a726

Please sign in to comment.