Showing 39 changed files with 5,082 additions and 1 deletion.
.gitignore
@@ -0,0 +1,10 @@
__pycache__/
.ipynb_checkpoints/
exp_local
output
nbs
tmp/
notebooks/
slurm/
data/
datasets/
README.md
@@ -1 +1,99 @@
-# exorl
# ExORL: Exploratory Data for Offline Reinforcement Learning

This is an original PyTorch implementation of the ExORL framework from

[Don't Change the Algorithm, Change the Data: Exploratory Data for Offline Reinforcement Learning](https://arxiv.org/abs/2201.13425) by

[Denis Yarats*](https://cs.nyu.edu/~dy1042/), [David Brandfonbrener*](https://davidbrandfonbrener.github.io/), [Hao Liu](https://www.haoliu.site/), [Misha Laskin](https://www.mishalaskin.com/), [Pieter Abbeel](https://people.eecs.berkeley.edu/~pabbeel/), [Alessandro Lazaric](http://chercheurs.lille.inria.fr/~lazaric/Webpage/Home/Home.html), and [Lerrel Pinto](https://www.lerrelpinto.com).

*Equal contribution.

## Prerequisites

Install [MuJoCo](http://www.mujoco.org/) if it is not already installed:

* Obtain a license on the [MuJoCo website](https://www.roboti.us/license.html).
* Download MuJoCo binaries [here](https://www.roboti.us/index.html).
* Unzip the downloaded archive into `~/.mujoco/mujoco200` and place your license key file `mjkey.txt` at `~/.mujoco`.
* Use the env variables `MUJOCO_PY_MJKEY_PATH` and `MUJOCO_PY_MUJOCO_PATH` to specify the MuJoCo license key path and the MuJoCo directory path.
* Append the MuJoCo subdirectory bin path to the env variable `LD_LIBRARY_PATH` (see the example below).
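
For example, with the default paths from the steps above (a sketch; adjust to your actual install locations):
```sh
export MUJOCO_PY_MJKEY_PATH=$HOME/.mujoco/mjkey.txt
export MUJOCO_PY_MUJOCO_PATH=$HOME/.mujoco/mujoco200
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$HOME/.mujoco/mujoco200/bin
```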

Install the following libraries:
```sh
sudo apt update
sudo apt install libosmesa6-dev libgl1-mesa-glx libglfw3 unzip
```

Install dependencies:
```sh
conda env create -f conda_env.yml
conda activate exorl
```

## Datasets

We provide exploratory datasets for 6 DeepMind Control Suite domains:

| Domain | Dataset name | Available task names |
|---|---|---|
| Cartpole | `cartpole` | `cartpole_balance`, `cartpole_balance_sparse`, `cartpole_swingup`, `cartpole_swingup_sparse` |
| Cheetah | `cheetah` | `cheetah_run`, `cheetah_run_backward` |
| Jaco Arm | `jaco` | `jaco_reach_top_left`, `jaco_reach_top_right`, `jaco_reach_bottom_left`, `jaco_reach_bottom_right` |
| Point Mass Maze | `point_mass_maze` | `point_mass_maze_reach_top_left`, `point_mass_maze_reach_top_right`, `point_mass_maze_reach_bottom_left`, `point_mass_maze_reach_bottom_right` |
| Quadruped | `quadruped` | `quadruped_walk`, `quadruped_run` |
| Walker | `walker` | `walker_stand`, `walker_walk`, `walker_run` |

For each domain, we collected datasets by running 9 unsupervised RL algorithms from [URLB](https://github.com/rll-research/url_benchmark) for a total of `10M` steps. Here is the list of algorithms:

| Unsupervised RL method | Name | Paper |
|---|---|---|
| APS | `aps` | [paper](http://proceedings.mlr.press/v139/liu21b.html) |
| APT(ICM) | `icm_apt` | [paper](https://arxiv.org/abs/2103.04551) |
| DIAYN | `diayn` | [paper](https://arxiv.org/abs/1802.06070) |
| Disagreement | `disagreement` | [paper](https://arxiv.org/abs/1906.04161) |
| ICM | `icm` | [paper](https://arxiv.org/abs/1705.05363) |
| ProtoRL | `proto` | [paper](https://arxiv.org/abs/2102.11271) |
| Random | `random` | N/A |
| RND | `rnd` | [paper](https://arxiv.org/abs/1810.12894) |
| SMM | `smm` | [paper](https://arxiv.org/abs/1906.05274) |

You can download a dataset by running `./download.sh <DOMAIN> <ALGO>`. For example, to download the ProtoRL dataset for Walker, run:
```sh
./download.sh walker proto
```
The script will download the dataset from S3 and store it under `datasets/walker/proto/`, where you can find episodes (under `buffer`) and episode videos (under `video`).
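
A minimal sketch for inspecting a downloaded episode, assuming episodes are stored as `.npz` archives in the `buffer` directory (the filename below is hypothetical; print `ep.files` to see the actual array keys):
```python
import numpy as np

# Hypothetical episode filename; list datasets/walker/proto/buffer/ for real ones.
ep = np.load("datasets/walker/proto/buffer/episode_000000_1000.npz")
for key in ep.files:
    print(key, ep[key].shape)  # per-step arrays, e.g. observations/actions/rewards
```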

## Offline RL training

We also provide implementations of 5 offline RL algorithms for evaluating the datasets:

| Offline RL method | Name | Paper |
|---|---|---|
| Behavior Cloning | `bc` | [paper](https://proceedings.neurips.cc/paper/1988/file/812b4ba287f5ee0bc9d43bbf5bbe87fb-Paper.pdf) |
| CQL | `cql` | [paper](https://arxiv.org/pdf/2006.04779.pdf) |
| CRR | `crr` | [paper](https://arxiv.org/pdf/2006.15134.pdf) |
| TD3+BC | `td3_bc` | [paper](https://arxiv.org/pdf/2106.06860.pdf) |
| TD3 | `td3` | [paper](https://arxiv.org/pdf/1802.09477.pdf) |

After downloading the required datasets, you can evaluate them with an offline RL method on a specific task. For example, to evaluate a dataset collected by ProtoRL on Walker for the walking task using TD3+BC, run:
```sh
python train_offline.py agent=td3_bc expl_agent=proto task=walker_walk
```
Logs are stored in the `output` folder. To launch TensorBoard, run:
```sh
tensorboard --logdir output
```

## Citation

If you use this repo in your research, please consider citing the paper as follows:
```
@article{yarats2022exorl,
  title={Don't Change the Algorithm, Change the Data: Exploratory Data for Offline Reinforcement Learning},
  author={Denis Yarats and David Brandfonbrener and Hao Liu and Michael Laskin and Pieter Abbeel and Alessandro Lazaric and Lerrel Pinto},
  journal={arXiv preprint arXiv:2201.13425},
  year={2022}
}
```

## License

The majority of ExORL is licensed under the MIT license; however, portions of the project are available under separate license terms: DeepMind code is licensed under the Apache 2.0 license.
agent/bc.py
@@ -0,0 +1,110 @@
import torch
import torch.nn as nn

import utils


class Actor(nn.Module):
    """Policy network that outputs a truncated Normal action distribution."""

    def __init__(self, obs_dim, action_dim, hidden_dim):
        super().__init__()

        self.policy = nn.Sequential(nn.Linear(obs_dim, hidden_dim),
                                    nn.LayerNorm(hidden_dim), nn.Tanh(),
                                    nn.Linear(hidden_dim, hidden_dim),
                                    nn.ReLU(inplace=True),
                                    nn.Linear(hidden_dim, action_dim))

        self.apply(utils.weight_init)

    def forward(self, obs, std):
        mu = self.policy(obs)
        mu = torch.tanh(mu)  # squash the mean into [-1, 1]
        std = torch.ones_like(mu) * std

        dist = utils.TruncatedNormal(mu, std)
        return dist


class BCAgent:
    """Behavior cloning: fit the actor to dataset actions by maximum likelihood."""

    def __init__(self,
                 name,
                 obs_shape,
                 action_shape,
                 device,
                 lr,
                 hidden_dim,
                 batch_size,
                 stddev_schedule,
                 use_tb,
                 has_next_action=False):
        self.action_dim = action_shape[0]
        self.hidden_dim = hidden_dim
        self.lr = lr
        self.device = device
        self.stddev_schedule = stddev_schedule
        self.use_tb = use_tb
        # BC is trained purely offline, so there is no random-exploration phase;
        # act() checks this attribute, which was previously left undefined.
        self.num_expl_steps = 0

        # models
        self.actor = Actor(obs_shape[0], action_shape[0],
                           hidden_dim).to(device)

        # optimizers
        self.actor_opt = torch.optim.Adam(self.actor.parameters(), lr=lr)

        self.train()

    def train(self, training=True):
        self.training = training
        self.actor.train(training)

    def act(self, obs, step, eval_mode):
        obs = torch.as_tensor(obs, device=self.device).unsqueeze(0)
        stddev = utils.schedule(self.stddev_schedule, step)
        policy = self.actor(obs, stddev)
        if eval_mode:
            action = policy.mean
        else:
            action = policy.sample(clip=None)
            if step < self.num_expl_steps:
                action.uniform_(-1.0, 1.0)
        return action.cpu().numpy()[0]

    def update_actor(self, obs, action, step):
        metrics = dict()

        stddev = utils.schedule(self.stddev_schedule, step)
        policy = self.actor(obs, stddev)

        # negative log-likelihood of the dataset actions under the policy
        log_prob = policy.log_prob(action).sum(-1, keepdim=True)
        actor_loss = (-log_prob).mean()

        self.actor_opt.zero_grad(set_to_none=True)
        actor_loss.backward()
        self.actor_opt.step()

        if self.use_tb:
            metrics['actor_loss'] = actor_loss.item()
            metrics['actor_ent'] = policy.entropy().sum(dim=-1).mean().item()

        return metrics

    def update(self, replay_iter, step):
        metrics = dict()

        batch = next(replay_iter)
        obs, action, reward, discount, next_obs = utils.to_torch(
            batch, self.device)

        if self.use_tb:
            metrics['batch_reward'] = reward.mean().item()

        # update actor on the sampled batch
        metrics.update(self.update_actor(obs, action, step))

        return metrics
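
A minimal smoke test for `BCAgent` on random tensors (a sketch: the dimensions are made up, and it assumes the repo's `utils` module is importable from the project root):
```python
import torch
from agent.bc import BCAgent

# Hypothetical dimensions: 24-dim observations, 6-dim actions.
agent = BCAgent(name='bc', obs_shape=(24,), action_shape=(6,), device='cpu',
                lr=1e-4, hidden_dim=1024, batch_size=4,
                stddev_schedule=0.2, use_tb=True)

obs = torch.randn(4, 24)
action = torch.rand(4, 6) * 2.0 - 1.0  # fake dataset actions in [-1, 1]
metrics = agent.update_actor(obs, action, step=0)
print(metrics)  # {'actor_loss': ..., 'actor_ent': ...}
```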
agent/bc.yaml
@@ -0,0 +1,12 @@
# @package agent
_target_: agent.bc.BCAgent
name: bc
obs_shape: ??? # to be specified later
action_shape: ??? # to be specified later
device: ${device}
lr: 1e-4
use_tb: ${use_tb}
hidden_dim: 1024
batch_size: 1024 # 256 for pixels
has_next_action: False
stddev_schedule: 0.2
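
The `???` placeholders are filled in at runtime before Hydra instantiates the agent from `_target_`. A minimal sketch of that wiring (the helper below is illustrative; the actual code lives in the training scripts and may differ):
```python
import hydra

def make_agent(obs_spec, action_spec, cfg):
    # Fill the ??? placeholders from the environment specs,
    # then let Hydra build agent.bc.BCAgent from _target_.
    cfg.obs_shape = obs_spec.shape
    cfg.action_shape = action_spec.shape
    return hydra.utils.instantiate(cfg)
```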