Skip to content

Commit

Permalink
LB-SAC implementation (tinkoff-ai#31)
Browse files Browse the repository at this point in the history
* lb-sac algo and configs from paper

* fix linter

* add edac init

* fix config

* updated readme and tables

* linter fix

* default device to cuda
  • Loading branch information
Howuhh authored Mar 1, 2023
1 parent f6cc8b6 commit ca1a949
Show file tree
Hide file tree
Showing 22 changed files with 1,098 additions and 41 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# Internal stuff
benchmark.py
mlc_run.sh
wandb_run.sh
sweep_config.yaml
.ml-job-preset.yml
wandb

Expand Down
75 changes: 38 additions & 37 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,36 +26,37 @@ docker run gpus=all -it --rm --name <container_name> <image_name>

## Algorithms Implemented

| Algorithm | Variants Implemented | Wandb Report |
| ----------- | ----------- | ----------- |
| ✅ Behavioral Cloning <br>(BC) | [`any_percent_bc.py`](algorithms/any_percent_bc.py) | [`Gym-MuJoCo, Maze2D`](https://wandb.ai/tlab/CORL/reports/BC-D4RL-Results--VmlldzoyNzA2MjE1)
| ✅ Behavioral Cloning-10% <br>(BC-10%) | [`any_percent_bc.py`](algorithms/any_percent_bc.py) | [`Gym-MuJoCo, Maze2D`](https://wandb.ai/tlab/CORL/reports/BC-10-D4RL-Results--VmlldzoyNzEwMjcx)
| ✅ [Conservative Q-Learning for Offline Reinforcement Learning <br>(CQL)](https://arxiv.org/abs/2006.04779) | [`cql.py`](algorithms/cql.py) | [`Gym-MuJoCo, Maze2D`](https://wandb.ai/tlab/CORL/reports/CQL-D4RL-Results--VmlldzoyNzA2MTk5)
| ✅ [Accelerating Online Reinforcement Learning with Offline Datasets <br>(AWAC)](https://arxiv.org/abs/2006.09359) | [`awac.py`](algorithms/awac.py) | [`Gym-MuJoCo, Maze2D`](https://wandb.ai/tlab/CORL/reports/AWAC-D4RL-Results--VmlldzoyNzA2MjE3)
| ✅ [Offline Reinforcement Learning with Implicit Q-Learning <br>(IQL)](https://arxiv.org/abs/2110.06169) | [`iql.py`](algorithms/iql.py) | [`Gym-MuJoCo, Maze2D`](https://wandb.ai/tlab/CORL/reports/IQL-D4RL-Results--VmlldzoyNzA2MTkx)
| ✅ [A Minimalist Approach to Offline Reinforcement Learning <br>(TD3+BC)](https://arxiv.org/abs/2106.06860) | [`td3_bc.py`](algorithms/td3_bc.py) | [`Gym-MuJoCo, Maze2D`](https://wandb.ai/tlab/CORL/reports/TD3-BC-D4RL-Results--VmlldzoyNzA2MjA0)
| ✅ [Decision Transformer: Reinforcement Learning via Sequence Modeling <br>(DT)](https://arxiv.org/abs/2106.01345) | [`dt.py`](algorithms/dt.py) | [`Gym-MuJoCo, Maze2D`](https://wandb.ai/tlab/CORL/reports/DT-D4RL-Results--VmlldzoyNzA2MTk3)
| ✅ [Uncertainty-Based Offline Reinforcement Learning with Diversified Q-Ensemble <br>(SAC-N)](https://arxiv.org/abs/2110.01548) | [`sac_n.py`](algorithms/sac_n.py) | [`Gym-MuJoCo, Maze2D`](https://wandb.ai/tlab/CORL/reports/SAC-N-D4RL-Results--VmlldzoyNzA1NTY1)
| ✅ [Uncertainty-Based Offline Reinforcement Learning with Diversified Q-Ensemble <br>(EDAC)](https://arxiv.org/abs/2110.01548) | [`edac.py`](algorithms/edac.py) | [`Gym-MuJoCo, Maze2D`](https://wandb.ai/tlab/CORL/reports/EDAC-D4RL-Results--VmlldzoyNzA5ODUw)
| Algorithm | Variants Implemented | Wandb Report |
|---------------------------------------------------------------------------------------------------------------------------------|----------------------------------------------------| ----------- |
| ✅ Behavioral Cloning <br>(BC) | [`any_percent_bc.py`](algorithms/any_percent_bc.py) | [`Gym-MuJoCo, Maze2D`](https://wandb.ai/tlab/CORL/reports/BC-D4RL-Results--VmlldzoyNzA2MjE1)
| ✅ Behavioral Cloning-10% <br>(BC-10%) | [`any_percent_bc.py`](algorithms/any_percent_bc.py) | [`Gym-MuJoCo, Maze2D`](https://wandb.ai/tlab/CORL/reports/BC-10-D4RL-Results--VmlldzoyNzEwMjcx)
| ✅ [Conservative Q-Learning for Offline Reinforcement Learning <br>(CQL)](https://arxiv.org/abs/2006.04779) | [`cql.py`](algorithms/cql.py) | [`Gym-MuJoCo, Maze2D`](https://wandb.ai/tlab/CORL/reports/CQL-D4RL-Results--VmlldzoyNzA2MTk5)
| ✅ [Accelerating Online Reinforcement Learning with Offline Datasets <br>(AWAC)](https://arxiv.org/abs/2006.09359) | [`awac.py`](algorithms/awac.py) | [`Gym-MuJoCo, Maze2D`](https://wandb.ai/tlab/CORL/reports/AWAC-D4RL-Results--VmlldzoyNzA2MjE3)
| ✅ [Offline Reinforcement Learning with Implicit Q-Learning <br>(IQL)](https://arxiv.org/abs/2110.06169) | [`iql.py`](algorithms/iql.py) | [`Gym-MuJoCo, Maze2D`](https://wandb.ai/tlab/CORL/reports/IQL-D4RL-Results--VmlldzoyNzA2MTkx)
| ✅ [A Minimalist Approach to Offline Reinforcement Learning <br>(TD3+BC)](https://arxiv.org/abs/2106.06860) | [`td3_bc.py`](algorithms/td3_bc.py) | [`Gym-MuJoCo, Maze2D`](https://wandb.ai/tlab/CORL/reports/TD3-BC-D4RL-Results--VmlldzoyNzA2MjA0)
| ✅ [Decision Transformer: Reinforcement Learning via Sequence Modeling <br>(DT)](https://arxiv.org/abs/2106.01345) | [`dt.py`](algorithms/dt.py) | [`Gym-MuJoCo, Maze2D`](https://wandb.ai/tlab/CORL/reports/DT-D4RL-Results--VmlldzoyNzA2MTk3)
| ✅ [Uncertainty-Based Offline Reinforcement Learning with Diversified Q-Ensemble <br>(SAC-N)](https://arxiv.org/abs/2110.01548) | [`sac_n.py`](algorithms/sac_n.py) | [`Gym-MuJoCo, Maze2D`](https://wandb.ai/tlab/CORL/reports/SAC-N-D4RL-Results--VmlldzoyNzA1NTY1)
| ✅ [Uncertainty-Based Offline Reinforcement Learning with Diversified Q-Ensemble <br>(EDAC)](https://arxiv.org/abs/2110.01548) | [`edac.py`](algorithms/edac.py) | [`Gym-MuJoCo, Maze2D`](https://wandb.ai/tlab/CORL/reports/EDAC-D4RL-Results--VmlldzoyNzA5ODUw)
| ✅ [Q-Ensemble for Offline RL: Don't Scale the Ensemble, Scale the Batch Size <br>(LB-SAC)](https://arxiv.org/abs/2211.11092) | [`lb_sac.py`](algorithms/lb_sac.py) | [`Gym-MuJoCo`](https://wandb.ai/tlab/CORL/reports/LB-SAC-D4RL-Results--VmlldzozNjIxMDY1)

## D4RL Benchmarks
For learning curves and all the details, you can check the links above. Here, we report reproduced **final** and **best** scores. Note that thay differ by a big margin, and some papers may use different approaches not making it always explicit which one reporting methodology they chose.

### Last Scores
#### Gym-MuJoCo
| **Task-Name**|BC|BC-10%|TD3 + BC|CQL|IQL|AWAC|SAC-N|EDAC|DT |
|------------------------------|------------|--------|--------|-----|-----|------|-------|------|----|
|halfcheetah-medium-v2 | 42.40±0.21 | 42.46±0.81 | 48.10±0.21 | 47.08±0.19 | 48.31±0.11 | 50.01±0.30 | 68.20±1.48 | 67.70±1.20 | 42.20±0.30|
|halfcheetah-medium-expert-v2 | 55.95±8.49 | 90.10±2.83 | 90.78±6.98 | 95.98±0.83 | 94.55±0.21 | 95.29±0.91 | 98.96±10.74 | 104.76±0.74 | 91.55±1.10|
|halfcheetah-medium-replay-v2 | 35.66±2.68 | 23.59±8.02 | 44.84±0.68 | 45.19±0.58 | 43.53±0.43 | 44.91±1.30 | 60.70±1.17 | 62.06±1.27 | 38.91±0.57|
|hopper-medium-v2 | 53.51±2.03 | 55.48±8.43 | 60.37±4.03 | 64.98±6.12 | 62.75±6.02 | 63.69±4.29 | 40.82±11.44 | 101.70±0.32 | 65.10±1.86|
|hopper-medium-expert-v2 | 52.30±4.63 | 111.16±1.19 | 101.17±10.48 | 93.89±14.34 | 106.24±6.09 | 105.29±7.19 | 101.31±13.43 | 105.19±11.64 | 110.44±0.39|
|hopper-medium-replay-v2 | 29.81±2.39 | 70.42±9.99 | 64.42±24.84 | 87.67±14.42 | 84.57±13.49 | 98.15±2.85 | 100.33±0.90 | 99.66±0.94 | 81.77±7.93|
|walker2d-medium-v2 | 63.23±18.76 | 67.34±5.97 | 82.71±5.51 | 80.38±3.45 | 84.03±5.42 | 69.39±31.97 | 87.47±0.76 | 93.36±1.60 | 67.63±2.93|
|walker2d-medium-expert-v2 | 98.96±18.45 | 108.70±0.29 | 110.03±0.41 | 109.68±0.52 | 111.68±0.56 | 111.16±2.41 | 114.93±0.48 | 114.75±0.86 | 107.11±1.11|
|walker2d-medium-replay-v2 | 21.80±11.72 | 54.35±7.32 | 85.62±4.63 | 79.24±4.97 | 82.55±8.00 | 71.73±13.98 | 78.99±0.58 | 87.10±3.21 | 59.86±3.15|
| | | | | | | | | | |
| **locomotion average** |50.40 | 69.29 | 76.45 | 78.23 | 79.80 | 78.85 | 83.52 | 92.92 | 73.84|
| **Task-Name**|BC|BC-10%|TD3 + BC|CQL|IQL|AWAC|SAC-N|EDAC|DT |LB-SAC |
|------------------------------|------------|--------|--------|-----|-----|------|-------|------|----|------|
|halfcheetah-medium-v2 | 42.40±0.21 | 42.46±0.81 | 48.10±0.21 | 47.08±0.19 | 48.31±0.11 | 50.01±0.30 | 68.20±1.48 | 67.70±1.20 | 42.20±0.30| 71.21±1.35|
|halfcheetah-medium-expert-v2 | 55.95±8.49 | 90.10±2.83 | 90.78±6.98 | 95.98±0.83 | 94.55±0.21 | 95.29±0.91 | 98.96±10.74 | 104.76±0.74 | 91.55±1.10| 106.57±3.90|
|halfcheetah-medium-replay-v2 | 35.66±2.68 | 23.59±8.02 | 44.84±0.68 | 45.19±0.58 | 43.53±0.43 | 44.91±1.30 | 60.70±1.17 | 62.06±1.27 | 38.91±0.57| 64.10±0.82|
|hopper-medium-v2 | 53.51±2.03 | 55.48±8.43 | 60.37±4.03 | 64.98±6.12 | 62.75±6.02 | 63.69±4.29 | 40.82±11.44 | 101.70±0.32 | 65.10±1.86| 103.75±0.07|
|hopper-medium-expert-v2 | 52.30±4.63 | 111.16±1.19 | 101.17±10.48 | 93.89±14.34 | 106.24±6.09 | 105.29±7.19 | 101.31±13.43 | 105.19±11.64 | 110.44±0.39| 110.93±0.51|
|hopper-medium-replay-v2 | 29.81±2.39 | 70.42±9.99 | 64.42±24.84 | 87.67±14.42 | 84.57±13.49 | 98.15±2.85 | 100.33±0.90 | 99.66±0.94 | 81.77±7.93| 102.53±0.92|
|walker2d-medium-v2 | 63.23±18.76 | 67.34±5.97 | 82.71±5.51 | 80.38±3.45 | 84.03±5.42 | 69.39±31.97 | 87.47±0.76 | 93.36±1.60 | 67.63±2.93| 90.95±0.65|
|walker2d-medium-expert-v2 | 98.96±18.45 | 108.70±0.29 | 110.03±0.41 | 109.68±0.52 | 111.68±0.56 | 111.16±2.41 | 114.93±0.48 | 114.75±0.86 | 107.11±1.11| 113.46±2.31|
|walker2d-medium-replay-v2 | 21.80±11.72 | 54.35±7.32 | 85.62±4.63 | 79.24±4.97 | 82.55±8.00 | 71.73±13.98 | 78.99±0.58 | 87.10±3.21 | 59.86±3.15| 87.95±1.43|
| | | | | | | | | | | |
| **locomotion average** |50.40 | 69.29 | 76.45 | 78.23 | 79.80 | 78.85 | 83.52 | 92.92 | 73.84| 94.60|

#### Maze2d
| **Task-Name**|BC|BC-10%|TD3 + BC|CQL|IQL|AWAC|SAC-N|EDAC|DT |
Expand All @@ -77,19 +78,19 @@ For learning curves and all the details, you can check the links above. Here, we

### Best Scores
#### Gym-MuJoCo
| **Task-Name**|BC|BC-10%|TD3 + BC|CQL|IQL|AWAC|SAC-N|EDAC|DT |
|------------------------------|------------|--------|--------|-----|-----|------|-------|------|----|
|halfcheetah-medium-v2 | 43.60±0.16 | 43.90±0.15 | 48.93±0.13 | 47.45±0.10 | 48.77±0.06 | 50.87±0.21 | 72.21±0.35 | 69.72±1.06 | 42.73±0.11|
|halfcheetah-medium-expert-v2 | 79.69±3.58 | 94.11±0.25 | 96.59±1.01 | 96.74±0.14 | 95.83±0.38 | 96.87±0.31 | 111.73±0.55 | 110.62±1.20 | 93.40±0.25|
|halfcheetah-medium-replay-v2 | 40.52±0.22 | 42.27±0.53 | 45.84±0.30 | 46.38±0.14 | 45.06±0.16 | 46.57±0.27 | 67.29±0.39 | 66.55±1.21 | 40.31±0.32|
|hopper-medium-v2 | 69.04±3.35 | 73.84±0.43 | 70.44±1.37 | 77.47±6.00 | 80.74±1.27 | 99.40±1.12 | 101.79±0.23 | 103.26±0.16 | 69.42±4.21|
|hopper-medium-expert-v2 | 90.63±12.68 | 113.13±0.19 | 113.22±0.50 | 112.74±0.07 | 111.79±0.47 | 113.37±0.63 | 111.24±0.17 | 111.80±0.13 | 111.18±0.24|
|hopper-medium-replay-v2 | 68.88±11.93 | 90.57±2.38 | 98.12±1.34 | 102.20±0.38 | 102.33±0.44 | 101.76±0.43 | 103.83±0.61 | 103.28±0.57 | 88.74±3.49|
|walker2d-medium-v2 | 80.64±1.06 | 82.05±1.08 | 86.91±0.32 | 84.57±0.15 | 87.99±0.83 | 86.22±4.58 | 90.17±0.63 | 95.78±1.23 | 74.70±0.64|
|walker2d-medium-expert-v2 | 109.95±0.72 | 109.90±0.10 | 112.21±0.07 | 111.63±0.20 | 113.19±0.33 | 113.40±2.57 | 116.93±0.49 | 116.52±0.86 | 108.71±0.39|
|walker2d-medium-replay-v2 | 48.41±8.78 | 76.09±0.47 | 91.17±0.83 | 89.34±0.59 | 91.85±2.26 | 87.06±0.93 | 85.18±1.89 | 89.69±1.60 | 68.22±1.39|
| | | | | | | | | | |
| **locomotion average** | 70.15 | 80.65 | 84.83 | 85.39 | 86.40 | 88.39 | 95.60 | 96.36 | 77.49 |
| **Task-Name**|BC|BC-10%|TD3 + BC|CQL|IQL|AWAC|SAC-N|EDAC|DT | LB-SAC|
|------------------------------|------------|--------|--------|-----|-----|------|-------|------|----|----|
|halfcheetah-medium-v2 | 43.60±0.16 | 43.90±0.15 | 48.93±0.13 | 47.45±0.10 | 48.77±0.06 | 50.87±0.21 | 72.21±0.35 | 69.72±1.06 | 42.73±0.11| 71.82±0.68|
|halfcheetah-medium-expert-v2 | 79.69±3.58 | 94.11±0.25 | 96.59±1.01 | 96.74±0.14 | 95.83±0.38 | 96.87±0.31 | 111.73±0.55 | 110.62±1.20 | 93.40±0.25| 110.37±0.47|
|halfcheetah-medium-replay-v2 | 40.52±0.22 | 42.27±0.53 | 45.84±0.30 | 46.38±0.14 | 45.06±0.16 | 46.57±0.27 | 67.29±0.39 | 66.55±1.21 | 40.31±0.32| 66.14±1.06|
|hopper-medium-v2 | 69.04±3.35 | 73.84±0.43 | 70.44±1.37 | 77.47±6.00 | 80.74±1.27 | 99.40±1.12 | 101.79±0.23 | 103.26±0.16 | 69.42±4.21| 103.88±0.17|
|hopper-medium-expert-v2 | 90.63±12.68 | 113.13±0.19 | 113.22±0.50 | 112.74±0.07 | 111.79±0.47 | 113.37±0.63 | 111.24±0.17 | 111.80±0.13 | 111.18±0.24| 110.93±0.51|
|hopper-medium-replay-v2 | 68.88±11.93 | 90.57±2.38 | 98.12±1.34 | 102.20±0.38 | 102.33±0.44 | 101.76±0.43 | 103.83±0.61 | 103.28±0.57 | 88.74±3.49| 104.00±0.94|
|walker2d-medium-v2 | 80.64±1.06 | 82.05±1.08 | 86.91±0.32 | 84.57±0.15 | 87.99±0.83 | 86.22±4.58 | 90.17±0.63 | 95.78±1.23 | 74.70±0.64| 90.95±0.65|
|walker2d-medium-expert-v2 | 109.95±0.72 | 109.90±0.10 | 112.21±0.07 | 111.63±0.20 | 113.19±0.33 | 113.40±2.57 | 116.93±0.49 | 116.52±0.86 | 108.71±0.39| 113.46±2.31|
|walker2d-medium-replay-v2 | 48.41±8.78 | 76.09±0.47 | 91.17±0.83 | 89.34±0.59 | 91.85±2.26 | 87.06±0.93 | 85.18±1.89 | 89.69±1.60 | 68.22±1.39| 92.25±2.20|
| | | | | | | | | | | |
| **locomotion average** | 70.15 | 80.65 | 84.83 | 85.39 | 86.40 | 88.39 | 95.60 | 96.36 | 77.49 | 95.97|


#### Maze2d
Expand Down
Loading

0 comments on commit ca1a949

Please sign in to comment.