Skip to content

Commit

Permalink
fix linear and logistic check-scalar
Browse files Browse the repository at this point in the history
  • Loading branch information
fullflu committed Sep 5, 2021
1 parent f8a4832 commit 7d8d0ce
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 39 deletions.
29 changes: 12 additions & 17 deletions obp/policy/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,10 @@
from dataclasses import dataclass

import numpy as np
from sklearn.utils import check_scalar

from .base import BaseContextualPolicy
from ..utils import check_array


@dataclass
Expand Down Expand Up @@ -123,10 +125,7 @@ class LinEpsilonGreedy(BaseLinPolicy):

def __post_init__(self) -> None:
"""Initialize class."""
if not 0 <= self.epsilon <= 1:
raise ValueError(
f"epsilon must be between 0 and 1, but {self.epsilon} is given"
)
check_scalar(self.epsilon, "epsilon", float, min_val=0.0, max_val=1.0)
self.policy_name = f"linear_epsilon_greedy_{self.epsilon}"

super().__post_init__()
Expand All @@ -145,10 +144,9 @@ def select_action(self, context: np.ndarray) -> np.ndarray:
List of selected actions.
"""
if context.ndim != 2 or context.shape[0] != 1:
raise ValueError(
f"context shape must be (1, dim_context),but {context.shape} is given"
)
check_array(array=context, name="context", expected_dim=2)
if context.shape[0] != 1:
raise ValueError("Expected `context.shape[1] == 1`, but found it False")

if self.random_.rand() > self.epsilon:
self.theta_hat = np.concatenate(
Expand Down Expand Up @@ -189,7 +187,7 @@ class LinUCB(BaseLinPolicy):
Controls the random seed in sampling actions.
epsilon: float, default=0.
Exploration hyperparameter that must take value in the range of [0., 1.].
Exploration hyperparameter that must be greater than or equal to 0.0.
References
--------------
Expand All @@ -203,10 +201,7 @@ class LinUCB(BaseLinPolicy):

def __post_init__(self) -> None:
"""Initialize class."""
if self.epsilon < 0:
raise ValueError(
f"epsilon must be positive scalar, but {self.epsilon} is given"
)
check_scalar(self.epsilon, "epsilon", float, min_val=0.0)
self.policy_name = f"linear_ucb_{self.epsilon}"

super().__post_init__()
Expand All @@ -225,10 +220,10 @@ def select_action(self, context: np.ndarray) -> np.ndarray:
List of selected actions.
"""
if context.ndim != 2 or context.shape[0] != 1:
raise ValueError(
f"context shape must be (1, dim_context),but {context.shape} is given"
)
check_array(array=context, name="context", expected_dim=2)
if context.shape[0] != 1:
raise ValueError("Expected `context.shape[1] == 1`, but found it False")

self.theta_hat = np.concatenate(
[
self.A_inv[i] @ self.b[:, i][:, np.newaxis]
Expand Down
28 changes: 10 additions & 18 deletions obp/policy/logistic.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from typing import Optional

import numpy as np
from sklearn.utils import check_random_state
from sklearn.utils import check_random_state, check_scalar
from scipy.optimize import minimize

from .base import BaseContextualPolicy
Expand Down Expand Up @@ -49,15 +49,13 @@ class BaseLogisticPolicy(BaseContextualPolicy):
def __post_init__(self) -> None:
"""Initialize class."""
super().__post_init__()
if not isinstance(self.alpha_, float) or self.alpha_ <= 0.0:
raise ValueError(
f"alpha_ should be a positive float, but {self.alpha_} is given"
)
check_scalar(self.alpha_, "alpha_", float)
if self.alpha_ <= 0.0:
raise ValueError(f"`alpha_`= {self.alpha_}, must be > 0.0.")

if not isinstance(self.lambda_, float) or self.lambda_ <= 0.0:
raise ValueError(
f"lambda_ should be a positive float, but {self.lambda_} is given"
)
check_scalar(self.lambda_, "lambda_", float)
if self.alpha_ <= 0.0:
raise ValueError(f"`lambda_`= {self.lambda_}, must be > 0.0.")

self.alpha_list = self.alpha_ * np.ones(self.n_actions)
self.lambda_list = self.lambda_ * np.ones(self.n_actions)
Expand Down Expand Up @@ -138,10 +136,7 @@ class LogisticEpsilonGreedy(BaseLogisticPolicy):

def __post_init__(self) -> None:
"""Initialize class."""
if not 0 <= self.epsilon <= 1:
raise ValueError(
f"epsilon must be between 0 and 1, but {self.epsilon} is given"
)
check_scalar(self.epsilon, "epsilon", float, min_val=0.0, max_val=1.0)
self.policy_name = f"logistic_egreedy_{self.epsilon}"

super().__post_init__()
Expand Down Expand Up @@ -200,7 +195,7 @@ class LogisticUCB(BaseLogisticPolicy):
Regularization hyperparameter for the online logistic regression.
epsilon: float, default=0.
Exploration hyperparameter that must take value in the range of [0., 1.].
Exploration hyperparameter that must be greater than or equal to 0.0.
References
----------
Expand All @@ -213,10 +208,7 @@ class LogisticUCB(BaseLogisticPolicy):

def __post_init__(self) -> None:
"""Initialize class."""
if self.epsilon < 0:
raise ValueError(
f"epsilon must be positive scalar, but {self.epsilon} is given"
)
check_scalar(self.epsilon, "epsilon", float, min_val=0.0)
self.policy_name = f"logistic_ucb_{self.epsilon}"

super().__post_init__()
Expand Down
8 changes: 4 additions & 4 deletions tests/policy/test_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def test_linear_base_exception():
with pytest.raises(ValueError):
LinEpsilonGreedy(n_actions=2, dim=0)

with pytest.raises(ValueError):
with pytest.raises(TypeError):
LinEpsilonGreedy(n_actions=2, dim="3")

# invalid n_actions
Expand All @@ -25,7 +25,7 @@ def test_linear_base_exception():
with pytest.raises(ValueError):
LinEpsilonGreedy(n_actions=1, dim=2)

with pytest.raises(ValueError):
with pytest.raises(TypeError):
LinEpsilonGreedy(n_actions="2", dim=2)

# invalid len_list
Expand All @@ -35,7 +35,7 @@ def test_linear_base_exception():
with pytest.raises(ValueError):
LinEpsilonGreedy(n_actions=2, dim=2, len_list=0)

with pytest.raises(ValueError):
with pytest.raises(TypeError):
LinEpsilonGreedy(n_actions=2, dim=2, len_list="3")

# invalid batch_size
Expand All @@ -45,7 +45,7 @@ def test_linear_base_exception():
with pytest.raises(ValueError):
LinEpsilonGreedy(n_actions=2, dim=2, batch_size=0)

with pytest.raises(ValueError):
with pytest.raises(TypeError):
LinEpsilonGreedy(n_actions=2, dim=2, batch_size="10")

# invalid relationship between n_actions and len_list
Expand Down

0 comments on commit 7d8d0ce

Please sign in to comment.