Skip to content

Commit

Permalink
Adding deterministic option to main_fp16_optimizer.py
Browse files Browse the repository at this point in the history
  • Loading branch information
definitelynotmcarilli committed Nov 30, 2018
1 parent b436213 commit 2a8022c
Showing 1 changed file with 8 additions and 0 deletions.
8 changes: 8 additions & 0 deletions examples/imagenet/main_fp16_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@
'--static-loss-scale.')
parser.add_argument('--prof', dest='prof', action='store_true',
help='Only run 10 iterations for profiling.')
parser.add_argument('--deterministic', action='store_true')

parser.add_argument("--local_rank", default=0, type=int)
parser.add_argument('--sync_bn', action='store_true',
Expand All @@ -94,6 +95,12 @@ def fast_collate(batch):

best_prec1 = 0
args = parser.parse_args()

if args.deterministic:
cudnn.benchmark = False
cudnn.deterministic = True
torch.manual_seed(args.local_rank)

def main():
global best_prec1, args

Expand Down Expand Up @@ -125,6 +132,7 @@ def main():
else:
print("=> creating model '{}'".format(args.arch))
model = models.__dict__[args.arch]()

if args.sync_bn:
import apex
print("using apex synced BN")
Expand Down

0 comments on commit 2a8022c

Please sign in to comment.