Skip to content

Commit

Permalink
fix offline check-scalar
Browse files Browse the repository at this point in the history
  • Loading branch information
fullflu committed Sep 5, 2021
1 parent 7d8d0ce commit e979ca8
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 85 deletions.
82 changes: 22 additions & 60 deletions obp/policy/offline.py
Original file line number Diff line number Diff line change
Expand Up @@ -484,10 +484,7 @@ def __post_init__(self) -> None:
if self.len_list != 1:
raise NotImplementedError("currently, len_list > 1 is not supported")

if not isinstance(self.dim_context, int) or self.dim_context <= 0:
raise ValueError(
f"dim_context must be a positive integer, but {self.dim_context} is given"
)
check_scalar(self.dim_context, "dim_context", int, min_val=1)

if not callable(self.off_policy_objective):
raise ValueError(
Expand All @@ -506,10 +503,7 @@ def __post_init__(self) -> None:
f"solver must be one of 'adam', 'lbfgs', or 'sgd', but {self.solver} is given"
)

if not isinstance(self.alpha, float) or self.alpha < 0.0:
raise ValueError(
f"alpha must be a non-negative float, but {self.alpha} is given"
)
check_scalar(self.alpha, "alpha", float, min_val=0.0)

if self.batch_size != "auto" and (
not isinstance(self.batch_size, int) or self.batch_size <= 0
Expand All @@ -518,29 +512,22 @@ def __post_init__(self) -> None:
f"batch_size must be a positive integer or 'auto', but {self.batch_size} is given"
)

if (
not isinstance(self.learning_rate_init, float)
or self.learning_rate_init <= 0.0
):
check_scalar(self.learning_rate_init, "learning_rate_init", float)
if self.learning_rate_init <= 0.0:
raise ValueError(
f"learning_rate_init must be a positive float, but {self.learning_rate_init} is given"
f"`learning_rate_init`= {self.learning_rate_init}, must be > 0.0"
)

if not isinstance(self.max_iter, int) or self.max_iter <= 0:
raise ValueError(
f"max_iter must be a positive integer, but {self.max_iter} is given"
)
check_scalar(self.max_iter, "max_iter", int, min_val=1)

if not isinstance(self.shuffle, bool):
raise ValueError(f"shuffle must be a bool, but {self.shuffle} is given")

if not isinstance(self.tol, float) or self.tol <= 0.0:
raise ValueError(f"tol must be a positive float, but {self.tol} is given")
check_scalar(self.tol, "tol", float)
if self.tol <= 0.0:
raise ValueError(f"`tol`= {self.tol}, must be > 0.0")

if not isinstance(self.momentum, float) or not 0.0 <= self.momentum <= 1.0:
raise ValueError(
f"momentum must be a float in [0., 1.], but {self.momentum} is given"
)
check_scalar(self.momentum, "momentum", float, min_val=0.0, max_val=1.0)

if not isinstance(self.nesterovs_momentum, bool):
raise ValueError(
Expand All @@ -557,43 +544,19 @@ def __post_init__(self) -> None:
f"if early_stopping is True, solver must be one of 'sgd' or 'adam', but {self.solver} is given"
)

if (
not isinstance(self.validation_fraction, float)
or not 0.0 < self.validation_fraction <= 1.0
):
raise ValueError(
f"validation_fraction must be a float in (0., 1.], but {self.validation_fraction} is given"
)

if not isinstance(self.beta_1, float) or not 0.0 <= self.beta_1 <= 1.0:
raise ValueError(
f"beta_1 must be a float in [0. 1.], but {self.beta_1} is given"
)

if not isinstance(self.beta_2, float) or not 0.0 <= self.beta_2 <= 1.0:
raise ValueError(
f"beta_2 must be a float in [0., 1.], but {self.beta_2} is given"
)

if not isinstance(self.beta_2, float) or not 0.0 <= self.beta_2 <= 1.0:
raise ValueError(
f"beta_2 must be a float in [0., 1.], but {self.beta_2} is given"
)

if not isinstance(self.epsilon, float) or self.epsilon < 0.0:
raise ValueError(
f"epsilon must be a non-negative float, but {self.epsilon} is given"
)

if not isinstance(self.n_iter_no_change, int) or self.n_iter_no_change <= 0:
check_scalar(
self.validation_fraction, "validation_fraction", float, max_val=1.0
)
if self.validation_fraction <= 0.0:
raise ValueError(
f"n_iter_no_change must be a positive integer, but {self.n_iter_no_change} is given"
f"`validation_fraction`= {self.validation_fraction}, must be > 0.0"
)

if not isinstance(self.max_fun, int) or self.max_fun <= 0:
raise ValueError(
f"max_fun must be a positive integer, but {self.max_fun} is given"
)
check_scalar(self.beta_1, "beta_1", float, min_val=0.0, max_val=1.0)
check_scalar(self.beta_2, "beta_2", float, min_val=0.0, max_val=1.0)
check_scalar(self.epsilon, "epsilon", float, min_val=0.0)
check_scalar(self.n_iter_no_change, "n_iter_no_change", int, min_val=1)
check_scalar(self.max_fun, "max_fun", int, min_val=1)

