Skip to content

Commit

Permalink
Fix bugs in LR search
Browse files Browse the repository at this point in the history
  • Loading branch information
neverix committed Sep 3, 2024
1 parent 5a40ca9 commit 2158d28
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 5 deletions.
Binary file modified lr_search.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
11 changes: 6 additions & 5 deletions src/jif/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,8 @@ def main(*args, **kwargs):
from matplotlib import pyplot as plt
from collections import defaultdict
from itertools import product
lrs = np.linspace(1e-5, 2e-3, 10).tolist()
# lrs = [1e-4, 5e-4, 1e-3, 5e-3, 1e-2]
lrs = np.exp(np.linspace(np.log(5e-4), np.log(3e-3), 10))
models = ["small", "medium", "big"]
colors = ["r", "g", "b"]
lrs_sampled = defaultdict(list)
Expand All @@ -256,14 +257,14 @@ def main(*args, **kwargs):
for lr, model in tqdm(all_configs):
print("Training", model, "model with", lr, "learning rate")
lrs_sampled[model].append(lr)
log_dict = train(*args, **kwargs | {"n_steps": 2_000, "lr": lr, "size": "small"},
log_dict = train(*args, **kwargs | {"n_steps": 2_000, "lr": lr, "size": model},
quiet=True, profile=False)
losses[model].append(log_dict["loss"])
losses[model].append(log_dict["loss_sma"])

for model, color in zip(models, colors):
plt.plot(lrs_sampled[model], losses[model], label=model, c=color)
plt.scatter(lrs_sampled[model], losses[model], marker="x", c=color)
plt.scatter(lrs_sampled[model], losses[model], label=model, c=color)
plt.xlim(lrs[0], lrs[-1])
plt.xscale("log")
plt.xlabel("Learning rate")
plt.ylabel("Final loss")
plt.legend()
Expand Down

0 comments on commit 2158d28

Please sign in to comment.