Skip to content

Commit

Permalink
Fix a bug in MIPW when the action embedding is one-dimensional
Browse files (browse the repository at this point in the history)
  • Loading branch information
usaito committed Mar 27, 2022
1 parent 623510e commit 66e052f
Showing 1 changed file with 21 additions and 22 deletions.
43 changes: 21 additions & 22 deletions obp/ope/estimators_embed.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,28 +293,27 @@ def estimate_policy_value(
else:
check_array(array=context, name="context", expected_dim=2)

if self.embedding_selection_method == "exact":
return self._estimate_with_exact_pruning(
context=context,
reward=reward,
action=action,
action_embed=action_embed,
position=position,
pi_b=pi_b,
action_dist=action_dist,
)

elif self.embedding_selection_method == "greedy":
return self._estimate_with_greedy_pruning(
context=context,
reward=reward,
action=action,
action_embed=action_embed,
position=position,
pi_b=pi_b,
action_dist=action_dist,
)

if action_embed.shape[1] > 1 and self.embedding_selection_method is not None:
if self.embedding_selection_method == "exact":
return self._estimate_with_exact_pruning(
context=context,
reward=reward,
action=action,
action_embed=action_embed,
position=position,
pi_b=pi_b,
action_dist=action_dist,
)
elif self.embedding_selection_method == "greedy":
return self._estimate_with_greedy_pruning(
context=context,
reward=reward,
action=action,
action_embed=action_embed,
position=position,
pi_b=pi_b,
action_dist=action_dist,
)
else:
return self._estimate_round_rewards(
context=context,
Expand Down

0 comments on commit 66e052f

Please sign in to comment.