Skip to content

Commit

Permalink
ActorCriticDataModule (facebookresearch#491)
Browse files Browse the repository at this point in the history
Summary: Pull Request resolved: facebookresearch#491

Reviewed By: czxttkl, bankawas

Differential Revision: D29251412

fbshipit-source-id: 0a6cbcf59956ecc113e9425079f91a6b3098c2de
  • Loading branch information
kittipatv authored and facebook-github-bot committed Jun 25, 2021
1 parent d5394c5 commit 395b079
Show file tree
Hide file tree
Showing 3 changed files with 108 additions and 51 deletions.
10 changes: 10 additions & 0 deletions reagent/data/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,12 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.

from .data_fetcher import DataFetcher
from .manual_data_module import ManualDataModule
from .reagent_data_module import ReAgentDataModule

__all__ = [
"DataFetcher",
"ManualDataModule",
"ReAgentDataModule",
]
4 changes: 1 addition & 3 deletions reagent/data/manual_data_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,9 +223,7 @@ def __getattr__(self, attr):
)
return normalization_data

raise AttributeError(
f"attr {attr} not available {type(self)} (subclass of ModelManager)."
)
raise AttributeError(f"attr {attr} not available {type(self)}")

@property
@abc.abstractmethod
Expand Down
145 changes: 97 additions & 48 deletions reagent/model_managers/actor_critic_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,7 @@
NormalizationData,
NormalizationKey,
)
from reagent.data.data_fetcher import DataFetcher
from reagent.data.reagent_data_module import ReAgentDataModule
from reagent.data import DataFetcher, ReAgentDataModule, ManualDataModule
from reagent.evaluation.evaluator import get_metrics_to_score
from reagent.gym.policies.policy import Policy
from reagent.gym.policies.predictor_policies import create_predictor_policy_from_model
Expand Down Expand Up @@ -104,7 +103,7 @@ def __post_init_post_parse__(self):

@property
def should_generate_eval_dataset(self) -> bool:
return self.eval_parameters.calc_cpe_in_training
raise NotImplementedError

def create_policy(self, serving: bool) -> Policy:
"""Create online actor critic policy."""
Expand Down Expand Up @@ -172,28 +171,7 @@ def get_action_preprocessing_options(self) -> PreprocessingOptions:
def run_feature_identification(
self, input_table_spec: TableSpec
) -> Dict[str, NormalizationData]:
# Run state feature identification
state_normalization_parameters = identify_normalization_parameters(
input_table_spec,
InputColumn.STATE_FEATURES,
self.get_state_preprocessing_options(),
)

# Run action feature identification
action_normalization_parameters = identify_normalization_parameters(
input_table_spec,
InputColumn.ACTION,
self.get_action_preprocessing_options(),
)

return {
NormalizationKey.STATE: NormalizationData(
dense_normalization_parameters=state_normalization_parameters
),
NormalizationKey.ACTION: NormalizationData(
dense_normalization_parameters=action_normalization_parameters
),
}
raise NotImplementedError

@property
def required_normalization_keys(self) -> List[str]:
Expand All @@ -206,28 +184,29 @@ def query_data(
reward_options: RewardOptions,
data_fetcher: DataFetcher,
) -> Dataset:
logger.info("Starting query")
return data_fetcher.query_data(
input_table_spec=input_table_spec,
discrete_action=False,
include_possible_actions=False,
custom_reward_expression=reward_options.custom_reward_expression,
sample_range=sample_range,
)
raise NotImplementedError

