Skip to content

Commit

Permalink
reflect review
Browse files Browse the repository at this point in the history
  • Loading branch information
fullflu committed Sep 5, 2021
1 parent 1eeb28f commit 664416c
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 3 deletions.
4 changes: 2 additions & 2 deletions obp/policy/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ def select_action(self, context: np.ndarray) -> np.ndarray:
"""
check_array(array=context, name="context", expected_dim=2)
if context.shape[0] != 1:
raise ValueError("Expected `context.shape[1] == 1`, but found it False")
raise ValueError("Expected `context.shape[0] == 1`, but found it False")

if self.random_.rand() > self.epsilon:
self.theta_hat = np.concatenate(
Expand Down Expand Up @@ -222,7 +222,7 @@ def select_action(self, context: np.ndarray) -> np.ndarray:
"""
check_array(array=context, name="context", expected_dim=2)
if context.shape[0] != 1:
raise ValueError("Expected `context.shape[1] == 1`, but found it False")
raise ValueError("Expected `context.shape[0] == 1`, but found it False")

self.theta_hat = np.concatenate(
[
Expand Down
2 changes: 1 addition & 1 deletion obp/policy/logistic.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def __post_init__(self) -> None:
raise ValueError(f"`alpha_`= {self.alpha_}, must be > 0.0.")

check_scalar(self.lambda_, "lambda_", float)
if self.alpha_ <= 0.0:
if self.lambda_ <= 0.0:
raise ValueError(f"`lambda_`= {self.lambda_}, must be > 0.0.")

self.alpha_list = self.alpha_ * np.ones(self.n_actions)
Expand Down
49 changes: 49 additions & 0 deletions tests/policy/test_logistic.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,55 @@
from obp.policy.logistic import MiniBatchLogisticRegression


def test_logistic_base_exception():
# invalid dim
with pytest.raises(ValueError):
LogisticEpsilonGreedy(n_actions=2, dim=-3)

with pytest.raises(ValueError):
LogisticEpsilonGreedy(n_actions=2, dim=0)

with pytest.raises(TypeError):
LogisticEpsilonGreedy(n_actions=2, dim="3")

# invalid n_actions
with pytest.raises(ValueError):
LogisticEpsilonGreedy(n_actions=-3, dim=2)

with pytest.raises(ValueError):
LogisticEpsilonGreedy(n_actions=1, dim=2)

with pytest.raises(TypeError):
LogisticEpsilonGreedy(n_actions="2", dim=2)

# invalid len_list
with pytest.raises(ValueError):
LogisticEpsilonGreedy(n_actions=2, dim=2, len_list=-3)

with pytest.raises(ValueError):
LogisticEpsilonGreedy(n_actions=2, dim=2, len_list=0)

with pytest.raises(TypeError):
LogisticEpsilonGreedy(n_actions=2, dim=2, len_list="3")

# invalid batch_size
with pytest.raises(ValueError):
LogisticEpsilonGreedy(n_actions=2, dim=2, batch_size=-2)

with pytest.raises(ValueError):
LogisticEpsilonGreedy(n_actions=2, dim=2, batch_size=0)

with pytest.raises(TypeError):
LogisticEpsilonGreedy(n_actions=2, dim=2, batch_size="10")

# invalid relationship between n_actions and len_list
with pytest.raises(ValueError):
LogisticEpsilonGreedy(n_actions=5, len_list=10, dim=2)

with pytest.raises(ValueError):
LogisticEpsilonGreedy(n_actions=2, len_list=3, dim=2)


def test_logistic_epsilon_normal_epsilon():

policy1 = LogisticEpsilonGreedy(n_actions=2, dim=2)
Expand Down

0 comments on commit 664416c

Please sign in to comment.