Skip to content

Commit

Permalink
Remove trainer property from ModelManager
Browse files Browse the repository at this point in the history
Summary: This is the start of making model manager stateless to reduce complexity

Reviewed By: czxttkl

Differential Revision: D29253248

fbshipit-source-id: 681d141cb46784e40c8802f2325c1636044c61de
  • Loading branch information
kittipatv authored and facebook-github-bot committed Jun 25, 2021
1 parent 9b802c7 commit e23e20e
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 6 deletions.
3 changes: 1 addition & 2 deletions reagent/model_managers/discrete_dqn_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,7 @@ def create_policy(self, serving: bool) -> Policy:
)
else:
sampler = GreedyActionSampler()
# pyre-fixme[16]: `RLTrainer` has no attribute `q_network`.
scorer = discrete_dqn_scorer(self.trainer.q_network)
scorer = discrete_dqn_scorer(self._q_network)
return Policy(scorer=scorer, sampler=sampler)

@property
Expand Down
9 changes: 6 additions & 3 deletions reagent/model_managers/model_based/seq2reward_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,12 +44,15 @@ class Seq2RewardModel(WorldModelBase):
# pyre-fixme[15]: `build_trainer` overrides method defined in `ModelManager`
# inconsistently.
def build_trainer(self, use_gpu: bool) -> Seq2RewardTrainer:
seq2reward_network = self.net_builder.value.build_value_network(
self.state_normalization_data
)
# pyre-fixme[16]: `Seq2RewardModel` has no attribute `_seq2reward_network`.
self._seq2reward_network = (
seq2reward_network
) = self.net_builder.value.build_value_network(self.state_normalization_data)
trainer = Seq2RewardTrainer(
seq2reward_network=seq2reward_network, params=self.trainer_param
)
# pyre-fixme[16]: `Seq2RewardModel` has no attribute `_step_predict_network`.
self._step_predict_network = trainer.step_predict_network
return trainer

def get_reporter(self) -> Seq2RewardReporter:
Expand Down
2 changes: 1 addition & 1 deletion reagent/model_managers/model_based/synthetic_reward.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ def build_trainer(self, use_gpu: bool) -> RewardNetTrainer:

def get_reporter(self):
return RewardNetworkReporter(
self.trainer.loss_type,
self.trainer_param.loss_type,
str(self.net_builder.value),
)

Expand Down

0 comments on commit e23e20e

Please sign in to comment.