def build_batch_preprocessor(self, use_gpu: bool) -> BatchPreprocessor:
state_preprocessor = Preprocessor(
self.state_normalization_data.dense_normalization_parameters,
use_gpu=use_gpu,
)
action_preprocessor = Preprocessor(
self.action_normalization_data.dense_normalization_parameters,
use_gpu=use_gpu,
)
return PolicyNetworkBatchPreprocessor(
state_preprocessor=state_preprocessor,
action_preprocessor=action_preprocessor,
use_gpu=use_gpu,
raise NotImplementedError

def get_data_module(
self,
*,
input_table_spec: Optional[TableSpec] = None,
reward_options: Optional[RewardOptions] = None,
reader_options: Optional[ReaderOptions] = None,
setup_data: Optional[Dict[str, bytes]] = None,
saved_setup_data: Optional[Dict[str, bytes]] = None,
resource_options: Optional[ResourceOptions] = None,
) -> Optional[ReAgentDataModule]:
return ActorCriticDataModule(
input_table_spec=input_table_spec,
reward_options=reward_options,
setup_data=setup_data,
saved_setup_data=saved_setup_data,
reader_options=reader_options,
resource_options=resource_options,
model_manager=self,
)

def get_reporter(self):
Expand All @@ -244,11 +223,11 @@ def train(
reader_options: ReaderOptions,
resource_options: ResourceOptions,
) -> RLTrainingOutput:
batch_preprocessor = self.build_batch_preprocessor(resource_options.use_gpu)
reporter = self.get_reporter()
# pyre-fixme[16]: `Trainer` has no attribute `set_reporter`.
# pyre-fixme[16]: `Trainer` has no attribute `set_reporter`.
self.trainer.set_reporter(reporter)
assert data_module

# assert eval_dataset is None

Expand All @@ -261,8 +240,7 @@ def train(
data_module=data_module,
num_epochs=num_epochs,
logger_name="ActorCritic",
batch_preprocessor=batch_preprocessor,
reader_options=self.reader_options,
reader_options=reader_options,
checkpoint_path=self._lightning_checkpoint_path,
resource_options=resource_options or ResourceOptions(),
)
Expand All @@ -278,3 +256,74 @@ def train(
return RLTrainingOutput(
training_report=training_report, logger_data=logger_data
)


class ActorCriticDataModule(ManualDataModule):
def run_feature_identification(
self, input_table_spec: TableSpec
) -> Dict[str, NormalizationData]:
"""
Derive preprocessing parameters from data. The keys of the dict should
match the keys from `required_normalization_keys()`
"""
# Run state feature identification
state_normalization_parameters = identify_normalization_parameters(
input_table_spec,
InputColumn.STATE_FEATURES,
self.model_manager.get_state_preprocessing_options(),
)

# Run action feature identification
action_normalization_parameters = identify_normalization_parameters(
input_table_spec,
InputColumn.ACTION,
self.model_manager.get_action_preprocessing_options(),
)

return {
NormalizationKey.STATE: NormalizationData(
dense_normalization_parameters=state_normalization_parameters
),
NormalizationKey.ACTION: NormalizationData(
dense_normalization_parameters=action_normalization_parameters
),
}

@property
def required_normalization_keys(self) -> List[str]:
"""Get the normalization keys required for current instance"""
return [NormalizationKey.STATE, NormalizationKey.ACTION]

@property
def should_generate_eval_dataset(self) -> bool:
return self.model_manager.eval_parameters.calc_cpe_in_training

def query_data(
self,
input_table_spec: TableSpec,
sample_range: Optional[Tuple[float, float]],
reward_options: RewardOptions,
data_fetcher: DataFetcher,
) -> Dataset:
return data_fetcher.query_data(
input_table_spec=input_table_spec,
discrete_action=False,
include_possible_actions=False,
custom_reward_expression=reward_options.custom_reward_expression,
sample_range=sample_range,
)

def build_batch_preprocessor(self) -> BatchPreprocessor:
state_preprocessor = Preprocessor(
self.state_normalization_data.dense_normalization_parameters,
use_gpu=self.resource_options.use_gpu,
)
action_preprocessor = Preprocessor(
self.action_normalization_data.dense_normalization_parameters,
use_gpu=self.resource_options.use_gpu,
)
return PolicyNetworkBatchPreprocessor(
state_preprocessor=state_preprocessor,
action_preprocessor=action_preprocessor,
use_gpu=self.resource_options.use_gpu,
)

0 comments on commit 395b079

Please sign in to comment.