Instead of hardcoding the values of beta1 and beta2 in the TTUR (Two-Timestep Update Rule) case, change the default values. This modification improves the readability of the code.
taesungp committed Oct 17, 2019
1 parent 929194e commit 1aa3633
Showing 2 changed files with 10 additions and 5 deletions.
3 changes: 1 addition & 2 deletions models/pix2pix_model.py

@@ -66,11 +66,10 @@ def create_optimizers(self, opt):
         if opt.isTrain:
             D_params = list(self.netD.parameters())

+        beta1, beta2 = opt.beta1, opt.beta2
         if opt.no_TTUR:
-            beta1, beta2 = opt.beta1, opt.beta2
             G_lr, D_lr = opt.lr, opt.lr
         else:
-            beta1, beta2 = 0, 0.9
             G_lr, D_lr = opt.lr / 2, opt.lr * 2

         optimizer_G = torch.optim.Adam(G_params, lr=G_lr, betas=(beta1, beta2))
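For context, TTUR (Heusel et al., 2017) trains the discriminator with a larger learning rate than the generator; after this commit the Adam betas always come from the parsed options, and only the learning rates branch on the TTUR flag. Below is a minimal, self-contained sketch of the resulting logic — not repo code: the SimpleNamespace options and the tiny Linear modules are hypothetical stand-ins for SPADE's real parsed options and netG/netD.

# Minimal sketch of the post-commit optimizer setup (assumed names:
# 'opt', 'netG', 'netD' are illustrative placeholders).
from types import SimpleNamespace

import torch
import torch.nn as nn

opt = SimpleNamespace(lr=0.0002, beta1=0.0, beta2=0.9, no_TTUR=False)
netG, netD = nn.Linear(4, 4), nn.Linear(4, 1)  # stand-ins for the real networks

# betas now always come from the options; only the learning rates branch
beta1, beta2 = opt.beta1, opt.beta2
if opt.no_TTUR:
    G_lr, D_lr = opt.lr, opt.lr
else:
    # TTUR: halve the generator lr, double the discriminator lr
    G_lr, D_lr = opt.lr / 2, opt.lr * 2

optimizer_G = torch.optim.Adam(netG.parameters(), lr=G_lr, betas=(beta1, beta2))
optimizer_D = torch.optim.Adam(netD.parameters(), lr=D_lr, betas=(beta1, beta2))
print(G_lr, D_lr, (beta1, beta2))  # 0.0001 0.0004 (0.0, 0.9)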
12 changes: 9 additions & 3 deletions options/train_options.py

@@ -24,8 +24,15 @@ def initialize(self, parser):
         parser.add_argument('--niter', type=int, default=50, help='# of iter at starting learning rate. This is NOT the total #epochs. Total #epochs is niter + niter_decay')
         parser.add_argument('--niter_decay', type=int, default=0, help='# of iter to linearly decay learning rate to zero')
         parser.add_argument('--optimizer', type=str, default='adam')
-        parser.add_argument('--beta1', type=float, default=0.5, help='momentum term of adam')
-        parser.add_argument('--beta2', type=float, default=0.999, help='momentum term of adam')
+        parser.add_argument('--beta1', type=float, default=0.0, help='momentum term of adam')
+        parser.add_argument('--beta2', type=float, default=0.9, help='momentum term of adam')
+        parser.add_argument('--no_TTUR', action='store_true', help='if specified, do *not* use TTUR training scheme')
+
+        # the default values for beta1 and beta2 differ by TTUR option
+        opt, _ = parser.parse_known_args()
+        if opt.no_TTUR:
+            parser.set_defaults(beta1=0.5, beta2=0.999)
+
         parser.add_argument('--lr', type=float, default=0.0002, help='initial learning rate for adam')
         parser.add_argument('--D_steps_per_G', type=int, default=1, help='number of discriminator iterations per generator iteration.')

@@ -37,7 +44,6 @@ def initialize(self, parser):
         parser.add_argument('--no_vgg_loss', action='store_true', help='if specified, do *not* use VGG feature matching loss')
         parser.add_argument('--gan_mode', type=str, default='hinge', help='(ls|original|hinge)')
         parser.add_argument('--netD', type=str, default='multiscale', help='(n_layers|multiscale|image)')
-        parser.add_argument('--no_TTUR', action='store_true', help='if specified, do *not* use TTUR training scheme')
         parser.add_argument('--lambda_kld', type=float, default=0.05)
         self.isTrain = True
         return parser
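The two-pass argparse pattern introduced here is worth noting: parse_known_args() peeks at the flags parsed so far, and set_defaults() then retroactively swaps the beta defaults when --no_TTUR is given, while an explicitly passed --beta1/--beta2 still wins over either default. A standalone sketch of the pattern follows — not repo code; the script name in the usage comments is hypothetical.

# Standalone sketch of the two-pass default-switching pattern above.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--beta1', type=float, default=0.0, help='momentum term of adam')
parser.add_argument('--beta2', type=float, default=0.9, help='momentum term of adam')
parser.add_argument('--no_TTUR', action='store_true', help='if specified, do *not* use TTUR training scheme')

# First pass: ignore unknown flags, just check whether --no_TTUR was given.
opt, _ = parser.parse_known_args()
if opt.no_TTUR:
    # Parser-level defaults override argument-level defaults,
    # but an explicit --beta1/--beta2 on the command line still wins.
    parser.set_defaults(beta1=0.5, beta2=0.999)

opt = parser.parse_args()
print(opt.beta1, opt.beta2)
# python sketch.py            -> 0.0 0.9   (TTUR defaults)
# python sketch.py --no_TTUR  -> 0.5 0.999 (plain-Adam defaults)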
