A simple JAX-based implementation of random search for locomotion tasks using MuJoCo XLA (MJX).
Clone the repository:
git clone https://github.com/thowell/rs
Optionally, create a conda environment:
conda create -n rs python=3.10
conda activate rs
Install with pip:
pip install -e .
Train cheetah in about 1 minute on an NVIDIA RTX 4090 running Ubuntu 22.04.4 LTS.
Run from the repository root:
python rs/train.py --env cheetah --search --visualize --nsample 2048 --ntop 512 --niter 50 --neval 5 --nhorizon_search 200 --nhorizon_eval 1000 --random_step 0.1 --update_step 0.1
Output:
Settings:
environment: cheetah
nsample: 2048 | ntop: 512
niter: 50 | neval: 5
nhorizon_search: 200 | nhorizon_eval: 1000
random_step: 0.1 | update_step: 0.1
nenveval: 128
reward_shift: 0.0
Search:
iteration (10 / 50): reward = 1172.42 +- 1144.11 | time = 17.52 | avg episode length: 1000 / 1000 | global steps: 8232960 | steps/second: 470022
iteration (20 / 50): reward = 2947.71 +- 1237.87 | time = 5.58 | avg episode length: 1000 / 1000 | global steps: 16465920 | steps/second: 1474670
iteration (30 / 50): reward = 3152.07 +- 1401.50 | time = 5.58 | avg episode length: 1000 / 1000 | global steps: 24698880 | steps/second: 1475961
iteration (40 / 50): reward = 4175.49 +- 783.41 | time = 5.59 | avg episode length: 1000 / 1000 | global steps: 32931840 | steps/second: 1472244
iteration (50 / 50): reward = 4293.36 +- 784.80 | time = 5.59 | avg episode length: 1000 / 1000 | global steps: 41164800 | steps/second: 1473380
total time: 56.43
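As a sanity check on the log above, the reported step counts and throughput can be reproduced with a few lines of arithmetic. This is an inference from the printed numbers, not something the README states: the counts are consistent with two rollouts (positive and negative perturbation) per sampled direction and 201 counted steps per rollout, where the extra step beyond nhorizon_search = 200 is an assumption.

```python
# Settings taken from the command above.
nsample = 2048
nhorizon_search = 200

# Assumption: one rollout for +direction and one for -direction,
# and one extra counted step per rollout (for example the initial state).
rollouts_per_iteration = 2 * nsample
steps_per_rollout = nhorizon_search + 1

steps_per_iteration = rollouts_per_iteration * steps_per_rollout
steps_per_10_iterations = 10 * steps_per_iteration

# The log reports 8232960 global steps after 10 iterations.
assert steps_per_10_iterations == 8232960

# Throughput for a 10-iteration block that took 5.58 s in the log.
steps_per_second = steps_per_10_iterations / 5.58
print(f"steps per iteration: {steps_per_iteration}")
print(f"steps/second: {steps_per_second:.0f}")
```

The computed throughput agrees with the logged steps/second to within the rounding of the printed time. The first block of 10 iterations is slower (17.52 s), and the total time exceeds the sum of the per-block times, so some work is not attributed to individual blocks.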
The pretrained policy can be visualized in MuJoCo's passive viewer (run from the rs/ directory):
python train.py --env cheetah --load pretrained/cheetah --visualize
Environments available:
- Ant
- based on ant_v5
- modified solver settings
- only contact between feet and floor
- no rewards or observations dependent on contact forces
- Cheetah
- based on half_cheetah_v5
- modified solver settings
- Humanoid
- based on humanoid_v5
- modified solver settings
- only contact between feet and floor
- no rewards or observations dependent on contact forces
- Walker
- based on walker2d_v5
- modified solver settings
- only contact between feet and floor
Note: results vary between runs; run the search multiple times (for example, with different --seed values) to find good policies.
First, change to the rs/ directory:
cd rs
Search:
python train.py --env ant --search
Visualize policy checkpoint:
python train.py --env ant --load pretrained/ant --visualize
Search:
python train.py --env cheetah --search
Visualize policy checkpoint:
python train.py --env cheetah --load pretrained/cheetah --visualize
Search:
python train.py --env humanoid --search
Visualize policy checkpoint:
python train.py --env humanoid --load pretrained/humanoid --visualize
Search:
python train.py --env walker --search
Visualize policy checkpoint:
python train.py --env walker --load pretrained/walker --visualize
Setup:
- --env: ant, cheetah, humanoid, walker
- --search: run random search to improve policy
- --checkpoint: filename in checkpoint/ to save policy
- --load: provide string in checkpoint/ directory to load policy from checkpoint
- --seed: int for random number generation
- --visualize: visualize policy
Search settings:
- --nsample: number of random directions to sample
- --ntop: number of random directions to use for policy update
- --niter: number of policy updates
- --neval: number of policy evaluations during search
- --nhorizon_search: number of environment steps during policy improvement
- --nhorizon_eval: number of environment steps during policy evaluation
- --random_step: step size for random direction during policy perturbation
- --update_step: step size for policy update during policy improvement
- --nenveval: number of environments for policy evaluation
- --reward_shift: subtract baseline from per-timestep reward
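Corresponding notation in the Augmented Random Search paper:
- update_step: step size α
- random_step: standard deviation of the exploration noise ν
- nsample: number of directions sampled per iteration N
- ntop: number of top-performing directions used b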
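To make the roles of nsample, ntop, random_step, and update_step concrete, here is a minimal sketch of one basic random search update in the style of Augmented Random Search. It uses NumPy rather than JAX for readability and is not the repository's implementation: `random_search_step` and `toy_rollout` are hypothetical names, and the toy reward (negative squared distance to a target matrix) stands in for an MJX rollout over nhorizon_search steps.

```python
import numpy as np


def random_search_step(policy, rollout, rng, nsample=16, ntop=8,
                       random_step=0.1, update_step=0.1):
    """One random search update for a policy stored as an array (sketch)."""
    # Sample nsample random directions with the same shape as the policy.
    directions = rng.standard_normal((nsample,) + policy.shape)
    # Evaluate the policy perturbed along +direction and -direction.
    r_pos = np.array([rollout(policy + random_step * d) for d in directions])
    r_neg = np.array([rollout(policy - random_step * d) for d in directions])
    # Keep the ntop directions with the largest max(r_pos, r_neg).
    top = np.argsort(-np.maximum(r_pos, r_neg))[:ntop]
    # Scale by the standard deviation of the rewards used in the update.
    sigma = np.concatenate([r_pos[top], r_neg[top]]).std() + 1e-8
    # Reward-difference-weighted sum of the selected directions.
    step = np.tensordot(r_pos[top] - r_neg[top], directions[top], axes=1)
    return policy + update_step / (ntop * sigma) * step


def toy_rollout(policy, target):
    """Hypothetical stand-in for an environment rollout: higher is better."""
    return -float(np.sum((policy - target) ** 2))


rng = np.random.default_rng(0)
target = np.ones((2, 3))
policy = np.zeros((2, 3))
initial_error = np.linalg.norm(policy - target)
for _ in range(100):
    policy = random_search_step(policy, lambda p: toy_rollout(p, target), rng)
final_error = np.linalg.norm(policy - target)
print(f"error: {initial_error:.3f} -> {final_error:.3f}")
```

Because each update is normalized by the reward standard deviation, the policy moves toward the target at a roughly constant rate and then hovers near it rather than converging exactly. This is one reason the evaluation reward in a real run can fluctuate between iterations.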
- The environments are based on the v5 MuJoCo Gym environments but may not be exact in all details.
- The search settings are based on Table 9 of "Simple random search provides a competitive approach to reinforcement learning", but may not be exact in all details either.
This repository was developed to:
- understand the Augmented Random Search algorithm
- understand how to compute numerically stable running statistics
- understand the details of Gym environments
- experiment with code generation tools that can shorten development time, including ChatGPT and Claude
- gain experience with MuJoCo XLA (MJX)
- gain experience with JAX
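For the running-statistics goal listed above, one standard numerically stable approach is Welford's online algorithm, which updates the mean and the sum of squared deviations one sample at a time instead of accumulating raw sums of squares. The sketch below uses NumPy and a hypothetical `RunningStats` class; whether the repository uses this exact formulation is not confirmed here.

```python
import numpy as np


class RunningStats:
    """Online mean and variance via Welford's algorithm (illustrative sketch)."""

    def __init__(self, dim):
        self.count = 0
        self.mean = np.zeros(dim)
        self.m2 = np.zeros(dim)  # running sum of squared deviations from the mean

    def update(self, x):
        x = np.asarray(x, dtype=np.float64)
        self.count += 1
        delta = x - self.mean                 # deviation from the old mean
        self.mean = self.mean + delta / self.count
        self.m2 = self.m2 + delta * (x - self.mean)  # uses the new mean

    @property
    def variance(self):
        # Population variance; zeros before any sample has been seen.
        if self.count == 0:
            return np.zeros_like(self.m2)
        return self.m2 / self.count


# Observations with a large offset, where the naive formula
# E[x^2] - E[x]^2 loses precision to cancellation.
rng = np.random.default_rng(0)
data = rng.normal(loc=1e6, scale=1.0, size=(1000, 3))

stats = RunningStats(dim=3)
for observation in data:
    stats.update(observation)

print("mean:", stats.mean)
print("variance:", stats.variance)
```

The running values can be compared against `data.mean(axis=0)` and `data.var(axis=0)`, which compute the same quantities in a batch pass over all samples.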
The MuJoCo models use resources from Gymnasium and dm_control.