Skip to content

Commit

Permalink
fix reg model to allow position=None
Browse files Browse the repository at this point in the history
  • Loading branch information
usaito committed Feb 22, 2021
1 parent 885754c commit b638601
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 15 deletions.
12 changes: 2 additions & 10 deletions obp/ope/regression_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,13 +135,9 @@ def fit(
)
n_rounds = context.shape[0]

if self.len_list == 1:
if position is None or self.len_list == 1:
position = np.zeros_like(action)
else:
if not (isinstance(position, np.ndarray) and position.ndim == 1):
raise ValueError(
"when len_list > 1, position must be a 1-dimensional ndarray"
)
if position.max() >= self.len_list:
raise ValueError(
f"position elements must be smaller than len_list, but the maximum value is {position.max()} (>= {self.len_list})"
Expand Down Expand Up @@ -307,13 +303,9 @@ def fit_predict(
f"random_state must be an integer, but {random_state} is given"
)

if self.len_list == 1:
if position is None or self.len_list == 1:
position = np.zeros_like(action)
else:
if not (isinstance(position, np.ndarray) and position.ndim == 1):
raise ValueError(
"when len_list > 1, position must be a 1-dimensional ndarray"
)
if position.max() >= self.len_list:
raise ValueError(
f"position elements must be smaller than len_list, but the maximum value is {position.max()} (>= {self.len_list})"
Expand Down
10 changes: 5 additions & 5 deletions tests/ope/test_regression_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -454,7 +454,7 @@
np.arange(n_rounds) % n_actions,
np.random.uniform(size=n_rounds),
np.ones(n_rounds) * 2,
None, #
np.ones((n_rounds, 2)), #
np.random.uniform(size=(n_actions, 8)),
n_actions,
len_list,
Expand All @@ -463,7 +463,7 @@
None,
3,
1,
"when len_list > 1, position must be a 1-dimensional ndarray",
"position must be 1-dimensional",
),
(
np.random.uniform(size=(n_rounds, 7)),
Expand Down Expand Up @@ -650,16 +650,16 @@
np.arange(n_rounds) % n_actions,
np.random.uniform(size=n_rounds),
None,
np.random.choice(len_list, size=n_rounds),
None,
np.random.uniform(size=(n_actions, 8)),
n_actions,
len_list,
1,
"normal",
Ridge(**hyperparams["ridge"]),
None,
1,
1,
"valid input without pscore and action_dist",
"valid input without pscore, position, and action_dist",
),
(
np.random.uniform(size=(n_rounds, 7)),
Expand Down

0 comments on commit b638601

Please sign in to comment.