# Feasibility-Guided Safe Offline Reinforcement Learning

The official implementation of FISOR, which **represents a pioneering effort in considering hard constraints (Hamilton-Jacobi reachability) within the safe offline RL setting**. FISOR transforms the original tightly coupled, safety-constrained offline RL problem into three decoupled, simple supervised objectives: 1) offline identification of the largest feasible region; 2) optimal advantage learning; and 3) optimal policy extraction via a guided diffusion model, enhancing both performance and stability.
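The three-stage decomposition above can be pictured on a toy 1-D batch (a conceptual sketch only — every function name and toy loss here is a hypothetical illustration, not the repository's actual objectives or API):

```python
# Conceptual sketch of FISOR's three decoupled supervised stages on a toy
# batch of three transitions. All names are hypothetical illustrations.

def identify_feasible_region(costs, threshold=0.0):
    """Stage 1: mark states whose (toy) HJ reachability value <= threshold."""
    return [c <= threshold for c in costs]

def learn_advantage(rewards, baseline):
    """Stage 2: toy advantage = reward minus a state-value baseline."""
    return [r - baseline for r in rewards]

def extract_policy(actions, feasible, advantages):
    """Stage 3: keep feasible, advantageous actions (standing in for
    feasibility- and advantage-guided diffusion sampling)."""
    return [a for a, ok, adv in zip(actions, feasible, advantages)
            if ok and adv > 0]

feasible = identify_feasible_region(costs=[-1.0, 0.5, -0.2])
advs = learn_advantage(rewards=[1.0, 2.0, 0.1], baseline=0.5)
policy_actions = extract_policy(['a0', 'a1', 'a2'], feasible, advs)
print(policy_actions)  # → ['a0']: only feasible, advantageous actions survive
```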
## Branches Overview

| Branch name | Usage |
| :---: | :---: |
| [master](https://github.com/ZhengYinan-AIR/FISOR) | FISOR implementation for ``Point Robot``, ``Safety-Gymnasium``, and ``Bullet-Safety-Gym``; data quantity experiment; feasible region visualization. |
| [metadrive_imitation](https://github.com/ZhengYinan-AIR/FISOR/tree/metadrive_imitation) | FISOR implementation for ``MetaDrive``; data quantity experiment; imitation learning experiment. |
## Installation
```bash
conda create -n FISOR python=3.9
conda activate FISOR
git clone https://github.com/ZhengYinan-AIR/FISOR.git
cd FISOR
pip install -r requirements.txt
```
## Main results
Run
```bash
# OfflineCarButton1Gymnasium-v0
export XLA_PYTHON_CLIENT_PREALLOCATE=False
python launcher/examples/train_offline.py --env_id 0 --config configs/train_config.py:fisor
```
where ``env_id`` serves as an index into the [list of environments](https://github.com/ZhengYinan-AIR/FISOR/blob/master/env/env_list.py).
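Conceptually, ``--env_id`` is just an index into a Python list of environment names (a minimal sketch — the real list lives in ``env/env_list.py``, and the second entry below is hypothetical):

```python
# Minimal sketch of how an integer env_id maps to an environment name.
# The real mapping lives in env/env_list.py; entries here are illustrative.
ENV_LIST = [
    "OfflineCarButton1Gymnasium-v0",  # env_id 0, as in the example above
    "OfflineCarButton2Gymnasium-v0",  # env_id 1 (hypothetical entry)
]

def resolve_env(env_id: int) -> str:
    """Return the environment name for a given env_id, with bounds checking."""
    if not 0 <= env_id < len(ENV_LIST):
        raise ValueError(f"env_id must be in [0, {len(ENV_LIST) - 1}]")
    return ENV_LIST[env_id]

print(resolve_env(0))  # → OfflineCarButton1Gymnasium-v0
```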
## Data Quantity Experiments
Run [filter_data.py](https://github.com/ZhengYinan-AIR/FISOR/blob/master/filter_data.py) to generate offline datasets of varying sizes, or download the necessary offline datasets ([download link](https://cloud.tsinghua.edu.cn/d/591cf8fd6d8649a89df4/)). Then run
```bash
python launcher/examples/train_offline.py --env_id 0 --config configs/train_config.py:fisor --ratio 0.1
```
where ``ratio`` is the proportion of the processed data relative to the original dataset.
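The effect of ``--ratio`` can be pictured as subsampling the offline dataset (a toy sketch — the actual selection logic lives in ``filter_data.py`` and may differ, e.g. it may select whole trajectories rather than individual transitions):

```python
import random

# Toy sketch of subsampling an offline dataset to a given ratio.
# The real logic lives in filter_data.py and may differ.
def subsample(dataset, ratio, seed=0):
    """Return a random subset containing `ratio` of the dataset."""
    k = int(len(dataset) * ratio)
    rng = random.Random(seed)  # fixed seed for reproducibility
    return rng.sample(dataset, k)

full = list(range(1000))           # stand-in for 1000 transitions
small = subsample(full, ratio=0.1)
print(len(small))  # → 100
```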
## Feasible Region Visualization
First download the necessary offline dataset for the ``Point Robot`` environment ([download link](https://cloud.tsinghua.edu.cn/d/162d6fe92bde43e28676/)). Train FISOR in the ``Point Robot`` environment:
```bash
python launcher/examples/train_offline.py --env_id 29 --config configs/train_config.py:fisor
```
Then visualize the feasible region by running [viz_map.py](https://github.com/ZhengYinan-AIR/FISOR/blob/master/launcher/viz/viz_map.py):
```bash
python launcher/viz/viz_map.py --model_location 'results/PointRobot/ddpm_feasibility_hj_N16_minqc_2023-12-09_s54_486'
```
<p float="left">
<img src="assets/viz_map.png" width="800">
</p>
## Bibtex

If you find our code and paper helpful, please cite our paper:
```
@article{zheng2023feasibility,
  title={Feasibility-Guided Safe Offline Reinforcement Learning},
  author={Zheng, Yinan and Li, Jianxiong and Yu, Dongjie and Yang, Yujie and Li, Shengbo Eben and Zhan, Xianyuan and Liu, Jingjing},
  journal={openreview},
  year={2023}
}
```
## Acknowledgements

Parts of this code are adapted from [IDQL](https://github.com/philippe-eecs/IDQL) and [DRPO](https://github.com/ManUtdMoon/Distributional-Reachability-Policy-Optimization).
---

**requirements.txt**
#gym[mujoco] >= 0.21.0, < 0.24.1 | ||
numpy >= 1.20.2 | ||
dm_control >= 1.0.0 | ||
# jax >= 0.3.13 | ||
jax == 0.4.9 | ||
flax == 0.5.3 | ||
ml_collections >= 0.1.0 | ||
tqdm >= 4.60.0 | ||
optax == 0.1.5 | ||
absl-py >= 0.12.0 | ||
scipy >= 1.6.0 | ||
wandb >= 0.12.14 | ||
tensorflow-probability >= 0.17.0 | ||
# dmcgym @ git+https://github.com/ikostrikov/dmcgym | ||
moviepy | ||
imageio | ||
dsrl | ||
matplotlib