Skip to content

Commit

Permalink
allow slope to use the true maginal importance weight for mips
Browse files Browse the repository at this point in the history
  • Loading branch information
usaito committed Jun 15, 2022
1 parent 44e7412 commit 83f3945
Showing 1 changed file with 7 additions and 0 deletions.
7 changes: 7 additions & 0 deletions obp/ope/estimators_embed.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,7 @@ def estimate_policy_value(
position=position,
pi_b=pi_b,
action_dist=action_dist,
p_e_a=p_e_a,
)
elif self.embedding_selection_method == "greedy":
return self._estimate_with_greedy_pruning(
Expand All @@ -313,6 +314,7 @@ def estimate_policy_value(
position=position,
pi_b=pi_b,
action_dist=action_dist,
p_e_a=p_e_a,
)
else:
return self._estimate_round_rewards(
Expand All @@ -335,6 +337,7 @@ def _estimate_with_exact_pruning(
pi_b: np.ndarray,
action_dist: np.ndarray,
position: np.ndarray,
p_e_a: Optional[np.ndarray] = None,
) -> float:
"""Apply an exact version of data-drive action embedding selection."""
n_emb_dim = action_embed.shape[1]
Expand All @@ -352,6 +355,7 @@ def _estimate_with_exact_pruning(
pi_b=pi_b,
action_dist=action_dist,
position=position,
p_e_a=p_e_a[:, :, comb],
with_dev=True,
)
if len(theta_list) > 0:
Expand Down Expand Up @@ -380,6 +384,7 @@ def _estimate_with_greedy_pruning(
pi_b: np.ndarray,
action_dist: np.ndarray,
position: np.ndarray,
p_e_a: Optional[np.ndarray] = None,
) -> float:
"""Apply a greedy version of data-drive action embedding selection."""
n_emb_dim = action_embed.shape[1]
Expand All @@ -395,6 +400,7 @@ def _estimate_with_greedy_pruning(
pi_b=pi_b,
action_dist=action_dist,
position=position,
p_e_a=p_e_a[:, :, current_feat],
with_dev=True,
)
theta_list.append(theta), cnf_list.append(cnf)
Expand All @@ -413,6 +419,7 @@ def _estimate_with_greedy_pruning(
pi_b=pi_b,
action_dist=action_dist,
position=position,
p_e_a=p_e_a[:, :, candidate_feat],
with_dev=True,
)
d_list_.append(d)
Expand Down

0 comments on commit 83f3945

Please sign in to comment.