Learning rate tuning
neverix committed Sep 2, 2024
1 parent 90fd37f commit 5a40ca9
Showing 2 changed files with 4 additions and 1 deletion.
Binary file modified lr_search.png
src/jif/__main__.py (5 changes: 4 additions & 1 deletion)
@@ -259,9 +259,11 @@ def main(*args, **kwargs):
             log_dict = train(*args, **kwargs | {"n_steps": 2_000, "lr": lr, "size": "small"},
                              quiet=True, profile=False)
             losses[model].append(log_dict["loss"])
+
     for model, color in zip(models, colors):
         plt.plot(lrs_sampled[model], losses[model], label=model, c=color)
-        plt.scatter(lrs_sampled[model], losses[model], label=model, marker="x", c=color)
+        plt.scatter(lrs_sampled[model], losses[model], marker="x", c=color)
+        plt.xlim(lrs[0], lrs[-1])
     plt.xlabel("Learning rate")
     plt.ylabel("Final loss")
     plt.legend()
@@ -270,4 +272,5 @@ def main(*args, **kwargs):
 
 
 if __name__ == "__main__":
+    # fire.Fire(train); exit()
     fire.Fire(main)
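
For context, a minimal runnable sketch of the learning-rate sweep this hunk appears to modify. Everything outside the diffed lines (the train stand-in, models, colors, lrs, lrs_sampled) is an assumption inferred from the variable names, not code from the repository:

# Sketch only: assumed context around the diffed plotting code.
import numpy as np
import matplotlib.pyplot as plt

def train(lr, n_steps=2_000, size="small", quiet=True, profile=False):
    # Stand-in for the real trainer; returns a log dict with a final loss.
    # The real train() lives elsewhere in src/jif and takes more arguments.
    return {"loss": (np.log10(lr) + 3.0) ** 2 + 0.1 * np.random.rand()}

models = ["model_a", "model_b"]          # assumed model names
colors = ["C0", "C1"]                    # one matplotlib color per model
lrs = np.logspace(-5, -1, 9)             # assumed sweep range, low to high

lrs_sampled = {m: list(lrs) for m in models}
losses = {m: [] for m in models}
for model in models:
    for lr in lrs_sampled[model]:
        log_dict = train(lr)
        losses[model].append(log_dict["loss"])

for model, color in zip(models, colors):
    # Label only the line, not the scatter, so each model gets one legend entry.
    plt.plot(lrs_sampled[model], losses[model], label=model, c=color)
    plt.scatter(lrs_sampled[model], losses[model], marker="x", c=color)
plt.xlim(lrs[0], lrs[-1])                # pin the axis to the swept range
plt.xlabel("Learning rate")
plt.ylabel("Final loss")
plt.legend()
plt.savefig("lr_search.png")             # matches the modified binary file

The substance of the change: dropping label=model from plt.scatter avoids a duplicate legend entry per model (the line and the crosses previously each registered a label), and plt.xlim(lrs[0], lrs[-1]) pins the x-axis to the swept range.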
