Skip to content

Commit

Permalink
run black
Browse files Browse the repository at this point in the history
  • Loading branch information
usaito committed Feb 7, 2021
1 parent c994893 commit 023c5b5
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 9 deletions.
15 changes: 11 additions & 4 deletions obp/dataset/synthetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,8 @@ def obtain_batch_bandit_feedback(self, n_rounds: int) -> BanditFeedback:
action = np.array(
[
self.random_.choice(
np.arange(self.n_actions), p=behavior_policy_[i],
np.arange(self.n_actions),
p=behavior_policy_[i],
)
for i in np.arange(n_rounds)
]
Expand Down Expand Up @@ -289,7 +290,9 @@ def calc_ground_truth_policy_value(


def logistic_reward_function(
context: np.ndarray, action_context: np.ndarray, random_state: Optional[int] = None,
context: np.ndarray,
action_context: np.ndarray,
random_state: Optional[int] = None,
) -> np.ndarray:
"""Logistic mean reward function for synthetic bandit datasets.
Expand Down Expand Up @@ -328,7 +331,9 @@ def logistic_reward_function(


def linear_reward_function(
context: np.ndarray, action_context: np.ndarray, random_state: Optional[int] = None,
context: np.ndarray,
action_context: np.ndarray,
random_state: Optional[int] = None,
) -> np.ndarray:
"""Linear mean reward function for synthetic bandit datasets.
Expand Down Expand Up @@ -367,7 +372,9 @@ def linear_reward_function(


def linear_behavior_policy(
context: np.ndarray, action_context: np.ndarray, random_state: Optional[int] = None,
context: np.ndarray,
action_context: np.ndarray,
random_state: Optional[int] = None,
) -> np.ndarray:
"""Linear contextual behavior policy for synthetic bandit datasets.
Expand Down
20 changes: 15 additions & 5 deletions tests/dataset/test_synthetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,15 +110,22 @@ def test_synthetic_obtain_batch_bandit_feedback():
]

valid_input_of_calc_policy_value = [
(np.ones((2, 3)), np.ones((2, 3, 1)), "valid shape",),
(
np.ones((2, 3)),
np.ones((2, 3, 1)),
"valid shape",
),
]


@pytest.mark.parametrize(
"expected_reward, action_dist, description", invalid_input_of_calc_policy_value,
"expected_reward, action_dist, description",
invalid_input_of_calc_policy_value,
)
def test_synthetic_calc_policy_value_using_invalid_inputs(
expected_reward, action_dist, description,
expected_reward,
action_dist,
description,
):
n_actions = 10
dataset = SyntheticBanditDataset(n_actions=n_actions)
Expand All @@ -130,10 +137,13 @@ def test_synthetic_calc_policy_value_using_invalid_inputs(


@pytest.mark.parametrize(
"expected_reward, action_dist, description", valid_input_of_calc_policy_value,
"expected_reward, action_dist, description",
valid_input_of_calc_policy_value,
)
def test_synthetic_calc_policy_value_using_valid_inputs(
expected_reward, action_dist, description,
expected_reward,
action_dist,
description,
):
n_actions = 10
dataset = SyntheticBanditDataset(n_actions=n_actions)
Expand Down

0 comments on commit 023c5b5

Please sign in to comment.