Skip to content

Commit

Permalink
Small fixes for SAC
Browse files Browse the repository at this point in the history
Summary:
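- Made the output transformer rescale from [-1 + 1e-6, 1 - 1e-6] by default (previously [-1, 1]), so that it matches the scaling applied when converting from serving to training
- Used ActorExporter in the unit test
- Made ActorExporter more generic: it now accepts a configurable predictor class and forwards extra keyword arguments to it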

Reviewed By: econti

Differential Revision: D12814412

fbshipit-source-id: da797404b3b4ed488c81d8377dc8e2f92b952b51
  • Loading branch information
kittipatv authored and facebook-github-bot committed Oct 31, 2018
1 parent 555c8dc commit d7465e7
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 14 deletions.
5 changes: 3 additions & 2 deletions ml/rl/models/output_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,10 +186,11 @@ def __init__(
self.serving_max_scale = np.array(serving_max_scale, dtype=np.float)
self.serving_min_scale = np.array(serving_min_scale, dtype=np.float)
self.training_max_scale = np.array(
training_max_scale or [1.0] * len(action_feature_ids), dtype=np.float
training_max_scale or [1.0 - 1e-6] * len(action_feature_ids), dtype=np.float
)
self.training_min_scale = np.array(
training_min_scale or [-1.0] * len(action_feature_ids), dtype=np.float
training_min_scale or [-1.0 + 1e-6] * len(action_feature_ids),
dtype=np.float,
)

def create_net(self, original_output):
Expand Down
17 changes: 9 additions & 8 deletions ml/rl/test/gridworld/test_gridworld_sac.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
SACModelParameters,
SACTrainingParameters,
)
from ml.rl.training.rl_exporter import ParametricDQNExporter
from ml.rl.training.rl_exporter import ActorExporter, ParametricDQNExporter
from ml.rl.training.sac_trainer import SACTrainer


Expand All @@ -47,7 +47,7 @@ def setUp(self):
np.random.seed(0)
random.seed(0)
torch.manual_seed(0)
super(self.__class__, self).setUp()
super(TestGridworldSAC, self).setUp()

def get_sac_parameters(self, use_2_q_functions=False):
return SACModelParameters(
Expand Down Expand Up @@ -140,12 +140,13 @@ def get_actor_predictor(self, trainer, environment):
environment.min_action_range.reshape(-1),
)

def container(actor_network):
return ActorWithPreprocessing(
actor_network, Preprocessor(environment.normalization, False, True)
)

return trainer.actor_predictor(feature_extractor, output_transformer, container)
predictor = ActorExporter(
trainer.actor_network,
feature_extractor,
output_transformer,
Preprocessor(environment.normalization, False, True),
).export()
return predictor

def _test_sac_trainer(self, use_2_q_functions=False, use_gpu=False):
environment = GridworldContinuous()
Expand Down
6 changes: 3 additions & 3 deletions ml/rl/test/models/test_output_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,9 +138,9 @@ def fetch_blob(b):
keys = fetch_blob("output/float_features.keys")
values = fetch_blob("output/float_features.values")

scaled_actions = (actions + np.ones(len(action_feature_ids))) / 2 * (
serving_max_scale - serving_min_scale
) + serving_min_scale
scaled_actions = (actions + np.ones(len(action_feature_ids)) - 1e-6) / (
(1 - 1e-6) * 2
) * (serving_max_scale - serving_min_scale) + serving_min_scale

npt.assert_array_equal([len(action_feature_ids)] * N, lengths)
npt.assert_array_equal(action_feature_ids * N, keys)
Expand Down
6 changes: 5 additions & 1 deletion ml/rl/training/rl_exporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,13 @@ def __init__(
feature_extractor=None,
output_transformer=None,
state_preprocessor=None,
predictor_class=ActorPredictor,
**kwargs
):
super(ActorExporter, self).__init__(dnn, feature_extractor, output_transformer)
self.state_preprocessor = state_preprocessor
self.predictor_class = predictor_class
self.kwargs = kwargs

def export(self):
module_to_export = self.dnn.cpu_model()
Expand All @@ -71,4 +75,4 @@ def export(self):
feature_extractor=self.feature_extractor,
output_transformer=self.output_transformer,
)
return ActorPredictor(pem, ws)
return self.predictor_class(pem, ws, **self.kwargs)

0 comments on commit d7465e7

Please sign in to comment.