Address minor issues after clarification by @araffin
AdamGleave committed Jul 8, 2020
1 parent e61d34a commit 91bbc28
Showing 3 changed files with 23 additions and 4 deletions.
10 changes: 6 additions & 4 deletions stable_baselines3/common/base_class.py
@@ -320,8 +320,8 @@ def load(cls, load_path: str, env: Optional[GymEnv] = None, **kwargs) -> 'BaseAl
                     f"Stored kwargs: {data['policy_kwargs']}, specified kwargs: {kwargs['policy_kwargs']}")
 
         # check if observation space and action space are part of the saved parameters
-        if ("observation_space" not in data or "action_space" not in data) and "env" not in data:
-            raise ValueError("The observation_space and action_space was not given, can't verify new environments")
+        if "observation_space" not in data or "action_space" not in data:
+            raise KeyError("The observation_space and action_space were not given, can't verify new environments")
         # check if given env is valid
         if env is not None:
            check_for_correct_spaces(env, data["observation_space"], data["action_space"])
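
For context: when an env is passed to load(), the stored spaces are compared against the new environment. A minimal usage sketch (the saved archive name is hypothetical):

import gym
from stable_baselines3 import PPO

env = gym.make("CartPole-v1")
# Loading with an env runs check_for_correct_spaces() against the
# observation/action spaces stored in the archive; with this commit,
# an archive missing those spaces raises KeyError instead of ValueError.
model = PPO.load("ppo_cartpole.zip", env=env)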
@@ -425,8 +425,10 @@ def _setup_learn(self,
         :return: (Tuple[int, BaseCallback])
         """
         self.start_time = time.time()
-        self.ep_info_buffer = deque(maxlen=100)
-        self.ep_success_buffer = deque(maxlen=100)
+        if self.ep_info_buffer is None or reset_num_timesteps:
+            # Initialize buffers if they don't exist, or reinitialize if resetting counters
+            self.ep_info_buffer = deque(maxlen=100)
+            self.ep_success_buffer = deque(maxlen=100)
 
         if self.action_noise is not None:
             self.action_noise.reset()
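
The new guard keeps the episode-info buffers alive across consecutive learn() calls. A minimal sketch of the intended usage (PPO chosen only for illustration):

import gym
from stable_baselines3 import PPO

model = PPO("MlpPolicy", gym.make("CartPole-v1"))
model.learn(total_timesteps=10_000)
# Continue training without resetting counters: with this change the
# ep_info_buffer/ep_success_buffer are reused rather than re-created,
# so logged episode statistics carry over between the two calls.
model.learn(total_timesteps=10_000, reset_num_timesteps=False)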
14 changes: 14 additions & 0 deletions stable_baselines3/common/distributions.py
@@ -17,6 +17,20 @@ class Distribution(ABC):
     def __init__(self):
         super(Distribution, self).__init__()
 
+    @abstractmethod
+    def proba_distribution_net(self, *args, **kwargs):
+        """Create the layers and parameters that represent the distribution.
+
+        Subclasses must define this, but the arguments and return type vary between
+        concrete classes."""
+
+    @abstractmethod
+    def proba_distribution(self, *args, **kwargs) -> 'Distribution':
+        """Set parameters of the distribution.
+
+        :return: (Distribution) self
+        """
+
     @abstractmethod
     def log_prob(self, x: th.Tensor) -> th.Tensor:
         """
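
To show how a concrete class might fill in these hooks, here is an abridged, illustrative subclass in the spirit of the library's CategoricalDistribution (simplified; the real ABC also declares entropy, sample and other abstract methods that a working subclass must implement):

import torch as th
from torch import nn
from torch.distributions import Categorical

from stable_baselines3.common.distributions import Distribution


class SimpleCategorical(Distribution):
    """Illustrative only: a distribution over `action_dim` discrete actions."""

    def __init__(self, action_dim: int):
        super().__init__()
        self.action_dim = action_dim
        self.distribution = None

    def proba_distribution_net(self, latent_dim: int) -> nn.Module:
        # The "net" is a linear layer mapping latent features to action logits.
        return nn.Linear(latent_dim, self.action_dim)

    def proba_distribution(self, action_logits: th.Tensor) -> 'SimpleCategorical':
        # Store the distribution parameters and return self, per the ABC's contract.
        self.distribution = Categorical(logits=action_logits)
        return self

    def log_prob(self, x: th.Tensor) -> th.Tensor:
        return self.distribution.log_prob(x)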
3 changes: 3 additions & 0 deletions stable_baselines3/common/policies.py
@@ -140,6 +140,7 @@ def predict(self,
         :return: (Tuple[np.ndarray, Optional[np.ndarray]]) the model's action and the next state
             (used in recurrent policies)
         """
+        # TODO (GH/1): add support for RNN policies
         # if state is None:
         #     state = self.initial_state
         # if mask is None:
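
The commented-out block under the new TODO is the eventual RNN state handling; until then, predict() returns None as the state. Reusing env and model from the earlier sketch:

obs = env.reset()
# `state` is meant for recurrent policies and is currently always None;
# RNN support is tracked in the GH issue referenced by the TODO above.
action, state = model.predict(obs, deterministic=True)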
@@ -438,6 +439,8 @@ def _build(self, lr_schedule: Callable[[float], float]) -> None:
             self.action_net = self.action_dist.proba_distribution_net(latent_dim=latent_dim_pi)
         elif isinstance(self.action_dist, BernoulliDistribution):
             self.action_net = self.action_dist.proba_distribution_net(latent_dim=latent_dim_pi)
+        else:
+            raise NotImplementedError(f"Unsupported distribution '{self.action_dist}'.")
 
         self.value_net = nn.Linear(self.mlp_extractor.latent_dim_vf, 1)
         # Init weights: use orthogonal initialization
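
The added else branch makes the dispatch exhaustive: previously an unhandled distribution type simply left self.action_net unset, which only surfaced later as an AttributeError. A sketch of the failure mode it now catches (MyCustomDistribution is hypothetical):

from stable_baselines3.common.distributions import (
    BernoulliDistribution, CategoricalDistribution, DiagGaussianDistribution)


class MyCustomDistribution:
    """Stand-in for a distribution type _build() does not know about."""


action_dist = MyCustomDistribution()
if isinstance(action_dist, (DiagGaussianDistribution, CategoricalDistribution,
                            BernoulliDistribution)):
    pass  # wire up the matching action_net, as _build() does above
else:
    # With this commit the mismatch is reported immediately:
    raise NotImplementedError(f"Unsupported distribution '{action_dist}'.")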
