diff --git a/obp/dataset/synthetic.py b/obp/dataset/synthetic.py
index 83d9620c..50488fff 100644
--- a/obp/dataset/synthetic.py
+++ b/obp/dataset/synthetic.py
@@ -201,7 +201,8 @@ def obtain_batch_bandit_feedback(self, n_rounds: int) -> BanditFeedback:
         action = np.array(
             [
                 self.random_.choice(
-                    np.arange(self.n_actions), p=behavior_policy_[i],
+                    np.arange(self.n_actions),
+                    p=behavior_policy_[i],
                 )
                 for i in np.arange(n_rounds)
             ]
@@ -289,7 +290,9 @@ def calc_ground_truth_policy_value(
 
 
 def logistic_reward_function(
-    context: np.ndarray, action_context: np.ndarray, random_state: Optional[int] = None,
+    context: np.ndarray,
+    action_context: np.ndarray,
+    random_state: Optional[int] = None,
 ) -> np.ndarray:
     """Logistic mean reward function for synthetic bandit datasets.
 
@@ -328,7 +331,9 @@ def logistic_reward_function(
 
 
 def linear_reward_function(
-    context: np.ndarray, action_context: np.ndarray, random_state: Optional[int] = None,
+    context: np.ndarray,
+    action_context: np.ndarray,
+    random_state: Optional[int] = None,
 ) -> np.ndarray:
     """Linear mean reward function for synthetic bandit datasets.
 
@@ -367,7 +372,9 @@ def linear_reward_function(
 
 
 def linear_behavior_policy(
-    context: np.ndarray, action_context: np.ndarray, random_state: Optional[int] = None,
+    context: np.ndarray,
+    action_context: np.ndarray,
+    random_state: Optional[int] = None,
 ) -> np.ndarray:
     """Linear contextual behavior policy for synthetic bandit datasets.
 
diff --git a/tests/dataset/test_synthetic.py b/tests/dataset/test_synthetic.py
index a495723a..d55d269b 100644
--- a/tests/dataset/test_synthetic.py
+++ b/tests/dataset/test_synthetic.py
@@ -110,15 +110,22 @@ def test_synthetic_obtain_batch_bandit_feedback():
 ]
 
 valid_input_of_calc_policy_value = [
-    (np.ones((2, 3)), np.ones((2, 3, 1)), "valid shape",),
+    (
+        np.ones((2, 3)),
+        np.ones((2, 3, 1)),
+        "valid shape",
+    ),
 ]
 
 
 @pytest.mark.parametrize(
-    "expected_reward, action_dist, description", invalid_input_of_calc_policy_value,
+    "expected_reward, action_dist, description",
+    invalid_input_of_calc_policy_value,
 )
 def test_synthetic_calc_policy_value_using_invalid_inputs(
-    expected_reward, action_dist, description,
+    expected_reward,
+    action_dist,
+    description,
 ):
     n_actions = 10
     dataset = SyntheticBanditDataset(n_actions=n_actions)
@@ -130,10 +137,13 @@ def test_synthetic_calc_policy_value_using_invalid_inputs(
 
 
 @pytest.mark.parametrize(
-    "expected_reward, action_dist, description", valid_input_of_calc_policy_value,
+    "expected_reward, action_dist, description",
+    valid_input_of_calc_policy_value,
 )
 def test_synthetic_calc_policy_value_using_valid_inputs(
-    expected_reward, action_dist, description,
+    expected_reward,
+    action_dist,
+    description,
 ):
     n_actions = 10
     dataset = SyntheticBanditDataset(n_actions=n_actions)
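
Review context only, not part of the patch: a minimal usage sketch of the functions touched above. It assumes the obp.dataset exports and the dim_context, reward_function, behavior_policy_function, and random_state constructor keywords that SyntheticBanditDataset exposes in this version of the library; the "expected_reward" feedback key and the (n_rounds, n_actions, len_list) shape of action_dist are taken from the diffed code and tests.

# Usage sketch; assumptions are flagged in the note above.
import numpy as np

from obp.dataset import (
    SyntheticBanditDataset,
    linear_behavior_policy,
    logistic_reward_function,
)

# Synthetic logged bandit data whose mean reward and behavior policy are the
# reformatted helper functions from this patch.
dataset = SyntheticBanditDataset(
    n_actions=10,
    dim_context=5,  # assumed keyword, not shown in the diff
    reward_function=logistic_reward_function,
    behavior_policy_function=linear_behavior_policy,
    random_state=12345,
)

# Log interactions under the behavior policy.
bandit_feedback = dataset.obtain_batch_bandit_feedback(n_rounds=1000)

# Ground-truth value of a uniform-random evaluation policy;
# action_dist has shape (n_rounds, n_actions, len_list) as in the tests.
uniform_action_dist = np.ones((1000, 10, 1)) / 10
policy_value = dataset.calc_ground_truth_policy_value(
    expected_reward=bandit_feedback["expected_reward"],
    action_dist=uniform_action_dist,
)
print(policy_value)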