Skip to content

Commit

Permalink
Fix type annotations (DLR-RM#522)
Browse files Browse the repository at this point in the history
* Fix type annotations

* Add citation file

* Update CITATION.cff

* Add note about tb logging

Co-authored-by: Anssi <[email protected]>
  • Loading branch information
araffin and Miffyli authored Jul 29, 2021
1 parent 5034259 commit be86883
Show file tree
Hide file tree
Showing 13 changed files with 55 additions and 16 deletions.
30 changes: 30 additions & 0 deletions CITATION.cff
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
cff-version: "1.1.0"
message: "If you use this software, please cite it as below."
authors:
-
family-names: Raffin
given-names: Antonin
orcid: https://orcid.org/0000-0001-6036-6950
-
family-names: Hill
given-names: Ashley
-
family-names: Ernestus
given-names: Maximilian
-
family-names: Gleave
given-names: Adam
orcid: https://orcid.org/0000-0002-3467-528X
-
family-names: Kanervisto
given-names: Anssi
orcid: https://orcid.org/0000-0002-7479-4574
-
family-names: Dormann
given-names: Noah

title: "Stable Baselines3"
repository-code: "https://github.com/DLR-RM/stable-baselines3"
date-released: 2020-05-05
license: MIT
doi: # TODO when paper is released
9 changes: 8 additions & 1 deletion docs/guide/tensorboard.rst
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,13 @@ Here is a simple example on how to log both additional tensor or arbitrary scala
model.learn(50000, callback=TensorboardCallback())
.. note::

If you want to log values to tensorboard more often than the default, you can manually call ``self.logger.dump(self.num_timesteps)`` in a callback
(see `issue #506 <https://github.com/DLR-RM/stable-baselines3/issues/506>`_).


Logging Images
--------------

Expand Down Expand Up @@ -230,7 +237,7 @@ Here is an example of how to render an episode and log the resulting video to Te
Directly Accessing The Summary Writer
-------------------------------------

If you would like to log arbitrary data (in one of the formats supported by `pytorch <https://pytorch.org/docs/stable/tensorboard.html>`_), you
If you would like to log arbitrary data (in one of the formats supported by `pytorch <https://pytorch.org/docs/stable/tensorboard.html>`_), you
can get direct access to the underlying SummaryWriter in a callback:

.. warning::
Expand Down
4 changes: 3 additions & 1 deletion docs/misc/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ Changelog
==========


Release 1.2.0a0 (WIP)
Release 1.2.0a1 (WIP)
---------------------------

Breaking Changes:
Expand All @@ -23,11 +23,13 @@ Deprecations:
Others:
^^^^^^^
- Enabled Python 3.9 in GitHub CI
- Fixed type annotations

Documentation:
^^^^^^^^^^^^^^
- Updated multiprocessing example
- Added example of ``VecEnvWrapper``
- Added a note about logging to tensorboard more often


Release 1.1.0 (2021-07-01)
Expand Down
2 changes: 1 addition & 1 deletion stable_baselines3/common/base_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def __init__(
env: Union[GymEnv, str, None],
policy_base: Type[BasePolicy],
learning_rate: Union[float, Schedule],
policy_kwargs: Dict[str, Any] = None,
policy_kwargs: Optional[Dict[str, Any]] = None,
tensorboard_log: Optional[str] = None,
verbose: int = 0,
device: Union[th.device, str] = "auto",
Expand Down
4 changes: 2 additions & 2 deletions stable_baselines3/common/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,8 +298,8 @@ def __init__(
callback_on_new_best: Optional[BaseCallback] = None,
n_eval_episodes: int = 5,
eval_freq: int = 10000,
log_path: str = None,
best_model_save_path: str = None,
log_path: Optional[str] = None,
best_model_save_path: Optional[str] = None,
deterministic: bool = True,
render: bool = False,
verbose: int = 1,
Expand Down
2 changes: 1 addition & 1 deletion stable_baselines3/common/monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ class ResultsWriter:
def __init__(
self,
filename: str = "",
header: Dict[str, Union[float, str]] = None,
header: Optional[Dict[str, Union[float, str]]] = None,
extra_keys: Tuple[str, ...] = (),
):
if header is None:
Expand Down
2 changes: 1 addition & 1 deletion stable_baselines3/common/off_policy_algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def __init__(
replay_buffer_class: Optional[ReplayBuffer] = None,
replay_buffer_kwargs: Optional[Dict[str, Any]] = None,
optimize_memory_usage: bool = False,
policy_kwargs: Dict[str, Any] = None,
policy_kwargs: Optional[Dict[str, Any]] = None,
tensorboard_log: Optional[str] = None,
verbose: int = 0,
device: Union[th.device, str] = "auto",
Expand Down
4 changes: 2 additions & 2 deletions stable_baselines3/common/policies.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import copy
from abc import ABC, abstractmethod
from functools import partial
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
from typing import Any, Dict, List, Optional, Tuple, Type, Union

import gym
import numpy as np
Expand Down Expand Up @@ -755,7 +755,7 @@ def __init__(
self,
observation_space: gym.spaces.Dict,
action_space: gym.spaces.Space,
lr_schedule: Callable,
lr_schedule: Schedule,
net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None,
activation_fn: Type[nn.Module] = nn.Tanh,
ortho_init: bool = True,
Expand Down
6 changes: 3 additions & 3 deletions stable_baselines3/common/save_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,9 +286,9 @@ def open_path_pathlib(path: pathlib.Path, mode: str, verbose: int = 0, suffix: O

def save_to_zip_file(
save_path: Union[str, pathlib.Path, io.BufferedIOBase],
data: Dict[str, Any] = None,
params: Dict[str, Any] = None,
pytorch_variables: Dict[str, Any] = None,
data: Optional[Dict[str, Any]] = None,
params: Optional[Dict[str, Any]] = None,
pytorch_variables: Optional[Dict[str, Any]] = None,
verbose: int = 0,
) -> None:
"""
Expand Down
2 changes: 1 addition & 1 deletion stable_baselines3/ddpg/ddpg.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def __init__(
optimize_memory_usage: bool = False,
tensorboard_log: Optional[str] = None,
create_eval_env: bool = False,
policy_kwargs: Dict[str, Any] = None,
policy_kwargs: Optional[Dict[str, Any]] = None,
verbose: int = 0,
seed: Optional[int] = None,
device: Union[th.device, str] = "auto",
Expand Down
2 changes: 1 addition & 1 deletion stable_baselines3/sac/sac.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def __init__(
use_sde_at_warmup: bool = False,
tensorboard_log: Optional[str] = None,
create_eval_env: bool = False,
policy_kwargs: Dict[str, Any] = None,
policy_kwargs: Optional[Dict[str, Any]] = None,
verbose: int = 0,
seed: Optional[int] = None,
device: Union[th.device, str] = "auto",
Expand Down
2 changes: 1 addition & 1 deletion stable_baselines3/td3/td3.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def __init__(
target_noise_clip: float = 0.5,
tensorboard_log: Optional[str] = None,
create_eval_env: bool = False,
policy_kwargs: Dict[str, Any] = None,
policy_kwargs: Optional[Dict[str, Any]] = None,
verbose: int = 0,
seed: Optional[int] = None,
device: Union[th.device, str] = "auto",
Expand Down
2 changes: 1 addition & 1 deletion stable_baselines3/version.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
1.2.0a0
1.2.0a1

0 comments on commit be86883

Please sign in to comment.