This repo features a high-speed JAX implementation of the Proximal Policy Optimisation (PPO) algorithm. It also includes a batch-size-invariant version, which uses exponentially weighted moving averages to remove dependence on the batch-size hyperparameter.
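As a minimal sketch of the underlying idea (the helper name and decay value below are illustrative, not the repo's exact API): instead of statistics computed once per batch, the invariant variant tracks them with a slowly-moving average, whose behaviour depends only on the decay rate rather than on how many samples arrive per batch. In JAX this can be written as:

import jax

def ewma_update(avg_params, new_params, decay=0.99):
    # Blend the running average towards the freshly updated values.
    # A larger decay gives a slower-moving, smoother average, and the
    # update is independent of the batch size used to produce new_params.
    return jax.tree_util.tree_map(
        lambda avg, new: decay * avg + (1.0 - decay) * new,
        avg_params, new_params,
    )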
To run a system, execute the following command:
python3 ppox/systems/ppo.py
Since Hydra is used for configuration management, parameter overrides can be passed as arguments to this command.
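For example, Hydra's key=value override syntax looks like the following (the parameter names here are illustrative; see the config files for the actual keys):

python3 ppox/systems/ppo.py system.seed=42 arch.num_envs=64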
We recommend managing dependencies using a virtual environment, which can be created and activated with the following commands:
python3.9 -m venv venv
source venv/bin/activate
Install dependencies using the requirements.txt file:
pip install -r requirements.txt
Install the codebase as a pip package with the following command:
pip install -e .
To run JAX on your accelerators, see the JAX documentation for platform-specific installation details.
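For example, a CUDA 12 build of JAX can typically be installed with:

pip install -U "jax[cuda12]"

Check the JAX installation guide for the command matching your CUDA version and hardware.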
The code follows the format of Mava and is inspired by PureJaxRL and CleanRL.