Commit: Init (#1)
denisyarats authored Feb 6, 2022
1 parent 8edecee commit 9cb4334
Showing 39 changed files with 5,082 additions and 1 deletion.
10 changes: 10 additions & 0 deletions .gitignore
@@ -0,0 +1,10 @@
__pycache__/
.ipynb_checkpoints/
exp_local
output
nbs
tmp/
notebooks/
slurm/
data/
datasets/
100 changes: 99 additions & 1 deletion README.md
@@ -1 +1,99 @@
# exorl


# ExORL: Exploratory Data for Offline Reinforcement Learning

This is an original PyTorch implementation of the ExORL framework from

[Don't Change the Algorithm, Change the Data: Exploratory Data for Offline Reinforcement Learning](https://arxiv.org/abs/2201.13425) by

[Denis Yarats*](https://cs.nyu.edu/~dy1042/), [David Brandfonbrener*](https://davidbrandfonbrener.github.io/), [Hao Liu](https://www.haoliu.site/), [Misha Laskin](https://www.mishalaskin.com/), [Pieter Abbeel](https://people.eecs.berkeley.edu/~pabbeel/), [Alessandro Lazaric](http://chercheurs.lille.inria.fr/~lazaric/Webpage/Home/Home.html), and [Lerrel Pinto](https://www.lerrelpinto.com).

*Equal contribution.

## Prerequisites

Install [MuJoCo](http://www.mujoco.org/) if it is not already installed:

* Obtain a license on the [MuJoCo website](https://www.roboti.us/license.html).
* Download MuJoCo binaries [here](https://www.roboti.us/index.html).
* Unzip the downloaded archive into `~/.mujoco/mujoco200` and place your license key file `mjkey.txt` at `~/.mujoco`.
* Use the env variables `MUJOCO_PY_MJKEY_PATH` and `MUJOCO_PY_MUJOCO_PATH` to specify the MuJoCo license key path and the MuJoCo directory path.
* Append the path of the MuJoCo `bin` subdirectory to the env variable `LD_LIBRARY_PATH`, as shown in the example below.
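
With the default install paths above, the environment can be configured along these lines (a sketch; adjust the paths to match your setup):
```sh
# assumes MuJoCo was unzipped into ~/.mujoco/mujoco200 and the key sits at ~/.mujoco/mjkey.txt
export MUJOCO_PY_MJKEY_PATH=$HOME/.mujoco/mjkey.txt
export MUJOCO_PY_MUJOCO_PATH=$HOME/.mujoco/mujoco200
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$HOME/.mujoco/mujoco200/bin
```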

Install the following libraries:
```sh
sudo apt update
sudo apt install libosmesa6-dev libgl1-mesa-glx libglfw3 unzip
```

Install dependencies:
```sh
conda env create -f conda_env.yml
conda activate exorl
```

## Datasets
We provide exploratory datasets for 6 DeepMind Control Suite domains:
| Domain | Dataset name | Available task names |
|---|---|---|
| Cartpole | `cartpole` | `cartpole_balance`, `cartpole_balance_sparse`, `cartpole_swingup`, `cartpole_swingup_sparse` |
| Cheetah | `cheetah` | `cheetah_run`, `cheetah_run_backward` |
| Jaco Arm | `jaco` | `jaco_reach_top_left`, `jaco_reach_top_right`, `jaco_reach_bottom_left`, `jaco_reach_bottom_right` |
| Point Mass Maze | `point_mass_maze` | `point_mass_maze_reach_top_left`, `point_mass_maze_reach_top_right`, `point_mass_maze_reach_bottom_left`, `point_mass_maze_reach_bottom_right` |
| Quadruped | `quadruped` | `quadruped_walk`, `quadruped_run` |
| Walker | `walker` | `walker_stand`, `walker_walk`, `walker_run` |


For each domain, we collected datasets by running 9 unsupervised RL algorithms from [URLB](https://github.com/rll-research/url_benchmark) for a total of `10M` steps. Here is the list of algorithms:
| Unsupervised RL method | Name | Paper |
|---|---|---|
| APS | `aps` | [paper](http://proceedings.mlr.press/v139/liu21b.html)|
| APT(ICM) | `icm_apt` | [paper](https://arxiv.org/abs/2103.04551)|
| DIAYN | `diayn` |[paper](https://arxiv.org/abs/1802.06070)|
| Disagreement | `disagreement` | [paper](https://arxiv.org/abs/1906.04161) |
| ICM | `icm` | [paper](https://arxiv.org/abs/1705.05363)|
| ProtoRL | `proto` | [paper](https://arxiv.org/abs/2102.11271)|
| Random | `random` | N/A |
| RND | `rnd` | [paper](https://arxiv.org/abs/1810.12894) |
| SMM | `smm` | [paper](https://arxiv.org/abs/1906.05274) |

You can download a dataset by running `./download.sh <DOMAIN> <ALGO>`; for example, to download the ProtoRL dataset for Walker, run
```sh
./download.sh walker proto
```
The script will download the dataset from S3 and store it under `datasets/walker/proto/`, where you can find episodes (under `buffer`) and episode videos (under `video`).
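
As a quick sanity check, the downloaded episodes can be inspected directly. The sketch below assumes each file under `buffer` is an `.npz` archive of per-step arrays (the exact keys and layout may differ):
```python
import glob

import numpy as np

# hypothetical path, matching the `./download.sh walker proto` example above
episode_files = sorted(glob.glob('datasets/walker/proto/buffer/*.npz'))
print(f'found {len(episode_files)} episode files')

# print the arrays stored in the first episode
with np.load(episode_files[0]) as episode:
    for key in episode.files:
        print(key, episode[key].shape, episode[key].dtype)
```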

## Offline RL training
We also provide implementations of 5 offline RL algorithms for evaluating the datasets:
| Offline RL method | Name | Paper |
|---|---|---|
| Behavior Cloning | `bc` | [paper](https://proceedings.neurips.cc/paper/1988/file/812b4ba287f5ee0bc9d43bbf5bbe87fb-Paper.pdf)|
| CQL | `cql` | [paper](https://arxiv.org/pdf/2006.04779.pdf)|
| CRR | `crr` |[paper](https://arxiv.org/pdf/2006.15134.pdf)|
| TD3+BC | `td3_bc` | [paper](https://arxiv.org/pdf/2106.06860.pdf) |
| TD3 | `td3` | [paper](https://arxiv.org/pdf/1802.09477.pdf)|

After downloading the required datasets, you can evaluate them with an offline RL method on a specific task. For example, to evaluate a dataset collected by ProtoRL on Walker for the walking task using TD3+BC, run
```sh
python train_offline.py agent=td3_bc expl_agent=proto task=walker_walk
```
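
Any combination of offline RL method, exploratory algorithm, and task from the tables above can be plugged into the same command. For example, to train CRR on the ICM dataset for Cheetah:
```sh
./download.sh cheetah icm
python train_offline.py agent=crr expl_agent=icm task=cheetah_run
```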
Logs are stored in the `output` folder. To launch TensorBoard, run:
```sh
tensorboard --logdir output
```

## Citation

If you use this repo in your research, please consider citing the paper as follows:
```
@article{yarats2022exorl,
title={Don't Change the Algorithm, Change the Data: Exploratory Data for Offline Reinforcement Learning},
  author={Denis Yarats and David Brandfonbrener and Hao Liu and Michael Laskin and Pieter Abbeel and Alessandro Lazaric and Lerrel Pinto},
journal={arXiv preprint arXiv:2201.13425},
year={2022}
}
```


## License
The majority of ExORL is licensed under the MIT license; however, portions of the project are available under separate license terms: code derived from DeepMind is licensed under the Apache 2.0 license.
110 changes: 110 additions & 0 deletions agent/bc.py
@@ -0,0 +1,110 @@
import hydra
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import OrderedDict
import functools

import utils
from dm_control.utils import rewards


class Actor(nn.Module):
    def __init__(self, obs_dim, action_dim, hidden_dim):
        super().__init__()

        self.policy = nn.Sequential(nn.Linear(obs_dim, hidden_dim),
                                    nn.LayerNorm(hidden_dim), nn.Tanh(),
                                    nn.Linear(hidden_dim, hidden_dim),
                                    nn.ReLU(inplace=True),
                                    nn.Linear(hidden_dim, action_dim))

        self.apply(utils.weight_init)

    def forward(self, obs, std):
        # squash the mean into [-1, 1] and use a fixed, schedule-driven std
        mu = self.policy(obs)
        mu = torch.tanh(mu)
        std = torch.ones_like(mu) * std

        dist = utils.TruncatedNormal(mu, std)
        return dist


class BCAgent:
    def __init__(self,
                 name,
                 obs_shape,
                 action_shape,
                 device,
                 lr,
                 hidden_dim,
                 batch_size,
                 stddev_schedule,
                 use_tb,
                 has_next_action=False):
        self.action_dim = action_shape[0]
        self.hidden_dim = hidden_dim
        self.lr = lr
        self.device = device
        self.stddev_schedule = stddev_schedule
        self.use_tb = use_tb
        # BC is trained purely offline, so no random-exploration phase is used
        self.num_expl_steps = 0

        # models
        self.actor = Actor(obs_shape[0], action_shape[0],
                           hidden_dim).to(device)

        # optimizers
        self.actor_opt = torch.optim.Adam(self.actor.parameters(), lr=lr)

        self.train()

    def train(self, training=True):
        self.training = training
        self.actor.train(training)

    def act(self, obs, step, eval_mode):
        obs = torch.as_tensor(obs, device=self.device).unsqueeze(0)
        stddev = utils.schedule(self.stddev_schedule, step)
        policy = self.actor(obs, stddev)
        if eval_mode:
            action = policy.mean
        else:
            action = policy.sample(clip=None)
            if step < self.num_expl_steps:
                action.uniform_(-1.0, 1.0)
        return action.cpu().numpy()[0]

    def update_actor(self, obs, action, step):
        metrics = dict()

        stddev = utils.schedule(self.stddev_schedule, step)
        policy = self.actor(obs, stddev)

        # behavior cloning: maximize the log-likelihood of dataset actions
        log_prob = policy.log_prob(action).sum(-1, keepdim=True)
        actor_loss = (-log_prob).mean()

        self.actor_opt.zero_grad(set_to_none=True)
        actor_loss.backward()
        self.actor_opt.step()

        if self.use_tb:
            metrics['actor_loss'] = actor_loss.item()
            metrics['actor_ent'] = policy.entropy().sum(dim=-1).mean().item()

        return metrics

    def update(self, replay_iter, step):
        metrics = dict()

        batch = next(replay_iter)
        obs, action, reward, discount, next_obs = utils.to_torch(
            batch, self.device)

        if self.use_tb:
            metrics['batch_reward'] = reward.mean().item()

        # update actor
        metrics.update(self.update_actor(obs, action, step))

        return metrics
12 changes: 12 additions & 0 deletions agent/bc.yaml
@@ -0,0 +1,12 @@
# @package agent
_target_: agent.bc.BCAgent
name: bc
obs_shape: ??? # to be specified later
action_shape: ??? # to be specified later
device: ${device}
lr: 1e-4
use_tb: ${use_tb}
hidden_dim: 1024
batch_size: 1024 # 256 for pixels
has_next_action: False
stddev_schedule: 0.2
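
The `???` fields are placeholders that the training script must fill in before the agent is built. Below is a minimal sketch of how such a config can be turned into a `BCAgent` with Hydra; the helper name and exact wiring are illustrative, not necessarily the repository's actual code.
```python
import hydra


def make_agent(obs_spec, action_spec, cfg):
    # fill in the fields marked ??? from the environment specs
    cfg.obs_shape = obs_spec.shape
    cfg.action_shape = action_spec.shape
    # instantiate the class named by _target_ (agent.bc.BCAgent) with the remaining fields
    return hydra.utils.instantiate(cfg)
```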