BDEvan5/ppx

A repo of Proximal Policy Optimisation (PPO) algorithms written in JAX
PPO Algorithms in JAX

This repo features a high-speed JAX implementation of the Proximal Policy Optimisation (PPO) algorithm. It also includes a batch-size-invariant version, which uses exponentially weighted moving averages to remove the dependence on the batch-size hyperparameter.
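As a rough sketch of the batch-size-invariance idea (not this repo's actual implementation), the statistics used to normalise advantages can be tracked with exponentially weighted moving averages rather than computed per batch, so the normalisation no longer depends on how many samples each batch contains:

```python
import jax.numpy as jnp

def ema_update(mean, var, x, decay=0.99):
    # Hypothetical sketch: blend the running mean/variance with the
    # current batch's statistics using an exponential moving average.
    batch_mean = jnp.mean(x)
    batch_var = jnp.var(x)
    new_mean = decay * mean + (1 - decay) * batch_mean
    new_var = decay * var + (1 - decay) * batch_var
    return new_mean, new_var

# Normalise a batch of advantages with the running statistics.
mean, var = 0.0, 1.0
advantages = jnp.array([1.0, 3.0, 5.0])
mean, var = ema_update(mean, var, advantages, decay=0.9)
normalised = (advantages - mean) / jnp.sqrt(var + 1e-8)
```

Because the running estimates change only a little per update, halving or doubling the batch size mainly changes how often the estimates are refreshed, not the scale of the normalised advantages.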

Usage

To run a system, execute the following command:

python3 ppox/systems/ppo.py 

Since Hydra is used for managing configurations, override parameters can be passed as arguments to this command.
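For example, Hydra overrides use a `key=value` syntax on the command line. The parameter names below are illustrative only; check the config files in the repo for the actual keys:

```shell
# Hypothetical override keys -- consult the repo's Hydra configs
# for the real parameter names.
python3 ppox/systems/ppo.py system.seed=42 system.total_timesteps=1000000
```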

Installation

We recommend managing dependencies in a virtual environment, which can be created and activated with the following commands:

python3.9 -m venv venv
source venv/bin/activate

Install dependencies using the requirements.txt file:

pip install -r requirements.txt

Install the codebase as an editable pip package with the following command:

pip install -e .

To use JAX on your accelerators (GPU/TPU), see the installation instructions in the JAX documentation.
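For instance, for NVIDIA GPUs the JAX documentation recommends installing a CUDA-enabled wheel; the exact extra depends on your CUDA and JAX versions, so verify it against the JAX install guide:

```shell
# Example for CUDA 12 (check the JAX docs for the extra matching
# your driver and JAX version).
pip install -U "jax[cuda12]"
```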

Acknowledgements

The code is based on the format of Mava and is inspired by PureJaxRL and CleanRL.
