
Commit

Merge pull request marlbenchmark#113 from HosnLS/main
fix: r_mappo & mat buffer factor unpack
zoeyuchao authored Jul 18, 2024
2 parents dd61a20 + cfc6f6d commit de66d7a
Showing 2 changed files with 16 additions and 6 deletions.
11 changes: 8 additions & 3 deletions onpolicy/algorithms/mat/mat_trainer.py
@@ -103,9 +103,14 @@ def ppo_update(self, sample):
         :return actor_grad_norm: (torch.Tensor) gradient norm from actor update.
         :return imp_weights: (torch.Tensor) importance sampling weights.
         """
-        share_obs_batch, obs_batch, rnn_states_batch, rnn_states_critic_batch, actions_batch, \
-        value_preds_batch, return_batch, masks_batch, active_masks_batch, old_action_log_probs_batch, \
-        adv_targ, available_actions_batch = sample
+        if len(sample) == 12:
+            share_obs_batch, obs_batch, rnn_states_batch, rnn_states_critic_batch, actions_batch, \
+            value_preds_batch, return_batch, masks_batch, active_masks_batch, old_action_log_probs_batch, \
+            adv_targ, available_actions_batch = sample
+        else:
+            share_obs_batch, obs_batch, rnn_states_batch, rnn_states_critic_batch, actions_batch, \
+            value_preds_batch, return_batch, masks_batch, active_masks_batch, old_action_log_probs_batch, \
+            adv_targ, available_actions_batch, _ = sample
 
         old_action_log_probs_batch = check(old_action_log_probs_batch).to(**self.tpdv)
         adv_targ = check(adv_targ).to(**self.tpdv)
11 changes: 8 additions & 3 deletions onpolicy/algorithms/r_mappo/r_mappo.py
@@ -101,9 +101,14 @@ def ppo_update(self, sample, update_actor=True):
         :return actor_grad_norm: (torch.Tensor) gradient norm from actor update.
         :return imp_weights: (torch.Tensor) importance sampling weights.
         """
-        share_obs_batch, obs_batch, rnn_states_batch, rnn_states_critic_batch, actions_batch, \
-        value_preds_batch, return_batch, masks_batch, active_masks_batch, old_action_log_probs_batch, \
-        adv_targ, available_actions_batch = sample
+        if len(sample) == 12:
+            share_obs_batch, obs_batch, rnn_states_batch, rnn_states_critic_batch, actions_batch, \
+            value_preds_batch, return_batch, masks_batch, active_masks_batch, old_action_log_probs_batch, \
+            adv_targ, available_actions_batch = sample
+        else:
+            share_obs_batch, obs_batch, rnn_states_batch, rnn_states_critic_batch, actions_batch, \
+            value_preds_batch, return_batch, masks_batch, active_masks_batch, old_action_log_probs_batch, \
+            adv_targ, available_actions_batch, _ = sample
 
         old_action_log_probs_batch = check(old_action_log_probs_batch).to(**self.tpdv)
         adv_targ = check(adv_targ).to(**self.tpdv)
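Note on the fix: the old code always unpacked `sample` as a 12-tuple, which raises a ValueError whenever the buffer's generator yields an extra trailing element (the "factor" term named in the commit message). The guarded unpack in both files accepts either length and discards the extra element. Below is a minimal, self-contained sketch of the same pattern; the helper name unpack_sample and the use of starred unpacking are illustrative assumptions, not code from the repository.

# Minimal sketch (hypothetical helper, not repository code) of the length-guarded
# unpacking this commit adds to ppo_update in both trainers.

def unpack_sample(sample):
    """Unpack a mini-batch tuple that may carry a trailing extra element (e.g. a factor term)."""
    if len(sample) == 12:
        *core, available_actions_batch = sample
        extra = None                                     # no trailing factor-style element
    else:
        *core, available_actions_batch, extra = sample   # trailing factor-style element present
    return core, available_actions_batch, extra

# Usage: both tuple lengths unpack without raising ValueError.
core, avail, extra = unpack_sample(tuple(range(12)))     # extra is None
core, avail, extra = unpack_sample(tuple(range(13)))     # extra is 12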
