Skip to content

Commit

Permalink
fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
usaito committed Dec 20, 2021
1 parent ab6b477 commit 0659552
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 2 deletions.
6 changes: 6 additions & 0 deletions obp/ope/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,3 +95,9 @@
"SwitchDoublyRobustTuning",
"DoublyRobustWithShrinkageTuning",
]


__all_estimators_tuning_sg__ = [
"SubGaussianInverseProbabilityWeightingTuning",
"SubGaussianDoublyRobustTuning",
]
62 changes: 60 additions & 2 deletions tests/ope/test_all_estimators.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,16 @@ def test_estimation_of_all_estimators_using_invalid_input_data(
for estimator_name in all_estimators_tuning
for tuning_method in ["slope", "mse"]
]
all_estimators_tuning_sg = ope.__all_estimators_tuning_sg__
estimators_tuning_sg = [
getattr(ope.estimators_tuning, estimator_name)(
lambdas=[0.001, 0.01, 0.1, 1.0],
tuning_method=tuning_method,
)
for estimator_name in all_estimators_tuning_sg
for tuning_method in ["slope", "mse"]
]
estimators_tuning = estimators_tuning + estimators_tuning_sg
# estimate_intervals function raises ValueError of all estimators
for estimator in estimators:
with pytest.raises(ValueError, match=f"{description}*"):
Expand Down Expand Up @@ -244,6 +254,16 @@ def test_estimation_of_all_estimators_using_valid_input_data(
for estimator_name in all_estimators_tuning
for tuning_method in ["slope", "mse"]
]
all_estimators_tuning_sg = ope.__all_estimators_tuning_sg__
estimators_tuning_sg = [
getattr(ope.estimators_tuning, estimator_name)(
lambdas=[0.001, 0.01, 0.1, 1.0],
tuning_method=tuning_method,
)
for estimator_name in all_estimators_tuning_sg
for tuning_method in ["slope", "mse"]
]
estimators_tuning = estimators_tuning + estimators_tuning_sg
# estimate_intervals function raises ValueError of all estimators
for estimator in estimators:
_ = estimator.estimate_policy_value(
Expand Down Expand Up @@ -348,6 +368,16 @@ def test_estimate_intervals_of_all_estimators_using_invalid_input_data(
for estimator_name in all_estimators_tuning
for tuning_method in ["slope", "mse"]
]
all_estimators_tuning_sg = ope.__all_estimators_tuning_sg__
estimators_tuning_sg = [
getattr(ope.estimators_tuning, estimator_name)(
lambdas=[0.001, 0.01, 0.1, 1.0],
tuning_method=tuning_method,
)
for estimator_name in all_estimators_tuning_sg
for tuning_method in ["slope", "mse"]
]
estimators_tuning = estimators_tuning + estimators_tuning_sg
# estimate_intervals function raises ValueError of all estimators
for estimator in estimators:
with pytest.raises(err, match=f"{description}*"):
Expand Down Expand Up @@ -409,6 +439,16 @@ def test_estimate_intervals_of_all_estimators_using_valid_input_data(
for estimator_name in all_estimators_tuning
for tuning_method in ["slope", "mse"]
]
all_estimators_tuning_sg = ope.__all_estimators_tuning_sg__
estimators_tuning_sg = [
getattr(ope.estimators_tuning, estimator_name)(
lambdas=[0.001, 0.01, 0.1, 1.0],
tuning_method=tuning_method,
)
for estimator_name in all_estimators_tuning_sg
for tuning_method in ["slope", "mse"]
]
estimators_tuning = estimators_tuning + estimators_tuning_sg
# estimate_intervals function raises ValueError of all estimators
for estimator in estimators:
_ = estimator.estimate_interval(
Expand Down Expand Up @@ -481,7 +521,16 @@ def test_performance_of_ope_estimators_using_random_evaluation_policy(
for estimator_name in all_estimators_tuning
for tuning_method in ["slope", "mse"]
]
estimators = estimators_standard + estimators_tuning
all_estimators_tuning_sg = ope.__all_estimators_tuning_sg__
estimators_tuning_sg = [
getattr(ope.estimators_tuning, estimator_name)(
lambdas=[0.001, 0.01, 0.1, 1.0],
tuning_method=tuning_method,
)
for estimator_name in all_estimators_tuning_sg
for tuning_method in ["slope", "mse"]
]
estimators = estimators_standard + estimators_tuning + estimators_tuning_sg
# conduct OPE
ope_instance = OffPolicyEvaluation(
bandit_feedback=synthetic_bandit_feedback, ope_estimators=estimators
Expand Down Expand Up @@ -523,7 +572,16 @@ def test_response_format_of_ope_estimators_using_random_evaluation_policy(
for estimator_name in all_estimators_tuning
for tuning_method in ["slope", "mse"]
]
estimators = estimators_standard + estimators_tuning
all_estimators_tuning_sg = ope.__all_estimators_tuning_sg__
estimators_tuning_sg = [
getattr(ope.estimators_tuning, estimator_name)(
lambdas=[0.001, 0.01, 0.1, 1.0],
tuning_method=tuning_method,
)
for estimator_name in all_estimators_tuning_sg
for tuning_method in ["slope", "mse"]
]
estimators = estimators_standard + estimators_tuning + estimators_tuning_sg
# conduct OPE
ope_instance = OffPolicyEvaluation(
bandit_feedback=synthetic_bandit_feedback, ope_estimators=estimators
Expand Down

0 comments on commit 0659552

Please sign in to comment.