modify readme
ZhengYinan-AIR committed Dec 9, 2023
1 parent 261c602 commit c20d1c9
Showing 2 changed files with 53 additions and 68 deletions.
111 changes: 50 additions & 61 deletions README.md
@@ -1,74 +1,63 @@
# IDQL: Implicit Q-Learning as an Actor-Critic Method with Diffusion Policies

Paper Link : https://arxiv.org/abs/2304.10573

Check out https://github.com/philippe-eecs/JaxDDPM for an implementation of DDPMs in JAX for continuous spaces!

# Reproducing Results

[Offline Script Location.](launcher/examples/train_ddpm_iql_offline.py)

Run the following line for each variant. Edit the script linked above to change the hyperparameters and environments to sweep over.

```
python3 launcher/examples/train_ddpm_iql_offline.py --variant 0...N
```
# Feasibility-Guided Safe Offline Reinforcement Learning

The official implementation of FISOR, which **represents a pioneering effort in considering hard constraints (Hamilton-Jacobi Reachability) within the safe offline RL setting**. FISOR transforms the original tightly coupled safety-constrained offline RL problem into three decoupled, simple supervised objectives: 1) offline identification of the largest feasible region; 2) optimal advantage learning; and 3) optimal policy extraction via a guided diffusion model, enhancing both performance and stability.

## Branches Overview
| Branch name | Usage |
|:---: |:---: |
| [master](https://github.com/ZhengYinan-AIR/FISOR) | FISOR implementation for ``Point Robot``, ``Safety-Gymnasium`` and ``Bullet-Safety-Gym``; data quantity experiment; feasible region visualization. |
| [metadrive_imitation](https://github.com/ZhengYinan-AIR/FISOR/tree/metadrive_imitation) | FISOR implementation for ``MetaDrive``; data quantity experiment; imitation learning experiment. |

## Installation
``` Bash
conda create -n FISOR python=3.9
conda activate FISOR
git clone https://github.com/ZhengYinan-AIR/FISOR.git
cd FISOR
pip install -r requirements.txt
```

[Finetune Script Location.](launcher/examples/train_ddpm_iql_finetune.py)

Run
```
python3 launcher/examples/train_ddpm_iql_finetune.py --variant 0...N
```
## Main results
Run
``` Bash
# OfflineCarButton1Gymnasium-v0
export XLA_PYTHON_CLIENT_PREALLOCATE=False
python launcher/examples/train_offline.py --env_id 0 --config configs/train_config.py:fisor
```
where ``env_id`` serves as an index for the [list of environments](https://github.com/ZhengYinan-AIR/FISOR/blob/master/env/env_list.py).
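As a rough illustration of how an integer ``env_id`` can select an environment, the sketch below indexes into a list of names. The list contents and the ``resolve_env`` helper are hypothetical stand-ins, not the actual contents of ``env/env_list.py``; only the mapping of index 0 to ``OfflineCarButton1Gymnasium-v0`` is taken from the command above.

```python
# Hypothetical stand-in for env/env_list.py; the real list differs.
ENV_LIST = [
    "OfflineCarButton1Gymnasium-v0",  # --env_id 0, per the comment in the command above
    "PlaceholderEnvA-v0",             # placeholder name, not a real environment
    "PlaceholderEnvB-v0",             # placeholder name, not a real environment
]


def resolve_env(env_id: int) -> str:
    """Return the environment name stored at position env_id."""
    if not 0 <= env_id < len(ENV_LIST):
        raise ValueError(f"env_id must be in [0, {len(ENV_LIST) - 1}], got {env_id}")
    return ENV_LIST[env_id]


if __name__ == "__main__":
    print(resolve_env(0))
```

Checking the index range explicitly avoids Python's negative indexing silently selecting an environment from the end of the list.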

# Important File Locations

[Main run script where the variant dictionary is passed.](/examples/states/train_diffusion_offline.py)

[DDPM Implementation.](/jaxrl5/networks/diffusion.py)

[LN_Resnet.](/jaxrl5/networks/resnet.py)

[DDPM IQL Learner.](/jaxrl5/agents/ddpm_iql/ddpm_iql_learner.py)

[![CircleCI](https://dl.circleci.com/status-badge/img/gh/ikostrikov/jaxrl5/tree/main.svg?style=svg&circle-token=668374ebe0f27c7ee70edbdfbbd1dd928725c01a)](https://dl.circleci.com/status-badge/redirect/gh/ikostrikov/jaxrl5/tree/main) [![codecov](https://codecov.io/gh/ikostrikov/jaxrl5/branch/main/graph/badge.svg?token=Q5QMIDZNZ3)](https://codecov.io/gh/ikostrikov/jaxrl5)

# Installation

Run
```bash
pip install --upgrade pip
pip install -r requirements.txt
pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
```
## Data Quantity Experiments
Run [filter_data.py](https://github.com/ZhengYinan-AIR/FISOR/blob/master/filter_data.py) to generate offline datasets of varying sizes, or download the necessary offline datasets ([Download link](https://cloud.tsinghua.edu.cn/d/591cf8fd6d8649a89df4/)). Then run
``` Bash
python launcher/examples/train_offline.py --env_id 0 --config configs/train_config.py:fisor --ratio 0.1
```
where ``ratio`` refers to the proportion of the processed data to the original dataset.
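To make the meaning of ``ratio`` concrete, here is a minimal sketch of keeping a fixed fraction of a dataset. It assumes uniform random sampling over transitions and uses plain Python lists in place of arrays; ``subsample_dataset`` is a hypothetical helper, and ``filter_data.py`` may select data differently (for example, per trajectory).

```python
import random


def subsample_dataset(dataset, ratio, seed=0):
    """Keep roughly `ratio` of the entries, using the same indices for every field.

    `dataset` maps field names to equal-length lists (an assumed layout).
    """
    if not 0.0 < ratio <= 1.0:
        raise ValueError(f"ratio must be in (0, 1], got {ratio}")
    n = len(next(iter(dataset.values())))
    n_keep = max(1, int(round(n * ratio)))
    # Sample indices without replacement, then sort to preserve the original order.
    idx = sorted(random.Random(seed).sample(range(n), n_keep))
    return {key: [values[i] for i in idx] for key, values in dataset.items()}


if __name__ == "__main__":
    data = {"observations": list(range(100)), "rewards": [float(i) for i in range(100)]}
    small = subsample_dataset(data, ratio=0.1)
    print(len(small["observations"]))
```

Sharing one index set across all fields keeps each observation aligned with its reward after subsampling.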

See instructions for other versions of CUDA [here](https://github.com/google/jax#pip-installation-gpu-cuda).
## Feasible Region Visualization
Download the necessary offline dataset for the ``Point Robot`` environment ([Download link](https://cloud.tsinghua.edu.cn/d/162d6fe92bde43e28676/)), then train FISOR in the ``Point Robot`` environment:
``` Bash
python launcher/examples/train_offline.py --env_id 29 --config configs/train_config.py:fisor
```
Then visualize the feasible region by running [viz_map.py](https://github.com/ZhengYinan-AIR/FISOR/blob/master/launcher/viz/viz_map.py).
<p float="left">
<img src="assets/viz_map.png" width="800">
</p>

Based on a re-implementation of https://github.com/ikostrikov/jaxrl
## Bibtex

# Citations
Cite this paper
If you find our code and paper helpful, please cite our paper as:
```
@misc{hansenestruch2023idql,
title={IDQL: Implicit Q-Learning as an Actor-Critic Method with Diffusion Policies},
author={Philippe Hansen-Estruch and Ilya Kostrikov and Michael Janner and Jakub Grudzien Kuba and Sergey Levine},
year={2023},
eprint={2304.10573},
archivePrefix={arXiv},
primaryClass={cs.LG}
}
@article{zheng2023feasibility,
title={Feasibility-Guided Safe Offline Reinforcement Learning},
author={Zheng, Yinan and Li, Jianxiong and Yu, Dongjie and Yang, Yujie and Li, Shengbo Eben and Zhan, Xianyuan and Liu, Jingjing},
journal={openreview},
year={2023}
}
```

Please also cite the JAXRL repo if you use this repo:
```
@misc{jaxrl,
author = {Kostrikov, Ilya},
doi = {10.5281/zenodo.5535154},
month = {10},
title = {{JAXRL: Implementations of Reinforcement Learning algorithms in JAX}},
url = {https://github.com/ikostrikov/jaxrl},
year = {2021}
}
```
## Acknowledgements

python launcher/viz/viz_map.py --model_location 'results/PointRobot/ddpm_feasibility_hj_N16_minqc_2023-12-09_s54_486'
Parts of this code are adapted from [IDQL](https://github.com/philippe-eecs/IDQL) and [DRPO](https://github.com/ManUtdMoon/Distributional-Reachability-Policy-Optimization).
10 changes: 3 additions & 7 deletions requirements.txt
@@ -1,15 +1,11 @@
#gym[mujoco] >= 0.21.0, < 0.24.1
numpy >= 1.20.2
dm_control >= 1.0.0
# jax >= 0.3.13
jax == 0.4.9
flax == 0.5.3
ml_collections >= 0.1.0
tqdm >= 4.60.0
optax == 0.1.5
absl-py >= 0.12.0
scipy >= 1.6.0
wandb >= 0.12.14
tensorflow-probability >= 0.17.0
# dmcgym @ git+https://github.com/ikostrikov/dmcgym
moviepy
imageio
dsrl
matplotlib
