Unsupervised Learning by Competing Hidden Units, in JAX
WIP
pip install -r requirements.txt
pip install -U "jax[cuda12]"  # replace cuda12 with the extra matching your CUDA version
As always, an isolated environment (e.g., conda or venv) is recommended.
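To confirm that the CUDA-enabled install is actually picked up, a quick check like the one below lists the devices JAX can see. It is a sketch, not part of the repository. If only CPU devices are listed, JAX did not find a usable GPU and will fall back to the CPU.

```python
import jax

# Devices JAX can run on; a working CUDA install should list GPU devices here.
devices = jax.devices()
# Name of the backend JAX uses by default (e.g. "cpu" or "gpu").
backend = jax.default_backend()
print(backend, devices)
```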
python jax_unsup_learning.py
See the evolution of the weights in the animation.mp4 file.
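For orientation, the rule named in the title (Krotov and Hopfield, "Unsupervised learning by competing hidden units", PNAS 2019) can be sketched as a single minibatch update. This is a minimal sketch based on the published description, not the code in jax_unsup_learning.py. The function name `krotov_hopfield_step` and its defaults (p = 3, k = 2, Δ = 0.4, step size eps = 0.02) are illustrative assumptions, and the target radius R is fixed to 1. For each sample, the hidden unit with the strongest input current is pulled toward the input, while the k-th strongest is pushed away with strength Δ. The update is then rescaled so that its largest entry equals eps.

```python
import jax
import jax.numpy as jnp


def krotov_hopfield_step(W, v, p=3.0, k=2, delta=0.4, eps=0.02, prec=1e-30):
    """One minibatch update (hypothetical helper; assumes R = 1, k >= 2).

    W: (K, N) weights of K hidden units; v: (N, B) inputs stored as columns.
    """
    K = W.shape[0]
    B = v.shape[1]
    # Input currents <W_mu, v> = sum_i sign(W_mu,i) |W_mu,i|^(p-1) v_i, shape (K, B)
    currents = (jnp.sign(W) * jnp.abs(W) ** (p - 1)) @ v
    # Rank hidden units by current for each sample (ascending along axis 0)
    order = jnp.argsort(currents, axis=0)
    cols = jnp.arange(B)
    g = jnp.zeros((K, B))
    # Strongest unit per sample: Hebbian update with weight 1
    g = g.at[order[K - 1], cols].set(1.0)
    # k-th strongest unit per sample: anti-Hebbian update with weight -delta
    g = g.at[order[K - k], cols].set(-delta)
    # Per hidden unit, batch sum of g * current (scales the decay term)
    xx = jnp.sum(g * currents, axis=1)
    # Batch-summed update: g-weighted inputs minus the weight-decay term
    ds = g @ v.T - xx[:, None] * W
    # Rescale so the largest update entry has magnitude eps
    nc = jnp.maximum(jnp.max(jnp.abs(ds)), prec)
    return W + eps * ds / nc


if __name__ == "__main__":
    # Toy data (assumption): 32 random 784-pixel inputs, 100 hidden units
    W = jax.random.normal(jax.random.PRNGKey(0), (100, 784))
    data = jax.random.uniform(jax.random.PRNGKey(1), (784, 32))
    step = jax.jit(krotov_hopfield_step)
    for _ in range(10):
        W = step(W, data)
    print(W.shape)
```

Normalizing by the largest entry of the update keeps the step size fixed regardless of the input scale. Wrapping the step in `jax.jit` works because `k` and the array shapes are static Python values.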