Skip to content

Commit

Permalink
disable alpha step
Browse files Browse the repository at this point in the history
  • Loading branch information
websitefingerprinting committed Aug 26, 2021
1 parent 31109cc commit cf7c5d4
Showing 1 changed file with 5 additions and 10 deletions.
15 changes: 5 additions & 10 deletions src/mpl_df_wgan_train.py → src/mlp_df_wgan_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,15 +37,13 @@ def parse_args():
parser.add_argument("--clip", action="store_true", default=False,
help="Whether to clip the burst size of the dataset before training")
parser.add_argument("--f_model", type=str, required=True, help="The directory of the pre-trained DF.")
parser.add_argument("--n_epochs", type=int, default=600, help="number of epochs of training")
parser.add_argument("--n_epochs", type=int, default=800, help="number of epochs of training")
parser.add_argument("--batch_size", type=int, default=64, help="size of the batches")
parser.add_argument("--lr", type=float, default=0.0002, help="learning rate")
parser.add_argument("--n_cpu", type=int, default=2, help="number of cpu threads to use during batch generation")
parser.add_argument("--latent_dim", type=int, default=50, help="dimensionality of the latent space")
parser.add_argument("--n_critic", type=int, default=5, help="number of training steps for discriminator per iter")
parser.add_argument("--alpha_max", type=float, default=0.02, help="Max ratio of f loss")
parser.add_argument("--alpha_step", type=float, default=0.0005, help="alpha growth step size")
parser.add_argument("--alpha_freq", type=int, default=20, help="alpha value update frequency")
parser.add_argument("--latent_dim", type=int, default=500, help="dimensionality of the latent space")
parser.add_argument("--n_critic", type=int, default=3, help="number of training steps for discriminator per iter")
parser.add_argument("--alpha", type=float, default=0.02, help="Max ratio of f loss")
parser.add_argument("--freq", type=int, default=20, help="Checkpoint every freq epochs")
parser.add_argument("--cuda_id", type=int, default=0, help="GPU ID")
args = parser.parse_args()
Expand Down Expand Up @@ -165,7 +163,7 @@ def my_clip(arr):
criterion = nn.CrossEntropyLoss()

# alpha initialization
alpha = 0
alpha = args.alpha

loss_checkpoints = {'generator': [], 'discriminator': [], 'dist': []}
for epoch in range(args.n_epochs):
Expand Down Expand Up @@ -281,9 +279,6 @@ def my_clip(arr):
generator_f_loss_epoch, generator_loss_combined_epoch, df_acc, w_dist_epoch)
)

if epoch % args.alpha_freq == 0:
alpha = min(alpha + args.alpha_step, args.alpha_max)

if (epoch == 0) or (epoch + 1) % args.freq == 0 or (w_dist_epoch <= w_dist_threshold and df_acc >= 0.9):
# every args.freq epoch, checkpoint
total_real = np.array(total_real)
Expand Down

0 comments on commit cf7c5d4

Please sign in to comment.