Skip to content

Commit

Permalink
Offline + Finetune: IQL (tinkoff-ai#46)
Browse files Browse the repository at this point in the history
- Implement IQL for fine-tuning
- Add separate learning rates for offline IQL (and update the configs)
  • Loading branch information
DT6A authored Jun 13, 2023
1 parent 3e0dfd3 commit a220ecb
Show file tree
Hide file tree
Showing 54 changed files with 1,177 additions and 22 deletions.
771 changes: 771 additions & 0 deletions algorithms/finetune/iql.py

Large diffs are not rendered by default.

29 changes: 19 additions & 10 deletions algorithms/offline/iql.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,9 @@ class TrainConfig:
iql_deterministic: bool = False # Use deterministic actor
normalize: bool = True # Normalize states
normalize_reward: bool = False # Normalize reward
vf_lr: float = 3e-4 # V function learning rate
qf_lr: float = 3e-4 # Critic learning rate
actor_lr: float = 3e-4 # Actor learning rate
actor_dropout: Optional[float] = None # Adroit uses dropout for policy network
# Wandb logging
project: str = "CORL"
Expand Down Expand Up @@ -418,12 +421,12 @@ def _update_v(self, observations, actions, log_dict) -> torch.Tensor:

def _update_q(
self,
next_v,
observations,
actions,
rewards,
terminals,
log_dict,
next_v: torch.Tensor,
observations: torch.Tensor,
actions: torch.Tensor,
rewards: torch.Tensor,
terminals: torch.Tensor,
log_dict: Dict,
):
targets = rewards + (1.0 - terminals.float()) * self.discount * next_v.detach()
qs = self.qf.both(observations, actions)
Expand All @@ -436,7 +439,13 @@ def _update_q(
# Update target Q network
soft_update(self.q_target, self.qf, self.tau)

def _update_policy(self, adv, observations, actions, log_dict):
def _update_policy(
self,
adv: torch.Tensor,
observations: torch.Tensor,
actions: torch.Tensor,
log_dict: Dict,
):
exp_adv = torch.exp(self.beta * adv.detach()).clamp(max=EXP_ADV_MAX)
policy_out = self.actor(observations)
if isinstance(policy_out, torch.distributions.Distribution):
Expand Down Expand Up @@ -560,9 +569,9 @@ def train(config: TrainConfig):
state_dim, action_dim, max_action, dropout=config.actor_dropout
)
).to(config.device)
v_optimizer = torch.optim.Adam(v_network.parameters(), lr=3e-4)
q_optimizer = torch.optim.Adam(q_network.parameters(), lr=3e-4)
actor_optimizer = torch.optim.Adam(actor.parameters(), lr=3e-4)
v_optimizer = torch.optim.Adam(v_network.parameters(), lr=config.vf_lr)
q_optimizer = torch.optim.Adam(q_network.parameters(), lr=config.qf_lr)
actor_optimizer = torch.optim.Adam(actor.parameters(), lr=config.actor_lr)