if self.random_state is not None:
self.random_ = check_random_state(self.random_state)
Expand Down Expand Up @@ -665,10 +628,9 @@ def _create_train_data_for_opl(
"""
if self.batch_size == "auto":
batch_size_ = min(200, context.shape[0])
elif isinstance(self.batch_size, int) and self.batch_size > 0:
batch_size_ = self.batch_size
else:
raise ValueError("batch_size must be a positive integer or 'auto'")
check_scalar(self.batch_size, "batch_size", int, min_val=1)
batch_size_ = self.batch_size

dataset = NNPolicyDataset(
torch.from_numpy(context).float(),
Expand Down
50 changes: 25 additions & 25 deletions tests/policy/test_offline.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,19 +18,19 @@
0, #
1,
base_classifier,
"n_actions must be an integer larger than 1",
"`n_actions`= 0, must be >= 1",
),
(
10,
-1, #
base_classifier,
"len_list must be a positive integer",
"`len_list`= -1, must be >= 0",
),
(
10,
20, #
base_classifier,
"Expected `n_actions",
"`len_list`= 20, must be <= 10",
),
(10, 1, base_regressor, "base_classifier must be a classifier"),
]
Expand Down Expand Up @@ -255,7 +255,7 @@ def test_ipw_learner_sample_action():
1e-8,
10,
15000,
"n_actions must be an integer larger than 1",
"`n_actions`= 0, must be >= 1",
),
(
10,
Expand All @@ -281,7 +281,7 @@ def test_ipw_learner_sample_action():
1e-8,
10,
15000,
"len_list must be a positive integer",
"`len_list`= -1, must be >= 0",
),
(
10,
Expand All @@ -307,7 +307,7 @@ def test_ipw_learner_sample_action():
1e-8,
10,
15000,
"dim_context must be a positive integer",
"`dim_context`= -1, must be >= 0",
),
(
10,
Expand Down Expand Up @@ -421,7 +421,7 @@ def test_ipw_learner_sample_action():
(100, 50, 100),
"relu",
"adam",
-1, #
-1.0, #
"auto",
0.0001,
200,
Expand All @@ -437,7 +437,7 @@ def test_ipw_learner_sample_action():
1e-8,
10,
15000,
"alpha must be a non-negative float",
"`alpha`= -1.0, must be >= 0.0",
),
(
10,
Expand Down Expand Up @@ -475,7 +475,7 @@ def test_ipw_learner_sample_action():
"adam",
0.001,
"auto",
0, #
0.0, #
200,
True,
123,
Expand All @@ -489,7 +489,7 @@ def test_ipw_learner_sample_action():
1e-8,
10,
15000,
"learning_rate_init must be a positive float",
"`learning_rate_init`= 0.0, must be > 0.0",
),
(
10,
Expand All @@ -515,7 +515,7 @@ def test_ipw_learner_sample_action():
1e-8,
10,
15000,
"max_iter must be a positive integer",
"`max_iter`= 0, must be >= 1",
),
(
10,
Expand Down Expand Up @@ -583,7 +583,7 @@ def test_ipw_learner_sample_action():
200,
True,
123,
-1, #
-1.0, #
0.9,
True,
True,
Expand All @@ -593,7 +593,7 @@ def test_ipw_learner_sample_action():
1e-8,
10,
15000,
"tol must be a positive float",
"`tol`= -1.0, must be > 0.0",
),
(
10,
Expand All @@ -610,7 +610,7 @@ def test_ipw_learner_sample_action():
True,
123,
1e-4,
2, #
2.0, #
True,
True,
0.1,
Expand All @@ -619,7 +619,7 @@ def test_ipw_learner_sample_action():
1e-8,
10,
15000,
"momentum must be a float in [0., 1.]",
"`momentum`= 2.0, must be <= 1.0",
),
(
10,
Expand Down Expand Up @@ -717,13 +717,13 @@ def test_ipw_learner_sample_action():
0.9,
True,
True,
2, #
2.0, #
0.9,
0.999,
1e-8,
10,
15000,
"validation_fraction must be a float in",
"`validation_fraction`= 2.0, must be <= 1.0",
),
(
10,
Expand All @@ -744,12 +744,12 @@ def test_ipw_learner_sample_action():
True,
True,
0.1,
2, #
2.0, #
0.999,
1e-8,
10,
15000,
"beta_1 must be a float in [0. 1.]",
"`beta_1`= 2.0, must be <= 1.0",
),
(
10,
Expand All @@ -771,11 +771,11 @@ def test_ipw_learner_sample_action():
True,
0.1,
0.9,
2, #
2.0, #
1e-8,
10,
15000,
"beta_2 must be a float in [0., 1.]",
"`beta_2`= 2.0, must be <= 1.0",
),
(
10,
Expand All @@ -798,10 +798,10 @@ def test_ipw_learner_sample_action():
0.1,
0.9,
0.999,
-1, #
-1.0, #
10,
15000,
"epsilon must be a non-negative float",
"`epsilon`= -1.0, must be >= 0.0",
),
(
10,
Expand All @@ -827,7 +827,7 @@ def test_ipw_learner_sample_action():
1e-8,
0, #
15000,
"n_iter_no_change must be a positive integer",
"`n_iter_no_change`= 0, must be >= 1",
),
(
10,
Expand All @@ -853,7 +853,7 @@ def test_ipw_learner_sample_action():
1e-8,
10,
0, #
"max_fun must be a positive integer",
"`max_fun`= 0, must be >= 1",
),
]

Expand Down

0 comments on commit e979ca8

Please sign in to comment.