A joint project by: Sean McLeish, John Kirchenbauer, David Yu Miller, Siddharth Singh, Abhinav Bhatele, Micah Goldblum, Ashwinee Panda and Tom Goldstein.
To cite our work, please use the following BibTeX entry:
@article{mcleish2024gemstones,
title={Gemstones: A Model Suite for Multi-Faceted Scaling Laws},
author={Sean McLeish and John Kirchenbauer and David Yu Miller and Siddharth Singh and Abhinav Bhatele and Micah Goldblum and Ashwinee Panda and Tom Goldstein},
journal={arXiv preprint arXiv:2502.06857},
year={2025},
url={https://arxiv.org/abs/2502.06857},
}
We developed with Python 3.10.4. To install, run:
git clone [email protected]:mcleish7/gemstone-scaling-laws
cd gemstone-scaling-laws
pip install .
All of our training runs were completed on Frontier at Oak Ridge National Laboratory. We train in two-hour intervals over multiple nodes of AMD MI250X GPUs, logging to wandb. We extract data from wandb using wandb_data_extraction.py, where we stitch the two-hour chunks back into complete runs. However, our wandb space is currently private, so we provide the intermediate dataframe produced after we process the models; this is close to raw form apart from the runs being grouped.
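The stitching step above can be sketched in pure Python. The record fields ("run", "step", "loss") are assumptions for illustration; the actual dataframe schema produced by wandb_data_extraction.py may differ.

```python
# Sketch of stitching two-hour wandb run chunks back into complete runs:
# sort the logged records, deduplicate steps repeated across restarts,
# and return one ordered sequence per run.
def stitch_chunks(records):
    """Merge chunked logs of the same run, dedup restarts, sort by step."""
    records = sorted(records, key=lambda r: (r["run"], r["step"]))
    merged = {}
    for r in records:
        merged[(r["run"], r["step"])] = r  # later chunks win on restarts
    return sorted(merged.values(), key=lambda r: (r["run"], r["step"]))

chunks = [
    {"run": "50m", "step": 200, "loss": 3.0},
    {"run": "50m", "step": 100, "loss": 3.2},
    {"run": "50m", "step": 200, "loss": 3.0},  # duplicate across a restart
    {"run": "50m", "step": 300, "loss": 2.8},
]
runs = stitch_chunks(chunks)
```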
We provide bash commands to run all fitting code in shells/fitting.sh. Because fitting is a compute-intensive process, we also provide the outputs in JSON form in the ./parameters folders.
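Loading a fitted law back from those JSON files is straightforward. The schema below (keys "E", "A", "alpha", "B", "beta") is an assumption for illustration; check the actual files in parameters/ for the exact field names.

```python
# Sketch of reading a fitted law from the JSON outputs in parameters/.
import json
import os
import tempfile

def load_law(path):
    """Load one fitted scaling law's parameters from a JSON file."""
    with open(path) as f:
        return json.load(f)

# Round-trip demo with a made-up parameter set (not the paper's fitted values).
fake = {"E": 1.7, "A": 400.0, "alpha": 0.34, "B": 410.0, "beta": 0.28}
with tempfile.NamedTemporaryFile("w", suffix=".json", delete=False) as f:
    json.dump(fake, f)
    path = f.name
law = load_law(path)
os.remove(path)
```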
We use approach_1.py to fit approach 1 laws. This is a quick process so we also plot at the same time.
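To illustrate what an approach-1 fit involves, here is a toy version assuming the Chinchilla-style functional form loss(N, D) = E + A / N^alpha + B / D^beta. The fitting in approach_1.py is more sophisticated; this sketch only grid-searches the two exponents on synthetic data, holding the other parameters fixed.

```python
# Toy approach-1 fit: recover the exponents of an assumed power law
# loss(N, D) = E + A / N**alpha + B / D**beta by brute-force grid search.
import itertools

def predict(N, D, E, A, alpha, B, beta):
    return E + A / N**alpha + B / D**beta

# Synthetic observations generated from known (made-up) parameters.
true = dict(E=1.7, A=400.0, alpha=0.34, B=410.0, beta=0.28)
data = [(N, D, predict(N, D, **true))
        for N in (1e8, 1e9, 1e10) for D in (1e9, 1e10, 1e11)]

def fit(data):
    """Grid-search alpha and beta, holding E, A, B at their true values."""
    grid = [x / 100 for x in range(20, 50)]
    best = None
    for alpha, beta in itertools.product(grid, repeat=2):
        err = sum(
            (predict(N, D, true["E"], true["A"], alpha, true["B"], beta) - L) ** 2
            for N, D, L in data
        )
        if best is None or err < best[0]:
            best = (err, alpha, beta)
    return best[1], best[2]
```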
We use depth_width.py to fit approach 3 laws. We provide our outputs in parameters/, parameters_delta-3/ and parameters_delta-4/.
We provide bash commands to run all plotting code in plotting.sh. Due to the large compute requirements of the grid searches in many parts of this code, we provide our cache files here; please read the README there for how to use them. The cache should be placed as follows:
gemstone-scaling-laws
└── plotters
└── data_cache
We use approach_3_brute_force.py to plot the output of approach 3 width-depth laws using brute-force search.
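The brute-force idea can be sketched as follows: given a fitted width/depth law, scan a grid of (width, depth) shapes at a fixed FLOP budget and keep the shape with the lowest predicted loss. The law below and the parameter-count and FLOP approximations are illustrative assumptions, not the paper's fitted values.

```python
# Brute-force sketch: find the best (width, depth) shape under a FLOP budget.
def params(width, depth):
    # Rough transformer parameter count (attention + MLP blocks).
    return 12 * depth * width**2

def predicted_loss(width, depth, tokens):
    # Assumed separable power law in width, depth, and tokens (made-up values).
    return 1.7 + 80.0 / width**0.6 + 3.0 / depth**0.4 + 400.0 / tokens**0.3

def best_shape(flop_budget, widths, depths):
    """Exhaustively scan the shape grid, returning (loss, width, depth)."""
    best = None
    for w in widths:
        for d in depths:
            tokens = flop_budget / (6 * params(w, d))  # C ~= 6 * N * D
            loss = predicted_loss(w, d, tokens)
            if best is None or loss < best[0]:
                best = (loss, w, d)
    return best

result = best_shape(1e21, widths=range(512, 4097, 512), depths=range(4, 65, 4))
```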
- The rainbow of scaling laws is plotted in rainbow.py. This requires the correct approach 1 and approach 3 laws to have been created.
- Plotting of overtraining parabolas is done in overtrain_parabola.py. This requires the relevant part of approach_3_brute_force.py to have been run beforehand to cache outputs correctly. Caution: this is currently hard-coded to point only to the files we use in the paper.
- Overspending analysis is done in approach_1.py.
- Chinchilla Reduced Sampling is visualised in chinchilla_reduced_sampling.py.
- Analysis of delta and grid search sizes is done in slope_analysis.py.
- Plotting for feasible model shapes is done in plot_feasible_model_shapes_paper_plots.ipynb.
- Plotting for μP is done in plot_mup.py.
- Loss curves are plotted in wandb_data_plot.py.
Please feel free to contact us with any questions, or open an issue on GitHub.
We used Resolving Discrepancies in Compute-Optimal Scaling of Language Models to guide the format of this code base.
We use the Epoch AI Analyzing Chinchilla data, stored in data/.