Skip to content

Commit

Permalink
Make ATAC take target_qf as input.
Browse files Browse the repository at this point in the history
  • Loading branch information
chinganc committed Jan 18, 2023
1 parent ee9ef8d commit b97066b
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 2 deletions.
3 changes: 2 additions & 1 deletion lightATAC/atac.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ class ATAC(nn.Module):
def __init__(self, *,
policy,
qf,
target_qf=None,
optimizer,
discount=0.99,
Vmin=-float('inf'), # min value of Q (used in target backup)
Expand Down Expand Up @@ -62,7 +63,7 @@ def __init__(self, *,
# networks
self.policy = policy
self._qf = qf
self._target_qf = copy.deepcopy(self._qf).requires_grad_(False)
self._target_qf = copy.deepcopy(self._qf).requires_grad_(False) if target_qf is None else target_qf

# optimization
self._policy_lr = policy_lr
Expand Down
3 changes: 2 additions & 1 deletion main.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ def main(args):
# Assume vector observation and action
obs_dim, act_dim = dataset['observations'].shape[1], dataset['actions'].shape[1]
qf = TwinQ(obs_dim, act_dim, hidden_dim=args.hidden_dim, n_hidden=args.n_hidden).to(DEFAULT_DEVICE)
target_qf = copy.deepcopy(qf).requires_grad_(False)
policy = GaussianPolicy(obs_dim, act_dim, hidden_dim=args.hidden_dim, n_hidden=args.n_hidden,
use_tanh=True, std_type='diagonal').to(DEFAULT_DEVICE)
dataset['actions'] = np.clip(dataset['actions'], -1+EPS, 1-EPS) # due to tanh
Expand All @@ -74,7 +75,7 @@ def main(args):

# ------------------ Pretraining ------------------ #
# Train policy and value to fit the behavior data
bp = BehaviorPretraining(qf=qf, target_qf=rl._target_qf, policy=policy, lr=args.fast_lr, discount=args.discount,
bp = BehaviorPretraining(qf=qf, target_qf=target_qf, policy=policy, lr=args.fast_lr, discount=args.discount,
td_weight=0.5, rs_weight=0.5, fixed_alpha=None, action_shape=act_dim).to(DEFAULT_DEVICE)
def bp_log_fun(metrics, step):
print(step, metrics)
Expand Down

0 comments on commit b97066b

Please sign in to comment.