kwargs = {
"max_action": max_action,
Expand Down
24 changes: 24 additions & 0 deletions configs/finetune/iql/antmaze/large_diverse_v2.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# IQL offline-to-online fine-tuning configuration for D4RL antmaze-large-diverse-v2.
actor_lr: 3e-4  # Actor learning rate
batch_size: 256  # Minibatch size per gradient step
beta: 10.0  # Inverse temperature: scales the advantage inside exp() in the policy loss
buffer_size: 10000000  # Replay buffer capacity
checkpoints_path: null  # null = do not save checkpoints
device: cuda
discount: 0.99  # Discount factor used in the TD target
env: antmaze-large-diverse-v2  # D4RL environment id
eval_freq: 50000  # Evaluation frequency — presumably training steps; confirm against the training loop
group: IQL-D4RL  # Wandb group
iql_deterministic: false  # false = use the stochastic actor
iql_tau: 0.9  # IQL expectile parameter — presumably used in the value loss; confirm
load_model: ''  # '' = train from scratch (no pretrained checkpoint)
offline_iterations: 1000000  # Offline pretraining steps
online_iterations: 1000000  # Online fine-tuning steps
n_episodes: 100  # Episodes per evaluation
name: IQL_antmaze-large-diverse-v2  # Wandb run name
normalize: true  # Normalize states
normalize_reward: true  # Normalize reward
qf_lr: 3e-4  # Critic (Q-function) learning rate
project: CORL  # Wandb project
seed: 0  # RNG seed
tau: 0.005  # Soft-update rate for the target Q network
vf_lr: 3e-4  # V-function learning rate
24 changes: 24 additions & 0 deletions configs/finetune/iql/antmaze/large_play_v2.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# IQL offline-to-online fine-tuning configuration for D4RL antmaze-large-play-v2.
actor_lr: 3e-4  # Actor learning rate
batch_size: 256  # Minibatch size per gradient step
beta: 10.0  # Inverse temperature: scales the advantage inside exp() in the policy loss
buffer_size: 10000000  # Replay buffer capacity
checkpoints_path: null  # null = do not save checkpoints
device: cuda
discount: 0.99  # Discount factor used in the TD target
env: antmaze-large-play-v2  # D4RL environment id
eval_freq: 50000  # Evaluation frequency — presumably training steps; confirm against the training loop
group: IQL-D4RL  # Wandb group
iql_deterministic: false  # false = use the stochastic actor
iql_tau: 0.9  # IQL expectile parameter — presumably used in the value loss; confirm
load_model: ''  # '' = train from scratch (no pretrained checkpoint)
offline_iterations: 1000000  # Offline pretraining steps
online_iterations: 1000000  # Online fine-tuning steps
n_episodes: 100  # Episodes per evaluation
name: IQL_antmaze-large-play-v2  # Wandb run name
normalize: true  # Normalize states
normalize_reward: true  # Normalize reward
qf_lr: 3e-4  # Critic (Q-function) learning rate
project: CORL  # Wandb project
seed: 0  # RNG seed
tau: 0.005  # Soft-update rate for the target Q network
vf_lr: 3e-4  # V-function learning rate
24 changes: 24 additions & 0 deletions configs/finetune/iql/antmaze/medium_diverse_v2.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# IQL offline-to-online fine-tuning configuration for D4RL antmaze-medium-diverse-v2.
actor_lr: 3e-4  # Actor learning rate
batch_size: 256  # Minibatch size per gradient step
beta: 10.0  # Inverse temperature: scales the advantage inside exp() in the policy loss
buffer_size: 10000000  # Replay buffer capacity
checkpoints_path: null  # null = do not save checkpoints
device: cuda
discount: 0.99  # Discount factor used in the TD target
env: antmaze-medium-diverse-v2  # D4RL environment id
eval_freq: 50000  # Evaluation frequency — presumably training steps; confirm against the training loop
group: IQL-D4RL  # Wandb group
iql_deterministic: false  # false = use the stochastic actor
iql_tau: 0.9  # IQL expectile parameter — presumably used in the value loss; confirm
load_model: ''  # '' = train from scratch (no pretrained checkpoint)
offline_iterations: 1000000  # Offline pretraining steps
online_iterations: 1000000  # Online fine-tuning steps
n_episodes: 100  # Episodes per evaluation
name: IQL_antmaze-medium-diverse-v2  # Wandb run name
normalize: true  # Normalize states
normalize_reward: true  # Normalize reward
qf_lr: 3e-4  # Critic (Q-function) learning rate
project: CORL  # Wandb project
seed: 0  # RNG seed
tau: 0.005  # Soft-update rate for the target Q network
vf_lr: 3e-4  # V-function learning rate
24 changes: 24 additions & 0 deletions configs/finetune/iql/antmaze/medium_play_v2.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# IQL offline-to-online fine-tuning configuration for D4RL antmaze-medium-play-v2.
actor_lr: 3e-4  # Actor learning rate
batch_size: 256  # Minibatch size per gradient step
beta: 10.0  # Inverse temperature: scales the advantage inside exp() in the policy loss
buffer_size: 10000000  # Replay buffer capacity
checkpoints_path: null  # null = do not save checkpoints
device: cuda
discount: 0.99  # Discount factor used in the TD target
env: antmaze-medium-play-v2  # D4RL environment id
eval_freq: 50000  # Evaluation frequency — presumably training steps; confirm against the training loop
group: IQL-D4RL  # Wandb group
iql_deterministic: false  # false = use the stochastic actor
iql_tau: 0.9  # IQL expectile parameter — presumably used in the value loss; confirm
load_model: ''  # '' = train from scratch (no pretrained checkpoint)
offline_iterations: 1000000  # Offline pretraining steps
online_iterations: 1000000  # Online fine-tuning steps
n_episodes: 100  # Episodes per evaluation
name: IQL_antmaze-medium-play-v2  # Wandb run name
normalize: true  # Normalize states
normalize_reward: true  # Normalize reward
qf_lr: 3e-4  # Critic (Q-function) learning rate
project: CORL  # Wandb project
seed: 0  # RNG seed
tau: 0.005  # Soft-update rate for the target Q network
vf_lr: 3e-4  # V-function learning rate
24 changes: 24 additions & 0 deletions configs/finetune/iql/antmaze/umaze_diverse_v2.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# IQL offline-to-online fine-tuning configuration for D4RL antmaze-umaze-diverse-v2.
actor_lr: 3e-4  # Actor learning rate
batch_size: 256  # Minibatch size per gradient step
beta: 10.0  # Inverse temperature: scales the advantage inside exp() in the policy loss
buffer_size: 10000000  # Replay buffer capacity
checkpoints_path: null  # null = do not save checkpoints
device: cuda
discount: 0.99  # Discount factor used in the TD target
env: antmaze-umaze-diverse-v2  # D4RL environment id
eval_freq: 50000  # Evaluation frequency — presumably training steps; confirm against the training loop
group: IQL-D4RL  # Wandb group
iql_deterministic: false  # false = use the stochastic actor
iql_tau: 0.9  # IQL expectile parameter — presumably used in the value loss; confirm
load_model: ''  # '' = train from scratch (no pretrained checkpoint)
offline_iterations: 1000000  # Offline pretraining steps
online_iterations: 1000000  # Online fine-tuning steps
n_episodes: 100  # Episodes per evaluation
name: IQL_antmaze-umaze-diverse-v2  # Wandb run name
normalize: true  # Normalize states
normalize_reward: true  # Normalize reward
qf_lr: 3e-4  # Critic (Q-function) learning rate
project: CORL  # Wandb project
seed: 0  # RNG seed
tau: 0.005  # Soft-update rate for the target Q network
vf_lr: 3e-4  # V-function learning rate
24 changes: 24 additions & 0 deletions configs/finetune/iql/antmaze/umaze_v2.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# IQL offline-to-online fine-tuning configuration for D4RL antmaze-umaze-v2.
actor_lr: 3e-4  # Actor learning rate
batch_size: 256  # Minibatch size per gradient step
beta: 10.0  # Inverse temperature: scales the advantage inside exp() in the policy loss
buffer_size: 10000000  # Replay buffer capacity
checkpoints_path: null  # null = do not save checkpoints
device: cuda
discount: 0.99  # Discount factor used in the TD target
env: antmaze-umaze-v2  # D4RL environment id
eval_freq: 50000  # Evaluation frequency — presumably training steps; confirm against the training loop
group: IQL-D4RL  # Wandb group
iql_deterministic: false  # false = use the stochastic actor
iql_tau: 0.9  # IQL expectile parameter — presumably used in the value loss; confirm
load_model: ''  # '' = train from scratch (no pretrained checkpoint)
offline_iterations: 1000000  # Offline pretraining steps
online_iterations: 1000000  # Online fine-tuning steps
n_episodes: 100  # Episodes per evaluation
name: IQL_antmaze-umaze-v2  # Wandb run name
normalize: true  # Normalize states
normalize_reward: true  # Normalize reward
qf_lr: 3e-4  # Critic (Q-function) learning rate
project: CORL  # Wandb project
seed: 0  # RNG seed
tau: 0.005  # Soft-update rate for the target Q network
vf_lr: 3e-4  # V-function learning rate
25 changes: 25 additions & 0 deletions configs/finetune/iql/door/cloned_v1.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
# IQL offline-to-online fine-tuning configuration for D4RL (Adroit) door-cloned-v1.
actor_lr: 3e-4  # Actor learning rate
actor_dropout: 0.1  # Dropout for the policy network (Adroit tasks use dropout)
batch_size: 256  # Minibatch size per gradient step
beta: 3.0  # Inverse temperature: scales the advantage inside exp() in the policy loss
buffer_size: 10000000  # Replay buffer capacity
checkpoints_path: null  # null = do not save checkpoints
device: cuda
discount: 0.99  # Discount factor used in the TD target
env: door-cloned-v1  # D4RL environment id
eval_freq: 5000  # Evaluation frequency — presumably training steps; confirm against the training loop
group: IQL-D4RL  # Wandb group
iql_deterministic: false  # false = use the stochastic actor
iql_tau: 0.8  # IQL expectile parameter — presumably used in the value loss; confirm
load_model: ''  # '' = train from scratch (no pretrained checkpoint)
offline_iterations: 1000000  # Offline pretraining steps
online_iterations: 1000000  # Online fine-tuning steps
n_episodes: 10  # Episodes per evaluation
name: IQL_door-cloned-v1  # Wandb run name
normalize: true  # Normalize states
normalize_reward: false  # Do not normalize reward
qf_lr: 3e-4  # Critic (Q-function) learning rate
project: CORL  # Wandb project
seed: 0  # RNG seed
tau: 0.005  # Soft-update rate for the target Q network
vf_lr: 3e-4  # V-function learning rate
25 changes: 25 additions & 0 deletions configs/finetune/iql/hammer/cloned_v1.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
# IQL offline-to-online fine-tuning configuration for D4RL (Adroit) hammer-cloned-v1.
actor_lr: 3e-4  # Actor learning rate
actor_dropout: 0.1  # Dropout for the policy network (Adroit tasks use dropout)
batch_size: 256  # Minibatch size per gradient step
beta: 3.0  # Inverse temperature: scales the advantage inside exp() in the policy loss
buffer_size: 10000000  # Replay buffer capacity
checkpoints_path: null  # null = do not save checkpoints
device: cuda
discount: 0.99  # Discount factor used in the TD target
env: hammer-cloned-v1  # D4RL environment id
eval_freq: 5000  # Evaluation frequency — presumably training steps; confirm against the training loop
group: IQL-D4RL  # Wandb group
iql_deterministic: false  # false = use the stochastic actor
iql_tau: 0.8  # IQL expectile parameter — presumably used in the value loss; confirm
load_model: ''  # '' = train from scratch (no pretrained checkpoint)
offline_iterations: 1000000  # Offline pretraining steps
online_iterations: 1000000  # Online fine-tuning steps
n_episodes: 10  # Episodes per evaluation
name: IQL_hammer-cloned-v1  # Wandb run name
normalize: true  # Normalize states
normalize_reward: false  # Do not normalize reward
qf_lr: 3e-4  # Critic (Q-function) learning rate
project: CORL  # Wandb project
seed: 0  # RNG seed
tau: 0.005  # Soft-update rate for the target Q network
vf_lr: 3e-4  # V-function learning rate
25 changes: 25 additions & 0 deletions configs/finetune/iql/pen/cloned_v1.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
# IQL offline-to-online fine-tuning configuration for D4RL (Adroit) pen-cloned-v1.
actor_lr: 3e-4  # Actor learning rate
actor_dropout: 0.1  # Dropout for the policy network (Adroit tasks use dropout)
batch_size: 256  # Minibatch size per gradient step
beta: 3.0  # Inverse temperature: scales the advantage inside exp() in the policy loss
buffer_size: 10000000  # Replay buffer capacity
checkpoints_path: null  # null = do not save checkpoints
device: cuda
discount: 0.99  # Discount factor used in the TD target
env: pen-cloned-v1  # D4RL environment id
eval_freq: 5000  # Evaluation frequency — presumably training steps; confirm against the training loop
group: IQL-D4RL  # Wandb group
iql_deterministic: false  # false = use the stochastic actor
iql_tau: 0.8  # IQL expectile parameter — presumably used in the value loss; confirm
load_model: ''  # '' = train from scratch (no pretrained checkpoint)
offline_iterations: 1000000  # Offline pretraining steps
online_iterations: 1000000  # Online fine-tuning steps
n_episodes: 10  # Episodes per evaluation
name: IQL_pen-cloned-v1  # Wandb run name
normalize: true  # Normalize states
normalize_reward: false  # Do not normalize reward
qf_lr: 3e-4  # Critic (Q-function) learning rate
project: CORL  # Wandb project
seed: 0  # RNG seed
tau: 0.005  # Soft-update rate for the target Q network
vf_lr: 3e-4  # V-function learning rate
25 changes: 25 additions & 0 deletions configs/finetune/iql/relocate/cloned_v1.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
# IQL offline-to-online fine-tuning configuration for D4RL (Adroit) relocate-cloned-v1.
actor_lr: 3e-4  # Actor learning rate
actor_dropout: 0.1  # Dropout for the policy network (Adroit tasks use dropout)
batch_size: 256  # Minibatch size per gradient step
beta: 3.0  # Inverse temperature: scales the advantage inside exp() in the policy loss
buffer_size: 10000000  # Replay buffer capacity
checkpoints_path: null  # null = do not save checkpoints
device: cuda
discount: 0.99  # Discount factor used in the TD target
env: relocate-cloned-v1  # D4RL environment id
eval_freq: 5000  # Evaluation frequency — presumably training steps; confirm against the training loop
group: IQL-D4RL  # Wandb group
iql_deterministic: false  # false = use the stochastic actor
iql_tau: 0.8  # IQL expectile parameter — presumably used in the value loss; confirm
load_model: ''  # '' = train from scratch (no pretrained checkpoint)
offline_iterations: 1000000  # Offline pretraining steps
online_iterations: 1000000  # Online fine-tuning steps
n_episodes: 10  # Episodes per evaluation
name: IQL_relocate-cloned-v1  # Wandb run name
normalize: true  # Normalize states
normalize_reward: false  # Do not normalize reward
qf_lr: 3e-4  # Critic (Q-function) learning rate
project: CORL  # Wandb project
seed: 0  # RNG seed
tau: 0.005  # Soft-update rate for the target Q network
vf_lr: 3e-4  # V-function learning rate
3 changes: 3 additions & 0 deletions configs/offline/iql/antmaze/large_diverse_v2.yaml
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
actor_lr: 3e-4
batch_size: 256
beta: 10.0
buffer_size: 10000000
Expand All @@ -15,6 +16,8 @@ n_episodes: 100
name: IQL
normalize: true
normalize_reward: true
qf_lr: 3e-4
project: CORL
seed: 0
tau: 0.005
vf_lr: 3e-4
3 changes: 3 additions & 0 deletions configs/offline/iql/antmaze/large_play_v2.yaml
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
actor_lr: 3e-4
batch_size: 256
beta: 10.0
buffer_size: 10000000
Expand All @@ -15,6 +16,8 @@ n_episodes: 100
name: IQL
normalize: true
normalize_reward: true
qf_lr: 3e-4
project: CORL
seed: 0
tau: 0.005
vf_lr: 3e-4
3 changes: 3 additions & 0 deletions configs/offline/iql/antmaze/medium_diverse_v2.yaml
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
actor_lr: 3e-4
batch_size: 256
beta: 10.0
buffer_size: 10000000
Expand All @@ -15,6 +16,8 @@ n_episodes: 100
name: IQL
normalize: true
normalize_reward: true
qf_lr: 3e-4
project: CORL
seed: 0
tau: 0.005
vf_lr: 3e-4
3 changes: 3 additions & 0 deletions configs/offline/iql/antmaze/medium_play_v2.yaml
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
actor_lr: 3e-4
batch_size: 256
beta: 10.0
buffer_size: 10000000
Expand All @@ -15,6 +16,8 @@ n_episodes: 100
name: IQL
normalize: true
normalize_reward: true
qf_lr: 3e-4
project: CORL
seed: 0
tau: 0.005
vf_lr: 3e-4
3 changes: 3 additions & 0 deletions configs/offline/iql/antmaze/umaze_diverse_v2.yaml
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
actor_lr: 3e-4
batch_size: 256
beta: 10.0
buffer_size: 10000000
Expand All @@ -15,6 +16,8 @@ n_episodes: 100
name: IQL
normalize: true
normalize_reward: true
qf_lr: 3e-4
project: CORL
seed: 0
tau: 0.005
vf_lr: 3e-4
3 changes: 3 additions & 0 deletions configs/offline/iql/antmaze/umaze_v2.yaml
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
actor_lr: 3e-4
batch_size: 256
beta: 10.0
buffer_size: 10000000
Expand All @@ -15,6 +16,8 @@ n_episodes: 100
name: IQL
normalize: true
normalize_reward: true
qf_lr: 3e-4
project: CORL
seed: 0
tau: 0.005
vf_lr: 3e-4
Loading

0 comments on commit a220ecb

Please sign in to comment.