Skip to content

Commit

Permalink
FEAT: change optimizer and add learning rate argument option
Browse files Browse the repository at this point in the history
  • Loading branch information
GunwooHan committed May 31, 2024
1 parent fb9ba6a commit dc8ddc8
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 1 deletion.
3 changes: 3 additions & 0 deletions data/bapps.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ def train_dataloader(self):
batch_size=self.batch_size,
num_workers=self.num_workers,
persistent_workers=True,
pin_memory=True,
shuffle=True
)

Expand All @@ -86,6 +87,7 @@ def val_dataloader(self):
batch_size=self.batch_size,
num_workers=self.num_workers,
persistent_workers=True,
pin_memory=True,
shuffle=False
)

Expand All @@ -95,6 +97,7 @@ def test_dataloader(self):
batch_size=self.batch_size,
num_workers=self.num_workers,
persistent_workers=True,
pin_memory=True,
shuffle=False
)

Expand Down
3 changes: 2 additions & 1 deletion single_reconstruction.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
parser = argparse.ArgumentParser()
parser.add_argument('--seed', type=int, default=42)
parser.add_argument('--num_workers', type=int, default=8)
parser.add_argument('--learninig_rate', type=float, default=0.0001)

parser.add_argument('--reconstruction_target', type=str, default='single_reconstruction_sample.jpeg')
parser.add_argument('--lpips_model_path', type=str, default='checkpoints/lpips/vgg-epoch=09-val/score=80.11.ckpt')
Expand Down Expand Up @@ -111,7 +112,7 @@ def forward(self, x):
pass

def configure_optimizers(self):
optimizer = optim.Adam(self.parameters(), lr=0.0001, betas=(0.9, 0.999))
optimizer = optim.AdamW(self.parameters(), lr=self.args.learning_rate)
return optimizer

def encode_text(self, text):
Expand Down

0 comments on commit dc8ddc8

Please sign in to comment.