add linear and off-policy bandits
usaito committed Aug 4, 2020
1 parent ac6bcd2 commit 398153e
Showing 6 changed files with 614 additions and 57 deletions.
4 changes: 3 additions & 1 deletion obp/policy/__init__.py
@@ -1,3 +1,5 @@
from .base import *
from .contextfree import *
from .contextual import *
from .linear import *
from .logistic import *
from .offline import *
146 changes: 145 additions & 1 deletion obp/policy/base.py
@@ -4,11 +4,14 @@
"""Base Interfaces for Bandit Algorithms."""
from abc import ABCMeta, abstractmethod
from dataclasses import dataclass
from typing import Optional, Union
from typing import Optional, Union, Tuple

import numpy as np
from sklearn.base import ClassifierMixin
from sklearn.utils import check_random_state

from ..ope import RegressionModel


@dataclass
class BaseContextFreePolicy(metaclass=ABCMeta):
@@ -144,4 +147,145 @@ def update_params(self, action: float, reward: float, context: np.ndarray) -> None:
pass


@dataclass
class BaseOffPolicyLearner(metaclass=ABCMeta):
"""Base Class for off-policy learner with standard OPE estimators.
Note
------
Parameters
-----------
base_model: ClassifierMixin
Machine learning classifier to be used to create the decision making policy.
Examples
----------
.. code-block:: python
Reference
-----------
Miroslav Dudík, Dumitru Erhan, John Langford, and Lihong Li.
"Doubly Robust Policy Evaluation and Optimization.", 2014.
"""

base_model: ClassifierMixin

def __post_init__(self) -> None:
"""Initialize class."""
pass

@abstractmethod
def _create_train_data_for_opl(
self,
context: np.ndarray,
action: np.ndarray,
reward: np.ndarray,
pscore: np.ndarray,
**kwargs,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
"""Create training data for off-policy learning.
Parameters
-----------
context: array-like, shape (n_actions,)
Context vectors in the given training logged bandit feedback.
action: array-like, shape (n_actions,)
Selected actions by behavior policy in the given training logged bandit feedback.
reward: array-like, shape (n_actions,)
Observed rewards in the given training logged bandit feedback.
pscore: Optional[np.ndarray], default: None
Propensity scores, the probability of selecting each action by behavior policy,
in the given training logged bandit feedback.
Returns
--------
(X, sample_weight, y): Tuple[np.ndarray, np.ndarray, np.ndarray]
Feature vectors, sample weights, and outcome for training the base machine learning model.
"""
return NotImplementedError

def fit(
self,
context: np.ndarray,
action: np.ndarray,
reward: np.ndarray,
pscore: Optional[np.ndarray] = None,
action_context: Optional[np.ndarray] = None,
regression_model: Optional[RegressionModel] = None,
) -> None:
"""Fits the offline bandit policy according to the given logged bandit feedback data.
Parameters
-----------
context: array-like, shape (n_actions,)
Context vectors in the given training logged bandit feedback.
action: array-like, shape (n_actions,)
Selected actions by behavior policy in the given training logged bandit feedback.
reward: array-like, shape (n_actions,)
Observed rewards in the given training logged bandit feedback.
pscore: Optional[np.ndarray], default: None
Propensity scores, the probability of selecting each action by behavior policy,
in the given training logged bandit feedback.
action_context: array-like, shape (n_actions, dim_action_context), default: None
Context vectors used as input to predict the mean reward function.
regression_model: Optional[RegressionModel], default: None
Regression model that predicts the mean reward function :math:`E[Y | X, A]`.
"""
X, sample_weight, y = self._create_train_data_for_opl(
context=context,
action=action,
reward=reward,
pscore=pscore,
action_context=action_context,
regression_model=regression_model,
)
self.base_model.fit(X=X, y=y, sample_weight=sample_weight)

def predict(self, context: np.ndarray) -> np.ndarray:
"""Predict best actions for new data.

Parameters
-----------
context: array-like, shape (n_rounds_of_new_data, dim_context)
Observed context vectors for new data.

Returns
-----------
pred: array-like, shape (n_rounds_of_new_data,)
Predicted best action for each round of new data.
"""
return self.base_model.predict(context)

def predict_proba(self, context: np.ndarray) -> np.ndarray:
"""Predict the probability of each action being the best one for new data.

Parameters
-----------
context: array-like, shape (n_rounds_of_new_data, dim_context)
Observed context vectors for new data.

Returns
-----------
pred_proba: array-like, shape (n_rounds_of_new_data, n_actions)
Probability estimates of each action being the best one for new data.
The returned estimates for all actions are ordered by action label.
"""
return self.base_model.predict_proba(context)


BanditPolicy = Union[BaseContextFreePolicy, BaseContextualPolicy]
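
To make the intended workflow of the new BaseOffPolicyLearner concrete, here is a minimal, self-contained sketch. It is not part of this commit: the subclass name SimpleIPWLearner, the uniform-logging fallback, and the synthetic logged data are illustrative assumptions, and the sample weights follow a plain inverse-propensity-weighting recipe rather than the full doubly robust objective of the Dudík et al. paper cited above (regression_model and action_context are simply ignored here).

import numpy as np
from sklearn.linear_model import LogisticRegression

from obp.policy.base import BaseOffPolicyLearner


class SimpleIPWLearner(BaseOffPolicyLearner):
    """Hypothetical learner: imitate logged actions, weighting each sample by reward / pscore."""

    def _create_train_data_for_opl(self, context, action, reward, pscore, **kwargs):
        if pscore is None:
            # assume a uniform behavior policy when propensity scores are missing
            pscore = np.full_like(reward, 1.0 / np.unique(action).shape[0], dtype=float)
        # X = contexts, y = logged actions, sample_weight = importance-weighted rewards
        return context, reward / pscore, action


# Synthetic logged bandit feedback: 1,000 rounds, 3 actions, 5-dimensional contexts.
rng = np.random.default_rng(0)
n_rounds, n_actions, dim_context = 1000, 3, 5
context = rng.normal(size=(n_rounds, dim_context))
action = rng.integers(n_actions, size=n_rounds)
reward = rng.binomial(1, 0.3, size=n_rounds)
pscore = np.full(n_rounds, 1.0 / n_actions)

learner = SimpleIPWLearner(base_model=LogisticRegression(max_iter=1000))
learner.fit(context=context, action=action, reward=reward, pscore=pscore)
print(learner.predict(context[:5]))        # predicted best action per round
print(learner.predict_proba(context[:5]))  # shape (5, n_actions)
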
15 changes: 6 additions & 9 deletions obp/policy/contextfree.py
@@ -54,12 +54,9 @@ def select_action(self) -> np.ndarray:
selected_actions: array-like shape (len_list, )
List of selected actions.
"""
self.n_trial += 1
if self.random_.rand() > self.epsilon:
unsorted_max_arms = np.argpartition(-self.reward_counts, self.len_list)[
: self.len_list
]
return unsorted_max_arms[np.argsort(-self.reward_counts[unsorted_max_arms])]
if (self.random_.rand() > self.epsilon) and (self.action_counts.min() > 0):
reward_preds = self.reward_counts / self.action_counts
return reward_preds.argsort()[::-1][: self.len_list]
else:
return self.random_.choice(
self.n_actions, size=self.len_list, replace=False
@@ -76,6 +73,7 @@ def update_params(self, action: int, reward: float) -> None:
reward: float
Observed reward for the chosen action and position.
"""
self.n_trial += 1
self.action_counts_temp[action] += 1
n, old_reward = self.action_counts_temp[action], self.reward_counts_temp[action]
self.reward_counts_temp[action] = (old_reward * (n - 1) / n) + (reward / n)
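
Judging from the hunk counts, the updated select_action replaces the argpartition-based top-k with an explicit ranking by reward_counts / action_counts and only exploits once every arm has been pulled at least once, while the n_trial increment moves into update_params. A self-contained sketch of the new selection rule, with made-up counts (treating reward_counts as cumulative reward per arm), might look like this:

import numpy as np

# Illustrative standalone version of the updated epsilon-greedy selection rule.
rng = np.random.RandomState(12345)
epsilon, len_list, n_actions = 0.1, 2, 4
action_counts = np.array([10, 12, 9, 11])        # pulls per arm
reward_counts = np.array([3.0, 9.0, 2.0, 5.0])   # cumulative reward per arm

if (rng.rand() > epsilon) and (action_counts.min() > 0):
    # exploit: rank arms by empirical mean reward and return the top len_list
    reward_preds = reward_counts / action_counts
    selected = reward_preds.argsort()[::-1][:len_list]
else:
    # explore: pick len_list distinct arms uniformly at random
    selected = rng.choice(n_actions, size=len_list, replace=False)

print(selected)  # e.g. array([1, 3]) when the exploit branch is taken
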
@@ -172,13 +170,11 @@ def select_action(self) -> np.ndarray:
selected_actions: array-like shape (len_list, )
List of selected actions.
"""
self.n_trial += 1
theta = self.random_.beta(
a=self.reward_counts + self.alpha,
b=(self.action_counts - self.reward_counts) + self.beta,
)
unsorted_max_arms = np.argpartition(-theta, self.len_list)[: self.len_list]
return unsorted_max_arms[np.argsort(-theta[unsorted_max_arms])]
return theta.argsort()[::-1][: self.len_list]

def update_params(self, action: int, reward: float) -> None:
"""Update policy parameters.
@@ -191,6 +187,7 @@ def update_params(self, action: int, reward: float) -> None:
reward: float
Observed reward for the chosen action and position.
"""
self.n_trial += 1
self.action_counts_temp[action] += 1
self.reward_counts_temp[action] += reward
if self.n_trial % self.batch_size == 0:
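
The Bernoulli Thompson sampling change mirrors the one above: the top len_list arms appear to be taken directly from an argsort of the sampled success probabilities, and the n_trial increment likewise moves into update_params. A self-contained sketch of the sampling step, with made-up counts and prior parameters, might read:

import numpy as np

# Illustrative standalone version of the Bernoulli Thompson sampling selection step.
rng = np.random.RandomState(0)
alpha, beta, len_list = 1.0, 1.0, 2               # Beta prior parameters and slate size
action_counts = np.array([20, 15, 30, 10])        # pulls per arm
reward_counts = np.array([8.0, 9.0, 12.0, 2.0])   # successes per arm

# Sample a success probability for each arm from its Beta posterior,
# then return the len_list arms with the largest sampled values.
theta = rng.beta(a=reward_counts + alpha, b=(action_counts - reward_counts) + beta)
selected = theta.argsort()[::-1][:len_list]
print(selected)
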
