
Commit

use predict_proba
Kurorororo committed Mar 8, 2021
1 parent d655183 commit c4ee9ab
Showing 4 changed files with 40 additions and 13 deletions.
6 changes: 4 additions & 2 deletions examples/opl/README.md
@@ -81,8 +81,10 @@ python evaluate_off_policy_learners.py\
# =============================================
# random_state=12345
# ---------------------------------------------
# random ipw nn
# policy value 0.604339 0.767615 0.77251
# policy value
# random 0.604339
# ipw 0.767615
# nn 0.764302
# =============================================
```

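The table now prints one row per policy because the example script stops transposing the results DataFrame (see the evaluate_off_policy_learners.py hunk below). A minimal, self-contained pandas sketch of the two layouts, using the values from the README output above:

```python
# Standalone reconstruction of the README output (values copied from above);
# the actual script fills the same DataFrame with estimated policy values.
import pandas as pd

policy_values = pd.DataFrame(
    [0.604339, 0.767615, 0.764302],
    columns=["policy value"],
    index=["random", "ipw", "nn"],
)
print(policy_values.round(6))    # new long format: one row per policy
print(policy_values.T.round(6))  # old wide format: one column per policy
```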
4 changes: 2 additions & 2 deletions examples/opl/evaluate_off_policy_learners.py
@@ -210,7 +210,7 @@
ipw_learner_action_dist = ipw_learner.predict(
context=bandit_feedback_test["context"],
)
nn_policy_learner_action_dist = nn_policy_learner.predict(
nn_policy_learner_action_dist = nn_policy_learner.predict_proba(
context=bandit_feedback_test["context"],
)

@@ -236,7 +236,7 @@
],
columns=["policy value"],
index=["random", "ipw", "nn"],
).T.round(6)
).round(6)
print("=" * 45)
print(f"random_state={random_state}")
print("-" * 45)
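The substantive change in this file is that the neural-network policy is now evaluated through its stochastic interface. A minimal sketch of the difference, assuming the fitted `nn_policy_learner` and the `bandit_feedback_test` dictionary from the script above; the (n_rounds, n_actions, len_list) shape is obp's usual convention for action distributions:

```python
import numpy as np

# predict() returns deterministic choices (effectively a one-hot distribution),
# while predict_proba() returns softmax action choice probabilities.
deterministic_dist = nn_policy_learner.predict(
    context=bandit_feedback_test["context"],
)
stochastic_dist = nn_policy_learner.predict_proba(
    context=bandit_feedback_test["context"],
)

# Both sum to one over the action axis for every round and position,
# but only predict_proba yields non-degenerate probabilities.
assert np.allclose(stochastic_dist.sum(axis=1), 1.0)
assert np.allclose(deterministic_dist.sum(axis=1), 1.0)
```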
14 changes: 7 additions & 7 deletions examples/quickstart/opl.ipynb
@@ -226,7 +226,7 @@
" estimated_rewards_by_reg_model=estimated_rewards_by_reg_model,\n",
")\n",
"# obtains action choice probabilities for the test set of the synthetic logged bandit feedback\n",
"action_dist_nn_dm = nn_dm.predict(context=bandit_feedback_test[\"context\"])"
"action_dist_nn_dm = nn_dm.predict_proba(context=bandit_feedback_test[\"context\"])"
]
},
{
@@ -252,7 +252,7 @@
" pscore=bandit_feedback_train[\"pscore\"],\n",
")\n",
"# obtains action choice probabilities for the test set of the synthetic logged bandit feedback\n",
"action_dist_nn_ipw = nn_ipw.predict(context=bandit_feedback_test[\"context\"])"
"action_dist_nn_ipw = nn_ipw.predict_proba(context=bandit_feedback_test[\"context\"])"
]
},
{
@@ -279,7 +279,7 @@
" estimated_rewards_by_reg_model=estimated_rewards_by_reg_model,\n",
")\n",
"# obtains action choice probabilities for the test set of the synthetic logged bandit feedback\n",
"action_dist_nn_dr = nn_dr.predict(context=bandit_feedback_test[\"context\"])"
"action_dist_nn_dr = nn_dr.predict_proba(context=bandit_feedback_test[\"context\"])"
]
},
{
@@ -372,9 +372,9 @@
"name": "stdout",
"output_type": "stream",
"text": [
"policy value of NN Policy Learner with DM: 0.6786610995854737\n",
"policy value of NN Policy Learner with IPW: 0.751523191255569\n",
"policy value of NN Policy Learner with DR: 0.7719675528392871\n",
"policy value of NN Policy Learner with DM: 0.6785771195516228\n",
"policy value of NN Policy Learner with IPW: 0.7429362678096227\n",
"policy value of NN Policy Learner with DR: 0.7651217293062053\n",
"policy value of IPW Learner with Logistic Regression: 0.767614655337475\n",
"policy value of IPW Learner with Random Forest: 0.703809241480009\n",
"policy value of Unifrom Random: 0.6043385526445931\n"
@@ -411,7 +411,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"In fact, NN Policy Learner with DR reveals the best performance among the 6 evaluation policies."
"In fact, IPW Learner with Logistic Regression performs the best, and NN Policy Learner with DR is the second best."
]
},
{
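The policy values printed in the notebook are ground-truth values on the synthetic test set. A minimal sketch of how one of them can be computed from an action distribution, assuming the `dataset`, `bandit_feedback_test`, and `action_dist_nn_dr` objects defined earlier in the notebook and the `calc_ground_truth_policy_value` helper that obp's synthetic dataset class provides:

```python
import numpy as np

# Via the synthetic dataset helper, as in the quickstart notebook.
policy_value_nn_dr = dataset.calc_ground_truth_policy_value(
    expected_reward=bandit_feedback_test["expected_reward"],
    action_dist=action_dist_nn_dr,
)

# Equivalent computation by hand: the expected reward of each round under the
# learned action choice probabilities, averaged over rounds (len_list = 1).
policy_value_by_hand = np.average(
    bandit_feedback_test["expected_reward"],
    weights=action_dist_nn_dr[:, :, 0],
    axis=1,
).mean()
print(f"policy value of NN Policy Learner with DR: {policy_value_nn_dr}")
```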
29 changes: 27 additions & 2 deletions obp/policy/offline.py
@@ -348,7 +348,11 @@ def predict_proba(

@dataclass
class NNPolicyLearner(BaseOfflinePolicyLearner):
"""Off-policy learner using an neural network whose objective function is an off-policy estimator.
"""Off-policy learner using an neural network whose objective function is an OPE estimator.
Note
--------
MLP is implemented in PyTorch.
Parameters
-----------
@@ -727,6 +731,16 @@ def fit(
) -> None:
"""Fits an offline bandit policy using the given logged bandit feedback data.
Note
----------
Given the training data :math:`\\mathcal{D}`, this policy maximizes the following objective function:
.. math::
\\hat{V}(\\pi_\\theta; \\mathcal{D}) - \\lambda \\Omega(\\theta)
where :math:`\\hat{V}` is an OPE estimator and :math:`\\lambda \\Omega(\\theta)` is a regularization term.
Parameters
-----------
context: array-like, shape (n_rounds, dim_context)
@@ -960,7 +974,18 @@ def predict_proba(
self,
context: np.ndarray,
) -> np.ndarray:
"""Obtains action choice probabilities for new data based on scores predicted by a classifier.
"""Obtains action choice probabilities for new data.
Note
--------
This policy uses a multi-layer perceptron (MLP) with the softmax function as its last layer.
This is a stochastic policy, represented as follows:
.. math::
\\pi_\\theta (a \\mid x) = \\frac{\\exp(f_\\theta(x, a))}{\\sum_{a' \\in \\mathcal{A}} \\exp(f_\\theta(x, a'))}
where :math:`f_\\theta(x, a)` is the MLP with parameter :math:`\\theta`.
Parameters
----------------
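The new predict_proba docstring says the learner maps the MLP scores f_θ(x, a) to action choice probabilities with a softmax over actions. A minimal NumPy sketch of that mapping, written independently of obp's PyTorch implementation:

```python
import numpy as np

def softmax_action_dist(scores: np.ndarray) -> np.ndarray:
    """Map per-round, per-action scores f_theta(x, a) of shape
    (n_rounds, n_actions) to action choice probabilities pi_theta(a|x)."""
    scores = scores - scores.max(axis=1, keepdims=True)  # numerical stability
    exp_scores = np.exp(scores)
    return exp_scores / exp_scores.sum(axis=1, keepdims=True)

# Example: two rounds, three actions; each row of the result sums to 1.
scores = np.array([[2.0, 1.0, 0.1], [0.5, 0.5, 0.5]])
print(softmax_action_dist(scores))
```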

