diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md
index 761e41e6a..2b35d61e8 100644
--- a/.github/PULL_REQUEST_TEMPLATE.md
+++ b/.github/PULL_REQUEST_TEMPLATE.md
@@ -24,7 +24,11 @@
- [ ] My change requires a change to the documentation.
- [ ] I have updated the tests accordingly (*required for a bug fix or a new feature*).
- [ ] I have updated the documentation accordingly.
-- [ ] I have checked the codestyle using `make lint`
-- [ ] I have ensured `make pytest` and `make type` both pass.
+- [ ] I have reformatted the code using `make format` (**required**)
+- [ ] I have checked the codestyle using `make check-codestyle` and `make lint` (**required**)
+- [ ] I have ensured `make pytest` and `make type` both pass. (**required**)
+
+
+Note: we are using a maximum length of 127 characters per line
diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index cbad4a146..47ca91980 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -38,6 +38,9 @@ jobs:
- name: Type check
run: |
make type
+ - name: Check codestyle
+ run: |
+ make check-codestyle
- name: Lint with flake8
run: |
make lint
diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml
index 8eeaf41f8..695c14543 100644
--- a/.gitlab-ci.yml
+++ b/.gitlab-ci.yml
@@ -1,4 +1,4 @@
-image: stablebaselines/stable-baselines3-cpu:0.8.0a1
+image: stablebaselines/stable-baselines3-cpu:0.8.0a4
type-check:
script:
@@ -15,4 +15,5 @@ doc-build:
lint-check:
script:
+ - make check-codestyle
- make lint
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index ff56c13e3..e04e59ee4 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -38,23 +38,9 @@ pip install -e .[docs,tests,extra]
## Codestyle
-We follow the [PEP8 codestyle](https://www.python.org/dev/peps/pep-0008/). Please order the imports as follows:
+We are using [black codestyle](https://github.com/psf/black) (max line length of 127 characters) together with [isort](https://github.com/timothycrosley/isort) to sort the imports.
-1. built-in
-2. packages
-3. current module
-
-with one space between each, that gives for instance:
-```python
-import os
-import warnings
-
-import numpy as np
-
-from stable_baselines3 import PPO
-```
-
-In general, we recommend using pycharm to format everything in an efficient way.
+**Please run `make format`** to reformat your code. You can check the codestyle using `make check-codestyle` and `make lint`.
Please document each function/method and [type](https://google.github.io/pytype/user_guide.html) them using the following template:
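As an illustration only (the actual template follows just below), a hypothetical helper documented and typed in that style could look like:

```python
import numpy as np


def scale_action(action: np.ndarray, low: float, high: float) -> np.ndarray:
    """
    Rescale an action from [low, high] to [-1, 1].
    (Hypothetical example for illustration, not part of the code base.)

    :param action: (np.ndarray) Action to rescale
    :param low: (float) Lower bound of the original interval
    :param high: (float) Upper bound of the original interval
    :return: (np.ndarray) Action rescaled to [-1, 1]
    """
    return 2.0 * (action - low) / (high - low) - 1.0
```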
@@ -77,11 +63,10 @@ def my_function(arg1: type1, arg2: type2) -> returntype:
Before proposing a PR, please open an issue, where the feature will be discussed. This prevents duplicated PRs from being proposed and also eases the code review process.
Each PR needs to be reviewed and accepted by at least one of the maintainers (@hill-a, @araffin, @erniejunior, @AdamGleave or @Miffyli).
-A PR must pass the Continuous Integration tests (travis + codacy) to be merged with the master branch.
+A PR must pass the Continuous Integration tests to be merged with the master branch.
-Note: in rare cases, we can create exception for codacy failure.
-## Test
+## Tests
All new features must add tests in the `tests/` folder ensuring that everything works fine.
We use [pytest](https://pytest.org/).
@@ -99,12 +84,18 @@ Type checking with `pytype`:
make type
```
-Codestyle check with `flake8`:
+Codestyle check with `black`, `isort` and `flake8`:
```
+make check-codestyle
make lint
```
+To run `pytype`, `format` and `lint` in one command:
+```
+make commit-checks
+```
+
Build the documentation:
```
@@ -121,6 +112,7 @@ make spelling
## Changelog and Documentation
Please do not forget to update the changelog (`docs/misc/changelog.rst`) and add documentation if needed.
+You should add your username next to each changelog entry that you added. If this is your first contribution, please add your username at the bottom too.
A README is present in the `docs/` folder for instructions on how to build the documentation.
diff --git a/Dockerfile b/Dockerfile
index 0b30e6d76..2ed1d2309 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -10,7 +10,8 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
curl \
ca-certificates \
libjpeg-dev \
- libpng-dev && \
+ libpng-dev \
+ libglib2.0-0 && \
rm -rf /var/lib/apt/lists/*
# Install anaconda and dependencies
@@ -37,15 +38,4 @@ RUN \
pip install opencv-python-headless && \
rm -rf $HOME/.cache/pip
-
-# Codacy deps
-RUN apt-get update && apt-get install -y --no-install-recommends \
- default-jre \
- jq && \
- rm -rf /var/lib/apt/lists/*
-
-# Codacy code coverage report: used for partial code coverage reporting
-RUN cd $CODE_DIR && \
- curl -Ls -o codacy-coverage-reporter.jar "$(curl -Ls https://api.github.com/repos/codacy/codacy-coverage-reporter/releases/latest | jq -r '.assets | map({name, browser_download_url} | select(.name | (startswith("codacy-coverage-reporter") and contains("assembly") and endswith(".jar")))) | .[0].browser_download_url')"
-
CMD /bin/bash
diff --git a/Makefile b/Makefile
index 05ce4d298..749bc026b 100644
--- a/Makefile
+++ b/Makefile
@@ -14,6 +14,20 @@ lint:
# exit-zero treats all errors as warnings.
flake8 ${LINT_PATHS} --count --exit-zero --statistics
+format:
+ # Sort imports
+ isort ${LINT_PATHS}
+ # Reformat using black
+ black -l 127 ${LINT_PATHS}
+
+check-codestyle:
+ # Sort imports
+ isort --check ${LINT_PATHS}
+ # Reformat using black
+ black --check -l 127 ${LINT_PATHS}
+
+commit-checks: format type lint
+
doc:
cd docs && make html
@@ -23,8 +37,6 @@ spelling:
clean:
cd docs && make clean
-.PHONY: clean spelling doc lint
-
# Build docker images
# If you do export RELEASE=True, it will also push them
docker: docker-cpu docker-gpu
@@ -46,3 +58,5 @@ test-release:
python setup.py sdist
python setup.py bdist_wheel
twine upload --repository-url https://test.pypi.org/legacy/ dist/*
+
+.PHONY: clean spelling doc lint format check-codestyle commit-checks
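For reference, a sketch (using a made-up helper, not code from the repository) of the kind of rewrite `make format` performs: `isort` sorts the imports and `black -l 127` normalizes quotes, enforces the blank lines around top-level definitions and joins a signature onto one line whenever it fits within 127 characters.

```python
# Hypothetical input before formatting:
#
#   from stable_baselines3 import A2C
#   def make_agent(env_id: str, learning_rate: float = 7e-4,
#                  verbose: int = 0) -> A2C:
#       return A2C('MlpPolicy', env_id, learning_rate=learning_rate,
#                  verbose=verbose)
#
# After `make format` (isort + black -l 127): double quotes, two blank lines
# before the definition, and a single-line signature since it fits in 127 chars.
from stable_baselines3 import A2C


def make_agent(env_id: str, learning_rate: float = 7e-4, verbose: int = 0) -> A2C:
    return A2C("MlpPolicy", env_id, learning_rate=learning_rate, verbose=verbose)
```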
diff --git a/README.md b/README.md
index d6338af7f..95e7c1536 100644
--- a/README.md
+++ b/README.md
@@ -1,7 +1,7 @@
[![pipeline status](https://gitlab.com/araffin/stable-baselines3/badges/master/pipeline.svg)](https://gitlab.com/araffin/stable-baselines3/-/commits/master) [![Documentation Status](https://readthedocs.org/projects/stable-baselines/badge/?version=master)](https://stable-baselines3.readthedocs.io/en/master/?badge=master) [![coverage report](https://gitlab.com/araffin/stable-baselines3/badges/master/coverage.svg)](https://gitlab.com/araffin/stable-baselines3/-/commits/master)
-
+[![codestyle](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black)
**WARNING: Stable Baselines3 is currently in a beta version, breaking changes may occur before 1.0 is released**
diff --git a/docs/conf.py b/docs/conf.py
index 78138f600..320834c0a 100644
--- a/docs/conf.py
+++ b/docs/conf.py
@@ -20,12 +20,13 @@
# PyEnchant.
try:
import sphinxcontrib.spelling # noqa: F401
+
enable_spell_check = True
except ImportError:
enable_spell_check = False
# source code directory, relative to this file, for sphinx-autobuild
-sys.path.insert(0, os.path.abspath('..'))
+sys.path.insert(0, os.path.abspath(".."))
class Mock(MagicMock):
@@ -44,18 +45,18 @@ def __getattr__(cls, name):
sys.modules.update((mod_name, Mock()) for mod_name in MOCK_MODULES)
# Read version from file
-version_file = os.path.join(os.path.dirname(__file__), '../stable_baselines3', 'version.txt')
-with open(version_file, 'r') as file_handler:
+version_file = os.path.join(os.path.dirname(__file__), "../stable_baselines3", "version.txt")
+with open(version_file, "r") as file_handler:
__version__ = file_handler.read().strip()
# -- Project information -----------------------------------------------------
-project = 'Stable Baselines3'
-copyright = '2020, Stable Baselines3'
-author = 'Stable Baselines3 Contributors'
+project = "Stable Baselines3"
+copyright = "2020, Stable Baselines3"
+author = "Stable Baselines3 Contributors"
# The short X.Y version
-version = 'master (' + __version__ + ' )'
+version = "master (" + __version__ + " )"
# The full version, including alpha/beta/rc tags
release = __version__
@@ -70,30 +71,30 @@ def __getattr__(cls, name):
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
# ones.
extensions = [
- 'sphinx.ext.autodoc',
+ "sphinx.ext.autodoc",
# 'sphinx_autodoc_typehints',
- 'sphinx.ext.autosummary',
- 'sphinx.ext.mathjax',
- 'sphinx.ext.ifconfig',
- 'sphinx.ext.viewcode',
+ "sphinx.ext.autosummary",
+ "sphinx.ext.mathjax",
+ "sphinx.ext.ifconfig",
+ "sphinx.ext.viewcode",
# 'sphinx.ext.intersphinx',
# 'sphinx.ext.doctest'
]
if enable_spell_check:
- extensions.append('sphinxcontrib.spelling')
+ extensions.append("sphinxcontrib.spelling")
# Add any paths that contain templates here, relative to this directory.
-templates_path = ['_templates']
+templates_path = ["_templates"]
# The suffix(es) of source filenames.
# You can specify multiple suffix as a list of string:
#
# source_suffix = ['.rst', '.md']
-source_suffix = '.rst'
+source_suffix = ".rst"
# The master toctree document.
-master_doc = 'index'
+master_doc = "index"
# The language for content autogenerated by Sphinx. Refer to documentation
# for a list of supported languages.
@@ -105,10 +106,10 @@ def __getattr__(cls, name):
# List of patterns, relative to source directory, that match files and
# directories to ignore when looking for source files.
# This pattern also affects html_static_path and html_extra_path .
-exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store']
+exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"]
# The name of the Pygments (syntax highlighting) style to use.
-pygments_style = 'sphinx'
+pygments_style = "sphinx"
# -- Options for HTML output -------------------------------------------------
@@ -117,13 +118,13 @@ def __getattr__(cls, name):
# a list of builtin themes.
# Fix for read the docs
-on_rtd = os.environ.get('READTHEDOCS') == 'True'
+on_rtd = os.environ.get("READTHEDOCS") == "True"
if on_rtd:
- html_theme = 'default'
+ html_theme = "default"
else:
- html_theme = 'sphinx_rtd_theme'
+ html_theme = "sphinx_rtd_theme"
-html_logo = '_static/img/logo.png'
+html_logo = "_static/img/logo.png"
def setup(app):
@@ -139,7 +140,7 @@ def setup(app):
# Add any paths that contain custom static files (such as style sheets) here,
# relative to this directory. They are copied after the builtin static files,
# so a file named "default.css" will overwrite the builtin "default.css".
-html_static_path = ['_static']
+html_static_path = ["_static"]
# Custom sidebar templates, must be a dictionary that maps document names
# to template names.
@@ -155,7 +156,7 @@ def setup(app):
# -- Options for HTMLHelp output ---------------------------------------------
# Output file base name for HTML help builder.
-htmlhelp_basename = 'StableBaselines3doc'
+htmlhelp_basename = "StableBaselines3doc"
# -- Options for LaTeX output ------------------------------------------------
@@ -164,15 +165,12 @@ def setup(app):
# The paper size ('letterpaper' or 'a4paper').
#
# 'papersize': 'letterpaper',
-
# The font size ('10pt', '11pt' or '12pt').
#
# 'pointsize': '10pt',
-
# Additional stuff for the LaTeX preamble.
#
# 'preamble': '',
-
# Latex figure (float) alignment
#
# 'figure_align': 'htbp',
@@ -182,8 +180,7 @@ def setup(app):
# (source start file, target name, title,
# author, documentclass [howto, manual, or own class]).
latex_documents = [
- (master_doc, 'StableBaselines3.tex', 'Stable Baselines3 Documentation',
- 'Stable Baselines3 Contributors', 'manual'),
+ (master_doc, "StableBaselines3.tex", "Stable Baselines3 Documentation", "Stable Baselines3 Contributors", "manual"),
]
@@ -191,10 +188,7 @@ def setup(app):
# One entry per manual page. List of tuples
# (source start file, name, description, authors, manual section).
-man_pages = [
- (master_doc, 'stablebaselines3', 'Stable Baselines3 Documentation',
- [author], 1)
-]
+man_pages = [(master_doc, "stablebaselines3", "Stable Baselines3 Documentation", [author], 1)]
# -- Options for Texinfo output ----------------------------------------------
@@ -203,9 +197,15 @@ def setup(app):
# (source start file, target name, title, author,
# dir menu entry, description, category)
texinfo_documents = [
- (master_doc, 'StableBaselines3', 'Stable Baselines3 Documentation',
- author, 'StableBaselines3', 'One line description of project.',
- 'Miscellaneous'),
+ (
+ master_doc,
+ "StableBaselines3",
+ "Stable Baselines3 Documentation",
+ author,
+ "StableBaselines3",
+ "One line description of project.",
+ "Miscellaneous",
+ ),
]
diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst
index bb5cecafa..f49fb4579 100644
--- a/docs/misc/changelog.rst
+++ b/docs/misc/changelog.rst
@@ -40,8 +40,10 @@ Others:
- Split the ``collect_rollout()`` method for off-policy algorithms
- Added ``_on_step()`` for off-policy base class
- Optimized replay buffer size by removing the need of ``next_observations`` numpy array
+- Switched to ``black`` codestyle and added ``make format``, ``make check-codestyle`` and ``commit-checks``
- Ignored errors from newer pytype version
- Added a check when using ``gSDE``
+- Removed codacy dependency from Dockerfile
Documentation:
^^^^^^^^^^^^^^
diff --git a/setup.cfg b/setup.cfg
index 1d0170067..011c3d9b1 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -21,7 +21,7 @@ filterwarnings =
inputs = stable_baselines3
[flake8]
-ignore = W503,W504 # line breaks before and after binary operators
+ignore = W503,W504,E203,E231 # line breaks around binary operators (W503/W504), black compatibility (E203/E231)
# Ignore import not used when aliases are defined
per-file-ignores =
./stable_baselines3/__init__.py:F401
@@ -48,3 +48,8 @@ exclude =
max-complexity = 15
# The GitHub editor is 127 chars wide
max-line-length = 127
+
+[isort]
+profile = black
+line_length = 127
+src_paths = stable_baselines3
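To illustrate the `[isort]` settings above: with `profile = black` and `src_paths = stable_baselines3`, imports are grouped into standard-library, third-party and first-party blocks, each alphabetized and separated by a blank line, which is the ordering applied to the Python files later in this patch. A minimal sketch (the group comments are annotations added here, not isort output):

```python
# standard library
import io
import time

# third-party
import gym
import numpy as np
import torch as th

# first-party (recognized via src_paths = stable_baselines3)
from stable_baselines3.common import logger
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback
```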
diff --git a/setup.py b/setup.py
index 5733441ee..8401c7c1d 100644
--- a/setup.py
+++ b/setup.py
@@ -1,7 +1,8 @@
import os
-from setuptools import setup, find_packages
-with open(os.path.join('stable_baselines3', 'version.txt'), 'r') as file_handler:
+from setuptools import find_packages, setup
+
+with open(os.path.join("stable_baselines3", "version.txt"), "r") as file_handler:
__version__ = file_handler.read().strip()
@@ -64,66 +65,69 @@
""" # noqa:E501
-setup(name='stable_baselines3',
- packages=[package for package in find_packages()
- if package.startswith('stable_baselines3')],
- package_data={
- 'stable_baselines3': ['py.typed', 'version.txt']
- },
- install_requires=[
- 'gym>=0.17',
- 'numpy',
- 'torch>=1.4.0',
- # For saving models
- 'cloudpickle',
- # For reading logs
- 'pandas',
- # Plotting learning curves
- 'matplotlib'
- ],
- extras_require={
- 'tests': [
- # Run tests and coverage
- 'pytest',
- 'pytest-cov',
- 'pytest-env',
- 'pytest-xdist',
- # Type check
- 'pytype',
- # Lint code
- 'flake8>=3.8'
- ],
- 'docs': [
- 'sphinx',
- 'sphinx-autobuild',
- 'sphinx-rtd-theme',
- # For spelling
- 'sphinxcontrib.spelling',
- # Type hints support
- # 'sphinx-autodoc-typehints'
- ],
- 'extra': [
- # For render
- 'opencv-python',
- # For atari games,
- 'atari_py~=0.2.0', 'pillow',
- # Tensorboard support
- 'tensorboard',
- # Checking memory taken by replay buffer
- 'psutil'
- ]
- },
- description='Pytorch version of Stable Baselines, implementations of reinforcement learning algorithms.',
- author='Antonin Raffin',
- url='https://github.com/DLR-RM/stable-baselines3',
- author_email='antonin.raffin@dlr.de',
- keywords="reinforcement-learning-algorithms reinforcement-learning machine-learning "
- "gym openai stable baselines toolbox python data-science",
- license="MIT",
- long_description=long_description,
- long_description_content_type='text/markdown',
- version=__version__,
- )
+setup(
+ name="stable_baselines3",
+ packages=[package for package in find_packages() if package.startswith("stable_baselines3")],
+ package_data={"stable_baselines3": ["py.typed", "version.txt"]},
+ install_requires=[
+ "gym>=0.17",
+ "numpy",
+ "torch>=1.4.0",
+ # For saving models
+ "cloudpickle",
+ # For reading logs
+ "pandas",
+ # Plotting learning curves
+ "matplotlib",
+ ],
+ extras_require={
+ "tests": [
+ # Run tests and coverage
+ "pytest",
+ "pytest-cov",
+ "pytest-env",
+ "pytest-xdist",
+ # Type check
+ "pytype",
+ # Lint code
+ "flake8>=3.8",
+ # Sort imports
+ "isort>=5.0",
+ # Reformat
+ "black",
+ ],
+ "docs": [
+ "sphinx",
+ "sphinx-autobuild",
+ "sphinx-rtd-theme",
+ # For spelling
+ "sphinxcontrib.spelling",
+ # Type hints support
+ # 'sphinx-autodoc-typehints'
+ ],
+ "extra": [
+ # For render
+ "opencv-python",
+ # For atari games,
+ "atari_py~=0.2.0",
+ "pillow",
+ # Tensorboard support
+ "tensorboard",
+ # Checking memory taken by replay buffer
+ "psutil",
+ ],
+ },
+ description="Pytorch version of Stable Baselines, implementations of reinforcement learning algorithms.",
+ author="Antonin Raffin",
+ url="https://github.com/DLR-RM/stable-baselines3",
+ author_email="antonin.raffin@dlr.de",
+ keywords="reinforcement-learning-algorithms reinforcement-learning machine-learning "
+ "gym openai stable baselines toolbox python data-science",
+ license="MIT",
+ long_description=long_description,
+ long_description_content_type="text/markdown",
+ version=__version__,
+)
# python setup.py sdist
# python setup.py bdist_wheel
diff --git a/stable_baselines3/__init__.py b/stable_baselines3/__init__.py
index 1bdb2e9e9..b88ca5d4c 100644
--- a/stable_baselines3/__init__.py
+++ b/stable_baselines3/__init__.py
@@ -8,6 +8,6 @@
from stable_baselines3.td3 import TD3
# Read version from file
-version_file = os.path.join(os.path.dirname(__file__), 'version.txt')
-with open(version_file, 'r') as file_handler:
+version_file = os.path.join(os.path.dirname(__file__), "version.txt")
+with open(version_file, "r") as file_handler:
__version__ = file_handler.read().strip()
diff --git a/stable_baselines3/a2c/__init__.py b/stable_baselines3/a2c/__init__.py
index f74a6497c..e6aeda5bb 100644
--- a/stable_baselines3/a2c/__init__.py
+++ b/stable_baselines3/a2c/__init__.py
@@ -1,2 +1,2 @@
from stable_baselines3.a2c.a2c import A2C
-from stable_baselines3.a2c.policies import MlpPolicy, CnnPolicy
+from stable_baselines3.a2c.policies import CnnPolicy, MlpPolicy
diff --git a/stable_baselines3/a2c/a2c.py b/stable_baselines3/a2c/a2c.py
index a27d864c5..c2c7b34e1 100644
--- a/stable_baselines3/a2c/a2c.py
+++ b/stable_baselines3/a2c/a2c.py
@@ -1,13 +1,14 @@
+from typing import Any, Callable, Dict, Optional, Type, Union
+
import torch as th
-import torch.nn.functional as F
from gym import spaces
-from typing import Type, Union, Callable, Optional, Dict, Any
+from torch.nn import functional as F
from stable_baselines3.common import logger
from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm
+from stable_baselines3.common.policies import ActorCriticPolicy
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback
from stable_baselines3.common.utils import explained_variance
-from stable_baselines3.common.policies import ActorCriticPolicy
class A2C(OnPolicyAlgorithm):
@@ -50,44 +51,59 @@ class A2C(OnPolicyAlgorithm):
:param _init_setup_model: (bool) Whether or not to build the network at the creation of the instance
"""
- def __init__(self, policy: Union[str, Type[ActorCriticPolicy]],
- env: Union[GymEnv, str],
- learning_rate: Union[float, Callable] = 7e-4,
- n_steps: int = 5,
- gamma: float = 0.99,
- gae_lambda: float = 1.0,
- ent_coef: float = 0.0,
- vf_coef: float = 0.5,
- max_grad_norm: float = 0.5,
- rms_prop_eps: float = 1e-5,
- use_rms_prop: bool = True,
- use_sde: bool = False,
- sde_sample_freq: int = -1,
- normalize_advantage: bool = False,
- tensorboard_log: Optional[str] = None,
- create_eval_env: bool = False,
- policy_kwargs: Optional[Dict[str, Any]] = None,
- verbose: int = 0,
- seed: Optional[int] = None,
- device: Union[th.device, str] = 'auto',
- _init_setup_model: bool = True):
-
- super(A2C, self).__init__(policy, env, learning_rate=learning_rate,
- n_steps=n_steps, gamma=gamma, gae_lambda=gae_lambda,
- ent_coef=ent_coef, vf_coef=vf_coef, max_grad_norm=max_grad_norm,
- use_sde=use_sde, sde_sample_freq=sde_sample_freq,
- tensorboard_log=tensorboard_log, policy_kwargs=policy_kwargs,
- verbose=verbose, device=device, create_eval_env=create_eval_env,
- seed=seed, _init_setup_model=False)
+ def __init__(
+ self,
+ policy: Union[str, Type[ActorCriticPolicy]],
+ env: Union[GymEnv, str],
+ learning_rate: Union[float, Callable] = 7e-4,
+ n_steps: int = 5,
+ gamma: float = 0.99,
+ gae_lambda: float = 1.0,
+ ent_coef: float = 0.0,
+ vf_coef: float = 0.5,
+ max_grad_norm: float = 0.5,
+ rms_prop_eps: float = 1e-5,
+ use_rms_prop: bool = True,
+ use_sde: bool = False,
+ sde_sample_freq: int = -1,
+ normalize_advantage: bool = False,
+ tensorboard_log: Optional[str] = None,
+ create_eval_env: bool = False,
+ policy_kwargs: Optional[Dict[str, Any]] = None,
+ verbose: int = 0,
+ seed: Optional[int] = None,
+ device: Union[th.device, str] = "auto",
+ _init_setup_model: bool = True,
+ ):
+
+ super(A2C, self).__init__(
+ policy,
+ env,
+ learning_rate=learning_rate,
+ n_steps=n_steps,
+ gamma=gamma,
+ gae_lambda=gae_lambda,
+ ent_coef=ent_coef,
+ vf_coef=vf_coef,
+ max_grad_norm=max_grad_norm,
+ use_sde=use_sde,
+ sde_sample_freq=sde_sample_freq,
+ tensorboard_log=tensorboard_log,
+ policy_kwargs=policy_kwargs,
+ verbose=verbose,
+ device=device,
+ create_eval_env=create_eval_env,
+ seed=seed,
+ _init_setup_model=False,
+ )
self.normalize_advantage = normalize_advantage
# Update optimizer inside the policy if we want to use RMSProp
# (original implementation) rather than Adam
- if use_rms_prop and 'optimizer_class' not in self.policy_kwargs:
- self.policy_kwargs['optimizer_class'] = th.optim.RMSprop
- self.policy_kwargs['optimizer_kwargs'] = dict(alpha=0.99, eps=rms_prop_eps,
- weight_decay=0)
+ if use_rms_prop and "optimizer_class" not in self.policy_kwargs:
+ self.policy_kwargs["optimizer_class"] = th.optim.RMSprop
+ self.policy_kwargs["optimizer_kwargs"] = dict(alpha=0.99, eps=rms_prop_eps, weight_decay=0)
if _init_setup_model:
self._setup_model()
@@ -139,8 +155,7 @@ def train(self) -> None:
th.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
self.policy.optimizer.step()
- explained_var = explained_variance(self.rollout_buffer.returns.flatten(),
- self.rollout_buffer.values.flatten())
+ explained_var = explained_variance(self.rollout_buffer.returns.flatten(), self.rollout_buffer.values.flatten())
self._n_updates += 1
logger.record("train/n_updates", self._n_updates, exclude="tensorboard")
@@ -148,21 +163,30 @@ def train(self) -> None:
logger.record("train/entropy_loss", entropy_loss.item())
logger.record("train/policy_loss", policy_loss.item())
logger.record("train/value_loss", value_loss.item())
- if hasattr(self.policy, 'log_std'):
+ if hasattr(self.policy, "log_std"):
logger.record("train/std", th.exp(self.policy.log_std).mean().item())
- def learn(self,
- total_timesteps: int,
- callback: MaybeCallback = None,
- log_interval: int = 100,
- eval_env: Optional[GymEnv] = None,
- eval_freq: int = -1,
- n_eval_episodes: int = 5,
- tb_log_name: str = "A2C",
- eval_log_path: Optional[str] = None,
- reset_num_timesteps: bool = True) -> 'A2C':
-
- return super(A2C, self).learn(total_timesteps=total_timesteps, callback=callback,
- log_interval=log_interval, eval_env=eval_env, eval_freq=eval_freq,
- n_eval_episodes=n_eval_episodes, tb_log_name=tb_log_name,
- eval_log_path=eval_log_path, reset_num_timesteps=reset_num_timesteps)
+ def learn(
+ self,
+ total_timesteps: int,
+ callback: MaybeCallback = None,
+ log_interval: int = 100,
+ eval_env: Optional[GymEnv] = None,
+ eval_freq: int = -1,
+ n_eval_episodes: int = 5,
+ tb_log_name: str = "A2C",
+ eval_log_path: Optional[str] = None,
+ reset_num_timesteps: bool = True,
+ ) -> "A2C":
+
+ return super(A2C, self).learn(
+ total_timesteps=total_timesteps,
+ callback=callback,
+ log_interval=log_interval,
+ eval_env=eval_env,
+ eval_freq=eval_freq,
+ n_eval_episodes=n_eval_episodes,
+ tb_log_name=tb_log_name,
+ eval_log_path=eval_log_path,
+ reset_num_timesteps=reset_num_timesteps,
+ )
diff --git a/stable_baselines3/a2c/policies.py b/stable_baselines3/a2c/policies.py
index ae1160d20..eed0ddea1 100644
--- a/stable_baselines3/a2c/policies.py
+++ b/stable_baselines3/a2c/policies.py
@@ -1,6 +1,6 @@
# This file is here just to define MlpPolicy/CnnPolicy
# that work for A2C
-from stable_baselines3.common.policies import ActorCriticPolicy, ActorCriticCnnPolicy, register_policy
+from stable_baselines3.common.policies import ActorCriticCnnPolicy, ActorCriticPolicy, register_policy
MlpPolicy = ActorCriticPolicy
CnnPolicy = ActorCriticCnnPolicy
diff --git a/stable_baselines3/common/__init__.py b/stable_baselines3/common/__init__.py
index 32675071f..275e3ad27 100644
--- a/stable_baselines3/common/__init__.py
+++ b/stable_baselines3/common/__init__.py
@@ -1,2 +1,2 @@
-from stable_baselines3.common.cmd_util import make_vec_env, make_atari_env
+from stable_baselines3.common.cmd_util import make_atari_env, make_vec_env
from stable_baselines3.common.utils import set_random_seed
diff --git a/stable_baselines3/common/atari_wrappers.py b/stable_baselines3/common/atari_wrappers.py
index 7af95fa0e..fdae2d327 100644
--- a/stable_baselines3/common/atari_wrappers.py
+++ b/stable_baselines3/common/atari_wrappers.py
@@ -1,8 +1,10 @@
import gym
-from gym import spaces
import numpy as np
+from gym import spaces
+
try:
import cv2 # pytype:disable=import-error
+
cv2.ocl.setUseOpenCL(False)
except ImportError:
cv2 = None
@@ -23,7 +25,7 @@ def __init__(self, env: gym.Env, noop_max: int = 30):
self.noop_max = noop_max
self.override_num_noops = None
self.noop_action = 0
- assert env.unwrapped.get_action_meanings()[0] == 'NOOP'
+ assert env.unwrapped.get_action_meanings()[0] == "NOOP"
def reset(self, **kwargs) -> np.ndarray:
self.env.reset(**kwargs)
@@ -48,7 +50,7 @@ def __init__(self, env: gym.Env):
:param env: (gym.Env) the environment to wrap
"""
gym.Wrapper.__init__(self, env)
- assert env.unwrapped.get_action_meanings()[1] == 'FIRE'
+ assert env.unwrapped.get_action_meanings()[1] == "FIRE"
assert len(env.unwrapped.get_action_meanings()) >= 3
def reset(self, **kwargs) -> np.ndarray:
@@ -180,8 +182,9 @@ def __init__(self, env: gym.Env, width: int = 84, height: int = 84):
gym.ObservationWrapper.__init__(self, env)
self.width = width
self.height = height
- self.observation_space = spaces.Box(low=0, high=255, shape=(self.height, self.width, 1),
- dtype=env.observation_space.dtype)
+ self.observation_space = spaces.Box(
+ low=0, high=255, shape=(self.height, self.width, 1), dtype=env.observation_space.dtype
+ )
def observation(self, frame: np.ndarray) -> np.ndarray:
"""
@@ -217,17 +220,21 @@ class AtariWrapper(gym.Wrapper):
life is lost.
    :param clip_reward: (bool) If True (default), the reward is clipped to {-1, 0, 1} depending on its sign.
"""
- def __init__(self, env: gym.Env,
- noop_max: int = 30,
- frame_skip: int = 4,
- screen_size: int = 84,
- terminal_on_life_loss: bool = True,
- clip_reward: bool = True):
+
+ def __init__(
+ self,
+ env: gym.Env,
+ noop_max: int = 30,
+ frame_skip: int = 4,
+ screen_size: int = 84,
+ terminal_on_life_loss: bool = True,
+ clip_reward: bool = True,
+ ):
env = NoopResetEnv(env, noop_max=noop_max)
env = MaxAndSkipEnv(env, skip=frame_skip)
if terminal_on_life_loss:
env = EpisodicLifeEnv(env)
- if 'FIRE' in env.unwrapped.get_action_meanings():
+ if "FIRE" in env.unwrapped.get_action_meanings():
env = FireResetEnv(env)
env = WarpFrame(env, width=screen_size, height=screen_size)
if clip_reward:
diff --git a/stable_baselines3/common/base_class.py b/stable_baselines3/common/base_class.py
index 56f9f9642..1ba7cc0ce 100644
--- a/stable_baselines3/common/base_class.py
+++ b/stable_baselines3/common/base_class.py
@@ -1,28 +1,32 @@
"""Abstract base classes for RL algorithms."""
+import io
+import pathlib
import time
-from typing import Union, Type, Optional, Dict, Any, Iterable, List, Tuple, Callable
from abc import ABC, abstractmethod
from collections import deque
-import pathlib
-import io
+from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Type, Union
import gym
-import torch as th
import numpy as np
+import torch as th
from stable_baselines3.common import logger, utils
-from stable_baselines3.common.policies import BasePolicy, get_policy_from_name
-from stable_baselines3.common.utils import (set_random_seed, get_schedule_fn, update_learning_rate, get_device,
- check_for_correct_spaces)
-from stable_baselines3.common.vec_env import DummyVecEnv, VecEnv, unwrap_vec_normalize, VecNormalize, VecTransposeImage
-from stable_baselines3.common.preprocessing import is_image_space
-from stable_baselines3.common.save_util import (recursive_getattr, recursive_setattr, save_to_zip_file,
- load_from_zip_file)
-from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback
from stable_baselines3.common.callbacks import BaseCallback, CallbackList, ConvertCallback, EvalCallback
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.noise import ActionNoise
+from stable_baselines3.common.policies import BasePolicy, get_policy_from_name
+from stable_baselines3.common.preprocessing import is_image_space
+from stable_baselines3.common.save_util import load_from_zip_file, recursive_getattr, recursive_setattr, save_to_zip_file
+from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback
+from stable_baselines3.common.utils import (
+ check_for_correct_spaces,
+ get_device,
+ get_schedule_fn,
+ set_random_seed,
+ update_learning_rate,
+)
+from stable_baselines3.common.vec_env import DummyVecEnv, VecEnv, VecNormalize, VecTransposeImage, unwrap_vec_normalize
def maybe_make_env(env: Union[GymEnv, str, None], monitor_wrapper: bool, verbose: int) -> Optional[GymEnv]:
@@ -72,21 +76,23 @@ class BaseAlgorithm(ABC):
Default: -1 (only sample at the beginning of the rollout)
"""
- def __init__(self,
- policy: Type[BasePolicy],
- env: Union[GymEnv, str, None],
- policy_base: Type[BasePolicy],
- learning_rate: Union[float, Callable],
- policy_kwargs: Dict[str, Any] = None,
- tensorboard_log: Optional[str] = None,
- verbose: int = 0,
- device: Union[th.device, str] = 'auto',
- support_multi_env: bool = False,
- create_eval_env: bool = False,
- monitor_wrapper: bool = True,
- seed: Optional[int] = None,
- use_sde: bool = False,
- sde_sample_freq: int = -1):
+ def __init__(
+ self,
+ policy: Type[BasePolicy],
+ env: Union[GymEnv, str, None],
+ policy_base: Type[BasePolicy],
+ learning_rate: Union[float, Callable],
+ policy_kwargs: Dict[str, Any] = None,
+ tensorboard_log: Optional[str] = None,
+ verbose: int = 0,
+ device: Union[th.device, str] = "auto",
+ support_multi_env: bool = False,
+ create_eval_env: bool = False,
+ monitor_wrapper: bool = True,
+ seed: Optional[int] = None,
+ use_sde: bool = False,
+ sde_sample_freq: int = -1,
+ ):
if isinstance(policy, str) and policy_base is not None:
self.policy_class = get_policy_from_name(policy_base, policy)
@@ -147,12 +153,12 @@ def __init__(self,
self.env = env
if not support_multi_env and self.n_envs > 1:
- raise ValueError("Error: the model does not support multiple envs; it requires "
- "a single vectorized environment.")
+ raise ValueError(
+ "Error: the model does not support multiple envs; it requires " "a single vectorized environment."
+ )
if self.use_sde and not isinstance(self.observation_space, gym.spaces.Box):
- raise ValueError("generalized State-Dependent Exploration (gSDE) can only "
- "be used with continuous actions.")
+ raise ValueError("generalized State-Dependent Exploration (gSDE) can only be used with continuous actions.")
def _wrap_env(self, env: GymEnv) -> VecEnv:
if not isinstance(env, VecEnv):
@@ -262,15 +268,18 @@ def get_torch_variables(self) -> Tuple[List[str], List[str]]:
return state_dicts, []
@abstractmethod
- def learn(self, total_timesteps: int,
- callback: MaybeCallback = None,
- log_interval: int = 100,
- tb_log_name: str = "run",
- eval_env: Optional[GymEnv] = None,
- eval_freq: int = -1,
- n_eval_episodes: int = 5,
- eval_log_path: Optional[str] = None,
- reset_num_timesteps: bool = True) -> 'BaseAlgorithm':
+ def learn(
+ self,
+ total_timesteps: int,
+ callback: MaybeCallback = None,
+ log_interval: int = 100,
+ tb_log_name: str = "run",
+ eval_env: Optional[GymEnv] = None,
+ eval_freq: int = -1,
+ n_eval_episodes: int = 5,
+ eval_log_path: Optional[str] = None,
+ reset_num_timesteps: bool = True,
+ ) -> "BaseAlgorithm":
"""
Return a trained model.
@@ -286,10 +295,13 @@ def learn(self, total_timesteps: int,
:return: (BaseAlgorithm) the trained model
"""
- def predict(self, observation: np.ndarray,
- state: Optional[np.ndarray] = None,
- mask: Optional[np.ndarray] = None,
- deterministic: bool = False) -> Tuple[np.ndarray, Optional[np.ndarray]]:
+ def predict(
+ self,
+ observation: np.ndarray,
+ state: Optional[np.ndarray] = None,
+ mask: Optional[np.ndarray] = None,
+ deterministic: bool = False,
+ ) -> Tuple[np.ndarray, Optional[np.ndarray]]:
"""
Get the model's action(s) from an observation
@@ -303,7 +315,7 @@ def predict(self, observation: np.ndarray,
return self.policy.predict(observation, state, mask, deterministic)
@classmethod
- def load(cls, load_path: str, env: Optional[GymEnv] = None, **kwargs) -> 'BaseAlgorithm':
+ def load(cls, load_path: str, env: Optional[GymEnv] = None, **kwargs) -> "BaseAlgorithm":
"""
Load the model from a zip-file
@@ -314,14 +326,16 @@ def load(cls, load_path: str, env: Optional[GymEnv] = None, **kwargs) -> 'BaseAl
"""
data, params, tensors = load_from_zip_file(load_path)
- if 'policy_kwargs' in data:
- for arg_to_remove in ['device']:
- if arg_to_remove in data['policy_kwargs']:
- del data['policy_kwargs'][arg_to_remove]
+ if "policy_kwargs" in data:
+ for arg_to_remove in ["device"]:
+ if arg_to_remove in data["policy_kwargs"]:
+ del data["policy_kwargs"][arg_to_remove]
- if 'policy_kwargs' in kwargs and kwargs['policy_kwargs'] != data['policy_kwargs']:
- raise ValueError(f"The specified policy kwargs do not equal the stored policy kwargs."
- f"Stored kwargs: {data['policy_kwargs']}, specified kwargs: {kwargs['policy_kwargs']}")
+ if "policy_kwargs" in kwargs and kwargs["policy_kwargs"] != data["policy_kwargs"]:
+ raise ValueError(
+ f"The specified policy kwargs do not equal the stored policy kwargs."
+ f"Stored kwargs: {data['policy_kwargs']}, specified kwargs: {kwargs['policy_kwargs']}"
+ )
# check if observation space and action space are part of the saved parameters
if "observation_space" not in data or "action_space" not in data:
@@ -334,8 +348,12 @@ def load(cls, load_path: str, env: Optional[GymEnv] = None, **kwargs) -> 'BaseAl
env = data["env"]
# noinspection PyArgumentList
- model = cls(policy=data["policy_class"], env=env,
- device='auto', _init_setup_model=False) # pytype: disable=not-instantiable,wrong-keyword-args
+ model = cls(
+ policy=data["policy_class"],
+ env=env,
+ device="auto",
+ _init_setup_model=False, # pytype: disable=not-instantiable,wrong-keyword-args
+ )
# load parameters
model.__dict__.update(data)
@@ -367,19 +385,21 @@ def set_random_seed(self, seed: Optional[int] = None) -> None:
"""
if seed is None:
return
- set_random_seed(seed, using_cuda=self.device == th.device('cuda'))
+ set_random_seed(seed, using_cuda=self.device == th.device("cuda"))
self.action_space.seed(seed)
if self.env is not None:
self.env.seed(seed)
if self.eval_env is not None:
self.eval_env.seed(seed)
- def _init_callback(self,
- callback: MaybeCallback,
- eval_env: Optional[VecEnv] = None,
- eval_freq: int = 10000,
- n_eval_episodes: int = 5,
- log_path: Optional[str] = None) -> BaseCallback:
+ def _init_callback(
+ self,
+ callback: MaybeCallback,
+ eval_env: Optional[VecEnv] = None,
+ eval_freq: int = 10000,
+ n_eval_episodes: int = 5,
+ log_path: Optional[str] = None,
+ ) -> BaseCallback:
"""
:param callback: (MaybeCallback) Callback(s) called at every step with state of the algorithm.
:param eval_freq: (Optional[int]) How many steps between evaluations; if None, do not evaluate.
@@ -398,24 +418,29 @@ def _init_callback(self,
# Create eval callback in charge of the evaluation
if eval_env is not None:
- eval_callback = EvalCallback(eval_env,
- best_model_save_path=log_path,
- log_path=log_path, eval_freq=eval_freq, n_eval_episodes=n_eval_episodes)
+ eval_callback = EvalCallback(
+ eval_env,
+ best_model_save_path=log_path,
+ log_path=log_path,
+ eval_freq=eval_freq,
+ n_eval_episodes=n_eval_episodes,
+ )
callback = CallbackList([callback, eval_callback])
callback.init_callback(self)
return callback
- def _setup_learn(self,
- total_timesteps: int,
- eval_env: Optional[GymEnv],
- callback: MaybeCallback = None,
- eval_freq: int = 10000,
- n_eval_episodes: int = 5,
- log_path: Optional[str] = None,
- reset_num_timesteps: bool = True,
- tb_log_name: str = 'run',
- ) -> Tuple[int, BaseCallback]:
+ def _setup_learn(
+ self,
+ total_timesteps: int,
+ eval_env: Optional[GymEnv],
+ callback: MaybeCallback = None,
+ eval_freq: int = 10000,
+ n_eval_episodes: int = 5,
+ log_path: Optional[str] = None,
+ reset_num_timesteps: bool = True,
+ tb_log_name: str = "run",
+ ) -> Tuple[int, BaseCallback]:
"""
Initialize different variables needed for training.
@@ -476,8 +501,8 @@ def _update_info_buffer(self, infos: List[Dict[str, Any]], dones: Optional[np.nd
if dones is None:
dones = np.array([False] * len(infos))
for idx, info in enumerate(infos):
- maybe_ep_info = info.get('episode')
- maybe_is_success = info.get('is_success')
+ maybe_ep_info = info.get("episode")
+ maybe_is_success = info.get("is_success")
if maybe_ep_info is not None:
self.ep_info_buffer.extend([maybe_ep_info])
if maybe_is_success is not None and dones[idx]:
@@ -522,7 +547,7 @@ def save(
torch_variables = state_dicts_names + tensors_names
for torch_var in torch_variables:
# we need to get only the name of the top most module as we'll remove that
- var_name = torch_var.split('.')[0]
+ var_name = torch_var.split(".")[0]
exclude.add(var_name)
# Remove parameter entries of parameters which are to be excluded
diff --git a/stable_baselines3/common/bit_flipping_env.py b/stable_baselines3/common/bit_flipping_env.py
index 1dd881502..b579fe157 100644
--- a/stable_baselines3/common/bit_flipping_env.py
+++ b/stable_baselines3/common/bit_flipping_env.py
@@ -21,27 +21,31 @@ class BitFlippingEnv(GoalEnv):
    :param discrete_obs_space: (bool) Whether to use the discrete observation
        version or not; by default, it uses the MultiBinary one
"""
- def __init__(self, n_bits: int = 10,
- continuous: bool = False,
- max_steps: Optional[int] = None,
- discrete_obs_space: bool = False):
+
+ def __init__(
+ self, n_bits: int = 10, continuous: bool = False, max_steps: Optional[int] = None, discrete_obs_space: bool = False
+ ):
super(BitFlippingEnv, self).__init__()
# The achieved goal is determined by the current state
        # here, it is a special case where they are equal
if discrete_obs_space:
# In the discrete case, the agent act on the binary
# representation of the observation
- self.observation_space = spaces.Dict({
- 'observation': spaces.Discrete(2 ** n_bits - 1),
- 'achieved_goal': spaces.Discrete(2 ** n_bits - 1),
- 'desired_goal': spaces.Discrete(2 ** n_bits - 1)
- })
+ self.observation_space = spaces.Dict(
+ {
+ "observation": spaces.Discrete(2 ** n_bits - 1),
+ "achieved_goal": spaces.Discrete(2 ** n_bits - 1),
+ "desired_goal": spaces.Discrete(2 ** n_bits - 1),
+ }
+ )
else:
- self.observation_space = spaces.Dict({
- 'observation': spaces.MultiBinary(n_bits),
- 'achieved_goal': spaces.MultiBinary(n_bits),
- 'desired_goal': spaces.MultiBinary(n_bits)
- })
+ self.observation_space = spaces.Dict(
+ {
+ "observation": spaces.MultiBinary(n_bits),
+ "achieved_goal": spaces.MultiBinary(n_bits),
+ "desired_goal": spaces.MultiBinary(n_bits),
+ }
+ )
self.obs_space = spaces.MultiBinary(n_bits)
@@ -69,7 +73,7 @@ def convert_if_needed(self, state: np.ndarray) -> Union[int, np.ndarray]:
if self.discrete_obs_space:
# The internal state is the binary representation of the
# observed one
- return int(sum([state[i] * 2**i for i in range(len(state))]))
+ return int(sum([state[i] * 2 ** i for i in range(len(state))]))
return state
def _get_obs(self) -> OrderedDict:
@@ -78,11 +82,13 @@ def _get_obs(self) -> OrderedDict:
:return: (OrderedDict)
"""
- return OrderedDict([
- ('observation', self.convert_if_needed(self.state.copy())),
- ('achieved_goal', self.convert_if_needed(self.state.copy())),
- ('desired_goal', self.convert_if_needed(self.desired_goal.copy()))
- ])
+ return OrderedDict(
+ [
+ ("observation", self.convert_if_needed(self.state.copy())),
+ ("achieved_goal", self.convert_if_needed(self.state.copy())),
+ ("desired_goal", self.convert_if_needed(self.desired_goal.copy())),
+ ]
+ )
def reset(self) -> OrderedDict:
self.current_step = 0
@@ -95,25 +101,22 @@ def step(self, action: Union[np.ndarray, int]) -> GymStepReturn:
else:
self.state[action] = 1 - self.state[action]
obs = self._get_obs()
- reward = self.compute_reward(obs['achieved_goal'], obs['desired_goal'], None)
+ reward = self.compute_reward(obs["achieved_goal"], obs["desired_goal"], None)
done = reward == 0
self.current_step += 1
        # Episode terminates when we reach the goal or the max number of steps
- info = {'is_success': done}
+ info = {"is_success": done}
done = done or self.current_step >= self.max_steps
return obs, reward, done, info
- def compute_reward(self,
- achieved_goal: np.ndarray,
- desired_goal: np.ndarray,
- _info) -> float:
+ def compute_reward(self, achieved_goal: np.ndarray, desired_goal: np.ndarray, _info) -> float:
# Deceptive reward: it is positive only when the goal is achieved
if self.discrete_obs_space:
return 0.0 if achieved_goal == desired_goal else -1.0
return 0.0 if (achieved_goal == desired_goal).all() else -1.0
- def render(self, mode: str = 'human') -> Optional[np.ndarray]:
- if mode == 'rgb_array':
+ def render(self, mode: str = "human") -> Optional[np.ndarray]:
+ if mode == "rgb_array":
return self.state.copy()
print(self.state)
diff --git a/stable_baselines3/common/buffers.py b/stable_baselines3/common/buffers.py
index 24ed8aee3..4534063a2 100644
--- a/stable_baselines3/common/buffers.py
+++ b/stable_baselines3/common/buffers.py
@@ -1,5 +1,5 @@
-from typing import Union, Optional, Generator
import warnings
+from typing import Generator, Optional, Union
import numpy as np
import torch as th
@@ -11,9 +11,9 @@
except ImportError:
psutil = None
-from stable_baselines3.common.vec_env import VecNormalize
-from stable_baselines3.common.type_aliases import RolloutBufferSamples, ReplayBufferSamples
from stable_baselines3.common.preprocessing import get_action_dim, get_obs_shape
+from stable_baselines3.common.type_aliases import ReplayBufferSamples, RolloutBufferSamples
+from stable_baselines3.common.vec_env import VecNormalize
class BaseBuffer(object):
@@ -28,12 +28,14 @@ class BaseBuffer(object):
:param n_envs: (int) Number of parallel environments
"""
- def __init__(self,
- buffer_size: int,
- observation_space: spaces.Space,
- action_space: spaces.Space,
- device: Union[th.device, str] = 'cpu',
- n_envs: int = 1):
+ def __init__(
+ self,
+ buffer_size: int,
+ observation_space: spaces.Space,
+ action_space: spaces.Space,
+ device: Union[th.device, str] = "cpu",
+ n_envs: int = 1,
+ ):
super(BaseBuffer, self).__init__()
self.buffer_size = buffer_size
self.observation_space = observation_space
@@ -89,10 +91,7 @@ def reset(self) -> None:
self.pos = 0
self.full = False
- def sample(self,
- batch_size: int,
- env: Optional[VecNormalize] = None
- ):
+ def sample(self, batch_size: int, env: Optional[VecNormalize] = None):
"""
        :param batch_size: (int) Number of elements to sample
:param env: (Optional[VecNormalize]) associated gym VecEnv
@@ -103,10 +102,7 @@ def sample(self,
batch_inds = np.random.randint(0, upper_bound, size=batch_size)
return self._get_samples(batch_inds, env=env)
- def _get_samples(self,
- batch_inds: np.ndarray,
- env: Optional[VecNormalize] = None
- ):
+ def _get_samples(self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None):
"""
:param batch_inds: (th.Tensor)
:param env: (Optional[VecNormalize])
@@ -129,15 +125,13 @@ def to_torch(self, array: np.ndarray, copy: bool = True) -> th.Tensor:
return th.as_tensor(array).to(self.device)
@staticmethod
- def _normalize_obs(obs: np.ndarray,
- env: Optional[VecNormalize] = None) -> np.ndarray:
+ def _normalize_obs(obs: np.ndarray, env: Optional[VecNormalize] = None) -> np.ndarray:
if env is not None:
return env.normalize_obs(obs).astype(np.float32)
return obs
@staticmethod
- def _normalize_reward(reward: np.ndarray,
- env: Optional[VecNormalize] = None) -> np.ndarray:
+ def _normalize_reward(reward: np.ndarray, env: Optional[VecNormalize] = None) -> np.ndarray:
if env is not None:
return env.normalize_reward(reward).astype(np.float32)
return reward
@@ -158,15 +152,17 @@ class ReplayBuffer(BaseBuffer):
See https://github.com/DLR-RM/stable-baselines3/issues/37#issuecomment-637501195
and https://github.com/DLR-RM/stable-baselines3/pull/28#issuecomment-637559274
"""
- def __init__(self,
- buffer_size: int,
- observation_space: spaces.Space,
- action_space: spaces.Space,
- device: Union[th.device, str] = 'cpu',
- n_envs: int = 1,
- optimize_memory_usage: bool = False):
- super(ReplayBuffer, self).__init__(buffer_size, observation_space,
- action_space, device, n_envs=n_envs)
+
+ def __init__(
+ self,
+ buffer_size: int,
+ observation_space: spaces.Space,
+ action_space: spaces.Space,
+ device: Union[th.device, str] = "cpu",
+ n_envs: int = 1,
+ optimize_memory_usage: bool = False,
+ ):
+ super(ReplayBuffer, self).__init__(buffer_size, observation_space, action_space, device, n_envs=n_envs)
assert n_envs == 1, "Replay buffer only support single environment for now"
@@ -186,8 +182,7 @@ def __init__(self,
self.dones = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
if psutil is not None:
- total_memory_usage = (self.observations.nbytes + self.actions.nbytes
- + self.rewards.nbytes + self.dones.nbytes)
+ total_memory_usage = self.observations.nbytes + self.actions.nbytes + self.rewards.nbytes + self.dones.nbytes
if self.next_observations is not None:
total_memory_usage += self.next_observations.nbytes
@@ -195,15 +190,12 @@ def __init__(self,
# Convert to GB
total_memory_usage /= 1e9
mem_available /= 1e9
- warnings.warn("This system does not have apparently enough memory to store the complete "
- f"replay buffer {total_memory_usage:.2f}GB > {mem_available:.2f}GB")
-
- def add(self,
- obs: np.ndarray,
- next_obs: np.ndarray,
- action: np.ndarray,
- reward: np.ndarray,
- done: np.ndarray) -> None:
+ warnings.warn(
+ "This system does not have apparently enough memory to store the complete "
+ f"replay buffer {total_memory_usage:.2f}GB > {mem_available:.2f}GB"
+ )
+
+ def add(self, obs: np.ndarray, next_obs: np.ndarray, action: np.ndarray, reward: np.ndarray, done: np.ndarray) -> None:
# Copy to avoid modification by reference
self.observations[self.pos] = np.array(obs).copy()
if self.optimize_memory_usage:
@@ -220,10 +212,7 @@ def add(self,
self.full = True
self.pos = 0
- def sample(self,
- batch_size: int,
- env: Optional[VecNormalize] = None
- ) -> ReplayBufferSamples:
+ def sample(self, batch_size: int, env: Optional[VecNormalize] = None) -> ReplayBufferSamples:
"""
Sample elements from the replay buffer.
Custom sampling when using memory efficient variant,
@@ -245,20 +234,19 @@ def sample(self,
batch_inds = np.random.randint(0, self.pos, size=batch_size)
return self._get_samples(batch_inds, env=env)
- def _get_samples(self,
- batch_inds: np.ndarray,
- env: Optional[VecNormalize] = None
- ) -> ReplayBufferSamples:
+ def _get_samples(self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None) -> ReplayBufferSamples:
if self.optimize_memory_usage:
next_obs = self._normalize_obs(self.observations[(batch_inds + 1) % self.buffer_size, 0, :], env)
else:
next_obs = self._normalize_obs(self.next_observations[batch_inds, 0, :], env)
- data = (self._normalize_obs(self.observations[batch_inds, 0, :], env),
- self.actions[batch_inds, 0, :],
- next_obs,
- self.dones[batch_inds],
- self._normalize_reward(self.rewards[batch_inds], env))
+ data = (
+ self._normalize_obs(self.observations[batch_inds, 0, :], env),
+ self.actions[batch_inds, 0, :],
+ next_obs,
+ self.dones[batch_inds],
+ self._normalize_reward(self.rewards[batch_inds], env),
+ )
return ReplayBufferSamples(*tuple(map(self.to_torch, data)))
@@ -276,17 +264,18 @@ class RolloutBuffer(BaseBuffer):
:param n_envs: (int) Number of parallel environments
"""
- def __init__(self,
- buffer_size: int,
- observation_space: spaces.Space,
- action_space: spaces.Space,
- device: Union[th.device, str] = 'cpu',
- gae_lambda: float = 1,
- gamma: float = 0.99,
- n_envs: int = 1):
-
- super(RolloutBuffer, self).__init__(buffer_size, observation_space,
- action_space, device, n_envs=n_envs)
+ def __init__(
+ self,
+ buffer_size: int,
+ observation_space: spaces.Space,
+ action_space: spaces.Space,
+ device: Union[th.device, str] = "cpu",
+ gae_lambda: float = 1,
+ gamma: float = 0.99,
+ n_envs: int = 1,
+ ):
+
+ super(RolloutBuffer, self).__init__(buffer_size, observation_space, action_space, device, n_envs=n_envs)
self.gae_lambda = gae_lambda
self.gamma = gamma
self.observations, self.actions, self.rewards, self.advantages = None, None, None, None
@@ -295,8 +284,7 @@ def __init__(self,
self.reset()
def reset(self) -> None:
- self.observations = np.zeros((self.buffer_size, self.n_envs,) + self.obs_shape,
- dtype=np.float32)
+ self.observations = np.zeros((self.buffer_size, self.n_envs,) + self.obs_shape, dtype=np.float32)
self.actions = np.zeros((self.buffer_size, self.n_envs, self.action_dim), dtype=np.float32)
self.rewards = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
self.returns = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
@@ -307,9 +295,7 @@ def reset(self) -> None:
self.generator_ready = False
super(RolloutBuffer, self).reset()
- def compute_returns_and_advantage(self,
- last_value: th.Tensor,
- dones: np.ndarray) -> None:
+ def compute_returns_and_advantage(self, last_value: th.Tensor, dones: np.ndarray) -> None:
"""
Post-processing step: compute the returns (sum of discounted rewards)
and GAE advantage.
@@ -340,13 +326,9 @@ def compute_returns_and_advantage(self,
self.advantages[step] = last_gae_lam
self.returns = self.advantages + self.values
- def add(self,
- obs: np.ndarray,
- action: np.ndarray,
- reward: np.ndarray,
- done: np.ndarray,
- value: th.Tensor,
- log_prob: th.Tensor) -> None:
+ def add(
+ self, obs: np.ndarray, action: np.ndarray, reward: np.ndarray, done: np.ndarray, value: th.Tensor, log_prob: th.Tensor
+ ) -> None:
"""
:param obs: (np.ndarray) Observation
:param action: (np.ndarray) Action
@@ -372,12 +354,11 @@ def add(self,
self.full = True
def get(self, batch_size: Optional[int] = None) -> Generator[RolloutBufferSamples, None, None]:
- assert self.full, ''
+ assert self.full, ""
indices = np.random.permutation(self.buffer_size * self.n_envs)
# Prepare the data
if not self.generator_ready:
- for tensor in ['observations', 'actions', 'values',
- 'log_probs', 'advantages', 'returns']:
+ for tensor in ["observations", "actions", "values", "log_probs", "advantages", "returns"]:
self.__dict__[tensor] = self.swap_and_flatten(self.__dict__[tensor])
self.generator_ready = True
@@ -387,15 +368,16 @@ def get(self, batch_size: Optional[int] = None) -> Generator[RolloutBufferSample
start_idx = 0
while start_idx < self.buffer_size * self.n_envs:
- yield self._get_samples(indices[start_idx:start_idx + batch_size])
+ yield self._get_samples(indices[start_idx : start_idx + batch_size])
start_idx += batch_size
- def _get_samples(self, batch_inds: np.ndarray,
- env: Optional[VecNormalize] = None) -> RolloutBufferSamples:
- data = (self.observations[batch_inds],
- self.actions[batch_inds],
- self.values[batch_inds].flatten(),
- self.log_probs[batch_inds].flatten(),
- self.advantages[batch_inds].flatten(),
- self.returns[batch_inds].flatten())
+ def _get_samples(self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None) -> RolloutBufferSamples:
+ data = (
+ self.observations[batch_inds],
+ self.actions[batch_inds],
+ self.values[batch_inds].flatten(),
+ self.log_probs[batch_inds].flatten(),
+ self.advantages[batch_inds].flatten(),
+ self.returns[batch_inds].flatten(),
+ )
return RolloutBufferSamples(*tuple(map(self.to_torch, data)))
diff --git a/stable_baselines3/common/callbacks.py b/stable_baselines3/common/callbacks.py
index ccd812f08..c5e53e58e 100644
--- a/stable_baselines3/common/callbacks.py
+++ b/stable_baselines3/common/callbacks.py
@@ -1,15 +1,15 @@
import os
-from abc import ABC, abstractmethod
-import warnings
import typing
-from typing import Union, List, Dict, Any, Optional
+import warnings
+from abc import ABC, abstractmethod
+from typing import Any, Dict, List, Optional, Union
import gym
import numpy as np
-from stable_baselines3.common.vec_env import DummyVecEnv, VecEnv, sync_envs_normalization
-from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common import logger
+from stable_baselines3.common.evaluation import evaluate_policy
+from stable_baselines3.common.vec_env import DummyVecEnv, VecEnv, sync_envs_normalization
if typing.TYPE_CHECKING:
from stable_baselines3.common.base_class import BaseAlgorithm # pytype: disable=pyi-error
@@ -21,6 +21,7 @@ class BaseCallback(ABC):
:param verbose: (int)
"""
+
def __init__(self, verbose: int = 0):
super(BaseCallback, self).__init__()
# The RL model
@@ -40,7 +41,7 @@ def __init__(self, verbose: int = 0):
self.parent = None # type: Optional[BaseCallback]
# Type hint as string to avoid circular import
- def init_callback(self, model: 'BaseAlgorithm') -> None:
+ def init_callback(self, model: "BaseAlgorithm") -> None:
"""
Initialize the callback by saving references to the
RL model and the training environment for convenience.
@@ -111,6 +112,7 @@ class EventCallback(BaseCallback):
when an event is triggered.
:param verbose: (int)
"""
+
def __init__(self, callback: Optional[BaseCallback] = None, verbose: int = 0):
super(EventCallback, self).__init__(verbose=verbose)
self.callback = callback
@@ -118,7 +120,7 @@ def __init__(self, callback: Optional[BaseCallback] = None, verbose: int = 0):
if callback is not None:
self.callback.parent = self
- def init_callback(self, model: 'BaseAlgorithm') -> None:
+ def init_callback(self, model: "BaseAlgorithm") -> None:
super(EventCallback, self).init_callback(model)
if self.callback is not None:
self.callback.init_callback(self.model)
@@ -143,6 +145,7 @@ class CallbackList(BaseCallback):
:param callbacks: (List[BaseCallback]) A list of callbacks that will be called
sequentially.
"""
+
def __init__(self, callbacks: List[BaseCallback]):
super(CallbackList, self).__init__()
assert isinstance(callbacks, list)
@@ -184,7 +187,8 @@ class CheckpointCallback(BaseCallback):
:param save_path: (str) Path to the folder where the model will be saved.
:param name_prefix: (str) Common prefix to the saved models
"""
- def __init__(self, save_freq: int, save_path: str, name_prefix='rl_model', verbose=0):
+
+ def __init__(self, save_freq: int, save_path: str, name_prefix="rl_model", verbose=0):
super(CheckpointCallback, self).__init__(verbose)
self.save_freq = save_freq
self.save_path = save_path
@@ -197,7 +201,7 @@ def _init_callback(self) -> None:
def _on_step(self) -> bool:
if self.n_calls % self.save_freq == 0:
- path = os.path.join(self.save_path, f'{self.name_prefix}_{self.num_timesteps}_steps')
+ path = os.path.join(self.save_path, f"{self.name_prefix}_{self.num_timesteps}_steps")
self.model.save(path)
if self.verbose > 1:
print(f"Saving model checkpoint to {path}")
@@ -211,6 +215,7 @@ class ConvertCallback(BaseCallback):
:param callback: (callable)
:param verbose: (int)
"""
+
def __init__(self, callback, verbose=0):
super(ConvertCallback, self).__init__(verbose)
self.callback = callback
@@ -240,15 +245,19 @@ class EvalCallback(EventCallback):
:param render: (bool) Whether to render or not the environment during evaluation
:param verbose: (int)
"""
- def __init__(self, eval_env: Union[gym.Env, VecEnv],
- callback_on_new_best: Optional[BaseCallback] = None,
- n_eval_episodes: int = 5,
- eval_freq: int = 10000,
- log_path: str = None,
- best_model_save_path: str = None,
- deterministic: bool = True,
- render: bool = False,
- verbose: int = 1):
+
+ def __init__(
+ self,
+ eval_env: Union[gym.Env, VecEnv],
+ callback_on_new_best: Optional[BaseCallback] = None,
+ n_eval_episodes: int = 5,
+ eval_freq: int = 10000,
+ log_path: str = None,
+ best_model_save_path: str = None,
+ deterministic: bool = True,
+ render: bool = False,
+ verbose: int = 1,
+ ):
super(EvalCallback, self).__init__(callback_on_new_best, verbose=verbose)
self.n_eval_episodes = n_eval_episodes
self.eval_freq = eval_freq
@@ -268,7 +277,7 @@ def __init__(self, eval_env: Union[gym.Env, VecEnv],
self.best_model_save_path = best_model_save_path
# Logs will be written in ``evaluations.npz``
if log_path is not None:
- log_path = os.path.join(log_path, 'evaluations')
+ log_path = os.path.join(log_path, "evaluations")
self.log_path = log_path
self.evaluations_results = []
self.evaluations_timesteps = []
@@ -277,8 +286,7 @@ def __init__(self, eval_env: Union[gym.Env, VecEnv],
def _init_callback(self):
# Does not work in some corner cases, where the wrapper is not the same
if not isinstance(self.training_env, type(self.eval_env)):
- warnings.warn("Training and eval env are not of the same type"
- f"{self.training_env} != {self.eval_env}")
+ warnings.warn("Training and eval env are not of the same type" f"{self.training_env} != {self.eval_env}")
# Create folders if needed
if self.best_model_save_path is not None:
@@ -292,36 +300,42 @@ def _on_step(self) -> bool:
# Sync training and eval env if there is VecNormalize
sync_envs_normalization(self.training_env, self.eval_env)
- episode_rewards, episode_lengths = evaluate_policy(self.model, self.eval_env,
- n_eval_episodes=self.n_eval_episodes,
- render=self.render,
- deterministic=self.deterministic,
- return_episode_rewards=True)
+ episode_rewards, episode_lengths = evaluate_policy(
+ self.model,
+ self.eval_env,
+ n_eval_episodes=self.n_eval_episodes,
+ render=self.render,
+ deterministic=self.deterministic,
+ return_episode_rewards=True,
+ )
if self.log_path is not None:
self.evaluations_timesteps.append(self.num_timesteps)
self.evaluations_results.append(episode_rewards)
self.evaluations_length.append(episode_lengths)
- np.savez(self.log_path, timesteps=self.evaluations_timesteps,
- results=self.evaluations_results, ep_lengths=self.evaluations_length)
+ np.savez(
+ self.log_path,
+ timesteps=self.evaluations_timesteps,
+ results=self.evaluations_results,
+ ep_lengths=self.evaluations_length,
+ )
mean_reward, std_reward = np.mean(episode_rewards), np.std(episode_rewards)
mean_ep_length, std_ep_length = np.mean(episode_lengths), np.std(episode_lengths)
self.last_mean_reward = mean_reward
if self.verbose > 0:
- print(f"Eval num_timesteps={self.num_timesteps}, "
- f"episode_reward={mean_reward:.2f} +/- {std_reward:.2f}")
+ print(f"Eval num_timesteps={self.num_timesteps}, " f"episode_reward={mean_reward:.2f} +/- {std_reward:.2f}")
print(f"Episode length: {mean_ep_length:.2f} +/- {std_ep_length:.2f}")
# Add to current Logger
- self.logger.record('eval/mean_reward', float(mean_reward))
- self.logger.record('eval/mean_ep_length', mean_ep_length)
+ self.logger.record("eval/mean_reward", float(mean_reward))
+ self.logger.record("eval/mean_ep_length", mean_ep_length)
if mean_reward > self.best_mean_reward:
if self.verbose > 0:
print("New best mean reward!")
if self.best_model_save_path is not None:
- self.model.save(os.path.join(self.best_model_save_path, 'best_model'))
+ self.model.save(os.path.join(self.best_model_save_path, "best_model"))
self.best_mean_reward = mean_reward
# Trigger callback if needed
if self.callback is not None:
@@ -341,18 +355,20 @@ class StopTrainingOnRewardThreshold(BaseCallback):
to stop training.
:param verbose: (int)
"""
+
def __init__(self, reward_threshold: float, verbose: int = 0):
super(StopTrainingOnRewardThreshold, self).__init__(verbose=verbose)
self.reward_threshold = reward_threshold
def _on_step(self) -> bool:
- assert self.parent is not None, ("``StopTrainingOnMinimumReward`` callback must be used "
- "with an ``EvalCallback``")
+        assert self.parent is not None, "``StopTrainingOnRewardThreshold`` callback must be used " "with an ``EvalCallback``"
        # Convert np.bool to bool, otherwise ``callback() is False`` won't work
continue_training = bool(self.parent.best_mean_reward < self.reward_threshold)
if self.verbose > 0 and not continue_training:
- print(f"Stopping training because the mean reward {self.parent.best_mean_reward:.2f} "
- f" is above the threshold {self.reward_threshold}")
+ print(
+ f"Stopping training because the mean reward {self.parent.best_mean_reward:.2f} "
+ f" is above the threshold {self.reward_threshold}"
+ )
return continue_training
@@ -364,6 +380,7 @@ class EveryNTimesteps(EventCallback):
:param callback: (BaseCallback) Callback that will be called
when the event is triggered.
"""
+
def __init__(self, n_steps: int, callback: BaseCallback):
super(EveryNTimesteps, self).__init__(callback)
self.n_steps = n_steps
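
A minimal usage sketch of the callbacks touched above, wiring `EvalCallback`, `StopTrainingOnRewardThreshold` and `CheckpointCallback` together. The choice of `PPO`, `"CartPole-v1"`, the `./logs/` folder and the reward threshold are illustrative assumptions, not part of this diff.

```python
import gym

from stable_baselines3 import PPO
from stable_baselines3.common.callbacks import CheckpointCallback, EvalCallback, StopTrainingOnRewardThreshold

eval_env = gym.make("CartPole-v1")
# Stop training once the best mean evaluation reward passes the threshold
stop_callback = StopTrainingOnRewardThreshold(reward_threshold=195, verbose=1)
eval_callback = EvalCallback(
    eval_env,
    callback_on_new_best=stop_callback,
    n_eval_episodes=5,
    eval_freq=1000,
    log_path="./logs/",
    best_model_save_path="./logs/",
)
# Save a checkpoint every 5000 steps
checkpoint_callback = CheckpointCallback(save_freq=5000, save_path="./logs/", name_prefix="rl_model")

model = PPO("MlpPolicy", "CartPole-v1", verbose=0)
# A list of callbacks is converted internally to a CallbackList
model.learn(total_timesteps=10_000, callback=[eval_callback, checkpoint_callback])
```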
diff --git a/stable_baselines3/common/cmd_util.py b/stable_baselines3/common/cmd_util.py
index ceeced7b7..4f49ccce5 100644
--- a/stable_baselines3/common/cmd_util.py
+++ b/stable_baselines3/common/cmd_util.py
@@ -1,23 +1,25 @@
import os
import warnings
-from typing import Dict, Any, Optional, Callable, Type, Union
+from typing import Any, Callable, Dict, Optional, Type, Union
import gym
-from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.atari_wrappers import AtariWrapper
+from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv
-def make_vec_env(env_id: Union[str, Type[gym.Env]],
- n_envs: int = 1,
- seed: Optional[int] = None,
- start_index: int = 0,
- monitor_dir: Optional[str] = None,
- wrapper_class: Optional[Callable] = None,
- env_kwargs: Optional[Dict[str, Any]] = None,
- vec_env_cls: Optional[Union[DummyVecEnv, SubprocVecEnv]] = None,
- vec_env_kwargs: Optional[Dict[str, Any]] = None):
+def make_vec_env(
+ env_id: Union[str, Type[gym.Env]],
+ n_envs: int = 1,
+ seed: Optional[int] = None,
+ start_index: int = 0,
+ monitor_dir: Optional[str] = None,
+ wrapper_class: Optional[Callable] = None,
+ env_kwargs: Optional[Dict[str, Any]] = None,
+ vec_env_cls: Optional[Union[DummyVecEnv, SubprocVecEnv]] = None,
+ vec_env_kwargs: Optional[Dict[str, Any]] = None,
+):
"""
Create a wrapped, monitored ``VecEnv``.
By default it uses a ``DummyVecEnv`` which is usually faster
@@ -62,6 +64,7 @@ def _init():
if wrapper_class is not None:
env = wrapper_class(env)
return env
+
return _init
# No custom VecEnv is passed
@@ -72,15 +75,17 @@ def _init():
return vec_env_cls([make_env(i + start_index) for i in range(n_envs)], **vec_env_kwargs)
-def make_atari_env(env_id: Union[str, Type[gym.Env]],
- n_envs: int = 1,
- seed: Optional[int] = None,
- start_index: int = 0,
- monitor_dir: Optional[str] = None,
- wrapper_kwargs: Optional[Dict[str, Any]] = None,
- env_kwargs: Optional[Dict[str, Any]] = None,
- vec_env_cls: Optional[Union[DummyVecEnv, SubprocVecEnv]] = None,
- vec_env_kwargs: Optional[Dict[str, Any]] = None):
+def make_atari_env(
+ env_id: Union[str, Type[gym.Env]],
+ n_envs: int = 1,
+ seed: Optional[int] = None,
+ start_index: int = 0,
+ monitor_dir: Optional[str] = None,
+ wrapper_kwargs: Optional[Dict[str, Any]] = None,
+ env_kwargs: Optional[Dict[str, Any]] = None,
+ vec_env_cls: Optional[Union[DummyVecEnv, SubprocVecEnv]] = None,
+ vec_env_kwargs: Optional[Dict[str, Any]] = None,
+):
"""
Create a wrapped, monitored VecEnv for Atari.
It is a wrapper around ``make_vec_env`` that includes common preprocessing for Atari games.
@@ -105,6 +110,14 @@ def atari_wrapper(env: gym.Env) -> gym.Env:
env = AtariWrapper(env, **wrapper_kwargs)
return env
- return make_vec_env(env_id, n_envs=n_envs, seed=seed, start_index=start_index,
- monitor_dir=monitor_dir, wrapper_class=atari_wrapper,
- env_kwargs=env_kwargs, vec_env_cls=vec_env_cls, vec_env_kwargs=vec_env_kwargs)
+ return make_vec_env(
+ env_id,
+ n_envs=n_envs,
+ seed=seed,
+ start_index=start_index,
+ monitor_dir=monitor_dir,
+ wrapper_class=atari_wrapper,
+ env_kwargs=env_kwargs,
+ vec_env_cls=vec_env_cls,
+ vec_env_kwargs=vec_env_kwargs,
+ )
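
A short sketch of the reformatted helpers above; the env ids are illustrative, and the Atari example assumes the Atari extras (ROMs) are installed.

```python
from stable_baselines3.common.cmd_util import make_atari_env, make_vec_env
from stable_baselines3.common.vec_env import SubprocVecEnv

# Four CartPole copies in the default DummyVecEnv, seeded and monitored
vec_env = make_vec_env("CartPole-v1", n_envs=4, seed=0, monitor_dir="./logs/")

# Atari preprocessing via AtariWrapper, running the envs in subprocesses
atari_env = make_atari_env(
    "BreakoutNoFrameskip-v4",
    n_envs=8,
    seed=0,
    vec_env_cls=SubprocVecEnv,
)
```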
diff --git a/stable_baselines3/common/distributions.py b/stable_baselines3/common/distributions.py
index 1a20928f2..f46691de6 100644
--- a/stable_baselines3/common/distributions.py
+++ b/stable_baselines3/common/distributions.py
@@ -1,12 +1,13 @@
"""Probability distributions."""
from abc import ABC, abstractmethod
-from typing import Optional, Tuple, Dict, Any, List
+from typing import Any, Dict, List, Optional, Tuple
+
import gym
import torch as th
-import torch.nn as nn
-from torch.distributions import Normal, Categorical, Bernoulli
from gym import spaces
+from torch import nn as nn
+from torch.distributions import Bernoulli, Categorical, Normal
from stable_baselines3.common.preprocessing import get_action_dim
@@ -25,7 +26,7 @@ def proba_distribution_net(self, *args, **kwargs):
concrete classes."""
@abstractmethod
- def proba_distribution(self, *args, **kwargs) -> 'Distribution':
+ def proba_distribution(self, *args, **kwargs) -> "Distribution":
"""Set parameters of the distribution.
:return: (Distribution) self
@@ -124,8 +125,7 @@ def __init__(self, action_dim: int):
self.mean_actions = None
self.log_std = None
- def proba_distribution_net(self, latent_dim: int,
- log_std_init: float = 0.0) -> Tuple[nn.Module, nn.Parameter]:
+ def proba_distribution_net(self, latent_dim: int, log_std_init: float = 0.0) -> Tuple[nn.Module, nn.Parameter]:
"""
Create the layers and parameter that represent the distribution:
one output will be the mean of the Gaussian, the other parameter will be the
@@ -140,8 +140,7 @@ def proba_distribution_net(self, latent_dim: int,
log_std = nn.Parameter(th.ones(self.action_dim) * log_std_init, requires_grad=True)
return mean_actions, log_std
- def proba_distribution(self, mean_actions: th.Tensor,
- log_std: th.Tensor) -> 'DiagGaussianDistribution':
+ def proba_distribution(self, mean_actions: th.Tensor, log_std: th.Tensor) -> "DiagGaussianDistribution":
"""
Create the distribution given its parameters (mean, std)
@@ -174,15 +173,12 @@ def sample(self) -> th.Tensor:
def mode(self) -> th.Tensor:
return self.distribution.mean
- def actions_from_params(self, mean_actions: th.Tensor,
- log_std: th.Tensor,
- deterministic: bool = False) -> th.Tensor:
+ def actions_from_params(self, mean_actions: th.Tensor, log_std: th.Tensor, deterministic: bool = False) -> th.Tensor:
# Update the proba distribution
self.proba_distribution(mean_actions, log_std)
return self.get_actions(deterministic=deterministic)
- def log_prob_from_params(self, mean_actions: th.Tensor,
- log_std: th.Tensor) -> Tuple[th.Tensor, th.Tensor]:
+ def log_prob_from_params(self, mean_actions: th.Tensor, log_std: th.Tensor) -> Tuple[th.Tensor, th.Tensor]:
"""
Compute the log probability of taking an action
given the distribution parameters.
@@ -210,13 +206,11 @@ def __init__(self, action_dim: int, epsilon: float = 1e-6):
self.epsilon = epsilon
self.gaussian_actions = None
- def proba_distribution(self, mean_actions: th.Tensor,
- log_std: th.Tensor) -> 'SquashedDiagGaussianDistribution':
+ def proba_distribution(self, mean_actions: th.Tensor, log_std: th.Tensor) -> "SquashedDiagGaussianDistribution":
super(SquashedDiagGaussianDistribution, self).proba_distribution(mean_actions, log_std)
return self
- def log_prob(self, actions: th.Tensor,
- gaussian_actions: Optional[th.Tensor] = None) -> th.Tensor:
+ def log_prob(self, actions: th.Tensor, gaussian_actions: Optional[th.Tensor] = None) -> th.Tensor:
# Inverse tanh
# Naive implementation (not stable): 0.5 * torch.log((1 + x) / (1 - x))
# We use numpy to avoid numerical instability
@@ -246,8 +240,7 @@ def mode(self) -> th.Tensor:
# Squash the output
return th.tanh(self.gaussian_actions)
- def log_prob_from_params(self, mean_actions: th.Tensor,
- log_std: th.Tensor) -> Tuple[th.Tensor, th.Tensor]:
+ def log_prob_from_params(self, mean_actions: th.Tensor, log_std: th.Tensor) -> Tuple[th.Tensor, th.Tensor]:
action = self.actions_from_params(mean_actions, log_std)
log_prob = self.log_prob(action, self.gaussian_actions)
return action, log_prob
@@ -278,7 +271,7 @@ def proba_distribution_net(self, latent_dim: int) -> nn.Module:
action_logits = nn.Linear(latent_dim, self.action_dim)
return action_logits
- def proba_distribution(self, action_logits: th.Tensor) -> 'CategoricalDistribution':
+ def proba_distribution(self, action_logits: th.Tensor) -> "CategoricalDistribution":
self.distribution = Categorical(logits=action_logits)
return self
@@ -294,8 +287,7 @@ def sample(self) -> th.Tensor:
def mode(self) -> th.Tensor:
return th.argmax(self.distribution.probs, dim=1)
- def actions_from_params(self, action_logits: th.Tensor,
- deterministic: bool = False) -> th.Tensor:
+ def actions_from_params(self, action_logits: th.Tensor, deterministic: bool = False) -> th.Tensor:
# Update the proba distribution
self.proba_distribution(action_logits)
return self.get_actions(deterministic=deterministic)
@@ -332,14 +324,15 @@ def proba_distribution_net(self, latent_dim: int) -> nn.Module:
action_logits = nn.Linear(latent_dim, sum(self.action_dims))
return action_logits
- def proba_distribution(self, action_logits: th.Tensor) -> 'MultiCategoricalDistribution':
+ def proba_distribution(self, action_logits: th.Tensor) -> "MultiCategoricalDistribution":
self.distributions = [Categorical(logits=split) for split in th.split(action_logits, tuple(self.action_dims), dim=1)]
return self
def log_prob(self, actions: th.Tensor) -> th.Tensor:
# Extract each discrete action and compute log prob for their respective distributions
- return th.stack([dist.log_prob(action) for dist, action in zip(self.distributions,
- th.unbind(actions, dim=1))], dim=1).sum(dim=1)
+ return th.stack(
+ [dist.log_prob(action) for dist, action in zip(self.distributions, th.unbind(actions, dim=1))], dim=1
+ ).sum(dim=1)
def entropy(self) -> th.Tensor:
return th.stack([dist.entropy() for dist in self.distributions], dim=1).sum(dim=1)
@@ -350,8 +343,7 @@ def sample(self) -> th.Tensor:
def mode(self) -> th.Tensor:
return th.stack([th.argmax(dist.probs, dim=1) for dist in self.distributions], dim=1)
- def actions_from_params(self, action_logits: th.Tensor,
- deterministic: bool = False) -> th.Tensor:
+ def actions_from_params(self, action_logits: th.Tensor, deterministic: bool = False) -> th.Tensor:
# Update the proba distribution
self.proba_distribution(action_logits)
return self.get_actions(deterministic=deterministic)
@@ -386,7 +378,7 @@ def proba_distribution_net(self, latent_dim: int) -> nn.Module:
action_logits = nn.Linear(latent_dim, self.action_dims)
return action_logits
- def proba_distribution(self, action_logits: th.Tensor) -> 'BernoulliDistribution':
+ def proba_distribution(self, action_logits: th.Tensor) -> "BernoulliDistribution":
self.distribution = Bernoulli(logits=action_logits)
return self
@@ -402,8 +394,7 @@ def sample(self) -> th.Tensor:
def mode(self) -> th.Tensor:
return th.round(self.distribution.probs)
- def actions_from_params(self, action_logits: th.Tensor,
- deterministic: bool = False) -> th.Tensor:
+ def actions_from_params(self, action_logits: th.Tensor, deterministic: bool = False) -> th.Tensor:
# Update the proba distribution
self.proba_distribution(action_logits)
return self.get_actions(deterministic=deterministic)
@@ -436,12 +427,15 @@ class StateDependentNoiseDistribution(Distribution):
:param epsilon: (float) small value to avoid NaN due to numerical imprecision.
"""
- def __init__(self, action_dim: int,
- full_std: bool = True,
- use_expln: bool = False,
- squash_output: bool = False,
- learn_features: bool = False,
- epsilon: float = 1e-6):
+ def __init__(
+ self,
+ action_dim: int,
+ full_std: bool = True,
+ use_expln: bool = False,
+ squash_output: bool = False,
+ learn_features: bool = False,
+ epsilon: float = 1e-6,
+ ):
super(StateDependentNoiseDistribution, self).__init__()
self.distribution = None
self.action_dim = action_dim
@@ -501,8 +495,9 @@ def sample_weights(self, log_std: th.Tensor, batch_size: int = 1) -> None:
# Pre-compute matrices in case of parallel exploration
self.exploration_matrices = self.weights_dist.rsample((batch_size,))
- def proba_distribution_net(self, latent_dim: int, log_std_init: float = -2.0,
- latent_sde_dim: Optional[int] = None) -> Tuple[nn.Module, nn.Parameter]:
+ def proba_distribution_net(
+ self, latent_dim: int, log_std_init: float = -2.0, latent_sde_dim: Optional[int] = None
+ ) -> Tuple[nn.Module, nn.Parameter]:
"""
Create the layers and parameter that represent the distribution:
one output will be the deterministic action, the other parameter will be the
@@ -527,9 +522,9 @@ def proba_distribution_net(self, latent_dim: int, log_std_init: float = -2.0,
self.sample_weights(log_std)
return mean_actions_net, log_std
- def proba_distribution(self, mean_actions: th.Tensor,
- log_std: th.Tensor,
- latent_sde: th.Tensor) -> 'StateDependentNoiseDistribution':
+ def proba_distribution(
+ self, mean_actions: th.Tensor, log_std: th.Tensor, latent_sde: th.Tensor
+ ) -> "StateDependentNoiseDistribution":
"""
Create the distribution given its parameters (mean, std)
@@ -591,17 +586,16 @@ def get_noise(self, latent_sde: th.Tensor) -> th.Tensor:
noise = th.bmm(latent_sde, self.exploration_matrices)
return noise.squeeze(1)
- def actions_from_params(self, mean_actions: th.Tensor,
- log_std: th.Tensor,
- latent_sde: th.Tensor,
- deterministic: bool = False) -> th.Tensor:
+ def actions_from_params(
+ self, mean_actions: th.Tensor, log_std: th.Tensor, latent_sde: th.Tensor, deterministic: bool = False
+ ) -> th.Tensor:
# Update the proba distribution
self.proba_distribution(mean_actions, log_std, latent_sde)
return self.get_actions(deterministic=deterministic)
- def log_prob_from_params(self, mean_actions: th.Tensor,
- log_std: th.Tensor,
- latent_sde: th.Tensor) -> Tuple[th.Tensor, th.Tensor]:
+ def log_prob_from_params(
+ self, mean_actions: th.Tensor, log_std: th.Tensor, latent_sde: th.Tensor
+ ) -> Tuple[th.Tensor, th.Tensor]:
actions = self.actions_from_params(mean_actions, log_std, latent_sde)
log_prob = self.log_prob(actions)
return actions, log_prob
@@ -644,16 +638,16 @@ def inverse(y: th.Tensor) -> th.Tensor:
"""
eps = th.finfo(y.dtype).eps
# Clip the action to avoid NaN
- return TanhBijector.atanh(y.clamp(min=-1. + eps, max=1. - eps))
+ return TanhBijector.atanh(y.clamp(min=-1.0 + eps, max=1.0 - eps))
def log_prob_correction(self, x: th.Tensor) -> th.Tensor:
# Squash correction (from original SAC implementation)
return th.log(1.0 - th.tanh(x) ** 2 + self.epsilon)
-def make_proba_distribution(action_space: gym.spaces.Space,
- use_sde: bool = False,
- dist_kwargs: Optional[Dict[str, Any]] = None) -> Distribution:
+def make_proba_distribution(
+ action_space: gym.spaces.Space, use_sde: bool = False, dist_kwargs: Optional[Dict[str, Any]] = None
+) -> Distribution:
"""
Return an instance of Distribution for the correct type of action space
@@ -677,6 +671,8 @@ def make_proba_distribution(action_space: gym.spaces.Space,
elif isinstance(action_space, spaces.MultiBinary):
return BernoulliDistribution(action_space.n, **dist_kwargs)
else:
- raise NotImplementedError("Error: probability distribution, not implemented for action space"
- f"of type {type(action_space)}."
- " Must be of type Gym Spaces: Box, Discrete, MultiDiscrete or MultiBinary.")
+ raise NotImplementedError(
+ "Error: probability distribution, not implemented for action space"
+ f"of type {type(action_space)}."
+ " Must be of type Gym Spaces: Box, Discrete, MultiDiscrete or MultiBinary."
+ )
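
A minimal sketch of the `Distribution` API reformatted above, assuming a latent vector of size 64 and a 2-dimensional continuous action space (both sizes are illustrative).

```python
import torch as th

from stable_baselines3.common.distributions import DiagGaussianDistribution

latent_dim, action_dim, batch_size = 64, 2, 8
dist = DiagGaussianDistribution(action_dim)
# Returns the mean network (a linear layer) and the learnable log std parameter
mean_net, log_std = dist.proba_distribution_net(latent_dim, log_std_init=0.0)

latent = th.randn(batch_size, latent_dim)
mean_actions = mean_net(latent)
dist.proba_distribution(mean_actions, log_std)

actions = dist.sample()            # shape (batch_size, action_dim)
log_prob = dist.log_prob(actions)  # shape (batch_size,)
```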
diff --git a/stable_baselines3/common/env_checker.py b/stable_baselines3/common/env_checker.py
index dbed5b8bf..558a9fe0a 100644
--- a/stable_baselines3/common/env_checker.py
+++ b/stable_baselines3/common/env_checker.py
@@ -2,8 +2,8 @@
from typing import Union
import gym
-from gym import spaces
import numpy as np
+from gym import spaces
from stable_baselines3.common.vec_env import DummyVecEnv, VecCheckNan
@@ -22,41 +22,48 @@ def _check_image_input(observation_space: spaces.Box) -> None:
when the observation is apparently an image.
"""
if observation_space.dtype != np.uint8:
- warnings.warn("It seems that your observation is an image but the `dtype` "
- "of your observation_space is not `np.uint8`. "
- "If your observation is not an image, we recommend you to flatten the observation "
- "to have only a 1D vector")
+ warnings.warn(
+ "It seems that your observation is an image but the `dtype` "
+ "of your observation_space is not `np.uint8`. "
+ "If your observation is not an image, we recommend you to flatten the observation "
+ "to have only a 1D vector"
+ )
if np.any(observation_space.low != 0) or np.any(observation_space.high != 255):
- warnings.warn("It seems that your observation space is an image but the "
- "upper and lower bounds are not in [0, 255]. "
- "Because the CNN policy normalize automatically the observation "
- "you may encounter issue if the values are not in that range."
- )
+ warnings.warn(
+ "It seems that your observation space is an image but the "
+ "upper and lower bounds are not in [0, 255]. "
+ "Because the CNN policy normalize automatically the observation "
+ "you may encounter issue if the values are not in that range."
+ )
if observation_space.shape[0] < 36 or observation_space.shape[1] < 36:
- warnings.warn("The minimal resolution for an image is 36x36 for the default CnnPolicy. "
- "You might need to use a custom `cnn_extractor` "
- "cf https://stable-baselines3.readthedocs.io/en/master/guide/custom_policy.html")
+ warnings.warn(
+ "The minimal resolution for an image is 36x36 for the default CnnPolicy. "
+ "You might need to use a custom `cnn_extractor` "
+ "cf https://stable-baselines3.readthedocs.io/en/master/guide/custom_policy.html"
+ )
def _check_unsupported_obs_spaces(env: gym.Env, observation_space: spaces.Space) -> None:
"""Emit warnings when the observation space used is not supported by Stable-Baselines."""
if isinstance(observation_space, spaces.Dict) and not isinstance(env, gym.GoalEnv):
- warnings.warn("The observation space is a Dict but the environment is not a gym.GoalEnv "
- "(cf https://github.com/openai/gym/blob/master/gym/core.py), "
- "this is currently not supported by Stable Baselines "
- "(cf https://github.com/hill-a/stable-baselines/issues/133), "
- "you will need to use a custom policy. "
- )
+ warnings.warn(
+ "The observation space is a Dict but the environment is not a gym.GoalEnv "
+ "(cf https://github.com/openai/gym/blob/master/gym/core.py), "
+ "this is currently not supported by Stable Baselines "
+ "(cf https://github.com/hill-a/stable-baselines/issues/133), "
+ "you will need to use a custom policy. "
+ )
if isinstance(observation_space, spaces.Tuple):
- warnings.warn("The observation space is a Tuple,"
- "this is currently not supported by Stable Baselines "
- "(cf https://github.com/hill-a/stable-baselines/issues/133), "
- "you will need to flatten the observation and maybe use a custom policy. "
- )
+ warnings.warn(
+ "The observation space is a Tuple,"
+ "this is currently not supported by Stable Baselines "
+ "(cf https://github.com/hill-a/stable-baselines/issues/133), "
+ "you will need to flatten the observation and maybe use a custom policy. "
+ )
def _check_nan(env: gym.Env) -> None:
@@ -67,26 +74,27 @@ def _check_nan(env: gym.Env) -> None:
_, _, _, _ = vec_env.step(action)
-def _check_obs(obs: Union[tuple, dict, np.ndarray, int],
- observation_space: spaces.Space,
- method_name: str) -> None:
+def _check_obs(obs: Union[tuple, dict, np.ndarray, int], observation_space: spaces.Space, method_name: str) -> None:
"""
Check that the observation returned by the environment
    corresponds to the declared one.
"""
if not isinstance(observation_space, spaces.Tuple):
- assert not isinstance(obs, tuple), ("The observation returned by the `{}()` "
- "method should be a single value, not a tuple".format(method_name))
+ assert not isinstance(
+ obs, tuple
+ ), "The observation returned by the `{}()` method should be a single value, not a tuple".format(method_name)
# The check for a GoalEnv is done by the base class
if isinstance(observation_space, spaces.Discrete):
assert isinstance(obs, int), "The observation returned by `{}()` method must be an int".format(method_name)
elif _enforce_array_obs(observation_space):
- assert isinstance(obs, np.ndarray), ("The observation returned by `{}()` "
- "method must be a numpy array".format(method_name))
+ assert isinstance(obs, np.ndarray), "The observation returned by `{}()` method must be a numpy array".format(
+ method_name
+ )
- assert observation_space.contains(obs), ("The observation returned by the `{}()` "
- "method does not match the given observation space".format(method_name))
+ assert observation_space.contains(
+ obs
+ ), "The observation returned by the `{}()` method does not match the given observation space".format(method_name)
def _check_returned_values(env: gym.Env, observation_space: spaces.Space, action_space: spaces.Space) -> None:
@@ -96,7 +104,7 @@ def _check_returned_values(env: gym.Env, observation_space: spaces.Space, action
    # because env inherits from gym.Env, we assume that `reset()` and `step()` methods exist
obs = env.reset()
- _check_obs(obs, observation_space, 'reset')
+ _check_obs(obs, observation_space, "reset")
# Sample a random action
action = action_space.sample()
@@ -107,7 +115,7 @@ def _check_returned_values(env: gym.Env, observation_space: spaces.Space, action
# Unpack
obs, reward, done, info = data
- _check_obs(obs, observation_space, 'step')
+ _check_obs(obs, observation_space, "step")
# We also allow int because the reward will be cast to float
assert isinstance(reward, (float, int)), "The reward returned by `step()` must be a float"
@@ -116,7 +124,7 @@ def _check_returned_values(env: gym.Env, observation_space: spaces.Space, action
if isinstance(env, gym.GoalEnv):
# For a GoalEnv, the keys are checked at reset
- assert reward == env.compute_reward(obs['achieved_goal'], obs['desired_goal'], info)
+ assert reward == env.compute_reward(obs["achieved_goal"], obs["desired_goal"], info)
def _check_spaces(env: gym.Env) -> None:
@@ -127,11 +135,10 @@ def _check_spaces(env: gym.Env) -> None:
# Helper to link to the code, because gym has no proper documentation
gym_spaces = " cf https://github.com/openai/gym/blob/master/gym/spaces/"
- assert hasattr(env, 'observation_space'), "You must specify an observation space (cf gym.spaces)" + gym_spaces
- assert hasattr(env, 'action_space'), "You must specify an action space (cf gym.spaces)" + gym_spaces
+ assert hasattr(env, "observation_space"), "You must specify an observation space (cf gym.spaces)" + gym_spaces
+ assert hasattr(env, "action_space"), "You must specify an action space (cf gym.spaces)" + gym_spaces
- assert isinstance(env.observation_space,
- spaces.Space), "The observation space must inherit from gym.spaces" + gym_spaces
+ assert isinstance(env.observation_space, spaces.Space), "The observation space must inherit from gym.spaces" + gym_spaces
assert isinstance(env.action_space, spaces.Space), "The action space must inherit from gym.spaces" + gym_spaces
@@ -145,18 +152,20 @@ def _check_render(env: gym.Env, warn: bool = True, headless: bool = False) -> No
:param headless: (bool) Whether to disable render modes
that require a graphical interface. False by default.
"""
- render_modes = env.metadata.get('render.modes')
+ render_modes = env.metadata.get("render.modes")
if render_modes is None:
if warn:
- warnings.warn("No render modes was declared in the environment "
- " (env.metadata['render.modes'] is None or not defined), "
- "you may have trouble when calling `.render()`")
+ warnings.warn(
+ "No render modes was declared in the environment "
+ " (env.metadata['render.modes'] is None or not defined), "
+ "you may have trouble when calling `.render()`"
+ )
else:
        # Don't check render modes that require a
# graphical interface (useful for CI)
- if headless and 'human' in render_modes:
- render_modes.remove('human')
+ if headless and "human" in render_modes:
+ render_modes.remove("human")
# Check all declared render modes
for render_mode in render_modes:
env.render(mode=render_mode)
@@ -178,8 +187,9 @@ def check_env(env: gym.Env, warn: bool = True, skip_render_check: bool = True) -
:param skip_render_check: (bool) Whether to skip the checks for the render method.
True by default (useful for the CI)
"""
- assert isinstance(env, gym.Env), ("You environment must inherit from gym.Env class "
- " cf https://github.com/openai/gym/blob/master/gym/core.py")
+ assert isinstance(
+ env, gym.Env
+ ), "You environment must inherit from gym.Env class cf https://github.com/openai/gym/blob/master/gym/core.py"
# ============= Check the spaces (observation and action) ================
_check_spaces(env)
@@ -199,16 +209,22 @@ def check_env(env: gym.Env, warn: bool = True, skip_render_check: bool = True) -
_check_image_input(observation_space)
if isinstance(observation_space, spaces.Box) and len(observation_space.shape) not in [1, 3]:
- warnings.warn("Your observation has an unconventional shape (neither an image, nor a 1D vector). "
- "We recommend you to flatten the observation "
- "to have only a 1D vector")
+ warnings.warn(
+ "Your observation has an unconventional shape (neither an image, nor a 1D vector). "
+ "We recommend you to flatten the observation "
+ "to have only a 1D vector"
+ )
# Check for the action space, it may lead to hard-to-debug issues
- if (isinstance(action_space, spaces.Box) and
- (np.any(np.abs(action_space.low) != np.abs(action_space.high))
- or np.any(np.abs(action_space.low) > 1) or np.any(np.abs(action_space.high) > 1))):
- warnings.warn("We recommend you to use a symmetric and normalized Box action space (range=[-1, 1]) "
- "cf https://stable-baselines3.readthedocs.io/en/master/guide/rl_tips.html")
+ if isinstance(action_space, spaces.Box) and (
+ np.any(np.abs(action_space.low) != np.abs(action_space.high))
+ or np.any(np.abs(action_space.low) > 1)
+ or np.any(np.abs(action_space.high) > 1)
+ ):
+ warnings.warn(
+ "We recommend you to use a symmetric and normalized Box action space (range=[-1, 1]) "
+ "cf https://stable-baselines3.readthedocs.io/en/master/guide/rl_tips.html"
+ )
# ============ Check the returned values ===============
_check_returned_values(env, observation_space, action_space)
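
A hedged sketch of running the checker above before training; `"CartPole-v1"` is just an illustrative choice, any `gym.Env` (typically a custom one) can be passed.

```python
import gym

from stable_baselines3.common.env_checker import check_env

env = gym.make("CartPole-v1")
# Warnings (image dtype, unbounded Box actions, ...) are emitted with warn=True;
# the render check is skipped by default, which is convenient on headless CI.
check_env(env, warn=True, skip_render_check=True)
```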
diff --git a/stable_baselines3/common/evaluation.py b/stable_baselines3/common/evaluation.py
index b298b65b4..6dac4d5a7 100644
--- a/stable_baselines3/common/evaluation.py
+++ b/stable_baselines3/common/evaluation.py
@@ -4,9 +4,16 @@
from stable_baselines3.common.vec_env import VecEnv
-def evaluate_policy(model, env, n_eval_episodes=10, deterministic=True,
- render=False, callback=None, reward_threshold=None,
- return_episode_rewards=False):
+def evaluate_policy(
+ model,
+ env,
+ n_eval_episodes=10,
+ deterministic=True,
+ render=False,
+ callback=None,
+ reward_threshold=None,
+ return_episode_rewards=False,
+):
"""
Runs policy for ``n_eval_episodes`` episodes and returns average reward.
This is made to work only with one env.
@@ -49,8 +56,7 @@ def evaluate_policy(model, env, n_eval_episodes=10, deterministic=True,
mean_reward = np.mean(episode_rewards)
std_reward = np.std(episode_rewards)
if reward_threshold is not None:
- assert mean_reward > reward_threshold, ('Mean reward below threshold: '
- f'{mean_reward:.2f} < {reward_threshold:.2f}')
+ assert mean_reward > reward_threshold, "Mean reward below threshold: " f"{mean_reward:.2f} < {reward_threshold:.2f}"
if return_episode_rewards:
return episode_rewards, episode_lengths
return mean_reward, std_reward
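
A minimal sketch of the reformatted `evaluate_policy` signature; `PPO`, `"CartPole-v1"` and the timestep budget are illustrative assumptions.

```python
from stable_baselines3 import PPO
from stable_baselines3.common.evaluation import evaluate_policy

model = PPO("MlpPolicy", "CartPole-v1").learn(total_timesteps=2000)

# Default: mean and standard deviation of the episode rewards
mean_reward, std_reward = evaluate_policy(model, model.get_env(), n_eval_episodes=10)

# With return_episode_rewards=True, the per-episode rewards and lengths are returned instead
rewards, lengths = evaluate_policy(model, model.get_env(), n_eval_episodes=5, return_episode_rewards=True)
```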
diff --git a/stable_baselines3/common/identity_env.py b/stable_baselines3/common/identity_env.py
index df77f6ac2..4a492e75a 100644
--- a/stable_baselines3/common/identity_env.py
+++ b/stable_baselines3/common/identity_env.py
@@ -1,18 +1,14 @@
-from typing import Union, Optional
+from typing import Optional, Union
import numpy as np
from gym import Env, Space
-from gym.spaces import Discrete, MultiDiscrete, MultiBinary, Box
+from gym.spaces import Box, Discrete, MultiBinary, MultiDiscrete
-
-from stable_baselines3.common.type_aliases import GymStepReturn, GymObs
+from stable_baselines3.common.type_aliases import GymObs, GymStepReturn
class IdentityEnv(Env):
- def __init__(self,
- dim: Optional[int] = None,
- space: Optional[Space] = None,
- ep_length: int = 100):
+ def __init__(self, dim: Optional[int] = None, space: Optional[Space] = None, ep_length: int = 100):
"""
Identity environment for testing purposes
@@ -55,14 +51,12 @@ def _choose_next_state(self) -> None:
def _get_reward(self, action: Union[int, np.ndarray]) -> float:
return 1.0 if np.all(self.state == action) else 0.0
- def render(self, mode: str = 'human') -> None:
+ def render(self, mode: str = "human") -> None:
pass
class IdentityEnvBox(IdentityEnv):
- def __init__(self, low: float = -1.0,
- high: float = 1.0, eps: float = 0.05,
- ep_length: int = 100):
+ def __init__(self, low: float = -1.0, high: float = 1.0, eps: float = 0.05, ep_length: int = 100):
"""
Identity environment for testing purposes
@@ -120,14 +114,12 @@ class FakeImageEnv(Env):
:param n_channels: (int) Number of color channels
:param discrete: (bool)
"""
- def __init__(self, action_dim: int = 6,
- screen_height: int = 84,
- screen_width: int = 84,
- n_channels: int = 1,
- discrete: bool = True):
-
- self.observation_space = Box(low=0, high=255, shape=(screen_height, screen_width,
- n_channels), dtype=np.uint8)
+
+ def __init__(
+ self, action_dim: int = 6, screen_height: int = 84, screen_width: int = 84, n_channels: int = 1, discrete: bool = True
+ ):
+
+ self.observation_space = Box(low=0, high=255, shape=(screen_height, screen_width, n_channels), dtype=np.uint8)
if discrete:
self.action_space = Discrete(action_dim)
else:
@@ -145,5 +137,5 @@ def step(self, action: Union[np.ndarray, int]) -> GymStepReturn:
done = self.current_step >= self.ep_length
return self.observation_space.sample(), reward, done, {}
- def render(self, mode: str = 'human') -> None:
+ def render(self, mode: str = "human") -> None:
pass
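
A quick sketch of how the identity/fake-image environments above are typically used as lightweight test fixtures; the sizes below are illustrative.

```python
from stable_baselines3.common.identity_env import FakeImageEnv, IdentityEnvBox

# Continuous identity task: the reward is 1.0 when the action is within eps of the state
box_env = IdentityEnvBox(low=-1.0, high=1.0, eps=0.05, ep_length=100)
obs = box_env.reset()
obs, reward, done, info = box_env.step(box_env.action_space.sample())

# Fake image observations, useful to exercise CNN policies without a real simulator
image_env = FakeImageEnv(screen_height=84, screen_width=84, n_channels=1, discrete=True)
assert image_env.observation_space.shape == (84, 84, 1)
```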
diff --git a/stable_baselines3/common/logger.py b/stable_baselines3/common/logger.py
index d23726e64..bed63273e 100644
--- a/stable_baselines3/common/logger.py
+++ b/stable_baselines3/common/logger.py
@@ -1,15 +1,16 @@
-import sys
import datetime
import json
import os
+import sys
import tempfile
import warnings
from collections import defaultdict
-from typing import Dict, List, TextIO, Union, Any, Optional, Tuple
+from typing import Any, Dict, List, Optional, TextIO, Tuple, Union
-import pandas
import numpy as np
+import pandas
import torch as th
+
try:
from torch.utils.tensorboard import SummaryWriter
except ImportError:
@@ -27,8 +28,7 @@ class KVWriter(object):
Key Value writer
"""
- def write(self, key_values: Dict[str, Any],
- key_excluded: Dict[str, Union[str, Tuple[str, ...]]], step: int = 0) -> None:
+ def write(self, key_values: Dict[str, Any], key_excluded: Dict[str, Union[str, Tuple[str, ...]]], step: int = 0) -> None:
"""
Write a dictionary to file
@@ -67,10 +67,10 @@ def __init__(self, filename_or_file: Union[str, TextIO]):
:param filename_or_file: (str or File) the file to write the log to
"""
if isinstance(filename_or_file, str):
- self.file = open(filename_or_file, 'wt')
+ self.file = open(filename_or_file, "wt")
self.own_file = True
else:
- assert hasattr(filename_or_file, 'write'), f'Expected file or str, got {filename_or_file}'
+ assert hasattr(filename_or_file, "write"), f"Expected file or str, got {filename_or_file}"
self.file = filename_or_file
self.own_file = False
@@ -80,56 +80,56 @@ def write(self, key_values: Dict, key_excluded: Dict, step: int = 0) -> None:
tag = None
for (key, value), (_, excluded) in zip(sorted(key_values.items()), sorted(key_excluded.items())):
- if excluded is not None and 'stdout' in excluded:
+ if excluded is not None and "stdout" in excluded:
continue
if isinstance(value, float):
# Align left
- value_str = f'{value:<8.3g}'
+ value_str = f"{value:<8.3g}"
else:
value_str = str(value)
- if key.find('/') > 0: # Find tag and add it to the dict
- tag = key[:key.find('/') + 1]
- key2str[self._truncate(tag)] = ''
+ if key.find("/") > 0: # Find tag and add it to the dict
+ tag = key[: key.find("/") + 1]
+ key2str[self._truncate(tag)] = ""
# Remove tag from key
if tag is not None and tag in key:
- key = str(' ' + key[len(tag):])
+ key = str(" " + key[len(tag) :])
key2str[self._truncate(key)] = self._truncate(value_str)
# Find max widths
if len(key2str) == 0:
- warnings.warn('Tried to write empty key-value dict')
+ warnings.warn("Tried to write empty key-value dict")
return
else:
key_width = max(map(len, key2str.keys()))
val_width = max(map(len, key2str.values()))
# Write out the data
- dashes = '-' * (key_width + val_width + 7)
+ dashes = "-" * (key_width + val_width + 7)
lines = [dashes]
for key, value in key2str.items():
- key_space = ' ' * (key_width - len(key))
- val_space = ' ' * (val_width - len(value))
+ key_space = " " * (key_width - len(key))
+ val_space = " " * (val_width - len(value))
lines.append(f"| {key}{key_space} | {value}{val_space} |")
lines.append(dashes)
- self.file.write('\n'.join(lines) + '\n')
+ self.file.write("\n".join(lines) + "\n")
# Flush the output to the file
self.file.flush()
@classmethod
def _truncate(cls, string: str, max_length: int = 23) -> str:
- return string[:max_length - 3] + '...' if len(string) > max_length else string
+ return string[: max_length - 3] + "..." if len(string) > max_length else string
def write_sequence(self, sequence: List) -> None:
sequence = list(sequence)
for i, elem in enumerate(sequence):
self.file.write(elem)
if i < len(sequence) - 1: # add space unless this is the last one
- self.file.write(' ')
- self.file.write('\n')
+ self.file.write(" ")
+ self.file.write("\n")
self.file.flush()
def close(self) -> None:
@@ -147,23 +147,22 @@ def __init__(self, filename: str):
:param filename: (str) the file to write the log to
"""
- self.file = open(filename, 'wt')
+ self.file = open(filename, "wt")
- def write(self, key_values: Dict[str, Any],
- key_excluded: Dict[str, Union[str, Tuple[str, ...]]], step: int = 0) -> None:
+ def write(self, key_values: Dict[str, Any], key_excluded: Dict[str, Union[str, Tuple[str, ...]]], step: int = 0) -> None:
for (key, value), (_, excluded) in zip(sorted(key_values.items()), sorted(key_excluded.items())):
- if excluded is not None and 'json' in excluded:
+ if excluded is not None and "json" in excluded:
continue
- if hasattr(value, 'dtype'):
+ if hasattr(value, "dtype"):
if value.shape == () or len(value) == 1:
# if value is a dimensionless numpy array or of length 1, serialize as a float
key_values[key] = float(value)
else:
# otherwise, a value is a numpy array, serialize as a list or nested lists
key_values[key] = value.tolist()
- self.file.write(json.dumps(key_values) + '\n')
+ self.file.write(json.dumps(key_values) + "\n")
self.file.flush()
def close(self) -> None:
@@ -182,12 +181,11 @@ def __init__(self, filename: str):
:param filename: (str) the file to write the log to
"""
- self.file = open(filename, 'w+t')
+ self.file = open(filename, "w+t")
self.keys = []
- self.separator = ','
+ self.separator = ","
- def write(self, key_values: Dict[str, Any],
- key_excluded: Dict[str, Union[str, Tuple[str, ...]]], step: int = 0) -> None:
+ def write(self, key_values: Dict[str, Any], key_excluded: Dict[str, Union[str, Tuple[str, ...]]], step: int = 0) -> None:
# Add our current row to the history
extra_keys = key_values.keys() - self.keys
if extra_keys:
@@ -197,20 +195,20 @@ def write(self, key_values: Dict[str, Any],
self.file.seek(0)
for (i, key) in enumerate(self.keys):
if i > 0:
- self.file.write(',')
+ self.file.write(",")
self.file.write(key)
- self.file.write('\n')
+ self.file.write("\n")
for line in lines[1:]:
self.file.write(line[:-1])
self.file.write(self.separator * len(extra_keys))
- self.file.write('\n')
+ self.file.write("\n")
for i, key in enumerate(self.keys):
if i > 0:
- self.file.write(',')
+ self.file.write(",")
value = key_values.get(key)
if value is not None:
self.file.write(str(value))
- self.file.write('\n')
+ self.file.write("\n")
self.file.flush()
def close(self) -> None:
@@ -227,17 +225,14 @@ def __init__(self, folder: str):
:param folder: (str) the folder to write the log to
"""
- assert SummaryWriter is not None, ("tensorboard is not installed, you can use "
- "pip install tensorboard to do so")
+ assert SummaryWriter is not None, "tensorboard is not installed, you can use " "pip install tensorboard to do so"
self.writer = SummaryWriter(log_dir=folder)
- def write(self, key_values: Dict[str, Any],
- key_excluded: Dict[str, Union[str, Tuple[str, ...]]], step: int = 0) -> None:
+ def write(self, key_values: Dict[str, Any], key_excluded: Dict[str, Union[str, Tuple[str, ...]]], step: int = 0) -> None:
- for (key, value), (_, excluded) in zip(sorted(key_values.items()),
- sorted(key_excluded.items())):
+ for (key, value), (_, excluded) in zip(sorted(key_values.items()), sorted(key_excluded.items())):
- if excluded is not None and 'tensorboard' in excluded:
+ if excluded is not None and "tensorboard" in excluded:
continue
if isinstance(value, np.ScalarType):
@@ -258,7 +253,7 @@ def close(self) -> None:
self.writer = None
-def make_output_format(_format: str, log_dir: str, log_suffix: str = '') -> KVWriter:
+def make_output_format(_format: str, log_dir: str, log_suffix: str = "") -> KVWriter:
"""
return a logger for the requested format
@@ -268,26 +263,26 @@ def make_output_format(_format: str, log_dir: str, log_suffix: str = '') -> KVWr
:return: (KVWriter) the logger
"""
os.makedirs(log_dir, exist_ok=True)
- if _format == 'stdout':
+ if _format == "stdout":
return HumanOutputFormat(sys.stdout)
- elif _format == 'log':
- return HumanOutputFormat(os.path.join(log_dir, f'log{log_suffix}.txt'))
- elif _format == 'json':
- return JSONOutputFormat(os.path.join(log_dir, f'progress{log_suffix}.json'))
- elif _format == 'csv':
- return CSVOutputFormat(os.path.join(log_dir, f'progress{log_suffix}.csv'))
- elif _format == 'tensorboard':
+ elif _format == "log":
+ return HumanOutputFormat(os.path.join(log_dir, f"log{log_suffix}.txt"))
+ elif _format == "json":
+ return JSONOutputFormat(os.path.join(log_dir, f"progress{log_suffix}.json"))
+ elif _format == "csv":
+ return CSVOutputFormat(os.path.join(log_dir, f"progress{log_suffix}.csv"))
+ elif _format == "tensorboard":
return TensorBoardOutputFormat(log_dir)
else:
- raise ValueError(f'Unknown format specified: {_format}')
+ raise ValueError(f"Unknown format specified: {_format}")
# ================================================================
# API
# ================================================================
-def record(key: str, value: Any,
- exclude: Optional[Union[str, Tuple[str, ...]]] = None) -> None:
+
+def record(key: str, value: Any, exclude: Optional[Union[str, Tuple[str, ...]]] = None) -> None:
"""
Log a value of some diagnostic
Call this once for each diagnostic quantity, each iteration
@@ -300,8 +295,7 @@ def record(key: str, value: Any,
Logger.CURRENT.record(key, value, exclude)
-def record_mean(key: str, value: Union[int, float],
- exclude: Optional[Union[str, Tuple[str, ...]]] = None) -> None:
+def record_mean(key: str, value: Union[int, float], exclude: Optional[Union[str, Tuple[str, ...]]] = None) -> None:
"""
The same as record(), but if called many times, values averaged.
@@ -431,6 +425,7 @@ def get_dir() -> str:
# Backend
# ================================================================
+
class Logger(object):
# A logger with no output files. (See right below class definition)
# So that you can still log to the terminal without setting up any output files
@@ -453,8 +448,7 @@ def __init__(self, folder: Optional[str], output_formats: List[KVWriter]):
# Logging API, forwarded
# ----------------------------------------
- def record(self, key: str, value: Any,
- exclude: Optional[Union[str, Tuple[str, ...]]] = None) -> None:
+ def record(self, key: str, value: Any, exclude: Optional[Union[str, Tuple[str, ...]]] = None) -> None:
"""
Log a value of some diagnostic
Call this once for each diagnostic quantity, each iteration
@@ -467,8 +461,7 @@ def record(self, key: str, value: Any,
self.name_to_value[key] = value
self.name_to_excluded[key] = exclude
- def record_mean(self, key: str, value: Any,
- exclude: Optional[Union[str, Tuple[str, ...]]] = None) -> None:
+ def record_mean(self, key: str, value: Any, exclude: Optional[Union[str, Tuple[str, ...]]] = None) -> None:
"""
The same as record(), but if called many times, values averaged.
@@ -565,21 +558,21 @@ def configure(folder: Optional[str] = None, format_strings: Optional[List[str]]
(if None, $SB3_LOG_FORMAT, if still None, ['stdout', 'log', 'csv'])
"""
if folder is None:
- folder = os.getenv('SB3_LOGDIR')
+ folder = os.getenv("SB3_LOGDIR")
if folder is None:
folder = os.path.join(tempfile.gettempdir(), datetime.datetime.now().strftime("SB3-%Y-%m-%d-%H-%M-%S-%f"))
assert isinstance(folder, str)
os.makedirs(folder, exist_ok=True)
- log_suffix = ''
+ log_suffix = ""
if format_strings is None:
- format_strings = os.getenv('SB3_LOG_FORMAT', 'stdout,log,csv').split(',')
+ format_strings = os.getenv("SB3_LOG_FORMAT", "stdout,log,csv").split(",")
format_strings = filter(None, format_strings)
output_formats = [make_output_format(f, folder, log_suffix) for f in format_strings]
Logger.CURRENT = Logger(folder=folder, output_formats=output_formats)
- log(f'Logging to {folder}')
+ log(f"Logging to {folder}")
def reset() -> None:
@@ -589,7 +582,7 @@ def reset() -> None:
if Logger.CURRENT is not Logger.DEFAULT:
Logger.CURRENT.close()
Logger.CURRENT = Logger.DEFAULT
- log('Reset logger')
+ log("Reset logger")
class ScopedConfigure(object):
@@ -621,6 +614,7 @@ def __exit__(self, *args) -> None:
# Readers
# ================================================================
+
def read_json(filename: str) -> pandas.DataFrame:
"""
read a json file using pandas
@@ -629,7 +623,7 @@ def read_json(filename: str) -> pandas.DataFrame:
:return: (pandas.DataFrame) the data in the json
"""
data = []
- with open(filename, 'rt') as file_handler:
+ with open(filename, "rt") as file_handler:
for line in file_handler:
data.append(json.loads(line))
return pandas.DataFrame(data)
@@ -642,4 +636,4 @@ def read_csv(filename: str) -> pandas.DataFrame:
:param filename: (str) the file path to read
:return: (pandas.DataFrame) the data in the csv
"""
- return pandas.read_csv(filename, index_col=None, comment='#')
+ return pandas.read_csv(filename, index_col=None, comment="#")
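
A hedged sketch of the module-level logger API reformatted above; the folder, keys and values are illustrative, and `dump()` is assumed to flush the recorded values to every configured output.

```python
from stable_baselines3.common import logger

logger.configure(folder="/tmp/sb3_log", format_strings=["stdout", "csv"])
logger.record("train/loss", 0.42)
logger.record_mean("rollout/ep_rew_mean", 200.0)
logger.dump(step=1)  # write the current key/value dict to stdout and progress.csv
```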
diff --git a/stable_baselines3/common/monitor.py b/stable_baselines3/common/monitor.py
index 3210a74d2..df6719f2a 100644
--- a/stable_baselines3/common/monitor.py
+++ b/stable_baselines3/common/monitor.py
@@ -1,15 +1,15 @@
-__all__ = ['Monitor', 'get_monitor_files', 'load_results']
+__all__ = ["Monitor", "get_monitor_files", "load_results"]
import csv
import json
import os
import time
from glob import glob
-from typing import Tuple, Dict, Any, List, Optional
+from typing import Any, Dict, List, Optional, Tuple
import gym
-import pandas
import numpy as np
+import pandas
class Monitor(gym.Wrapper):
@@ -23,14 +23,17 @@ class Monitor(gym.Wrapper):
if extra parameters are needed at reset
:param info_keywords: (Tuple[str, ...]) extra information to log, from the information return of env.step()
"""
+
EXT = "monitor.csv"
- def __init__(self,
- env: gym.Env,
- filename: Optional[str] = None,
- allow_early_resets: bool = True,
- reset_keywords: Tuple[str, ...] = (),
- info_keywords: Tuple[str, ...] = ()):
+ def __init__(
+ self,
+ env: gym.Env,
+ filename: Optional[str] = None,
+ allow_early_resets: bool = True,
+ reset_keywords: Tuple[str, ...] = (),
+ info_keywords: Tuple[str, ...] = (),
+ ):
super(Monitor, self).__init__(env=env)
self.t_start = time.time()
if filename is None:
@@ -43,9 +46,8 @@ def __init__(self,
else:
filename = filename + "." + Monitor.EXT
self.file_handler = open(filename, "wt")
- self.file_handler.write('#%s\n' % json.dumps({"t_start": self.t_start, 'env_id': env.spec and env.spec.id}))
- self.logger = csv.DictWriter(self.file_handler,
- fieldnames=('r', 'l', 't') + reset_keywords + info_keywords)
+ self.file_handler.write("#%s\n" % json.dumps({"t_start": self.t_start, "env_id": env.spec and env.spec.id}))
+ self.logger = csv.DictWriter(self.file_handler, fieldnames=("r", "l", "t") + reset_keywords + info_keywords)
self.logger.writeheader()
self.file_handler.flush()
@@ -68,14 +70,16 @@ def reset(self, **kwargs) -> np.ndarray:
:return: (np.ndarray) the first observation of the environment
"""
if not self.allow_early_resets and not self.needs_reset:
- raise RuntimeError("Tried to reset an environment before done. If you want to allow early resets, "
- "wrap your env with Monitor(env, path, allow_early_resets=True)")
+ raise RuntimeError(
+ "Tried to reset an environment before done. If you want to allow early resets, "
+ "wrap your env with Monitor(env, path, allow_early_resets=True)"
+ )
self.rewards = []
self.needs_reset = False
for key in self.reset_keywords:
value = kwargs.get(key)
if value is None:
- raise ValueError('Expected you to pass kwarg {} into reset'.format(key))
+ raise ValueError("Expected you to pass kwarg {} into reset".format(key))
self.current_reset_info[key] = value
return self.env.reset(**kwargs)
@@ -104,7 +108,7 @@ def step(self, action: np.ndarray) -> Tuple[np.ndarray, float, bool, Dict[Any, A
if self.logger:
self.logger.writerow(ep_info)
self.file_handler.flush()
- info['episode'] = ep_info
+ info["episode"] = ep_info
self.total_steps += 1
return observation, reward, done, info
@@ -153,6 +157,7 @@ class LoadMonitorResultsError(Exception):
"""
Raised when loading the monitor log fails.
"""
+
pass
@@ -178,16 +183,16 @@ def load_results(path: str) -> pandas.DataFrame:
raise LoadMonitorResultsError("no monitor files of the form *%s found in %s" % (Monitor.EXT, path))
data_frames, headers = [], []
for file_name in monitor_files:
- with open(file_name, 'rt') as file_handler:
+ with open(file_name, "rt") as file_handler:
first_line = file_handler.readline()
- assert first_line[0] == '#'
+ assert first_line[0] == "#"
header = json.loads(first_line[1:])
data_frame = pandas.read_csv(file_handler, index_col=None)
headers.append(header)
- data_frame['t'] += header['t_start']
+ data_frame["t"] += header["t_start"]
data_frames.append(data_frame)
data_frame = pandas.concat(data_frames)
- data_frame.sort_values('t', inplace=True)
+ data_frame.sort_values("t", inplace=True)
data_frame.reset_index(inplace=True)
- data_frame['t'] -= min(header['t_start'] for header in headers)
+ data_frame["t"] -= min(header["t_start"] for header in headers)
    return data_frame
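
A short sketch of wrapping an env with the `Monitor` above and reading the resulting csv back with `load_results`; the env id and the log folder are illustrative.

```python
import os

import gym

from stable_baselines3.common.monitor import Monitor, load_results

log_dir = "/tmp/sb3_monitor"
os.makedirs(log_dir, exist_ok=True)

env = Monitor(gym.make("CartPole-v1"), filename=os.path.join(log_dir, "run"))
obs = env.reset()
for _ in range(100):
    obs, reward, done, info = env.step(env.action_space.sample())
    if done:
        obs = env.reset()
env.close()

# Concatenates every *monitor.csv found in the folder into a single DataFrame
df = load_results(log_dir)
print(df[["r", "l", "t"]].head())
```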
diff --git a/stable_baselines3/common/noise.py b/stable_baselines3/common/noise.py
index 63b4b4a48..f7b8b7372 100644
--- a/stable_baselines3/common/noise.py
+++ b/stable_baselines3/common/noise.py
@@ -1,6 +1,6 @@
-from typing import Optional, List, Iterable
-from abc import ABC, abstractmethod
import copy
+from abc import ABC, abstractmethod
+from typing import Iterable, List, Optional
import numpy as np
@@ -41,7 +41,7 @@ def __call__(self) -> np.ndarray:
return np.random.normal(self._mu, self._sigma)
def __repr__(self) -> str:
- return f'NormalActionNoise(mu={self._mu}, sigma={self._sigma})'
+ return f"NormalActionNoise(mu={self._mu}, sigma={self._sigma})"
class OrnsteinUhlenbeckActionNoise(ActionNoise):
@@ -57,11 +57,14 @@ class OrnsteinUhlenbeckActionNoise(ActionNoise):
:param initial_noise: (Optional[np.ndarray]) the initial value for the noise output, (if None: 0)
"""
- def __init__(self, mean: np.ndarray,
- sigma: np.ndarray,
- theta: float = .15,
- dt: float = 1e-2,
- initial_noise: Optional[np.ndarray] = None):
+ def __init__(
+ self,
+ mean: np.ndarray,
+ sigma: np.ndarray,
+ theta: float = 0.15,
+ dt: float = 1e-2,
+ initial_noise: Optional[np.ndarray] = None,
+ ):
self._theta = theta
self._mu = mean
self._sigma = sigma
@@ -72,8 +75,11 @@ def __init__(self, mean: np.ndarray,
super(OrnsteinUhlenbeckActionNoise, self).__init__()
def __call__(self) -> np.ndarray:
- noise = (self.noise_prev + self._theta * (self._mu - self.noise_prev) * self._dt
- + self._sigma * np.sqrt(self._dt) * np.random.normal(size=self._mu.shape))
+ noise = (
+ self.noise_prev
+ + self._theta * (self._mu - self.noise_prev) * self._dt
+ + self._sigma * np.sqrt(self._dt) * np.random.normal(size=self._mu.shape)
+ )
self.noise_prev = noise
return noise
@@ -84,7 +90,7 @@ def reset(self) -> None:
self.noise_prev = self.initial_noise if self.initial_noise is not None else np.zeros_like(self._mu)
def __repr__(self) -> str:
- return f'OrnsteinUhlenbeckActionNoise(mu={self._mu}, sigma={self._sigma})'
+ return f"OrnsteinUhlenbeckActionNoise(mu={self._mu}, sigma={self._sigma})"
class VectorizedActionNoise(ActionNoise):
@@ -149,15 +155,11 @@ def noises(self, noises: List[ActionNoise]) -> None:
noises = list(noises) # raises TypeError if not iterable
assert len(noises) == self.n_envs, f"Expected a list of {self.n_envs} ActionNoises, found {len(noises)}."
- different_types = [
- i for i, noise in enumerate(noises)
- if not isinstance(noise, type(self.base_noise))
- ]
+ different_types = [i for i, noise in enumerate(noises) if not isinstance(noise, type(self.base_noise))]
if len(different_types):
raise ValueError(
- f"Noise instances at indices {different_types} don't match the type of base_noise",
- type(self.base_noise)
+ f"Noise instances at indices {different_types} don't match the type of base_noise", type(self.base_noise)
)
self._noises = noises
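
A minimal sketch of the noise classes reformatted above, assuming a 2-dimensional action space and four parallel envs; the sigma values are illustrative.

```python
import numpy as np

from stable_baselines3.common.noise import NormalActionNoise, OrnsteinUhlenbeckActionNoise, VectorizedActionNoise

n_actions = 2
normal_noise = NormalActionNoise(mean=np.zeros(n_actions), sigma=0.1 * np.ones(n_actions))
ou_noise = OrnsteinUhlenbeckActionNoise(mean=np.zeros(n_actions), sigma=0.2 * np.ones(n_actions), theta=0.15)

# One independent copy of the base noise per environment
vec_noise = VectorizedActionNoise(base_noise=normal_noise, n_envs=4)
sample = vec_noise()  # shape (4, n_actions)
```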
diff --git a/stable_baselines3/common/off_policy_algorithm.py b/stable_baselines3/common/off_policy_algorithm.py
index 5a8d085b0..b2814df14 100644
--- a/stable_baselines3/common/off_policy_algorithm.py
+++ b/stable_baselines3/common/off_policy_algorithm.py
@@ -1,23 +1,23 @@
+import io
+import pathlib
import time
import warnings
-import pathlib
-from typing import Union, Type, Optional, Dict, Any, Callable, List, Tuple
-import io
+from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
import gym
-import torch as th
import numpy as np
+import torch as th
from stable_baselines3.common import logger
from stable_baselines3.common.base_class import BaseAlgorithm
+from stable_baselines3.common.buffers import ReplayBuffer
+from stable_baselines3.common.callbacks import BaseCallback
+from stable_baselines3.common.noise import ActionNoise
from stable_baselines3.common.policies import BasePolicy
+from stable_baselines3.common.save_util import load_from_pkl, save_to_pkl
+from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, RolloutReturn
from stable_baselines3.common.utils import safe_mean
from stable_baselines3.common.vec_env import VecEnv
-from stable_baselines3.common.type_aliases import GymEnv, RolloutReturn, MaybeCallback
-from stable_baselines3.common.callbacks import BaseCallback
-from stable_baselines3.common.noise import ActionNoise
-from stable_baselines3.common.buffers import ReplayBuffer
-from stable_baselines3.common.save_util import save_to_pkl, load_from_pkl
class OffPolicyAlgorithm(BaseAlgorithm):
@@ -69,40 +69,52 @@ class OffPolicyAlgorithm(BaseAlgorithm):
:param sde_support: (bool) Whether the model support gSDE or not
"""
- def __init__(self,
- policy: Type[BasePolicy],
- env: Union[GymEnv, str],
- policy_base: Type[BasePolicy],
- learning_rate: Union[float, Callable],
- buffer_size: int = int(1e6),
- learning_starts: int = 100,
- batch_size: int = 256,
- tau: float = 0.005,
- gamma: float = 0.99,
- train_freq: int = 1,
- gradient_steps: int = 1,
- n_episodes_rollout: int = -1,
- action_noise: Optional[ActionNoise] = None,
- optimize_memory_usage: bool = False,
- policy_kwargs: Dict[str, Any] = None,
- tensorboard_log: Optional[str] = None,
- verbose: int = 0,
- device: Union[th.device, str] = 'auto',
- support_multi_env: bool = False,
- create_eval_env: bool = False,
- monitor_wrapper: bool = True,
- seed: Optional[int] = None,
- use_sde: bool = False,
- sde_sample_freq: int = -1,
- use_sde_at_warmup: bool = False,
- sde_support: bool = True):
-
- super(OffPolicyAlgorithm, self).__init__(policy=policy, env=env, policy_base=policy_base,
- learning_rate=learning_rate, policy_kwargs=policy_kwargs,
- tensorboard_log=tensorboard_log, verbose=verbose,
- device=device, support_multi_env=support_multi_env,
- create_eval_env=create_eval_env, monitor_wrapper=monitor_wrapper,
- seed=seed, use_sde=use_sde, sde_sample_freq=sde_sample_freq)
+ def __init__(
+ self,
+ policy: Type[BasePolicy],
+ env: Union[GymEnv, str],
+ policy_base: Type[BasePolicy],
+ learning_rate: Union[float, Callable],
+ buffer_size: int = int(1e6),
+ learning_starts: int = 100,
+ batch_size: int = 256,
+ tau: float = 0.005,
+ gamma: float = 0.99,
+ train_freq: int = 1,
+ gradient_steps: int = 1,
+ n_episodes_rollout: int = -1,
+ action_noise: Optional[ActionNoise] = None,
+ optimize_memory_usage: bool = False,
+ policy_kwargs: Dict[str, Any] = None,
+ tensorboard_log: Optional[str] = None,
+ verbose: int = 0,
+ device: Union[th.device, str] = "auto",
+ support_multi_env: bool = False,
+ create_eval_env: bool = False,
+ monitor_wrapper: bool = True,
+ seed: Optional[int] = None,
+ use_sde: bool = False,
+ sde_sample_freq: int = -1,
+ use_sde_at_warmup: bool = False,
+ sde_support: bool = True,
+ ):
+
+ super(OffPolicyAlgorithm, self).__init__(
+ policy=policy,
+ env=env,
+ policy_base=policy_base,
+ learning_rate=learning_rate,
+ policy_kwargs=policy_kwargs,
+ tensorboard_log=tensorboard_log,
+ verbose=verbose,
+ device=device,
+ support_multi_env=support_multi_env,
+ create_eval_env=create_eval_env,
+ monitor_wrapper=monitor_wrapper,
+ seed=seed,
+ use_sde=use_sde,
+ sde_sample_freq=sde_sample_freq,
+ )
self.buffer_size = buffer_size
self.batch_size = batch_size
self.learning_starts = learning_starts
@@ -115,30 +127,40 @@ def __init__(self,
self.optimize_memory_usage = optimize_memory_usage
if train_freq > 0 and n_episodes_rollout > 0:
- warnings.warn("You passed a positive value for `train_freq` and `n_episodes_rollout`."
- "Please make sure this is intended. "
- "The agent will collect data by stepping in the environment "
- "until both conditions are true: "
- "`number of steps in the env` >= `train_freq` and "
- "`number of episodes` > `n_episodes_rollout`")
+ warnings.warn(
+                "You passed a positive value for `train_freq` and `n_episodes_rollout`. "
+ "Please make sure this is intended. "
+ "The agent will collect data by stepping in the environment "
+ "until both conditions are true: "
+ "`number of steps in the env` >= `train_freq` and "
+ "`number of episodes` > `n_episodes_rollout`"
+ )
self.actor = None # type: Optional[th.nn.Module]
self.replay_buffer = None # type: Optional[ReplayBuffer]
# Update policy keyword arguments
if sde_support:
- self.policy_kwargs['use_sde'] = self.use_sde
- self.policy_kwargs['device'] = self.device
+ self.policy_kwargs["use_sde"] = self.use_sde
+ self.policy_kwargs["device"] = self.device
# For gSDE only
self.use_sde_at_warmup = use_sde_at_warmup
def _setup_model(self) -> None:
self._setup_lr_schedule()
self.set_random_seed(self.seed)
- self.replay_buffer = ReplayBuffer(self.buffer_size, self.observation_space,
- self.action_space, self.device,
- optimize_memory_usage=self.optimize_memory_usage)
- self.policy = self.policy_class(self.observation_space, self.action_space,
- self.lr_schedule, **self.policy_kwargs) # pytype:disable=not-instantiable
+ self.replay_buffer = ReplayBuffer(
+ self.buffer_size,
+ self.observation_space,
+ self.action_space,
+ self.device,
+ optimize_memory_usage=self.optimize_memory_usage,
+ )
+ self.policy = self.policy_class(
+ self.observation_space,
+ self.action_space,
+ self.lr_schedule,
+ **self.policy_kwargs # pytype:disable=not-instantiable
+ )
self.policy = self.policy.to(self.device)
def save_replay_buffer(self, path: Union[str, pathlib.Path, io.BufferedIOBase]) -> None:
@@ -158,65 +180,78 @@ def load_replay_buffer(self, path: Union[str, pathlib.Path, io.BufferedIOBase])
:param path: (Union[str, pathlib.Path, io.BufferedIOBase]) Path to the pickled replay buffer.
"""
self.replay_buffer = load_from_pkl(path, self.verbose)
- assert isinstance(self.replay_buffer, ReplayBuffer), 'The replay buffer must inherit from ReplayBuffer class'
-
- def _setup_learn(self,
- total_timesteps: int,
- eval_env: Optional[GymEnv],
- callback: Union[None, Callable, List[BaseCallback], BaseCallback] = None,
- eval_freq: int = 10000,
- n_eval_episodes: int = 5,
- log_path: Optional[str] = None,
- reset_num_timesteps: bool = True,
- tb_log_name: str = 'run',
- ) -> Tuple[int, BaseCallback]:
+ assert isinstance(self.replay_buffer, ReplayBuffer), "The replay buffer must inherit from ReplayBuffer class"
+
+ def _setup_learn(
+ self,
+ total_timesteps: int,
+ eval_env: Optional[GymEnv],
+ callback: Union[None, Callable, List[BaseCallback], BaseCallback] = None,
+ eval_freq: int = 10000,
+ n_eval_episodes: int = 5,
+ log_path: Optional[str] = None,
+ reset_num_timesteps: bool = True,
+ tb_log_name: str = "run",
+ ) -> Tuple[int, BaseCallback]:
"""
cf `BaseAlgorithm`.
"""
# Prevent continuity issue by truncating trajectory
# when using memory efficient replay buffer
# see https://github.com/DLR-RM/stable-baselines3/issues/46
- truncate_last_traj = (self.optimize_memory_usage and reset_num_timesteps
- and self.replay_buffer is not None
- and (self.replay_buffer.full or self.replay_buffer.pos > 0))
+ truncate_last_traj = (
+ self.optimize_memory_usage
+ and reset_num_timesteps
+ and self.replay_buffer is not None
+ and (self.replay_buffer.full or self.replay_buffer.pos > 0)
+ )
if truncate_last_traj:
- warnings.warn("The last trajectory in the replay buffer will be truncated, "
- "see https://github.com/DLR-RM/stable-baselines3/issues/46."
- "You should use `reset_num_timesteps=False` or `optimize_memory_usage=False`"
- "to avoid that issue.")
+ warnings.warn(
+ "The last trajectory in the replay buffer will be truncated, "
+                "see https://github.com/DLR-RM/stable-baselines3/issues/46. "
+                "You should use `reset_num_timesteps=False` or `optimize_memory_usage=False` "
+ "to avoid that issue."
+ )
# Go to the previous index
pos = (self.replay_buffer.pos - 1) % self.replay_buffer.buffer_size
self.replay_buffer.dones[pos] = True
- return super()._setup_learn(total_timesteps, eval_env, callback, eval_freq,
- n_eval_episodes, log_path, reset_num_timesteps, tb_log_name)
-
- def learn(self,
- total_timesteps: int,
- callback: MaybeCallback = None,
- log_interval: int = 4,
- eval_env: Optional[GymEnv] = None,
- eval_freq: int = -1,
- n_eval_episodes: int = 5,
- tb_log_name: str = "run",
- eval_log_path: Optional[str] = None,
- reset_num_timesteps: bool = True) -> 'OffPolicyAlgorithm':
-
- total_timesteps, callback = self._setup_learn(total_timesteps, eval_env, callback, eval_freq,
- n_eval_episodes, eval_log_path, reset_num_timesteps,
- tb_log_name)
+ return super()._setup_learn(
+ total_timesteps, eval_env, callback, eval_freq, n_eval_episodes, log_path, reset_num_timesteps, tb_log_name
+ )
+
+ def learn(
+ self,
+ total_timesteps: int,
+ callback: MaybeCallback = None,
+ log_interval: int = 4,
+ eval_env: Optional[GymEnv] = None,
+ eval_freq: int = -1,
+ n_eval_episodes: int = 5,
+ tb_log_name: str = "run",
+ eval_log_path: Optional[str] = None,
+ reset_num_timesteps: bool = True,
+ ) -> "OffPolicyAlgorithm":
+
+ total_timesteps, callback = self._setup_learn(
+ total_timesteps, eval_env, callback, eval_freq, n_eval_episodes, eval_log_path, reset_num_timesteps, tb_log_name
+ )
callback.on_training_start(locals(), globals())
while self.num_timesteps < total_timesteps:
- rollout = self.collect_rollouts(self.env, n_episodes=self.n_episodes_rollout,
- n_steps=self.train_freq, action_noise=self.action_noise,
- callback=callback,
- learning_starts=self.learning_starts,
- replay_buffer=self.replay_buffer,
- log_interval=log_interval)
+ rollout = self.collect_rollouts(
+ self.env,
+ n_episodes=self.n_episodes_rollout,
+ n_steps=self.train_freq,
+ action_noise=self.action_noise,
+ callback=callback,
+ learning_starts=self.learning_starts,
+ replay_buffer=self.replay_buffer,
+ log_interval=log_interval,
+ )
if rollout.continue_training is False:
break
@@ -238,8 +273,9 @@ def train(self, gradient_steps: int, batch_size: int) -> None:
"""
raise NotImplementedError()
- def _sample_action(self, learning_starts: int,
- action_noise: Optional[ActionNoise] = None) -> Tuple[np.ndarray, np.ndarray]:
+ def _sample_action(
+ self, learning_starts: int, action_noise: Optional[ActionNoise] = None
+ ) -> Tuple[np.ndarray, np.ndarray]:
"""
Sample an action according to the exploration policy.
This is either done by sampling the probability distribution of the policy,
@@ -288,16 +324,16 @@ def _dump_logs(self) -> None:
fps = int(self.num_timesteps / (time.time() - self.start_time))
logger.record("time/episodes", self._episode_num, exclude="tensorboard")
if len(self.ep_info_buffer) > 0 and len(self.ep_info_buffer[0]) > 0:
- logger.record('rollout/ep_rew_mean', safe_mean([ep_info['r'] for ep_info in self.ep_info_buffer]))
- logger.record('rollout/ep_len_mean', safe_mean([ep_info['l'] for ep_info in self.ep_info_buffer]))
+ logger.record("rollout/ep_rew_mean", safe_mean([ep_info["r"] for ep_info in self.ep_info_buffer]))
+ logger.record("rollout/ep_len_mean", safe_mean([ep_info["l"] for ep_info in self.ep_info_buffer]))
logger.record("time/fps", fps)
- logger.record('time/time_elapsed', int(time.time() - self.start_time), exclude="tensorboard")
+ logger.record("time/time_elapsed", int(time.time() - self.start_time), exclude="tensorboard")
logger.record("time/total timesteps", self.num_timesteps, exclude="tensorboard")
if self.use_sde:
logger.record("train/std", (self.actor.get_std()).mean().item())
if len(self.ep_success_buffer) > 0:
- logger.record('rollout/success rate', safe_mean(self.ep_success_buffer))
+ logger.record("rollout/success rate", safe_mean(self.ep_success_buffer))
# Pass the number of timesteps for tensorboard
logger.dump(step=self.num_timesteps)
@@ -309,15 +345,17 @@ def _on_step(self) -> None:
"""
pass
- def collect_rollouts(self,
- env: VecEnv,
- callback: BaseCallback,
- n_episodes: int = 1,
- n_steps: int = -1,
- action_noise: Optional[ActionNoise] = None,
- learning_starts: int = 0,
- replay_buffer: Optional[ReplayBuffer] = None,
- log_interval: Optional[int] = None) -> RolloutReturn:
+ def collect_rollouts(
+ self,
+ env: VecEnv,
+ callback: BaseCallback,
+ n_episodes: int = 1,
+ n_steps: int = -1,
+ action_noise: Optional[ActionNoise] = None,
+ learning_starts: int = 0,
+ replay_buffer: Optional[ReplayBuffer] = None,
+ log_interval: Optional[int] = None,
+ ) -> RolloutReturn:
"""
Collect experiences and store them into a ReplayBuffer.
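
The reformatted `save_replay_buffer`/`load_replay_buffer` pair above is a thin wrapper around `save_to_pkl`/`load_from_pkl`. A minimal usage sketch, not part of this diff (the `Pendulum-v0` environment and the buffer file name are arbitrary examples):

```python
from stable_baselines3 import SAC

# Illustrative only: train briefly, then round-trip the replay buffer to disk.
model = SAC("MlpPolicy", "Pendulum-v0", verbose=0)
model.learn(total_timesteps=1000)
model.save_replay_buffer("sac_pendulum_buffer")  # path is an arbitrary example
model.load_replay_buffer("sac_pendulum_buffer")
print(model.replay_buffer.pos)  # current insertion index of the restored buffer
```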
diff --git a/stable_baselines3/common/on_policy_algorithm.py b/stable_baselines3/common/on_policy_algorithm.py
index 2937b77cb..f84d18f34 100644
--- a/stable_baselines3/common/on_policy_algorithm.py
+++ b/stable_baselines3/common/on_policy_algorithm.py
@@ -1,18 +1,18 @@
import time
-from typing import Union, Type, Optional, Dict, Any, List, Tuple, Callable
+from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
import gym
-import torch as th
import numpy as np
+import torch as th
from stable_baselines3.common import logger
-from stable_baselines3.common.utils import safe_mean
from stable_baselines3.common.base_class import BaseAlgorithm
+from stable_baselines3.common.buffers import RolloutBuffer
+from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.policies import ActorCriticPolicy
-from stable_baselines3.common.vec_env import VecEnv
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback
-from stable_baselines3.common.callbacks import BaseCallback
-from stable_baselines3.common.buffers import RolloutBuffer
+from stable_baselines3.common.utils import safe_mean
+from stable_baselines3.common.vec_env import VecEnv
class OnPolicyAlgorithm(BaseAlgorithm):
@@ -48,32 +48,44 @@ class OnPolicyAlgorithm(BaseAlgorithm):
:param _init_setup_model: (bool) Whether or not to build the network at the creation of the instance
"""
- def __init__(self,
- policy: Union[str, Type[ActorCriticPolicy]],
- env: Union[GymEnv, str],
- learning_rate: Union[float, Callable],
- n_steps: int,
- gamma: float,
- gae_lambda: float,
- ent_coef: float,
- vf_coef: float,
- max_grad_norm: float,
- use_sde: bool,
- sde_sample_freq: int,
- tensorboard_log: Optional[str] = None,
- create_eval_env: bool = False,
- monitor_wrapper: bool = True,
- policy_kwargs: Optional[Dict[str, Any]] = None,
- verbose: int = 0,
- seed: Optional[int] = None,
- device: Union[th.device, str] = 'auto',
- _init_setup_model: bool = True):
-
- super(OnPolicyAlgorithm, self).__init__(policy=policy, env=env, policy_base=ActorCriticPolicy,
- learning_rate=learning_rate, policy_kwargs=policy_kwargs,
- verbose=verbose, device=device, use_sde=use_sde,
- sde_sample_freq=sde_sample_freq, create_eval_env=create_eval_env,
- support_multi_env=True, seed=seed, tensorboard_log=tensorboard_log)
+ def __init__(
+ self,
+ policy: Union[str, Type[ActorCriticPolicy]],
+ env: Union[GymEnv, str],
+ learning_rate: Union[float, Callable],
+ n_steps: int,
+ gamma: float,
+ gae_lambda: float,
+ ent_coef: float,
+ vf_coef: float,
+ max_grad_norm: float,
+ use_sde: bool,
+ sde_sample_freq: int,
+ tensorboard_log: Optional[str] = None,
+ create_eval_env: bool = False,
+ monitor_wrapper: bool = True,
+ policy_kwargs: Optional[Dict[str, Any]] = None,
+ verbose: int = 0,
+ seed: Optional[int] = None,
+ device: Union[th.device, str] = "auto",
+ _init_setup_model: bool = True,
+ ):
+
+ super(OnPolicyAlgorithm, self).__init__(
+ policy=policy,
+ env=env,
+ policy_base=ActorCriticPolicy,
+ learning_rate=learning_rate,
+ policy_kwargs=policy_kwargs,
+ verbose=verbose,
+ device=device,
+ use_sde=use_sde,
+ sde_sample_freq=sde_sample_freq,
+ create_eval_env=create_eval_env,
+ support_multi_env=True,
+ seed=seed,
+ tensorboard_log=tensorboard_log,
+ )
self.n_steps = n_steps
self.gamma = gamma
@@ -90,20 +102,28 @@ def _setup_model(self) -> None:
self._setup_lr_schedule()
self.set_random_seed(self.seed)
- self.rollout_buffer = RolloutBuffer(self.n_steps, self.observation_space,
- self.action_space, self.device,
- gamma=self.gamma, gae_lambda=self.gae_lambda,
- n_envs=self.n_envs)
- self.policy = self.policy_class(self.observation_space, self.action_space,
- self.lr_schedule, use_sde=self.use_sde, device=self.device,
- **self.policy_kwargs) # pytype:disable=not-instantiable
+ self.rollout_buffer = RolloutBuffer(
+ self.n_steps,
+ self.observation_space,
+ self.action_space,
+ self.device,
+ gamma=self.gamma,
+ gae_lambda=self.gae_lambda,
+ n_envs=self.n_envs,
+ )
+ self.policy = self.policy_class(
+ self.observation_space,
+ self.action_space,
+ self.lr_schedule,
+ use_sde=self.use_sde,
+ device=self.device,
+ **self.policy_kwargs # pytype:disable=not-instantiable
+ )
self.policy = self.policy.to(self.device)
- def collect_rollouts(self,
- env: VecEnv,
- callback: BaseCallback,
- rollout_buffer: RolloutBuffer,
- n_rollout_steps: int) -> bool:
+ def collect_rollouts(
+ self, env: VecEnv, callback: BaseCallback, rollout_buffer: RolloutBuffer, n_rollout_steps: int
+ ) -> bool:
"""
Collect rollouts using the current policy and fill a `RolloutBuffer`.
@@ -169,29 +189,29 @@ def train(self) -> None:
"""
raise NotImplementedError
- def learn(self,
- total_timesteps: int,
- callback: MaybeCallback = None,
- log_interval: int = 1,
- eval_env: Optional[GymEnv] = None,
- eval_freq: int = -1,
- n_eval_episodes: int = 5,
- tb_log_name: str = "OnPolicyAlgorithm",
- eval_log_path: Optional[str] = None,
- reset_num_timesteps: bool = True) -> 'OnPolicyAlgorithm':
+ def learn(
+ self,
+ total_timesteps: int,
+ callback: MaybeCallback = None,
+ log_interval: int = 1,
+ eval_env: Optional[GymEnv] = None,
+ eval_freq: int = -1,
+ n_eval_episodes: int = 5,
+ tb_log_name: str = "OnPolicyAlgorithm",
+ eval_log_path: Optional[str] = None,
+ reset_num_timesteps: bool = True,
+ ) -> "OnPolicyAlgorithm":
iteration = 0
- total_timesteps, callback = self._setup_learn(total_timesteps, eval_env, callback, eval_freq,
- n_eval_episodes, eval_log_path, reset_num_timesteps,
- tb_log_name)
+ total_timesteps, callback = self._setup_learn(
+ total_timesteps, eval_env, callback, eval_freq, n_eval_episodes, eval_log_path, reset_num_timesteps, tb_log_name
+ )
callback.on_training_start(locals(), globals())
while self.num_timesteps < total_timesteps:
- continue_training = self.collect_rollouts(self.env, callback,
- self.rollout_buffer,
- n_rollout_steps=self.n_steps)
+ continue_training = self.collect_rollouts(self.env, callback, self.rollout_buffer, n_rollout_steps=self.n_steps)
if continue_training is False:
break
@@ -204,10 +224,8 @@ def learn(self,
fps = int(self.num_timesteps / (time.time() - self.start_time))
logger.record("time/iterations", iteration, exclude="tensorboard")
if len(self.ep_info_buffer) > 0 and len(self.ep_info_buffer[0]) > 0:
- logger.record("rollout/ep_rew_mean",
- safe_mean([ep_info["r"] for ep_info in self.ep_info_buffer]))
- logger.record("rollout/ep_len_mean",
- safe_mean([ep_info["l"] for ep_info in self.ep_info_buffer]))
+ logger.record("rollout/ep_rew_mean", safe_mean([ep_info["r"] for ep_info in self.ep_info_buffer]))
+ logger.record("rollout/ep_len_mean", safe_mean([ep_info["l"] for ep_info in self.ep_info_buffer]))
logger.record("time/fps", fps)
logger.record("time/time_elapsed", int(time.time() - self.start_time), exclude="tensorboard")
logger.record("time/total_timesteps", self.num_timesteps, exclude="tensorboard")
diff --git a/stable_baselines3/common/policies.py b/stable_baselines3/common/policies.py
index 276a0f318..acadd8845 100644
--- a/stable_baselines3/common/policies.py
+++ b/stable_baselines3/common/policies.py
@@ -1,24 +1,28 @@
"""Policies: abstract base class and concrete implementations."""
-from abc import ABC, abstractmethod
import collections
-from typing import Union, Type, Dict, List, Tuple, Optional, Any, Callable
+from abc import ABC, abstractmethod
from functools import partial
+from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
import gym
-import torch as th
-import torch.nn as nn
import numpy as np
-
-from stable_baselines3.common.preprocessing import preprocess_obs, is_image_space, get_action_dim
-from stable_baselines3.common.torch_layers import (FlattenExtractor, BaseFeaturesExtractor, create_mlp,
- NatureCNN, MlpExtractor)
+import torch as th
+from torch import nn as nn
+
+from stable_baselines3.common.distributions import (
+ BernoulliDistribution,
+ CategoricalDistribution,
+ DiagGaussianDistribution,
+ Distribution,
+ MultiCategoricalDistribution,
+ StateDependentNoiseDistribution,
+ make_proba_distribution,
+)
+from stable_baselines3.common.preprocessing import get_action_dim, is_image_space, preprocess_obs
+from stable_baselines3.common.torch_layers import BaseFeaturesExtractor, FlattenExtractor, MlpExtractor, NatureCNN, create_mlp
from stable_baselines3.common.utils import get_device, is_vectorized_observation
from stable_baselines3.common.vec_env import VecTransposeImage
-from stable_baselines3.common.distributions import (make_proba_distribution, Distribution,
- DiagGaussianDistribution, CategoricalDistribution,
- MultiCategoricalDistribution, BernoulliDistribution,
- StateDependentNoiseDistribution)
class BaseModel(nn.Module, ABC):
@@ -44,16 +48,18 @@ class BaseModel(nn.Module, ABC):
excluding the learning rate, to pass to the optimizer
"""
- def __init__(self,
- observation_space: gym.spaces.Space,
- action_space: gym.spaces.Space,
- device: Union[th.device, str] = 'auto',
- features_extractor_class: Type[BaseFeaturesExtractor] = FlattenExtractor,
- features_extractor_kwargs: Optional[Dict[str, Any]] = None,
- features_extractor: Optional[nn.Module] = None,
- normalize_images: bool = True,
- optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
- optimizer_kwargs: Optional[Dict[str, Any]] = None):
+ def __init__(
+ self,
+ observation_space: gym.spaces.Space,
+ action_space: gym.spaces.Space,
+ device: Union[th.device, str] = "auto",
+ features_extractor_class: Type[BaseFeaturesExtractor] = FlattenExtractor,
+ features_extractor_kwargs: Optional[Dict[str, Any]] = None,
+ features_extractor: Optional[nn.Module] = None,
+ normalize_images: bool = True,
+ optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
+ optimizer_kwargs: Optional[Dict[str, Any]] = None,
+ ):
super(BaseModel, self).__init__()
if optimizer_kwargs is None:
@@ -86,7 +92,7 @@ def extract_features(self, obs: th.Tensor) -> th.Tensor:
:param obs: (th.Tensor)
:return: (th.Tensor)
"""
- assert self.features_extractor is not None, 'No feature extractor was set'
+ assert self.features_extractor is not None, "No feature extractor was set"
preprocessed_obs = preprocess_obs(obs, self.observation_space, normalize_images=self.normalize_images)
return self.features_extractor(preprocessed_obs)
@@ -112,10 +118,10 @@ def save(self, path: str) -> None:
:param path: (str)
"""
- th.save({'state_dict': self.state_dict(), 'data': self._get_data()}, path)
+ th.save({"state_dict": self.state_dict(), "data": self._get_data()}, path)
@classmethod
- def load(cls, path: str, device: Union[th.device, str] = 'auto') -> 'BaseModel':
+ def load(cls, path: str, device: Union[th.device, str] = "auto") -> "BaseModel":
"""
Load model from path.
@@ -126,9 +132,9 @@ def load(cls, path: str, device: Union[th.device, str] = 'auto') -> 'BaseModel':
device = get_device(device)
saved_variables = th.load(path, map_location=device)
# Create policy object
- model = cls(**saved_variables['data']) # pytype: disable=not-instantiable
+ model = cls(**saved_variables["data"]) # pytype: disable=not-instantiable
# Load weights
- model.load_state_dict(saved_variables['state_dict'])
+ model.load_state_dict(saved_variables["state_dict"])
model.to(device)
return model
@@ -159,6 +165,7 @@ class BasePolicy(BaseModel):
:param squash_output: (bool) For continuous actions, whether the output is squashed
or not using a ``tanh()`` function.
"""
+
def __init__(self, *args, squash_output: bool = False, **kwargs):
super(BasePolicy, self).__init__(*args, **kwargs)
self._squash_output = squash_output
@@ -196,11 +203,13 @@ def _predict(self, observation: th.Tensor, deterministic: bool = False) -> th.Te
:return: (th.Tensor) Taken action according to the policy
"""
- def predict(self,
- observation: np.ndarray,
- state: Optional[np.ndarray] = None,
- mask: Optional[np.ndarray] = None,
- deterministic: bool = False) -> Tuple[np.ndarray, Optional[np.ndarray]]:
+ def predict(
+ self,
+ observation: np.ndarray,
+ state: Optional[np.ndarray] = None,
+ mask: Optional[np.ndarray] = None,
+ deterministic: bool = False,
+ ) -> Tuple[np.ndarray, Optional[np.ndarray]]:
"""
Get the policy action and state from an observation (and optional state).
Includes sugar-coating to handle different observations (e.g. normalizing images).
@@ -222,12 +231,15 @@ def predict(self,
# Handle the different cases for images
# as PyTorch use channel first format
if is_image_space(self.observation_space):
- if not (observation.shape == self.observation_space.shape
- or observation.shape[1:] == self.observation_space.shape):
+ if not (
+ observation.shape == self.observation_space.shape or observation.shape[1:] == self.observation_space.shape
+ ):
# Try to re-order the channels
transpose_obs = VecTransposeImage.transpose_image(observation)
- if (transpose_obs.shape == self.observation_space.shape
- or transpose_obs.shape[1:] == self.observation_space.shape):
+ if (
+ transpose_obs.shape == self.observation_space.shape
+ or transpose_obs.shape[1:] == self.observation_space.shape
+ ):
observation = transpose_obs
vectorized_env = is_vectorized_observation(observation, self.observation_space)
@@ -313,40 +325,44 @@ class ActorCriticPolicy(BasePolicy):
excluding the learning rate, to pass to the optimizer
"""
- def __init__(self,
- observation_space: gym.spaces.Space,
- action_space: gym.spaces.Space,
- lr_schedule: Callable[[float], float],
- net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None,
- device: Union[th.device, str] = 'auto',
- activation_fn: Type[nn.Module] = nn.Tanh,
- ortho_init: bool = True,
- use_sde: bool = False,
- log_std_init: float = 0.0,
- full_std: bool = True,
- sde_net_arch: Optional[List[int]] = None,
- use_expln: bool = False,
- squash_output: bool = False,
- features_extractor_class: Type[BaseFeaturesExtractor] = FlattenExtractor,
- features_extractor_kwargs: Optional[Dict[str, Any]] = None,
- normalize_images: bool = True,
- optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
- optimizer_kwargs: Optional[Dict[str, Any]] = None):
+ def __init__(
+ self,
+ observation_space: gym.spaces.Space,
+ action_space: gym.spaces.Space,
+ lr_schedule: Callable[[float], float],
+ net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None,
+ device: Union[th.device, str] = "auto",
+ activation_fn: Type[nn.Module] = nn.Tanh,
+ ortho_init: bool = True,
+ use_sde: bool = False,
+ log_std_init: float = 0.0,
+ full_std: bool = True,
+ sde_net_arch: Optional[List[int]] = None,
+ use_expln: bool = False,
+ squash_output: bool = False,
+ features_extractor_class: Type[BaseFeaturesExtractor] = FlattenExtractor,
+ features_extractor_kwargs: Optional[Dict[str, Any]] = None,
+ normalize_images: bool = True,
+ optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
+ optimizer_kwargs: Optional[Dict[str, Any]] = None,
+ ):
if optimizer_kwargs is None:
optimizer_kwargs = {}
# Small values to avoid NaN in Adam optimizer
if optimizer_class == th.optim.Adam:
- optimizer_kwargs['eps'] = 1e-5
-
- super(ActorCriticPolicy, self).__init__(observation_space,
- action_space,
- device,
- features_extractor_class,
- features_extractor_kwargs,
- optimizer_class=optimizer_class,
- optimizer_kwargs=optimizer_kwargs,
- squash_output=squash_output)
+ optimizer_kwargs["eps"] = 1e-5
+
+ super(ActorCriticPolicy, self).__init__(
+ observation_space,
+ action_space,
+ device,
+ features_extractor_class,
+ features_extractor_kwargs,
+ optimizer_class=optimizer_class,
+ optimizer_kwargs=optimizer_kwargs,
+ squash_output=squash_output,
+ )
# Default network architecture, from stable-baselines
if net_arch is None:
@@ -359,8 +375,7 @@ def __init__(self,
self.activation_fn = activation_fn
self.ortho_init = ortho_init
- self.features_extractor = features_extractor_class(self.observation_space,
- **self.features_extractor_kwargs)
+ self.features_extractor = features_extractor_class(self.observation_space, **self.features_extractor_kwargs)
self.features_dim = self.features_extractor.features_dim
self.normalize_images = normalize_images
@@ -369,10 +384,10 @@ def __init__(self,
# Keyword arguments for gSDE distribution
if use_sde:
dist_kwargs = {
- 'full_std': full_std,
- 'squash_output': squash_output,
- 'use_expln': use_expln,
- 'learn_features': sde_net_arch is not None
+ "full_std": full_std,
+ "squash_output": squash_output,
+ "use_expln": use_expln,
+ "learn_features": sde_net_arch is not None,
}
self.sde_features_extractor = None
@@ -390,22 +405,24 @@ def _get_data(self) -> Dict[str, Any]:
default_none_kwargs = self.dist_kwargs or collections.defaultdict(lambda: None)
- data.update(dict(
- net_arch=self.net_arch,
- activation_fn=self.activation_fn,
- use_sde=self.use_sde,
- log_std_init=self.log_std_init,
- squash_output=default_none_kwargs['squash_output'],
- full_std=default_none_kwargs['full_std'],
- sde_net_arch=default_none_kwargs['sde_net_arch'],
- use_expln=default_none_kwargs['use_expln'],
- lr_schedule=self._dummy_schedule, # dummy lr schedule, not needed for loading policy alone
- ortho_init=self.ortho_init,
- optimizer_class=self.optimizer_class,
- optimizer_kwargs=self.optimizer_kwargs,
- features_extractor_class=self.features_extractor_class,
- features_extractor_kwargs=self.features_extractor_kwargs
- ))
+ data.update(
+ dict(
+ net_arch=self.net_arch,
+ activation_fn=self.activation_fn,
+ use_sde=self.use_sde,
+ log_std_init=self.log_std_init,
+ squash_output=default_none_kwargs["squash_output"],
+ full_std=default_none_kwargs["full_std"],
+ sde_net_arch=default_none_kwargs["sde_net_arch"],
+ use_expln=default_none_kwargs["use_expln"],
+ lr_schedule=self._dummy_schedule, # dummy lr schedule, not needed for loading policy alone
+ ortho_init=self.ortho_init,
+ optimizer_class=self.optimizer_class,
+ optimizer_kwargs=self.optimizer_kwargs,
+ features_extractor_class=self.features_extractor_class,
+ features_extractor_kwargs=self.features_extractor_kwargs,
+ )
+ )
return data
def reset_noise(self, n_envs: int = 1) -> None:
@@ -414,8 +431,7 @@ def reset_noise(self, n_envs: int = 1) -> None:
:param n_envs: (int)
"""
- assert isinstance(self.action_dist,
- StateDependentNoiseDistribution), 'reset_noise() is only available when using gSDE'
+ assert isinstance(self.action_dist, StateDependentNoiseDistribution), "reset_noise() is only available when using gSDE"
self.action_dist.sample_weights(self.log_std, batch_size=n_envs)
def _build(self, lr_schedule: Callable[[float], float]) -> None:
@@ -428,25 +444,27 @@ def _build(self, lr_schedule: Callable[[float], float]) -> None:
# Note: If net_arch is None and some features extractor is used,
# net_arch here is an empty list and mlp_extractor does not
# really contain any layers (acts like an identity module).
- self.mlp_extractor = MlpExtractor(self.features_dim, net_arch=self.net_arch,
- activation_fn=self.activation_fn, device=self.device)
+ self.mlp_extractor = MlpExtractor(
+ self.features_dim, net_arch=self.net_arch, activation_fn=self.activation_fn, device=self.device
+ )
latent_dim_pi = self.mlp_extractor.latent_dim_pi
# Separate feature extractor for gSDE
if self.sde_net_arch is not None:
- self.sde_features_extractor, latent_sde_dim = create_sde_features_extractor(self.features_dim,
- self.sde_net_arch,
- self.activation_fn)
+ self.sde_features_extractor, latent_sde_dim = create_sde_features_extractor(
+ self.features_dim, self.sde_net_arch, self.activation_fn
+ )
if isinstance(self.action_dist, DiagGaussianDistribution):
- self.action_net, self.log_std = self.action_dist.proba_distribution_net(latent_dim=latent_dim_pi,
- log_std_init=self.log_std_init)
+ self.action_net, self.log_std = self.action_dist.proba_distribution_net(
+ latent_dim=latent_dim_pi, log_std_init=self.log_std_init
+ )
elif isinstance(self.action_dist, StateDependentNoiseDistribution):
latent_sde_dim = latent_dim_pi if self.sde_net_arch is None else latent_sde_dim
- self.action_net, self.log_std = self.action_dist.proba_distribution_net(latent_dim=latent_dim_pi,
- latent_sde_dim=latent_sde_dim,
- log_std_init=self.log_std_init)
+ self.action_net, self.log_std = self.action_dist.proba_distribution_net(
+ latent_dim=latent_dim_pi, latent_sde_dim=latent_sde_dim, log_std_init=self.log_std_init
+ )
elif isinstance(self.action_dist, CategoricalDistribution):
self.action_net = self.action_dist.proba_distribution_net(latent_dim=latent_dim_pi)
elif isinstance(self.action_dist, MultiCategoricalDistribution):
@@ -468,7 +486,7 @@ def _build(self, lr_schedule: Callable[[float], float]) -> None:
self.features_extractor: np.sqrt(2),
self.mlp_extractor: np.sqrt(2),
self.action_net: 0.01,
- self.value_net: 1
+ self.value_net: 1,
}
for module, gain in module_gains.items():
module.apply(partial(self.init_weights, gain=gain))
@@ -476,8 +494,7 @@ def _build(self, lr_schedule: Callable[[float], float]) -> None:
# Setup optimizer with initial learning rate
self.optimizer = self.optimizer_class(self.parameters(), lr=lr_schedule(1), **self.optimizer_kwargs)
- def forward(self, obs: th.Tensor,
- deterministic: bool = False) -> Tuple[th.Tensor, th.Tensor, th.Tensor]:
+ def forward(self, obs: th.Tensor, deterministic: bool = False) -> Tuple[th.Tensor, th.Tensor, th.Tensor]:
"""
Forward pass in all the networks (actor and critic)
@@ -512,8 +529,7 @@ def _get_latent(self, obs: th.Tensor) -> Tuple[th.Tensor, th.Tensor, th.Tensor]:
latent_sde = self.sde_features_extractor(features)
return latent_pi, latent_vf, latent_sde
- def _get_action_dist_from_latent(self, latent_pi: th.Tensor,
- latent_sde: Optional[th.Tensor] = None) -> Distribution:
+ def _get_action_dist_from_latent(self, latent_pi: th.Tensor, latent_sde: Optional[th.Tensor] = None) -> Distribution:
"""
Retrieve action distribution given the latent codes.
@@ -537,7 +553,7 @@ def _get_action_dist_from_latent(self, latent_pi: th.Tensor,
elif isinstance(self.action_dist, StateDependentNoiseDistribution):
return self.action_dist.proba_distribution(mean_actions, self.log_std, latent_sde)
else:
- raise ValueError('Invalid action distribution')
+ raise ValueError("Invalid action distribution")
def _predict(self, observation: th.Tensor, deterministic: bool = False) -> th.Tensor:
"""
@@ -551,8 +567,7 @@ def _predict(self, observation: th.Tensor, deterministic: bool = False) -> th.Te
distribution = self._get_action_dist_from_latent(latent_pi, latent_sde)
return distribution.get_actions(deterministic=deterministic)
- def evaluate_actions(self, obs: th.Tensor,
- actions: th.Tensor) -> Tuple[th.Tensor, th.Tensor, th.Tensor]:
+ def evaluate_actions(self, obs: th.Tensor, actions: th.Tensor) -> Tuple[th.Tensor, th.Tensor, th.Tensor]:
"""
Evaluate actions according to the current policy,
given the observations.
@@ -604,43 +619,47 @@ class ActorCriticCnnPolicy(ActorCriticPolicy):
excluding the learning rate, to pass to the optimizer
"""
- def __init__(self,
- observation_space: gym.spaces.Space,
- action_space: gym.spaces.Space,
- lr_schedule: Callable,
- net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None,
- device: Union[th.device, str] = 'auto',
- activation_fn: Type[nn.Module] = nn.Tanh,
- ortho_init: bool = True,
- use_sde: bool = False,
- log_std_init: float = 0.0,
- full_std: bool = True,
- sde_net_arch: Optional[List[int]] = None,
- use_expln: bool = False,
- squash_output: bool = False,
- features_extractor_class: Type[BaseFeaturesExtractor] = NatureCNN,
- features_extractor_kwargs: Optional[Dict[str, Any]] = None,
- normalize_images: bool = True,
- optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
- optimizer_kwargs: Optional[Dict[str, Any]] = None):
- super(ActorCriticCnnPolicy, self).__init__(observation_space,
- action_space,
- lr_schedule,
- net_arch,
- device,
- activation_fn,
- ortho_init,
- use_sde,
- log_std_init,
- full_std,
- sde_net_arch,
- use_expln,
- squash_output,
- features_extractor_class,
- features_extractor_kwargs,
- normalize_images,
- optimizer_class,
- optimizer_kwargs)
+ def __init__(
+ self,
+ observation_space: gym.spaces.Space,
+ action_space: gym.spaces.Space,
+ lr_schedule: Callable,
+ net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None,
+ device: Union[th.device, str] = "auto",
+ activation_fn: Type[nn.Module] = nn.Tanh,
+ ortho_init: bool = True,
+ use_sde: bool = False,
+ log_std_init: float = 0.0,
+ full_std: bool = True,
+ sde_net_arch: Optional[List[int]] = None,
+ use_expln: bool = False,
+ squash_output: bool = False,
+ features_extractor_class: Type[BaseFeaturesExtractor] = NatureCNN,
+ features_extractor_kwargs: Optional[Dict[str, Any]] = None,
+ normalize_images: bool = True,
+ optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
+ optimizer_kwargs: Optional[Dict[str, Any]] = None,
+ ):
+ super(ActorCriticCnnPolicy, self).__init__(
+ observation_space,
+ action_space,
+ lr_schedule,
+ net_arch,
+ device,
+ activation_fn,
+ ortho_init,
+ use_sde,
+ log_std_init,
+ full_std,
+ sde_net_arch,
+ use_expln,
+ squash_output,
+ features_extractor_class,
+ features_extractor_kwargs,
+ normalize_images,
+ optimizer_class,
+ optimizer_kwargs,
+ )
class ContinuousCritic(BaseModel):
@@ -669,19 +688,25 @@ class ContinuousCritic(BaseModel):
:param n_critics: (int) Number of critic networks to create.
"""
- def __init__(self, observation_space: gym.spaces.Space,
- action_space: gym.spaces.Space,
- net_arch: List[int],
- features_extractor: nn.Module,
- features_dim: int,
- activation_fn: Type[nn.Module] = nn.ReLU,
- normalize_images: bool = True,
- device: Union[th.device, str] = 'auto',
- n_critics: int = 2):
- super().__init__(observation_space, action_space,
- features_extractor=features_extractor,
- normalize_images=normalize_images,
- device=device)
+ def __init__(
+ self,
+ observation_space: gym.spaces.Space,
+ action_space: gym.spaces.Space,
+ net_arch: List[int],
+ features_extractor: nn.Module,
+ features_dim: int,
+ activation_fn: Type[nn.Module] = nn.ReLU,
+ normalize_images: bool = True,
+ device: Union[th.device, str] = "auto",
+ n_critics: int = 2,
+ ):
+ super().__init__(
+ observation_space,
+ action_space,
+ features_extractor=features_extractor,
+ normalize_images=normalize_images,
+ device=device,
+ )
action_dim = get_action_dim(self.action_space)
@@ -690,7 +715,7 @@ def __init__(self, observation_space: gym.spaces.Space,
for idx in range(n_critics):
q_net = create_mlp(features_dim + action_dim, 1, net_arch, activation_fn)
q_net = nn.Sequential(*q_net)
- self.add_module(f'qf{idx}', q_net)
+ self.add_module(f"qf{idx}", q_net)
self.q_networks.append(q_net)
def forward(self, obs: th.Tensor, actions: th.Tensor) -> Tuple[th.Tensor, ...]:
@@ -711,9 +736,9 @@ def q1_forward(self, obs: th.Tensor, actions: th.Tensor) -> th.Tensor:
return self.q_networks[0](th.cat([features, actions], dim=1))
-def create_sde_features_extractor(features_dim: int,
- sde_net_arch: List[int],
- activation_fn: Type[nn.Module]) -> Tuple[nn.Sequential, int]:
+def create_sde_features_extractor(
+ features_dim: int, sde_net_arch: List[int], activation_fn: Type[nn.Module]
+) -> Tuple[nn.Sequential, int]:
"""
Create the neural network that will be used to extract features
for the gSDE exploration function.
@@ -747,8 +772,10 @@ def get_policy_from_name(base_policy_type: Type[BasePolicy], name: str) -> Type[
if base_policy_type not in _policy_registry:
raise KeyError(f"Error: the policy type {base_policy_type} is not registered!")
if name not in _policy_registry[base_policy_type]:
- raise KeyError(f"Error: unknown policy type {name},"
- f"the only registed policy type are: {list(_policy_registry[base_policy_type].keys())}!")
+ raise KeyError(
+            f"Error: unknown policy type {name}, "
+            f"the only registered policy types are: {list(_policy_registry[base_policy_type].keys())}!"
+ )
return _policy_registry[base_policy_type][name]
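
The `net_arch`/`dist_kwargs` handling in `ActorCriticPolicy` above is normally driven through `policy_kwargs`. A hedged sketch of how a user passes it (the layer sizes and activation are arbitrary examples, not part of this diff):

```python
import torch as th

from stable_baselines3 import PPO

# Illustrative only: shared trunk defined by net_arch, then separate pi/vf heads.
policy_kwargs = dict(
    activation_fn=th.nn.ReLU,
    net_arch=[dict(pi=[64, 64], vf=[64, 64])],
)
model = PPO("MlpPolicy", "CartPole-v1", policy_kwargs=policy_kwargs, verbose=0)
print(model.policy)  # shows the MlpExtractor built from net_arch
```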
diff --git a/stable_baselines3/common/preprocessing.py b/stable_baselines3/common/preprocessing.py
index 849756f17..3efaf9b53 100644
--- a/stable_baselines3/common/preprocessing.py
+++ b/stable_baselines3/common/preprocessing.py
@@ -1,14 +1,12 @@
from typing import Tuple
+import numpy as np
import torch as th
-import torch.nn.functional as F
from gym import spaces
-import numpy as np
+from torch.nn import functional as F
-def is_image_space(observation_space: spaces.Space,
- channels_last: bool = True,
- check_channels: bool = False) -> bool:
+def is_image_space(observation_space: spaces.Space, channels_last: bool = True, check_channels: bool = False) -> bool:
"""
    Check if an observation space has the shape, limits and dtype
of a valid image.
@@ -45,8 +43,7 @@ def is_image_space(observation_space: spaces.Space,
return False
-def preprocess_obs(obs: th.Tensor, observation_space: spaces.Space,
- normalize_images: bool = True) -> th.Tensor:
+def preprocess_obs(obs: th.Tensor, observation_space: spaces.Space, normalize_images: bool = True) -> th.Tensor:
"""
    Preprocess observation to be fed to a neural network.
For images, it normalizes the values by dividing them by 255 (to have values in [0, 1])
@@ -69,9 +66,13 @@ def preprocess_obs(obs: th.Tensor, observation_space: spaces.Space,
elif isinstance(observation_space, spaces.MultiDiscrete):
# Tensor concatenation of one hot encodings of each Categorical sub-space
- return th.cat([F.one_hot(obs_.long(), num_classes=int(observation_space.nvec[idx])).float()
- for idx, obs_ in enumerate(th.split(obs.long(), 1, dim=1))],
- dim=-1).view(obs.shape[0], sum(observation_space.nvec))
+ return th.cat(
+ [
+ F.one_hot(obs_.long(), num_classes=int(observation_space.nvec[idx])).float()
+ for idx, obs_ in enumerate(th.split(obs.long(), 1, dim=1))
+ ],
+ dim=-1,
+ ).view(obs.shape[0], sum(observation_space.nvec))
elif isinstance(observation_space, spaces.MultiBinary):
return obs.float()
@@ -91,13 +92,13 @@ def get_obs_shape(observation_space: spaces.Space) -> Tuple[int, ...]:
return observation_space.shape
elif isinstance(observation_space, spaces.Discrete):
# Observation is an int
- return 1,
+ return (1,)
elif isinstance(observation_space, spaces.MultiDiscrete):
# Number of discrete features
- return int(len(observation_space.nvec)),
+ return (int(len(observation_space.nvec)),)
elif isinstance(observation_space, spaces.MultiBinary):
# Number of binary features
- return int(observation_space.n),
+ return (int(observation_space.n),)
else:
raise NotImplementedError()
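
The `MultiDiscrete` branch of `preprocess_obs` reformatted above concatenates one one-hot block per sub-space. A standalone sketch of the same transformation (the `nvec` values here are made up for the example):

```python
import torch as th
from torch.nn import functional as F

# Illustrative only: two MultiDiscrete dimensions with 3 and 2 categories.
nvec = [3, 2]
obs = th.tensor([[2, 0], [1, 1]])
encoded = th.cat(
    [
        F.one_hot(obs_.long(), num_classes=int(nvec[idx])).float()
        for idx, obs_ in enumerate(th.split(obs.long(), 1, dim=1))
    ],
    dim=-1,
).view(obs.shape[0], sum(nvec))
print(encoded.shape)  # torch.Size([2, 5])
```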
diff --git a/stable_baselines3/common/results_plotter.py b/stable_baselines3/common/results_plotter.py
index 4f879be2a..8fe805c1c 100644
--- a/stable_baselines3/common/results_plotter.py
+++ b/stable_baselines3/common/results_plotter.py
@@ -1,17 +1,17 @@
-from typing import Tuple, Callable, List, Optional
+from typing import Callable, List, Optional, Tuple
import numpy as np
import pandas as pd
+
# import matplotlib
# matplotlib.use('TkAgg') # Can change to 'Agg' for non-interactive mode
-import matplotlib.pyplot as plt
+from matplotlib import pyplot as plt
from stable_baselines3.common.monitor import load_results
-
-X_TIMESTEPS = 'timesteps'
-X_EPISODES = 'episodes'
-X_WALLTIME = 'walltime_hrs'
+X_TIMESTEPS = "timesteps"
+X_EPISODES = "episodes"
+X_WALLTIME = "walltime_hrs"
POSSIBLE_X_AXES = [X_TIMESTEPS, X_EPISODES, X_WALLTIME]
EPISODES_WINDOW = 100
@@ -29,8 +29,7 @@ def rolling_window(array: np.ndarray, window: int) -> np.ndarray:
return np.lib.stride_tricks.as_strided(array, shape=shape, strides=strides)
-def window_func(var_1: np.ndarray, var_2: np.ndarray,
- window: int, func: Callable) -> Tuple[np.ndarray, np.ndarray]:
+def window_func(var_1: np.ndarray, var_2: np.ndarray, window: int, func: Callable) -> Tuple[np.ndarray, np.ndarray]:
"""
Apply a function to the rolling window of 2 arrays
@@ -42,7 +41,7 @@ def window_func(var_1: np.ndarray, var_2: np.ndarray,
"""
var_2_window = rolling_window(var_2, window)
function_on_var2 = func(var_2_window, axis=-1)
- return var_1[window - 1:], function_on_var2
+ return var_1[window - 1 :], function_on_var2
def ts2xy(data_frame: pd.DataFrame, x_axis: str) -> Tuple[np.ndarray, np.ndarray]:
@@ -62,15 +61,16 @@ def ts2xy(data_frame: pd.DataFrame, x_axis: str) -> Tuple[np.ndarray, np.ndarray
y_var = data_frame.r.values
elif x_axis == X_WALLTIME:
# Convert to hours
- x_var = data_frame.t.values / 3600.
+ x_var = data_frame.t.values / 3600.0
y_var = data_frame.r.values
else:
raise NotImplementedError
return x_var, y_var
-def plot_curves(xy_list: List[Tuple[np.ndarray, np.ndarray]],
- x_axis: str, title: str, figsize: Tuple[int, int] = (8, 2)) -> None:
+def plot_curves(
+ xy_list: List[Tuple[np.ndarray, np.ndarray]], x_axis: str, title: str, figsize: Tuple[int, int] = (8, 2)
+) -> None:
"""
plot the curves
@@ -98,8 +98,9 @@ def plot_curves(xy_list: List[Tuple[np.ndarray, np.ndarray]],
plt.tight_layout()
-def plot_results(dirs: List[str], num_timesteps: Optional[int],
- x_axis: str, task_name: str, figsize: Tuple[int, int] = (8, 2)) -> None:
+def plot_results(
+ dirs: List[str], num_timesteps: Optional[int], x_axis: str, task_name: str, figsize: Tuple[int, int] = (8, 2)
+) -> None:
"""
Plot the results using csv files from ``Monitor`` wrapper.
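
A short usage sketch for the plotting helpers above (not part of this diff; `./logs` is a hypothetical directory containing `Monitor` csv files):

```python
from matplotlib import pyplot as plt

from stable_baselines3.common import results_plotter
from stable_baselines3.common.results_plotter import plot_results

# Illustrative only: plot smoothed episode rewards against timesteps from ./logs.
plot_results(["./logs"], int(1e5), results_plotter.X_TIMESTEPS, "Demo run")
plt.show()
```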
diff --git a/stable_baselines3/common/running_mean_std.py b/stable_baselines3/common/running_mean_std.py
index f50310676..b98fccc65 100644
--- a/stable_baselines3/common/running_mean_std.py
+++ b/stable_baselines3/common/running_mean_std.py
@@ -22,9 +22,7 @@ def update(self, arr: np.ndarray) -> None:
batch_count = arr.shape[0]
self.update_from_moments(batch_mean, batch_var, batch_count)
- def update_from_moments(self, batch_mean: np.ndarray,
- batch_var: np.ndarray,
- batch_count: int) -> None:
+ def update_from_moments(self, batch_mean: np.ndarray, batch_var: np.ndarray, batch_count: int) -> None:
delta = batch_mean - self.mean
tot_count = self.count + batch_count
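
Only the first lines of `update_from_moments` appear in this hunk. For reference, the standard parallel-moments combination it is based on can be written standalone as below; this is a sketch, not a copy of the method body:

```python
import numpy as np


def combine_moments(mean, var, count, batch_mean, batch_var, batch_count):
    # Chan et al. parallel update of a running mean/variance from batch statistics.
    delta = batch_mean - mean
    tot_count = count + batch_count
    new_mean = mean + delta * batch_count / tot_count
    m_2 = var * count + batch_var * batch_count + np.square(delta) * count * batch_count / tot_count
    return new_mean, m_2 / tot_count, tot_count


print(combine_moments(0.0, 1.0, 100, 0.5, 1.2, 50))
```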
diff --git a/stable_baselines3/common/save_util.py b/stable_baselines3/common/save_util.py
index 0d2fe9823..51fa8cd17 100644
--- a/stable_baselines3/common/save_util.py
+++ b/stable_baselines3/common/save_util.py
@@ -2,19 +2,19 @@
Save util taken from stable_baselines
used to serialize data (class parameters) of model classes
"""
-import io
-import os
-import json
import base64
import functools
-from typing import Dict, Any, Tuple, Optional, Union
-import warnings
-import zipfile
+import io
+import json
+import os
import pathlib
import pickle
+import warnings
+import zipfile
+from typing import Any, Dict, Optional, Tuple, Union
-import torch as th
import cloudpickle
+import torch as th
from stable_baselines3.common.type_aliases import TensorDict
from stable_baselines3.common.utils import get_device
@@ -112,11 +112,7 @@ def data_to_json(data: Dict[str, Any]) -> str:
# e.g. numpy scalars)
if hasattr(data_item, "__dict__") or isinstance(data_item, dict):
# Take elements from __dict__ for custom classes
- item_generator = (
- data_item.items
- if isinstance(data_item, dict)
- else data_item.__dict__.items
- )
+ item_generator = data_item.items if isinstance(data_item, dict) else data_item.__dict__.items
for variable_name, variable_item in item_generator():
# Check if serializable. If not, just include the
# string-representation of the object.
@@ -130,9 +126,7 @@ def data_to_json(data: Dict[str, Any]) -> str:
return json_string
-def json_to_data(
- json_string: str, custom_objects: Optional[Dict[str, Any]] = None
-) -> Dict[str, Any]:
+def json_to_data(json_string: str, custom_objects: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
"""
Turn JSON serialization of class-parameters back into dictionary.
@@ -181,9 +175,7 @@ def json_to_data(
@functools.singledispatch
-def open_path(
- path: Union[str, pathlib.Path, io.BufferedIOBase], mode: str, verbose=0, suffix=None
-):
+def open_path(path: Union[str, pathlib.Path, io.BufferedIOBase], mode: str, verbose=0, suffix=None):
"""
Opens a path for reading or writing with a preferred suffix and raises debug information.
If the provided path is a derivative of io.BufferedIOBase it ensures that the file
@@ -221,9 +213,7 @@ def open_path(
@open_path.register(str)
-def open_path_str(
- path: str, mode: str, verbose=0, suffix=None
-) -> io.BufferedIOBase:
+def open_path_str(path: str, mode: str, verbose=0, suffix=None) -> io.BufferedIOBase:
"""
Open a path given by a string. If writing to the path, the function ensures
that the path exists.
@@ -240,9 +230,7 @@ def open_path_str(
@open_path.register(pathlib.Path)
-def open_path_pathlib(
- path: pathlib.Path, mode: str, verbose=0, suffix=None
-) -> io.BufferedIOBase:
+def open_path_pathlib(path: pathlib.Path, mode: str, verbose=0, suffix=None) -> io.BufferedIOBase:
"""
Open a path given by a string. If writing to the path, the function ensures
that the path exists.
@@ -331,9 +319,7 @@ def save_to_zip_file(
th.save(dict_, param_file)
-def save_to_pkl(
- path: Union[str, pathlib.Path, io.BufferedIOBase], obj, verbose=0
-) -> None:
+def save_to_pkl(path: Union[str, pathlib.Path, io.BufferedIOBase], obj, verbose=0) -> None:
"""
Save an object to path creating the necessary folders along the way.
If the path exists and is a directory, it will raise a warning and rename the path.
@@ -364,9 +350,7 @@ def load_from_pkl(path: Union[str, pathlib.Path, io.BufferedIOBase], verbose=0)
def load_from_zip_file(
- load_path: Union[str, pathlib.Path, io.BufferedIOBase],
- load_data: bool = True,
- verbose=0,
+ load_path: Union[str, pathlib.Path, io.BufferedIOBase], load_data: bool = True, verbose=0,
) -> (Tuple[Optional[Dict[str, Any]], Optional[TensorDict], Optional[TensorDict]]):
"""
Load model data from a .zip archive
@@ -412,10 +396,7 @@ def load_from_zip_file(
# check for all other .pth files
other_files = [
- file_name
- for file_name in namelist
- if os.path.splitext(file_name)[1] == ".pth"
- and file_name != "tensors.pth"
+ file_name for file_name in namelist if os.path.splitext(file_name)[1] == ".pth" and file_name != "tensors.pth"
]
# if there are any other files which end with .pth and aren't "params.pth"
# assume that they each are optimizer parameters
@@ -429,9 +410,7 @@ def load_from_zip_file(
# go to start of file
file_content.seek(0)
# load the parameters with the right ``map_location``
- params[os.path.splitext(file_path)[0]] = th.load(
- file_content, map_location=device
- )
+ params[os.path.splitext(file_path)[0]] = th.load(file_content, map_location=device)
except zipfile.BadZipFile:
# load_path wasn't a zip file
raise ValueError(f"Error: the file {load_path} wasn't a zip-file")
diff --git a/stable_baselines3/common/torch_layers.py b/stable_baselines3/common/torch_layers.py
index f92b50fc8..9c74017cb 100644
--- a/stable_baselines3/common/torch_layers.py
+++ b/stable_baselines3/common/torch_layers.py
@@ -1,10 +1,9 @@
-from typing import Union, Type, Dict, List, Tuple
-
from itertools import zip_longest
+from typing import Dict, List, Tuple, Type, Union
import gym
import torch as th
-import torch.nn as nn
+from torch import nn as nn
from stable_baselines3.common.preprocessing import get_flattened_obs_dim, is_image_space
from stable_baselines3.common.utils import get_device
@@ -60,22 +59,25 @@ class NatureCNN(BaseFeaturesExtractor):
    This corresponds to the number of units for the last layer.
"""
- def __init__(self, observation_space: gym.spaces.Box,
- features_dim: int = 512):
+ def __init__(self, observation_space: gym.spaces.Box, features_dim: int = 512):
super(NatureCNN, self).__init__(observation_space, features_dim)
# We assume CxHxW images (channels first)
# Re-ordering will be done by pre-preprocessing or wrapper
- assert is_image_space(observation_space), ('You should use NatureCNN '
- f'only with images not with {observation_space} '
- '(you are probably using `CnnPolicy` instead of `MlpPolicy`)')
+ assert is_image_space(observation_space), (
+ "You should use NatureCNN "
+ f"only with images not with {observation_space} "
+ "(you are probably using `CnnPolicy` instead of `MlpPolicy`)"
+ )
n_input_channels = observation_space.shape[0]
- self.cnn = nn.Sequential(nn.Conv2d(n_input_channels, 32, kernel_size=8, stride=4, padding=0),
- nn.ReLU(),
- nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=0),
- nn.ReLU(),
- nn.Conv2d(64, 32, kernel_size=3, stride=1, padding=0),
- nn.ReLU(),
- nn.Flatten())
+ self.cnn = nn.Sequential(
+ nn.Conv2d(n_input_channels, 32, kernel_size=8, stride=4, padding=0),
+ nn.ReLU(),
+ nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=0),
+ nn.ReLU(),
+ nn.Conv2d(64, 32, kernel_size=3, stride=1, padding=0),
+ nn.ReLU(),
+ nn.Flatten(),
+ )
# Compute shape by doing one forward pass
with th.no_grad():
@@ -87,11 +89,9 @@ def forward(self, observations: th.Tensor) -> th.Tensor:
return self.linear(self.cnn(observations))
-def create_mlp(input_dim: int,
- output_dim: int,
- net_arch: List[int],
- activation_fn: Type[nn.Module] = nn.ReLU,
- squash_output: bool = False) -> List[nn.Module]:
+def create_mlp(
+ input_dim: int, output_dim: int, net_arch: List[int], activation_fn: Type[nn.Module] = nn.ReLU, squash_output: bool = False
+) -> List[nn.Module]:
"""
Create a multi layer perceptron (MLP), which is
a collection of fully-connected layers each followed by an activation function.
@@ -152,10 +152,13 @@ class MlpExtractor(nn.Module):
:param device: (th.device)
"""
- def __init__(self, feature_dim: int,
- net_arch: List[Union[int, Dict[str, List[int]]]],
- activation_fn: Type[nn.Module],
- device: Union[th.device, str] = 'auto'):
+ def __init__(
+ self,
+ feature_dim: int,
+ net_arch: List[Union[int, Dict[str, List[int]]]],
+ activation_fn: Type[nn.Module],
+ device: Union[th.device, str] = "auto",
+ ):
super(MlpExtractor, self).__init__()
device = get_device(device)
shared_net, policy_net, value_net = [], [], []
@@ -173,13 +176,13 @@ def __init__(self, feature_dim: int,
last_layer_dim_shared = layer_size
else:
assert isinstance(layer, dict), "Error: the net_arch list can only contain ints and dicts"
- if 'pi' in layer:
- assert isinstance(layer['pi'], list), "Error: net_arch[-1]['pi'] must contain a list of integers."
- policy_only_layers = layer['pi']
+ if "pi" in layer:
+ assert isinstance(layer["pi"], list), "Error: net_arch[-1]['pi'] must contain a list of integers."
+ policy_only_layers = layer["pi"]
- if 'vf' in layer:
- assert isinstance(layer['vf'], list), "Error: net_arch[-1]['vf'] must contain a list of integers."
- value_only_layers = layer['vf']
+ if "vf" in layer:
+ assert isinstance(layer["vf"], list), "Error: net_arch[-1]['vf'] must contain a list of integers."
+ value_only_layers = layer["vf"]
break # From here on the network splits up in policy and value network
last_layer_dim_pi = last_layer_dim_shared
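
For reference, `create_mlp` (reformatted above) returns a plain list of layers that callers wrap themselves, as `ContinuousCritic` does for each Q-network. A sketch with arbitrary sizes:

```python
import torch as th
from torch import nn as nn

from stable_baselines3.common.torch_layers import create_mlp

# Illustrative only: 8-dim input, scalar output, two hidden layers of 64 units.
q_net = nn.Sequential(*create_mlp(8, 1, net_arch=[64, 64], activation_fn=nn.ReLU))
print(q_net(th.zeros(1, 8)).shape)  # torch.Size([1, 1])
```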
diff --git a/stable_baselines3/common/type_aliases.py b/stable_baselines3/common/type_aliases.py
index 5ef94f91e..e16f43539 100644
--- a/stable_baselines3/common/type_aliases.py
+++ b/stable_baselines3/common/type_aliases.py
@@ -1,14 +1,13 @@
"""Common aliases for type hints"""
-from typing import Union, Dict, Any, NamedTuple, List, Callable, Tuple
+from typing import Any, Callable, Dict, List, NamedTuple, Tuple, Union
+import gym
import numpy as np
import torch as th
-import gym
-from stable_baselines3.common.vec_env import VecEnv
from stable_baselines3.common.callbacks import BaseCallback
-
+from stable_baselines3.common.vec_env import VecEnv
GymEnv = Union[gym.Env, VecEnv]
GymObs = Union[Tuple, Dict[str, Any], np.ndarray, int]
diff --git a/stable_baselines3/common/utils.py b/stable_baselines3/common/utils.py
index fbe357219..9c00f0e6b 100644
--- a/stable_baselines3/common/utils.py
+++ b/stable_baselines3/common/utils.py
@@ -1,13 +1,13 @@
-from collections import deque
-from typing import Callable, Union, Optional
-import random
-import os
import glob
-
+import os
+import random
+from collections import deque
+from typing import Callable, Optional, Union
import gym
import numpy as np
import torch as th
+
# Check if tensorboard is available for pytorch
try:
from torch.utils.tensorboard import SummaryWriter
@@ -15,8 +15,8 @@
SummaryWriter = None
from stable_baselines3.common import logger
-from stable_baselines3.common.type_aliases import GymEnv
from stable_baselines3.common.preprocessing import is_image_space
+from stable_baselines3.common.type_aliases import GymEnv
from stable_baselines3.common.vec_env import VecTransposeImage
@@ -68,7 +68,7 @@ def update_learning_rate(optimizer: th.optim.Optimizer, learning_rate: float) ->
:param learning_rate: (float)
"""
for param_group in optimizer.param_groups:
- param_group['lr'] = learning_rate
+ param_group["lr"] = learning_rate
def get_schedule_fn(value_schedule: Union[Callable, float]) -> Callable:
@@ -128,7 +128,7 @@ def func(_):
return func
-def get_device(device: Union[th.device, str] = 'auto') -> th.device:
+def get_device(device: Union[th.device, str] = "auto") -> th.device:
"""
Retrieve PyTorch device.
It checks that the requested device is available first.
@@ -139,19 +139,19 @@ def get_device(device: Union[th.device, str] = 'auto') -> th.device:
:return: (th.device)
"""
# Cuda by default
- if device == 'auto':
- device = 'cuda'
+ if device == "auto":
+ device = "cuda"
# Force conversion to th.device
device = th.device(device)
# Cuda not available
- if device == th.device('cuda') and not th.cuda.is_available():
- return th.device('cpu')
+ if device == th.device("cuda") and not th.cuda.is_available():
+ return th.device("cpu")
return device
-def get_latest_run_id(log_path: Optional[str] = None, log_name: str = '') -> int:
+def get_latest_run_id(log_path: Optional[str] = None, log_name: str = "") -> int:
"""
Returns the latest run number for the given log name and log path,
by finding the greatest number in the directories.
@@ -167,8 +167,9 @@ def get_latest_run_id(log_path: Optional[str] = None, log_name: str = '') -> int
return max_run_id
-def configure_logger(verbose: int = 0, tensorboard_log: Optional[str] = None,
- tb_log_name: str = '', reset_num_timesteps: bool = True) -> None:
+def configure_logger(
+ verbose: int = 0, tensorboard_log: Optional[str] = None, tb_log_name: str = "", reset_num_timesteps: bool = True
+) -> None:
"""
Configure the logger's outputs.
@@ -202,13 +203,17 @@ def check_for_correct_spaces(env: GymEnv, observation_space: gym.spaces.Space, a
:param observation_space: (gym.spaces.Space) Observation space to check against
:param action_space: (gym.spaces.Space) Action space to check against
"""
- if (observation_space != env.observation_space
+ if (
+ observation_space != env.observation_space
# Special cases for images that need to be transposed
- and not (is_image_space(env.observation_space)
- and observation_space == VecTransposeImage.transpose_space(env.observation_space))):
- raise ValueError(f'Observation spaces do not match: {observation_space} != {env.observation_space}')
+ and not (
+ is_image_space(env.observation_space)
+ and observation_space == VecTransposeImage.transpose_space(env.observation_space)
+ )
+ ):
+ raise ValueError(f"Observation spaces do not match: {observation_space} != {env.observation_space}")
if action_space != env.action_space:
- raise ValueError(f'Action spaces do not match: {action_space} != {env.action_space}')
+ raise ValueError(f"Action spaces do not match: {action_space} != {env.action_space}")
def is_vectorized_observation(observation: np.ndarray, observation_space: gym.spaces.Space) -> bool:
@@ -226,18 +231,21 @@ def is_vectorized_observation(observation: np.ndarray, observation_space: gym.sp
elif observation.shape[1:] == observation_space.shape:
return True
else:
- raise ValueError(f"Error: Unexpected observation shape {observation.shape} for "
- + f"Box environment, please use {observation_space.shape} "
- + "or (n_env, {}) for the observation shape."
- .format(", ".join(map(str, observation_space.shape))))
+ raise ValueError(
+ f"Error: Unexpected observation shape {observation.shape} for "
+ + f"Box environment, please use {observation_space.shape} "
+ + "or (n_env, {}) for the observation shape.".format(", ".join(map(str, observation_space.shape)))
+ )
elif isinstance(observation_space, gym.spaces.Discrete):
if observation.shape == (): # A numpy array of a number, has shape empty tuple '()'
return False
elif len(observation.shape) == 1:
return True
else:
- raise ValueError(f"Error: Unexpected observation shape {observation.shape} for "
- + "Discrete environment, please use (1,) or (n_env, 1) for the observation shape.")
+ raise ValueError(
+ f"Error: Unexpected observation shape {observation.shape} for "
+ + "Discrete environment, please use (1,) or (n_env, 1) for the observation shape."
+ )
elif isinstance(observation_space, gym.spaces.MultiDiscrete):
if observation.shape == (len(observation_space.nvec),):
@@ -245,21 +253,26 @@ def is_vectorized_observation(observation: np.ndarray, observation_space: gym.sp
elif len(observation.shape) == 2 and observation.shape[1] == len(observation_space.nvec):
return True
else:
- raise ValueError(f"Error: Unexpected observation shape {observation.shape} for MultiDiscrete "
- + f"environment, please use ({len(observation_space.nvec)},) or "
- + f"(n_env, {len(observation_space.nvec)}) for the observation shape.")
+ raise ValueError(
+ f"Error: Unexpected observation shape {observation.shape} for MultiDiscrete "
+ + f"environment, please use ({len(observation_space.nvec)},) or "
+ + f"(n_env, {len(observation_space.nvec)}) for the observation shape."
+ )
elif isinstance(observation_space, gym.spaces.MultiBinary):
if observation.shape == (observation_space.n,):
return False
elif len(observation.shape) == 2 and observation.shape[1] == observation_space.n:
return True
else:
- raise ValueError(f"Error: Unexpected observation shape {observation.shape} for MultiBinary "
- + f"environment, please use ({observation_space.n},) or "
- + f"(n_env, {observation_space.n}) for the observation shape.")
+ raise ValueError(
+ f"Error: Unexpected observation shape {observation.shape} for MultiBinary "
+ + f"environment, please use ({observation_space.n},) or "
+ + f"(n_env, {observation_space.n}) for the observation shape."
+ )
else:
- raise ValueError("Error: Cannot determine if the observation is vectorized "
- + f" with the space type {observation_space}.")
+ raise ValueError(
+ "Error: Cannot determine if the observation is vectorized " + f" with the space type {observation_space}."
+ )
def safe_mean(arr: Union[np.ndarray, list, deque]) -> np.ndarray:
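Not part of the patch: a minimal usage sketch of the reformatted `get_device` helper from `stable_baselines3/common/utils.py`, illustrating the `"auto"` fallback shown in the hunk above.

```python
import torch as th

from stable_baselines3.common.utils import get_device

# "auto" resolves to CUDA when available and silently falls back to CPU otherwise
device = get_device("auto")
assert isinstance(device, th.device)
```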
diff --git a/stable_baselines3/common/vec_env/__init__.py b/stable_baselines3/common/vec_env/__init__.py
index 535ed031b..9944130b2 100644
--- a/stable_baselines3/common/vec_env/__init__.py
+++ b/stable_baselines3/common/vec_env/__init__.py
@@ -1,24 +1,29 @@
# flake8: noqa F401
import typing
-from typing import Optional, Union
from copy import deepcopy
+from typing import Optional, Union
-from stable_baselines3.common.vec_env.base_vec_env import (AlreadySteppingError, NotSteppingError,
- VecEnv, VecEnvWrapper, CloudpickleWrapper)
+from stable_baselines3.common.vec_env.base_vec_env import (
+ AlreadySteppingError,
+ CloudpickleWrapper,
+ NotSteppingError,
+ VecEnv,
+ VecEnvWrapper,
+)
from stable_baselines3.common.vec_env.dummy_vec_env import DummyVecEnv
from stable_baselines3.common.vec_env.subproc_vec_env import SubprocVecEnv
+from stable_baselines3.common.vec_env.vec_check_nan import VecCheckNan
from stable_baselines3.common.vec_env.vec_frame_stack import VecFrameStack
from stable_baselines3.common.vec_env.vec_normalize import VecNormalize
from stable_baselines3.common.vec_env.vec_transpose import VecTransposeImage
from stable_baselines3.common.vec_env.vec_video_recorder import VecVideoRecorder
-from stable_baselines3.common.vec_env.vec_check_nan import VecCheckNan
# Avoid circular import
if typing.TYPE_CHECKING:
from stable_baselines3.common.type_aliases import GymEnv
-def unwrap_vec_normalize(env: Union['GymEnv', VecEnv]) -> Optional[VecNormalize]:
+def unwrap_vec_normalize(env: Union["GymEnv", VecEnv]) -> Optional[VecNormalize]:
"""
:param env: (gym.Env)
:return: (VecNormalize)
@@ -32,7 +37,7 @@ def unwrap_vec_normalize(env: Union['GymEnv', VecEnv]) -> Optional[VecNormalize]
# Define here to avoid circular import
-def sync_envs_normalization(env: 'GymEnv', eval_env: 'GymEnv') -> None:
+def sync_envs_normalization(env: "GymEnv", eval_env: "GymEnv") -> None:
"""
Sync eval env and train env when using VecNormalize
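Not part of the patch: a usage sketch of the re-exported `unwrap_vec_normalize` helper; `"CartPole-v1"` is only an example environment id, not something introduced by this diff.

```python
import gym

from stable_baselines3.common.vec_env import DummyVecEnv, VecNormalize, unwrap_vec_normalize

venv = VecNormalize(DummyVecEnv([lambda: gym.make("CartPole-v1")]))
# The helper walks down the VecEnvWrapper chain and returns the VecNormalize instance (or None)
assert unwrap_vec_normalize(venv) is venv
```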
diff --git a/stable_baselines3/common/vec_env/base_vec_env.py b/stable_baselines3/common/vec_env/base_vec_env.py
index b2caa461c..6b5a42143 100644
--- a/stable_baselines3/common/vec_env/base_vec_env.py
+++ b/stable_baselines3/common/vec_env/base_vec_env.py
@@ -1,10 +1,10 @@
import inspect
import pickle
from abc import ABC, abstractmethod
-from typing import Sequence, Optional, List, Union
+from typing import List, Optional, Sequence, Union
-import numpy as np
import cloudpickle
+import numpy as np
from stable_baselines3.common import logger
@@ -42,7 +42,7 @@ class AlreadySteppingError(Exception):
"""
def __init__(self):
- msg = 'already running an async step'
+ msg = "already running an async step"
Exception.__init__(self, msg)
@@ -53,7 +53,7 @@ class NotSteppingError(Exception):
"""
def __init__(self):
- msg = 'not running an async step'
+ msg = "not running an async step"
Exception.__init__(self, msg)
@@ -65,9 +65,8 @@ class VecEnv(ABC):
:param observation_space: (Gym Space) the observation space
:param action_space: (Gym Space) the action space
"""
- metadata = {
- 'render.modes': ['human', 'rgb_array']
- }
+
+ metadata = {"render.modes": ["human", "rgb_array"]}
def __init__(self, num_envs, observation_space, action_space):
self.num_envs = num_envs
@@ -168,7 +167,7 @@ def get_images(self) -> Sequence[np.ndarray]:
"""
raise NotImplementedError
- def render(self, mode: str = 'human'):
+ def render(self, mode: str = "human"):
"""
Gym environment rendering
@@ -177,19 +176,20 @@ def render(self, mode: str = 'human'):
try:
imgs = self.get_images()
except NotImplementedError:
- logger.warn(f'Render not defined for {self}')
+ logger.warn(f"Render not defined for {self}")
return
# Create a big image by tiling images from subprocesses
bigimg = tile_images(imgs)
- if mode == 'human':
+ if mode == "human":
import cv2 # pytype:disable=import-error
- cv2.imshow('vecenv', bigimg[:, :, ::-1])
+
+ cv2.imshow("vecenv", bigimg[:, :, ::-1])
cv2.waitKey(1)
- elif mode == 'rgb_array':
+ elif mode == "rgb_array":
return bigimg
else:
- raise NotImplementedError(f'Render mode {mode} is not supported by VecEnvs')
+ raise NotImplementedError(f"Render mode {mode} is not supported by VecEnvs")
@abstractmethod
def seed(self, seed: Optional[int] = None) -> List[Union[None, int]]:
@@ -247,8 +247,12 @@ class VecEnvWrapper(VecEnv):
def __init__(self, venv, observation_space=None, action_space=None):
self.venv = venv
- VecEnv.__init__(self, num_envs=venv.num_envs, observation_space=observation_space or venv.observation_space,
- action_space=action_space or venv.action_space)
+ VecEnv.__init__(
+ self,
+ num_envs=venv.num_envs,
+ observation_space=observation_space or venv.observation_space,
+ action_space=action_space or venv.action_space,
+ )
self.class_attributes = dict(inspect.getmembers(self.__class__))
def step_async(self, actions):
@@ -268,7 +272,7 @@ def seed(self, seed=None):
def close(self):
return self.venv.close()
- def render(self, mode: str = 'human'):
+ def render(self, mode: str = "human"):
return self.venv.render(mode=mode)
def get_images(self):
@@ -291,8 +295,10 @@ def __getattr__(self, name):
blocked_class = self.getattr_depth_check(name, already_found=False)
if blocked_class is not None:
own_class = f"{type(self).__module__}.{type(self).__name__}"
- error_str = (f"Error: Recursive attribute lookup for {name} from {own_class} is "
- "ambiguous and hides attribute from {blocked_class}")
+ error_str = (
+ f"Error: Recursive attribute lookup for {name} from {own_class} is "
+ "ambiguous and hides attribute from {blocked_class}"
+ )
raise AttributeError(error_str)
return self.getattr_recursive(name)
@@ -315,7 +321,7 @@ def getattr_recursive(self, name):
all_attributes = self._get_all_attributes()
if name in all_attributes: # attribute is present in this wrapper
attr = getattr(self, name)
- elif hasattr(self.venv, 'getattr_recursive'):
+ elif hasattr(self.venv, "getattr_recursive"):
# Attribute not present, child is wrapper. Call getattr_recursive rather than getattr
# to avoid a duplicate call to getattr_depth_check.
attr = self.venv.getattr_recursive(name)
diff --git a/stable_baselines3/common/vec_env/dummy_vec_env.py b/stable_baselines3/common/vec_env/dummy_vec_env.py
index a11108b6b..c577bd87f 100644
--- a/stable_baselines3/common/vec_env/dummy_vec_env.py
+++ b/stable_baselines3/common/vec_env/dummy_vec_env.py
@@ -26,9 +26,7 @@ def __init__(self, env_fns):
obs_space = env.observation_space
self.keys, shapes, dtypes = obs_space_info(obs_space)
- self.buf_obs = OrderedDict([
- (k, np.zeros((self.num_envs,) + tuple(shapes[k]), dtype=dtypes[k]))
- for k in self.keys])
+ self.buf_obs = OrderedDict([(k, np.zeros((self.num_envs,) + tuple(shapes[k]), dtype=dtypes[k])) for k in self.keys])
self.buf_dones = np.zeros((self.num_envs,), dtype=np.bool)
self.buf_rews = np.zeros((self.num_envs,), dtype=np.float32)
self.buf_infos = [{} for _ in range(self.num_envs)]
@@ -40,15 +38,15 @@ def step_async(self, actions):
def step_wait(self):
for env_idx in range(self.num_envs):
- obs, self.buf_rews[env_idx], self.buf_dones[env_idx], self.buf_infos[env_idx] =\
- self.envs[env_idx].step(self.actions[env_idx])
+ obs, self.buf_rews[env_idx], self.buf_dones[env_idx], self.buf_infos[env_idx] = self.envs[env_idx].step(
+ self.actions[env_idx]
+ )
if self.buf_dones[env_idx]:
# save final observation where user can get it, then reset
- self.buf_infos[env_idx]['terminal_observation'] = obs
+ self.buf_infos[env_idx]["terminal_observation"] = obs
obs = self.envs[env_idx].reset()
self._save_obs(env_idx, obs)
- return (self._obs_from_buf(), np.copy(self.buf_rews), np.copy(self.buf_dones),
- deepcopy(self.buf_infos))
+ return (self._obs_from_buf(), np.copy(self.buf_rews), np.copy(self.buf_dones), deepcopy(self.buf_infos))
def seed(self, seed=None):
seeds = list()
@@ -67,9 +65,9 @@ def close(self):
env.close()
def get_images(self) -> Sequence[np.ndarray]:
- return [env.render(mode='rgb_array') for env in self.envs]
+ return [env.render(mode="rgb_array") for env in self.envs]
- def render(self, mode: str = 'human'):
+ def render(self, mode: str = "human"):
"""
Gym environment rendering. If there are multiple environments then
they are tiled together in one image via ``BaseVecEnv.render()``.
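Not part of the patch: an illustrative sketch of the `terminal_observation` convention reformatted above. When an episode ends, `DummyVecEnv` resets the environment automatically and keeps the true final observation in the step infos (`"CartPole-v1"` is only an example id).

```python
import gym
import numpy as np

from stable_baselines3.common.vec_env import DummyVecEnv

venv = DummyVecEnv([lambda: gym.make("CartPole-v1")])
obs = venv.reset()
for _ in range(500):
    obs, rewards, dones, infos = venv.step(np.array([venv.action_space.sample()]))
    if dones[0]:
        # observation that actually ended the episode, stored before the automatic reset
        terminal_obs = infos[0]["terminal_observation"]
        break
```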
diff --git a/stable_baselines3/common/vec_env/subproc_vec_env.py b/stable_baselines3/common/vec_env/subproc_vec_env.py
index b12218f62..937e94365 100644
--- a/stable_baselines3/common/vec_env/subproc_vec_env.py
+++ b/stable_baselines3/common/vec_env/subproc_vec_env.py
@@ -5,7 +5,7 @@
import gym
import numpy as np
-from stable_baselines3.common.vec_env.base_vec_env import VecEnv, CloudpickleWrapper
+from stable_baselines3.common.vec_env.base_vec_env import CloudpickleWrapper, VecEnv
def _worker(remote, parent_remote, env_fn_wrapper):
@@ -14,32 +14,32 @@ def _worker(remote, parent_remote, env_fn_wrapper):
while True:
try:
cmd, data = remote.recv()
- if cmd == 'step':
+ if cmd == "step":
observation, reward, done, info = env.step(data)
if done:
# save final observation where user can get it, then reset
- info['terminal_observation'] = observation
+ info["terminal_observation"] = observation
observation = env.reset()
remote.send((observation, reward, done, info))
- elif cmd == 'seed':
+ elif cmd == "seed":
remote.send(env.seed(data))
- elif cmd == 'reset':
+ elif cmd == "reset":
observation = env.reset()
remote.send(observation)
- elif cmd == 'render':
+ elif cmd == "render":
remote.send(env.render(data))
- elif cmd == 'close':
+ elif cmd == "close":
env.close()
remote.close()
break
- elif cmd == 'get_spaces':
+ elif cmd == "get_spaces":
remote.send((env.observation_space, env.action_space))
- elif cmd == 'env_method':
+ elif cmd == "env_method":
method = getattr(env, data[0])
remote.send(method(*data[1], **data[2]))
- elif cmd == 'get_attr':
+ elif cmd == "get_attr":
remote.send(getattr(env, data))
- elif cmd == 'set_attr':
+ elif cmd == "set_attr":
remote.send(setattr(env, data[0], data[1]))
else:
raise NotImplementedError(f"`{cmd}` is not implemented in the worker")
@@ -80,8 +80,8 @@ def __init__(self, env_fns, start_method=None):
# Fork is not a thread safe method (see issue #217)
# but is more user friendly (does not require to wrap the code in
# a `if __name__ == "__main__":`)
- forkserver_available = 'forkserver' in multiprocessing.get_all_start_methods()
- start_method = 'forkserver' if forkserver_available else 'spawn'
+ forkserver_available = "forkserver" in multiprocessing.get_all_start_methods()
+ start_method = "forkserver" if forkserver_available else "spawn"
ctx = multiprocessing.get_context(start_method)
self.remotes, self.work_remotes = zip(*[ctx.Pipe() for _ in range(n_envs)])
@@ -94,13 +94,13 @@ def __init__(self, env_fns, start_method=None):
self.processes.append(process)
work_remote.close()
- self.remotes[0].send(('get_spaces', None))
+ self.remotes[0].send(("get_spaces", None))
observation_space, action_space = self.remotes[0].recv()
VecEnv.__init__(self, len(env_fns), observation_space, action_space)
def step_async(self, actions):
for remote, action in zip(self.remotes, actions):
- remote.send(('step', action))
+ remote.send(("step", action))
self.waiting = True
def step_wait(self):
@@ -111,12 +111,12 @@ def step_wait(self):
def seed(self, seed=None):
for idx, remote in enumerate(self.remotes):
- remote.send(('seed', seed + idx))
+ remote.send(("seed", seed + idx))
return [remote.recv() for remote in self.remotes]
def reset(self):
for remote in self.remotes:
- remote.send(('reset', None))
+ remote.send(("reset", None))
obs = [remote.recv() for remote in self.remotes]
return _flatten_obs(obs, self.observation_space)
@@ -127,7 +127,7 @@ def close(self):
for remote in self.remotes:
remote.recv()
for remote in self.remotes:
- remote.send(('close', None))
+ remote.send(("close", None))
for process in self.processes:
process.join()
self.closed = True
@@ -136,7 +136,7 @@ def get_images(self) -> Sequence[np.ndarray]:
for pipe in self.remotes:
# gather images from subprocesses
# `mode` will be taken into account later
- pipe.send(('render', 'rgb_array'))
+ pipe.send(("render", "rgb_array"))
imgs = [pipe.recv() for pipe in self.remotes]
return imgs
@@ -144,14 +144,14 @@ def get_attr(self, attr_name, indices=None):
"""Return attribute from vectorized environment (see base class)."""
target_remotes = self._get_target_remotes(indices)
for remote in target_remotes:
- remote.send(('get_attr', attr_name))
+ remote.send(("get_attr", attr_name))
return [remote.recv() for remote in target_remotes]
def set_attr(self, attr_name, value, indices=None):
"""Set attribute inside vectorized environments (see base class)."""
target_remotes = self._get_target_remotes(indices)
for remote in target_remotes:
- remote.send(('set_attr', (attr_name, value)))
+ remote.send(("set_attr", (attr_name, value)))
for remote in target_remotes:
remote.recv()
@@ -159,7 +159,7 @@ def env_method(self, method_name, *method_args, indices=None, **method_kwargs):
"""Call instance methods of vectorized environments."""
target_remotes = self._get_target_remotes(indices)
for remote in target_remotes:
- remote.send(('env_method', (method_name, method_args, method_kwargs)))
+ remote.send(("env_method", (method_name, method_args, method_kwargs)))
return [remote.recv() for remote in target_remotes]
def _get_target_remotes(self, indices):
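Not part of the patch: a sketch of how the string-command protocol in `_worker` above is driven from the `SubprocVecEnv` side. `"CartPole-v1"` is only an example id, and the `__main__` guard matters because the default start method is forkserver/spawn.

```python
import gym

from stable_baselines3.common.vec_env import SubprocVecEnv

if __name__ == "__main__":
    venv = SubprocVecEnv([lambda: gym.make("CartPole-v1") for _ in range(2)])
    seeds = venv.seed(0)                  # sends ("seed", 0 + idx) to each worker
    metadata = venv.get_attr("metadata")  # sends ("get_attr", "metadata")
    venv.close()                          # sends ("close", None) and joins the worker processes
```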
diff --git a/stable_baselines3/common/vec_env/util.py b/stable_baselines3/common/vec_env/util.py
index 0e8629a71..0ebecb05b 100644
--- a/stable_baselines3/common/vec_env/util.py
+++ b/stable_baselines3/common/vec_env/util.py
@@ -61,7 +61,7 @@ def obs_space_info(obs_space):
elif isinstance(obs_space, gym.spaces.Tuple):
subspaces = {i: space for i, space in enumerate(obs_space.spaces)}
else:
- assert not hasattr(obs_space, 'spaces'), f"Unsupported structured space '{type(obs_space)}'"
+ assert not hasattr(obs_space, "spaces"), f"Unsupported structured space '{type(obs_space)}'"
subspaces = {None: obs_space}
keys = []
shapes = {}
diff --git a/stable_baselines3/common/vec_env/vec_frame_stack.py b/stable_baselines3/common/vec_env/vec_frame_stack.py
index b32ddebd0..5e6b7f251 100644
--- a/stable_baselines3/common/vec_env/vec_frame_stack.py
+++ b/stable_baselines3/common/vec_env/vec_frame_stack.py
@@ -30,16 +30,14 @@ def step_wait(self):
self.stackedobs = np.roll(self.stackedobs, shift=-last_ax_size, axis=-1)
for i, done in enumerate(dones):
if done:
- if 'terminal_observation' in infos[i]:
- old_terminal = infos[i]['terminal_observation']
- new_terminal = np.concatenate(
- (self.stackedobs[i, ..., :-last_ax_size], old_terminal), axis=-1)
- infos[i]['terminal_observation'] = new_terminal
+ if "terminal_observation" in infos[i]:
+ old_terminal = infos[i]["terminal_observation"]
+ new_terminal = np.concatenate((self.stackedobs[i, ..., :-last_ax_size], old_terminal), axis=-1)
+ infos[i]["terminal_observation"] = new_terminal
else:
- warnings.warn(
- "VecFrameStack wrapping a VecEnv without terminal_observation info")
+ warnings.warn("VecFrameStack wrapping a VecEnv without terminal_observation info")
self.stackedobs[i] = 0
- self.stackedobs[..., -observations.shape[-1]:] = observations
+ self.stackedobs[..., -observations.shape[-1] :] = observations
return self.stackedobs, rewards, dones, infos
def reset(self):
@@ -48,7 +46,7 @@ def reset(self):
"""
obs = self.venv.reset()
self.stackedobs[...] = 0
- self.stackedobs[..., -obs.shape[-1]:] = obs
+ self.stackedobs[..., -obs.shape[-1] :] = obs
return self.stackedobs
def close(self):
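Not part of the patch: a standalone numpy sketch of the frame-stacking slice reformatted above; the newest observation always occupies the last `obs.shape[-1]` channels of the stacked buffer.

```python
import numpy as np

stackedobs = np.zeros((1, 84, 84, 4), dtype=np.uint8)  # one env, four stacked grayscale frames
obs = np.full((1, 84, 84, 1), 255, dtype=np.uint8)     # newest frame
stackedobs = np.roll(stackedobs, shift=-obs.shape[-1], axis=-1)
stackedobs[..., -obs.shape[-1] :] = obs
assert stackedobs[0, 0, 0].tolist() == [0, 0, 0, 255]
```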
diff --git a/stable_baselines3/common/vec_env/vec_normalize.py b/stable_baselines3/common/vec_env/vec_normalize.py
index 47939a577..213dd6a83 100644
--- a/stable_baselines3/common/vec_env/vec_normalize.py
+++ b/stable_baselines3/common/vec_env/vec_normalize.py
@@ -2,8 +2,8 @@
import numpy as np
-from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvWrapper
from stable_baselines3.common.running_mean_std import RunningMeanStd
+from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvWrapper
class VecNormalize(VecEnvWrapper):
@@ -21,8 +21,9 @@ class VecNormalize(VecEnvWrapper):
:param epsilon: (float) To avoid division by zero
"""
- def __init__(self, venv, training=True, norm_obs=True, norm_reward=True,
- clip_obs=10., clip_reward=10., gamma=0.99, epsilon=1e-8):
+ def __init__(
+ self, venv, training=True, norm_obs=True, norm_reward=True, clip_obs=10.0, clip_reward=10.0, gamma=0.99, epsilon=1e-8
+ ):
VecEnvWrapper.__init__(self, venv)
self.obs_rms = RunningMeanStd(shape=self.observation_space.shape)
self.ret_rms = RunningMeanStd(shape=())
@@ -45,10 +46,10 @@ def __getstate__(self):
Excludes self.venv, as in general VecEnv's may not be pickleable."""
state = self.__dict__.copy()
# these attributes are not pickleable
- del state['venv']
- del state['class_attributes']
+ del state["venv"]
+ del state["class_attributes"]
# these attributes depend on the above and so we would prefer not to pickle
- del state['ret']
+ del state["ret"]
return state
def __setstate__(self, state):
@@ -59,7 +60,7 @@ def __setstate__(self, state):
:param state: (dict)"""
self.__dict__.update(state)
- assert 'venv' not in state
+ assert "venv" not in state
self.venv = None
def set_venv(self, venv):
@@ -110,9 +111,7 @@ def normalize_obs(self, obs):
Calling this method does not update statistics.
"""
if self.norm_obs:
- obs = np.clip((obs - self.obs_rms.mean) / np.sqrt(self.obs_rms.var + self.epsilon),
- -self.clip_obs,
- self.clip_obs)
+ obs = np.clip((obs - self.obs_rms.mean) / np.sqrt(self.obs_rms.var + self.epsilon), -self.clip_obs, self.clip_obs)
return obs
def normalize_reward(self, reward):
@@ -121,8 +120,7 @@ def normalize_reward(self, reward):
Calling this method does not update statistics.
"""
if self.norm_reward:
- reward = np.clip(reward / np.sqrt(self.ret_rms.var + self.epsilon),
- -self.clip_reward, self.clip_reward)
+ reward = np.clip(reward / np.sqrt(self.ret_rms.var + self.epsilon), -self.clip_reward, self.clip_reward)
return reward
def unnormalize_obs(self, obs):
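Not part of the patch: a numpy sketch of the observation normalization reformatted above (the numbers stand in for the running mean/variance kept by `RunningMeanStd`).

```python
import numpy as np

obs = np.array([1.5, -0.2])
mean, var = np.array([1.0, 0.0]), np.array([4.0, 1.0])
epsilon, clip_obs = 1e-8, 10.0

# centre by the running mean, scale by the running std, then clip to [-clip_obs, clip_obs]
normalized = np.clip((obs - mean) / np.sqrt(var + epsilon), -clip_obs, clip_obs)
# -> approximately [0.25, -0.2]
```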
diff --git a/stable_baselines3/common/vec_env/vec_transpose.py b/stable_baselines3/common/vec_env/vec_transpose.py
index 3d25649b1..5dc0364dd 100644
--- a/stable_baselines3/common/vec_env/vec_transpose.py
+++ b/stable_baselines3/common/vec_env/vec_transpose.py
@@ -1,9 +1,10 @@
import typing
+
import numpy as np
from gym import spaces
-from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvWrapper
from stable_baselines3.common.preprocessing import is_image_space
+from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvWrapper
if typing.TYPE_CHECKING:
from stable_baselines3.common.type_aliases import GymStepReturn # noqa: F401
@@ -18,7 +19,7 @@ class VecTransposeImage(VecEnvWrapper):
"""
def __init__(self, venv: VecEnv):
- assert is_image_space(venv.observation_space), 'The observation space must be an image'
+ assert is_image_space(venv.observation_space), "The observation space must be an image"
observation_space = self.transpose_space(venv.observation_space)
super(VecTransposeImage, self).__init__(venv, observation_space=observation_space)
@@ -31,7 +32,7 @@ def transpose_space(observation_space: spaces.Box) -> spaces.Box:
:param observation_space: (spaces.Box)
:return: (spaces.Box)
"""
- assert is_image_space(observation_space), 'The observation space must be an image'
+ assert is_image_space(observation_space), "The observation space must be an image"
width, height, channels = observation_space.shape
new_shape = (channels, width, height)
return spaces.Box(low=0, high=255, shape=new_shape, dtype=observation_space.dtype)
@@ -48,7 +49,7 @@ def transpose_image(image: np.ndarray) -> np.ndarray:
return np.transpose(image, (2, 0, 1))
return np.transpose(image, (0, 3, 1, 2))
- def step_wait(self) -> 'GymStepReturn':
+ def step_wait(self) -> "GymStepReturn":
observations, rewards, dones, infos = self.venv.step_wait()
return self.transpose_image(observations), rewards, dones, infos
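Not part of the patch: a minimal numpy check of the channel-first transposition performed by `VecTransposeImage.transpose_image` above.

```python
import numpy as np

single = np.zeros((84, 84, 3), dtype=np.uint8)    # (H, W, C) image
batch = np.zeros((4, 84, 84, 3), dtype=np.uint8)  # (N, H, W, C) batch of images

assert np.transpose(single, (2, 0, 1)).shape == (3, 84, 84)
assert np.transpose(batch, (0, 3, 1, 2)).shape == (4, 3, 84, 84)
```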
diff --git a/stable_baselines3/common/vec_env/vec_video_recorder.py b/stable_baselines3/common/vec_env/vec_video_recorder.py
index 1343ce66d..f7b6ffac7 100644
--- a/stable_baselines3/common/vec_env/vec_video_recorder.py
+++ b/stable_baselines3/common/vec_env/vec_video_recorder.py
@@ -24,8 +24,7 @@ class VecVideoRecorder(VecEnvWrapper):
:param name_prefix: (str) Prefix to the video name
"""
- def __init__(self, venv, video_folder, record_video_trigger,
- video_length=200, name_prefix='rl-video'):
+ def __init__(self, venv, video_folder, record_video_trigger, video_length=200, name_prefix="rl-video"):
VecEnvWrapper.__init__(self, venv)
@@ -39,7 +38,7 @@ def __init__(self, venv, video_folder, record_video_trigger,
temp_env = temp_env.venv
if isinstance(temp_env, DummyVecEnv) or isinstance(temp_env, SubprocVecEnv):
- metadata = temp_env.get_attr('metadata')[0]
+ metadata = temp_env.get_attr("metadata")[0]
else:
metadata = temp_env.metadata
@@ -67,12 +66,11 @@ def reset(self):
def start_video_recorder(self):
self.close_video_recorder()
- video_name = f'{self.name_prefix}-step-{self.step_id}-to-step-{self.step_id + self.video_length}'
+ video_name = f"{self.name_prefix}-step-{self.step_id}-to-step-{self.step_id + self.video_length}"
base_path = os.path.join(self.video_folder, video_name)
- self.video_recorder = video_recorder.VideoRecorder(env=self.env,
- base_path=base_path,
- metadata={'step_id': self.step_id}
- )
+ self.video_recorder = video_recorder.VideoRecorder(
+ env=self.env, base_path=base_path, metadata={"step_id": self.step_id}
+ )
self.video_recorder.capture_frame()
self.recorded_frames = 1
diff --git a/stable_baselines3/ddpg/__init__.py b/stable_baselines3/ddpg/__init__.py
index 11a890b95..0b164b2de 100644
--- a/stable_baselines3/ddpg/__init__.py
+++ b/stable_baselines3/ddpg/__init__.py
@@ -1,2 +1,2 @@
from stable_baselines3.ddpg.ddpg import DDPG
-from stable_baselines3.ddpg.policies import MlpPolicy, CnnPolicy
+from stable_baselines3.ddpg.policies import CnnPolicy, MlpPolicy
diff --git a/stable_baselines3/ddpg/ddpg.py b/stable_baselines3/ddpg/ddpg.py
index 83325b450..cc6e2899f 100644
--- a/stable_baselines3/ddpg/ddpg.py
+++ b/stable_baselines3/ddpg/ddpg.py
@@ -1,11 +1,12 @@
+from typing import Any, Callable, Dict, Optional, Type, Union
+
import torch as th
-from typing import Type, Union, Callable, Optional, Dict, Any
-from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
from stable_baselines3.common.noise import ActionNoise
+from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback
-from stable_baselines3.td3.td3 import TD3
from stable_baselines3.td3.policies import TD3Policy
+from stable_baselines3.td3.td3 import TD3
class DDPG(TD3):
@@ -50,67 +51,86 @@ class DDPG(TD3):
:param _init_setup_model: (bool) Whether or not to build the network at the creation of the instance
"""
- def __init__(self, policy: Union[str, Type[TD3Policy]],
- env: Union[GymEnv, str],
- learning_rate: Union[float, Callable] = 1e-3,
- buffer_size: int = int(1e6),
- learning_starts: int = 100,
- batch_size: int = 100,
- tau: float = 0.005,
- gamma: float = 0.99,
- train_freq: int = -1,
- gradient_steps: int = -1,
- n_episodes_rollout: int = 1,
- action_noise: Optional[ActionNoise] = None,
- optimize_memory_usage: bool = False,
- tensorboard_log: Optional[str] = None,
- create_eval_env: bool = False,
- policy_kwargs: Dict[str, Any] = None,
- verbose: int = 0,
- seed: Optional[int] = None,
- device: Union[th.device, str] = 'auto',
- _init_setup_model: bool = True):
+ def __init__(
+ self,
+ policy: Union[str, Type[TD3Policy]],
+ env: Union[GymEnv, str],
+ learning_rate: Union[float, Callable] = 1e-3,
+ buffer_size: int = int(1e6),
+ learning_starts: int = 100,
+ batch_size: int = 100,
+ tau: float = 0.005,
+ gamma: float = 0.99,
+ train_freq: int = -1,
+ gradient_steps: int = -1,
+ n_episodes_rollout: int = 1,
+ action_noise: Optional[ActionNoise] = None,
+ optimize_memory_usage: bool = False,
+ tensorboard_log: Optional[str] = None,
+ create_eval_env: bool = False,
+ policy_kwargs: Dict[str, Any] = None,
+ verbose: int = 0,
+ seed: Optional[int] = None,
+ device: Union[th.device, str] = "auto",
+ _init_setup_model: bool = True,
+ ):
- super(DDPG, self).__init__(policy=policy,
- env=env,
- learning_rate=learning_rate,
- buffer_size=buffer_size,
- learning_starts=learning_starts,
- batch_size=batch_size,
- tau=tau, gamma=gamma,
- train_freq=train_freq,
- gradient_steps=gradient_steps,
- n_episodes_rollout=n_episodes_rollout,
- action_noise=action_noise,
- policy_kwargs=policy_kwargs,
- tensorboard_log=tensorboard_log,
- verbose=verbose, device=device,
- create_eval_env=create_eval_env, seed=seed,
- optimize_memory_usage=optimize_memory_usage,
- # Remove all tricks from TD3 to obtain DDPG:
- # we still need to specify target_policy_noise > 0 to avoid errors
- policy_delay=1, target_noise_clip=0.0, target_policy_noise=0.1,
- _init_setup_model=False)
+ super(DDPG, self).__init__(
+ policy=policy,
+ env=env,
+ learning_rate=learning_rate,
+ buffer_size=buffer_size,
+ learning_starts=learning_starts,
+ batch_size=batch_size,
+ tau=tau,
+ gamma=gamma,
+ train_freq=train_freq,
+ gradient_steps=gradient_steps,
+ n_episodes_rollout=n_episodes_rollout,
+ action_noise=action_noise,
+ policy_kwargs=policy_kwargs,
+ tensorboard_log=tensorboard_log,
+ verbose=verbose,
+ device=device,
+ create_eval_env=create_eval_env,
+ seed=seed,
+ optimize_memory_usage=optimize_memory_usage,
+ # Remove all tricks from TD3 to obtain DDPG:
+ # we still need to specify target_policy_noise > 0 to avoid errors
+ policy_delay=1,
+ target_noise_clip=0.0,
+ target_policy_noise=0.1,
+ _init_setup_model=False,
+ )
# Use only one critic
- if 'n_critics' not in self.policy_kwargs:
- self.policy_kwargs['n_critics'] = 1
+ if "n_critics" not in self.policy_kwargs:
+ self.policy_kwargs["n_critics"] = 1
if _init_setup_model:
self._setup_model()
- def learn(self,
- total_timesteps: int,
- callback: MaybeCallback = None,
- log_interval: int = 4,
- eval_env: Optional[GymEnv] = None,
- eval_freq: int = -1,
- n_eval_episodes: int = 5,
- tb_log_name: str = "DDPG",
- eval_log_path: Optional[str] = None,
- reset_num_timesteps: bool = True) -> OffPolicyAlgorithm:
+ def learn(
+ self,
+ total_timesteps: int,
+ callback: MaybeCallback = None,
+ log_interval: int = 4,
+ eval_env: Optional[GymEnv] = None,
+ eval_freq: int = -1,
+ n_eval_episodes: int = 5,
+ tb_log_name: str = "DDPG",
+ eval_log_path: Optional[str] = None,
+ reset_num_timesteps: bool = True,
+ ) -> OffPolicyAlgorithm:
- return super(DDPG, self).learn(total_timesteps=total_timesteps, callback=callback, log_interval=log_interval,
- eval_env=eval_env, eval_freq=eval_freq, n_eval_episodes=n_eval_episodes,
- tb_log_name=tb_log_name, eval_log_path=eval_log_path,
- reset_num_timesteps=reset_num_timesteps)
+ return super(DDPG, self).learn(
+ total_timesteps=total_timesteps,
+ callback=callback,
+ log_interval=log_interval,
+ eval_env=eval_env,
+ eval_freq=eval_freq,
+ n_eval_episodes=n_eval_episodes,
+ tb_log_name=tb_log_name,
+ eval_log_path=eval_log_path,
+ reset_num_timesteps=reset_num_timesteps,
+ )
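Not part of the patch: an illustrative instantiation of the reformatted `DDPG` class; `"Pendulum-v0"` is only an example environment id.

```python
from stable_baselines3 import DDPG

# Internally this is TD3 with policy_delay=1, target_noise_clip=0.0 and a single critic,
# as set in the constructor above.
model = DDPG("MlpPolicy", "Pendulum-v0", verbose=1)
model.learn(total_timesteps=1000)
```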
diff --git a/stable_baselines3/ddpg/policies.py b/stable_baselines3/ddpg/policies.py
index b826c7af8..64b166f44 100644
--- a/stable_baselines3/ddpg/policies.py
+++ b/stable_baselines3/ddpg/policies.py
@@ -1,2 +1,2 @@
# DDPG can be viewed as a special case of TD3
-from stable_baselines3.td3.policies import MlpPolicy, CnnPolicy # noqa:F401
+from stable_baselines3.td3.policies import CnnPolicy, MlpPolicy # noqa:F401
diff --git a/stable_baselines3/dqn/__init__.py b/stable_baselines3/dqn/__init__.py
index 8f6968354..4ae42872c 100644
--- a/stable_baselines3/dqn/__init__.py
+++ b/stable_baselines3/dqn/__init__.py
@@ -1,3 +1,2 @@
from stable_baselines3.dqn.dqn import DQN
-from stable_baselines3.dqn.policies import MlpPolicy
-from stable_baselines3.dqn.policies import CnnPolicy
+from stable_baselines3.dqn.policies import CnnPolicy, MlpPolicy
diff --git a/stable_baselines3/dqn/dqn.py b/stable_baselines3/dqn/dqn.py
index 05e0a0ae9..55d4fb69a 100644
--- a/stable_baselines3/dqn/dqn.py
+++ b/stable_baselines3/dqn/dqn.py
@@ -1,8 +1,8 @@
-from typing import List, Tuple, Type, Union, Callable, Optional, Dict, Any
+from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
import numpy as np
import torch as th
-import torch.nn.functional as F
+from torch.nn import functional as F
from stable_baselines3.common import logger
from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
@@ -55,41 +55,57 @@ class DQN(OffPolicyAlgorithm):
:param _init_setup_model: (bool) Whether or not to build the network at the creation of the instance
"""
- def __init__(self, policy: Union[str, Type[DQNPolicy]],
- env: Union[GymEnv, str],
- learning_rate: Union[float, Callable] = 1e-4,
- buffer_size: int = 1000000,
- learning_starts: int = 50000,
- batch_size: Optional[int] = 32,
- tau: float = 1.0,
- gamma: float = 0.99,
- train_freq: int = 4,
- gradient_steps: int = 1,
- n_episodes_rollout: int = -1,
- optimize_memory_usage: bool = False,
- target_update_interval: int = 10000,
- exploration_fraction: float = 0.1,
- exploration_initial_eps: float = 1.0,
- exploration_final_eps: float = 0.05,
- max_grad_norm: float = 10,
- tensorboard_log: Optional[str] = None,
- create_eval_env: bool = False,
- policy_kwargs: Optional[Dict[str, Any]] = None,
- verbose: int = 0,
- seed: Optional[int] = None,
- device: Union[th.device, str] = 'auto',
- _init_setup_model: bool = True):
-
- super(DQN, self).__init__(policy, env, DQNPolicy, learning_rate,
- buffer_size, learning_starts, batch_size,
- tau, gamma, train_freq, gradient_steps,
- n_episodes_rollout, action_noise=None, # No action noise
- policy_kwargs=policy_kwargs,
- tensorboard_log=tensorboard_log,
- verbose=verbose, device=device,
- create_eval_env=create_eval_env,
- seed=seed, sde_support=False,
- optimize_memory_usage=optimize_memory_usage)
+ def __init__(
+ self,
+ policy: Union[str, Type[DQNPolicy]],
+ env: Union[GymEnv, str],
+ learning_rate: Union[float, Callable] = 1e-4,
+ buffer_size: int = 1000000,
+ learning_starts: int = 50000,
+ batch_size: Optional[int] = 32,
+ tau: float = 1.0,
+ gamma: float = 0.99,
+ train_freq: int = 4,
+ gradient_steps: int = 1,
+ n_episodes_rollout: int = -1,
+ optimize_memory_usage: bool = False,
+ target_update_interval: int = 10000,
+ exploration_fraction: float = 0.1,
+ exploration_initial_eps: float = 1.0,
+ exploration_final_eps: float = 0.05,
+ max_grad_norm: float = 10,
+ tensorboard_log: Optional[str] = None,
+ create_eval_env: bool = False,
+ policy_kwargs: Optional[Dict[str, Any]] = None,
+ verbose: int = 0,
+ seed: Optional[int] = None,
+ device: Union[th.device, str] = "auto",
+ _init_setup_model: bool = True,
+ ):
+
+ super(DQN, self).__init__(
+ policy,
+ env,
+ DQNPolicy,
+ learning_rate,
+ buffer_size,
+ learning_starts,
+ batch_size,
+ tau,
+ gamma,
+ train_freq,
+ gradient_steps,
+ n_episodes_rollout,
+ action_noise=None, # No action noise
+ policy_kwargs=policy_kwargs,
+ tensorboard_log=tensorboard_log,
+ verbose=verbose,
+ device=device,
+ create_eval_env=create_eval_env,
+ seed=seed,
+ sde_support=False,
+ optimize_memory_usage=optimize_memory_usage,
+ )
self.exploration_initial_eps = exploration_initial_eps
self.exploration_final_eps = exploration_final_eps
@@ -108,8 +124,9 @@ def __init__(self, policy: Union[str, Type[DQNPolicy]],
def _setup_model(self) -> None:
super(DQN, self)._setup_model()
self._create_aliases()
- self.exploration_schedule = get_linear_fn(self.exploration_initial_eps, self.exploration_final_eps,
- self.exploration_fraction)
+ self.exploration_schedule = get_linear_fn(
+ self.exploration_initial_eps, self.exploration_final_eps, self.exploration_fraction
+ )
def _create_aliases(self) -> None:
self.q_net = self.policy.q_net
@@ -164,12 +181,15 @@ def train(self, gradient_steps: int, batch_size: int = 100) -> None:
# Increase update counter
self._n_updates += gradient_steps
- logger.record("train/n_updates", self._n_updates, exclude='tensorboard')
+ logger.record("train/n_updates", self._n_updates, exclude="tensorboard")
- def predict(self, observation: np.ndarray,
- state: Optional[np.ndarray] = None,
- mask: Optional[np.ndarray] = None,
- deterministic: bool = False) -> Tuple[np.ndarray, Optional[np.ndarray]]:
+ def predict(
+ self,
+ observation: np.ndarray,
+ state: Optional[np.ndarray] = None,
+ mask: Optional[np.ndarray] = None,
+ deterministic: bool = False,
+ ) -> Tuple[np.ndarray, Optional[np.ndarray]]:
"""
Overrides the base_class predict function to include epsilon-greedy exploration.
@@ -187,21 +207,30 @@ def predict(self, observation: np.ndarray,
action, state = self.policy.predict(observation, state, mask, deterministic)
return action, state
- def learn(self,
- total_timesteps: int,
- callback: MaybeCallback = None,
- log_interval: int = 4,
- eval_env: Optional[GymEnv] = None,
- eval_freq: int = -1,
- n_eval_episodes: int = 5,
- tb_log_name: str = "DQN",
- eval_log_path: Optional[str] = None,
- reset_num_timesteps: bool = True) -> OffPolicyAlgorithm:
-
- return super(DQN, self).learn(total_timesteps=total_timesteps, callback=callback, log_interval=log_interval,
- eval_env=eval_env, eval_freq=eval_freq, n_eval_episodes=n_eval_episodes,
- tb_log_name=tb_log_name, eval_log_path=eval_log_path,
- reset_num_timesteps=reset_num_timesteps)
+ def learn(
+ self,
+ total_timesteps: int,
+ callback: MaybeCallback = None,
+ log_interval: int = 4,
+ eval_env: Optional[GymEnv] = None,
+ eval_freq: int = -1,
+ n_eval_episodes: int = 5,
+ tb_log_name: str = "DQN",
+ eval_log_path: Optional[str] = None,
+ reset_num_timesteps: bool = True,
+ ) -> OffPolicyAlgorithm:
+
+ return super(DQN, self).learn(
+ total_timesteps=total_timesteps,
+ callback=callback,
+ log_interval=log_interval,
+ eval_env=eval_env,
+ eval_freq=eval_freq,
+ n_eval_episodes=n_eval_episodes,
+ tb_log_name=tb_log_name,
+ eval_log_path=eval_log_path,
+ reset_num_timesteps=reset_num_timesteps,
+ )
def excluded_save_params(self) -> List[str]:
"""
diff --git a/stable_baselines3/dqn/policies.py b/stable_baselines3/dqn/policies.py
index 718452229..ea0eaa278 100644
--- a/stable_baselines3/dqn/policies.py
+++ b/stable_baselines3/dqn/policies.py
@@ -1,10 +1,11 @@
-from typing import Optional, List, Callable, Union, Type, Any, Dict
+from typing import Any, Callable, Dict, List, Optional, Type, Union
import gym
import torch as th
-import torch.nn as nn
+from torch import nn as nn
+
from stable_baselines3.common.policies import BasePolicy, register_policy
-from stable_baselines3.common.torch_layers import create_mlp, NatureCNN, BaseFeaturesExtractor, FlattenExtractor
+from stable_baselines3.common.torch_layers import BaseFeaturesExtractor, FlattenExtractor, NatureCNN, create_mlp
class QNetwork(BasePolicy):
@@ -20,18 +21,24 @@ class QNetwork(BasePolicy):
dividing by 255.0 (True by default)
"""
- def __init__(self, observation_space: gym.spaces.Space,
- action_space: gym.spaces.Space,
- features_extractor: nn.Module,
- features_dim: int,
- net_arch: Optional[List[int]] = None,
- device: Union[th.device, str] = 'auto',
- activation_fn: Type[nn.Module] = nn.ReLU,
- normalize_images: bool = True):
- super(QNetwork, self).__init__(observation_space, action_space,
- features_extractor=features_extractor,
- normalize_images=normalize_images,
- device=device)
+ def __init__(
+ self,
+ observation_space: gym.spaces.Space,
+ action_space: gym.spaces.Space,
+ features_extractor: nn.Module,
+ features_dim: int,
+ net_arch: Optional[List[int]] = None,
+ device: Union[th.device, str] = "auto",
+ activation_fn: Type[nn.Module] = nn.ReLU,
+ normalize_images: bool = True,
+ ):
+ super(QNetwork, self).__init__(
+ observation_space,
+ action_space,
+ features_extractor=features_extractor,
+ normalize_images=normalize_images,
+ device=device,
+ )
if net_arch is None:
net_arch = [64, 64]
@@ -63,13 +70,15 @@ def _predict(self, observation: th.Tensor, deterministic: bool = True) -> th.Ten
def _get_data(self) -> Dict[str, Any]:
data = super()._get_data()
- data.update(dict(
- net_arch=self.net_arch,
- features_dim=self.features_dim,
- activation_fn=self.activation_fn,
- features_extractor=self.features_extractor,
- epsilon=self.epsilon,
- ))
+ data.update(
+ dict(
+ net_arch=self.net_arch,
+ features_dim=self.features_dim,
+ activation_fn=self.activation_fn,
+ features_extractor=self.features_extractor,
+ epsilon=self.epsilon,
+ )
+ )
return data
@@ -94,23 +103,29 @@ class DQNPolicy(BasePolicy):
excluding the learning rate, to pass to the optimizer
"""
- def __init__(self, observation_space: gym.spaces.Space,
- action_space: gym.spaces.Space,
- lr_schedule: Callable,
- net_arch: Optional[List[int]] = None,
- device: Union[th.device, str] = 'auto',
- activation_fn: Type[nn.Module] = nn.ReLU,
- features_extractor_class: Type[BaseFeaturesExtractor] = FlattenExtractor,
- features_extractor_kwargs: Optional[Dict[str, Any]] = None,
- normalize_images: bool = True,
- optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
- optimizer_kwargs: Optional[Dict[str, Any]] = None):
- super(DQNPolicy, self).__init__(observation_space, action_space,
- device,
- features_extractor_class,
- features_extractor_kwargs,
- optimizer_class=optimizer_class,
- optimizer_kwargs=optimizer_kwargs)
+ def __init__(
+ self,
+ observation_space: gym.spaces.Space,
+ action_space: gym.spaces.Space,
+ lr_schedule: Callable,
+ net_arch: Optional[List[int]] = None,
+ device: Union[th.device, str] = "auto",
+ activation_fn: Type[nn.Module] = nn.ReLU,
+ features_extractor_class: Type[BaseFeaturesExtractor] = FlattenExtractor,
+ features_extractor_kwargs: Optional[Dict[str, Any]] = None,
+ normalize_images: bool = True,
+ optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
+ optimizer_kwargs: Optional[Dict[str, Any]] = None,
+ ):
+ super(DQNPolicy, self).__init__(
+ observation_space,
+ action_space,
+ device,
+ features_extractor_class,
+ features_extractor_kwargs,
+ optimizer_class=optimizer_class,
+ optimizer_kwargs=optimizer_kwargs,
+ )
if net_arch is None:
if features_extractor_class == FlattenExtractor:
@@ -118,22 +133,21 @@ def __init__(self, observation_space: gym.spaces.Space,
else:
net_arch = []
- self.features_extractor = features_extractor_class(self.observation_space,
- **self.features_extractor_kwargs)
+ self.features_extractor = features_extractor_class(self.observation_space, **self.features_extractor_kwargs)
self.features_dim = self.features_extractor.features_dim
self.net_arch = net_arch
self.activation_fn = activation_fn
self.normalize_images = normalize_images
self.net_args = {
- 'observation_space': self.observation_space,
- 'action_space': self.action_space,
- 'features_extractor': self.features_extractor,
- 'features_dim': self.features_dim,
- 'net_arch': self.net_arch,
- 'activation_fn': self.activation_fn,
- 'normalize_images': normalize_images,
- 'device': device
+ "observation_space": self.observation_space,
+ "action_space": self.action_space,
+ "features_extractor": self.features_extractor,
+ "features_dim": self.features_dim,
+ "net_arch": self.net_arch,
+ "activation_fn": self.activation_fn,
+ "normalize_images": normalize_images,
+ "device": device,
}
self.q_net, self.q_net_target = None, None
@@ -152,8 +166,7 @@ def _build(self, lr_schedule: Callable) -> None:
self.q_net_target.load_state_dict(self.q_net.state_dict())
# Setup optimizer with initial learning rate
- self.optimizer = self.optimizer_class(self.parameters(), lr=lr_schedule(1),
- **self.optimizer_kwargs)
+ self.optimizer = self.optimizer_class(self.parameters(), lr=lr_schedule(1), **self.optimizer_kwargs)
def make_q_net(self) -> QNetwork:
return QNetwork(**self.net_args).to(self.device)
@@ -167,15 +180,17 @@ def _predict(self, obs: th.Tensor, deterministic: bool = True) -> th.Tensor:
def _get_data(self) -> Dict[str, Any]:
data = super()._get_data()
- data.update(dict(
- net_arch=self.net_args['net_arch'],
- activation_fn=self.net_args['activation_fn'],
- lr_schedule=self._dummy_schedule, # dummy lr schedule, not needed for loading policy alone
- optimizer_class=self.optimizer_class,
- optimizer_kwargs=self.optimizer_kwargs,
- features_extractor_class=self.features_extractor_class,
- features_extractor_kwargs=self.features_extractor_kwargs
- ))
+ data.update(
+ dict(
+ net_arch=self.net_args["net_arch"],
+ activation_fn=self.net_args["activation_fn"],
+ lr_schedule=self._dummy_schedule, # dummy lr schedule, not needed for loading policy alone
+ optimizer_class=self.optimizer_class,
+ optimizer_kwargs=self.optimizer_kwargs,
+ features_extractor_class=self.features_extractor_class,
+ features_extractor_kwargs=self.features_extractor_kwargs,
+ )
+ )
return data
@@ -201,28 +216,33 @@ class CnnPolicy(DQNPolicy):
excluding the learning rate, to pass to the optimizer
"""
- def __init__(self, observation_space: gym.spaces.Space,
- action_space: gym.spaces.Space,
- lr_schedule: Callable,
- net_arch: Optional[List[int]] = None,
- device: Union[th.device, str] = 'auto',
- activation_fn: Type[nn.Module] = nn.ReLU,
- features_extractor_class: Type[BaseFeaturesExtractor] = NatureCNN,
- features_extractor_kwargs: Optional[Dict[str, Any]] = None,
- normalize_images: bool = True,
- optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
- optimizer_kwargs: Optional[Dict[str, Any]] = None):
- super(CnnPolicy, self).__init__(observation_space,
- action_space,
- lr_schedule,
- net_arch,
- device,
- activation_fn,
- features_extractor_class,
- features_extractor_kwargs,
- normalize_images,
- optimizer_class,
- optimizer_kwargs)
+ def __init__(
+ self,
+ observation_space: gym.spaces.Space,
+ action_space: gym.spaces.Space,
+ lr_schedule: Callable,
+ net_arch: Optional[List[int]] = None,
+ device: Union[th.device, str] = "auto",
+ activation_fn: Type[nn.Module] = nn.ReLU,
+ features_extractor_class: Type[BaseFeaturesExtractor] = NatureCNN,
+ features_extractor_kwargs: Optional[Dict[str, Any]] = None,
+ normalize_images: bool = True,
+ optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
+ optimizer_kwargs: Optional[Dict[str, Any]] = None,
+ ):
+ super(CnnPolicy, self).__init__(
+ observation_space,
+ action_space,
+ lr_schedule,
+ net_arch,
+ device,
+ activation_fn,
+ features_extractor_class,
+ features_extractor_kwargs,
+ normalize_images,
+ optimizer_class,
+ optimizer_kwargs,
+ )
register_policy("MlpPolicy", MlpPolicy)
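Not part of the patch: a hedged sketch of the linear exploration schedule that `get_linear_fn` builds in `DQN._setup_model` above, assuming it interpolates from the initial to the final epsilon over the first `end_fraction` of training (with `progress_remaining` going from 1 to 0).

```python
def linear_schedule(start: float, end: float, end_fraction: float):
    """Illustrative re-implementation, not the library function itself."""

    def func(progress_remaining: float) -> float:
        progress = 1.0 - progress_remaining
        if progress > end_fraction:
            return end
        return start + progress * (end - start) / end_fraction

    return func


eps = linear_schedule(1.0, 0.05, 0.1)
print(eps(1.0), eps(0.95), eps(0.0))  # 1.0, 0.525, 0.05
```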
diff --git a/stable_baselines3/ppo/__init__.py b/stable_baselines3/ppo/__init__.py
index 68fe97f82..c5b80937c 100644
--- a/stable_baselines3/ppo/__init__.py
+++ b/stable_baselines3/ppo/__init__.py
@@ -1,2 +1,2 @@
+from stable_baselines3.ppo.policies import CnnPolicy, MlpPolicy
from stable_baselines3.ppo.ppo import PPO
-from stable_baselines3.ppo.policies import MlpPolicy, CnnPolicy
diff --git a/stable_baselines3/ppo/policies.py b/stable_baselines3/ppo/policies.py
index 95a7ec835..7d21de8bf 100644
--- a/stable_baselines3/ppo/policies.py
+++ b/stable_baselines3/ppo/policies.py
@@ -1,6 +1,6 @@
# This file is here just to define MlpPolicy/CnnPolicy
# that work for PPO
-from stable_baselines3.common.policies import ActorCriticPolicy, ActorCriticCnnPolicy, register_policy
+from stable_baselines3.common.policies import ActorCriticCnnPolicy, ActorCriticPolicy, register_policy
MlpPolicy = ActorCriticPolicy
CnnPolicy = ActorCriticCnnPolicy
diff --git a/stable_baselines3/ppo/ppo.py b/stable_baselines3/ppo/ppo.py
index 2f05c4d57..eadf96162 100644
--- a/stable_baselines3/ppo/ppo.py
+++ b/stable_baselines3/ppo/ppo.py
@@ -1,15 +1,15 @@
-from typing import Type, Union, Callable, Optional, Dict, Any
+from typing import Any, Callable, Dict, Optional, Type, Union
-from gym import spaces
-import torch as th
-import torch.nn.functional as F
import numpy as np
+import torch as th
+from gym import spaces
+from torch.nn import functional as F
from stable_baselines3.common import logger
from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm
+from stable_baselines3.common.policies import ActorCriticPolicy
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback
from stable_baselines3.common.utils import explained_variance, get_schedule_fn
-from stable_baselines3.common.policies import ActorCriticPolicy
class PPO(OnPolicyAlgorithm):
@@ -62,37 +62,53 @@ class PPO(OnPolicyAlgorithm):
:param _init_setup_model: (bool) Whether or not to build the network at the creation of the instance
"""
- def __init__(self, policy: Union[str, Type[ActorCriticPolicy]],
- env: Union[GymEnv, str],
- learning_rate: Union[float, Callable] = 3e-4,
- n_steps: int = 2048,
- batch_size: Optional[int] = 64,
- n_epochs: int = 10,
- gamma: float = 0.99,
- gae_lambda: float = 0.95,
- clip_range: float = 0.2,
- clip_range_vf: Optional[float] = None,
- ent_coef: float = 0.0,
- vf_coef: float = 0.5,
- max_grad_norm: float = 0.5,
- use_sde: bool = False,
- sde_sample_freq: int = -1,
- target_kl: Optional[float] = None,
- tensorboard_log: Optional[str] = None,
- create_eval_env: bool = False,
- policy_kwargs: Optional[Dict[str, Any]] = None,
- verbose: int = 0,
- seed: Optional[int] = None,
- device: Union[th.device, str] = "auto",
- _init_setup_model: bool = True):
-
- super(PPO, self).__init__(policy, env, learning_rate=learning_rate,
- n_steps=n_steps, gamma=gamma, gae_lambda=gae_lambda,
- ent_coef=ent_coef, vf_coef=vf_coef, max_grad_norm=max_grad_norm,
- use_sde=use_sde, sde_sample_freq=sde_sample_freq,
- tensorboard_log=tensorboard_log, policy_kwargs=policy_kwargs,
- verbose=verbose, device=device, create_eval_env=create_eval_env,
- seed=seed, _init_setup_model=False)
+ def __init__(
+ self,
+ policy: Union[str, Type[ActorCriticPolicy]],
+ env: Union[GymEnv, str],
+ learning_rate: Union[float, Callable] = 3e-4,
+ n_steps: int = 2048,
+ batch_size: Optional[int] = 64,
+ n_epochs: int = 10,
+ gamma: float = 0.99,
+ gae_lambda: float = 0.95,
+ clip_range: float = 0.2,
+ clip_range_vf: Optional[float] = None,
+ ent_coef: float = 0.0,
+ vf_coef: float = 0.5,
+ max_grad_norm: float = 0.5,
+ use_sde: bool = False,
+ sde_sample_freq: int = -1,
+ target_kl: Optional[float] = None,
+ tensorboard_log: Optional[str] = None,
+ create_eval_env: bool = False,
+ policy_kwargs: Optional[Dict[str, Any]] = None,
+ verbose: int = 0,
+ seed: Optional[int] = None,
+ device: Union[th.device, str] = "auto",
+ _init_setup_model: bool = True,
+ ):
+
+ super(PPO, self).__init__(
+ policy,
+ env,
+ learning_rate=learning_rate,
+ n_steps=n_steps,
+ gamma=gamma,
+ gae_lambda=gae_lambda,
+ ent_coef=ent_coef,
+ vf_coef=vf_coef,
+ max_grad_norm=max_grad_norm,
+ use_sde=use_sde,
+ sde_sample_freq=sde_sample_freq,
+ tensorboard_log=tensorboard_log,
+ policy_kwargs=policy_kwargs,
+ verbose=verbose,
+ device=device,
+ create_eval_env=create_eval_env,
+ seed=seed,
+ _init_setup_model=False,
+ )
self.batch_size = batch_size
self.n_epochs = n_epochs
@@ -110,8 +126,7 @@ def _setup_model(self) -> None:
self.clip_range = get_schedule_fn(self.clip_range)
if self.clip_range_vf is not None:
if isinstance(self.clip_range_vf, (float, int)):
- assert self.clip_range_vf > 0, ("`clip_range_vf` must be positive, "
- "pass `None` to deactivate vf clipping")
+ assert self.clip_range_vf > 0, "`clip_range_vf` must be positive, " "pass `None` to deactivate vf clipping"
self.clip_range_vf = get_schedule_fn(self.clip_range_vf)
@@ -173,8 +188,9 @@ def train(self) -> None:
else:
# Clip the different between old and new value
# NOTE: this depends on the reward scaling
- values_pred = rollout_data.old_values + th.clamp(values - rollout_data.old_values, -clip_range_vf,
- clip_range_vf)
+ values_pred = rollout_data.old_values + th.clamp(
+ values - rollout_data.old_values, -clip_range_vf, clip_range_vf
+ )
# Value loss using the TD(gae_lambda) target
value_loss = F.mse_loss(rollout_data.returns, values_pred)
value_losses.append(value_loss.item())
@@ -205,8 +221,7 @@ def train(self) -> None:
break
self._n_updates += self.n_epochs
- explained_var = explained_variance(self.rollout_buffer.returns.flatten(),
- self.rollout_buffer.values.flatten())
+ explained_var = explained_variance(self.rollout_buffer.returns.flatten(), self.rollout_buffer.values.flatten())
# Logs
logger.record("train/entropy_loss", np.mean(entropy_losses))
@@ -224,18 +239,27 @@ def train(self) -> None:
if self.clip_range_vf is not None:
logger.record("train/clip_range_vf", clip_range_vf)
- def learn(self,
- total_timesteps: int,
- callback: MaybeCallback = None,
- log_interval: int = 1,
- eval_env: Optional[GymEnv] = None,
- eval_freq: int = -1,
- n_eval_episodes: int = 5,
- tb_log_name: str = "PPO",
- eval_log_path: Optional[str] = None,
- reset_num_timesteps: bool = True) -> "PPO":
-
- return super(PPO, self).learn(total_timesteps=total_timesteps, callback=callback,
- log_interval=log_interval, eval_env=eval_env, eval_freq=eval_freq,
- n_eval_episodes=n_eval_episodes, tb_log_name=tb_log_name,
- eval_log_path=eval_log_path, reset_num_timesteps=reset_num_timesteps)
+ def learn(
+ self,
+ total_timesteps: int,
+ callback: MaybeCallback = None,
+ log_interval: int = 1,
+ eval_env: Optional[GymEnv] = None,
+ eval_freq: int = -1,
+ n_eval_episodes: int = 5,
+ tb_log_name: str = "PPO",
+ eval_log_path: Optional[str] = None,
+ reset_num_timesteps: bool = True,
+ ) -> "PPO":
+
+ return super(PPO, self).learn(
+ total_timesteps=total_timesteps,
+ callback=callback,
+ log_interval=log_interval,
+ eval_env=eval_env,
+ eval_freq=eval_freq,
+ n_eval_episodes=n_eval_episodes,
+ tb_log_name=tb_log_name,
+ eval_log_path=eval_log_path,
+ reset_num_timesteps=reset_num_timesteps,
+ )
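Not part of the patch: a self-contained torch sketch of the clipped value loss reformatted in the PPO hunk above (the tensor values are made up for illustration).

```python
import torch as th
from torch.nn import functional as F

old_values = th.tensor([0.0, 1.0])
values = th.tensor([0.5, 2.0])
returns = th.tensor([0.4, 1.6])
clip_range_vf = 0.3

# the update to the value estimate is limited to +/- clip_range_vf around the old value
values_pred = old_values + th.clamp(values - old_values, -clip_range_vf, clip_range_vf)
value_loss = F.mse_loss(returns, values_pred)  # values_pred == [0.3, 1.3]
```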
diff --git a/stable_baselines3/sac/__init__.py b/stable_baselines3/sac/__init__.py
index 45c7ae740..5b0e89900 100644
--- a/stable_baselines3/sac/__init__.py
+++ b/stable_baselines3/sac/__init__.py
@@ -1,2 +1,2 @@
+from stable_baselines3.sac.policies import CnnPolicy, MlpPolicy
from stable_baselines3.sac.sac import SAC
-from stable_baselines3.sac.policies import MlpPolicy, CnnPolicy
diff --git a/stable_baselines3/sac/policies.py b/stable_baselines3/sac/policies.py
index e40929313..b5a66670c 100644
--- a/stable_baselines3/sac/policies.py
+++ b/stable_baselines3/sac/policies.py
@@ -1,13 +1,13 @@
-from typing import Optional, List, Tuple, Callable, Union, Type, Dict, Any
+from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
import gym
import torch as th
-import torch.nn as nn
+from torch import nn as nn
-from stable_baselines3.common.preprocessing import get_action_dim
-from stable_baselines3.common.policies import BasePolicy, register_policy, create_sde_features_extractor, ContinuousCritic
-from stable_baselines3.common.torch_layers import create_mlp, NatureCNN, BaseFeaturesExtractor, FlattenExtractor
from stable_baselines3.common.distributions import SquashedDiagGaussianDistribution, StateDependentNoiseDistribution
+from stable_baselines3.common.policies import BasePolicy, ContinuousCritic, create_sde_features_extractor, register_policy
+from stable_baselines3.common.preprocessing import get_action_dim
+from stable_baselines3.common.torch_layers import BaseFeaturesExtractor, FlattenExtractor, NatureCNN, create_mlp
# CAP the standard deviation of the actor
LOG_STD_MAX = 2
@@ -41,25 +41,31 @@ class Actor(BasePolicy):
:param device: (Union[th.device, str]) Device on which the code should run.
"""
- def __init__(self, observation_space: gym.spaces.Space,
- action_space: gym.spaces.Space,
- net_arch: List[int],
- features_extractor: nn.Module,
- features_dim: int,
- activation_fn: Type[nn.Module] = nn.ReLU,
- use_sde: bool = False,
- log_std_init: float = -3,
- full_std: bool = True,
- sde_net_arch: Optional[List[int]] = None,
- use_expln: bool = False,
- clip_mean: float = 2.0,
- normalize_images: bool = True,
- device: Union[th.device, str] = 'auto'):
- super(Actor, self).__init__(observation_space, action_space,
- features_extractor=features_extractor,
- normalize_images=normalize_images,
- device=device,
- squash_output=True)
+ def __init__(
+ self,
+ observation_space: gym.spaces.Space,
+ action_space: gym.spaces.Space,
+ net_arch: List[int],
+ features_extractor: nn.Module,
+ features_dim: int,
+ activation_fn: Type[nn.Module] = nn.ReLU,
+ use_sde: bool = False,
+ log_std_init: float = -3,
+ full_std: bool = True,
+ sde_net_arch: Optional[List[int]] = None,
+ use_expln: bool = False,
+ clip_mean: float = 2.0,
+ normalize_images: bool = True,
+ device: Union[th.device, str] = "auto",
+ ):
+ super(Actor, self).__init__(
+ observation_space,
+ action_space,
+ features_extractor=features_extractor,
+ normalize_images=normalize_images,
+ device=device,
+ squash_output=True,
+ )
# Save arguments to re-create object at loading
self.use_sde = use_sde
@@ -83,14 +89,16 @@ def __init__(self, observation_space: gym.spaces.Space,
latent_sde_dim = last_layer_dim
# Separate feature extractor for gSDE
if sde_net_arch is not None:
- self.sde_features_extractor, latent_sde_dim = create_sde_features_extractor(features_dim, sde_net_arch,
- activation_fn)
-
- self.action_dist = StateDependentNoiseDistribution(action_dim, full_std=full_std, use_expln=use_expln,
- learn_features=True, squash_output=True)
- self.mu, self.log_std = self.action_dist.proba_distribution_net(latent_dim=last_layer_dim,
- latent_sde_dim=latent_sde_dim,
- log_std_init=log_std_init)
+ self.sde_features_extractor, latent_sde_dim = create_sde_features_extractor(
+ features_dim, sde_net_arch, activation_fn
+ )
+
+ self.action_dist = StateDependentNoiseDistribution(
+ action_dim, full_std=full_std, use_expln=use_expln, learn_features=True, squash_output=True
+ )
+ self.mu, self.log_std = self.action_dist.proba_distribution_net(
+ latent_dim=last_layer_dim, latent_sde_dim=latent_sde_dim, log_std_init=log_std_init
+ )
# Avoid numerical issues by limiting the mean of the Gaussian
# to be in [-clip_mean, clip_mean]
if clip_mean > 0.0:
@@ -103,18 +111,20 @@ def __init__(self, observation_space: gym.spaces.Space,
def _get_data(self) -> Dict[str, Any]:
data = super()._get_data()
- data.update(dict(
- net_arch=self.net_arch,
- features_dim=self.features_dim,
- activation_fn=self.activation_fn,
- use_sde=self.use_sde,
- log_std_init=self.log_std_init,
- full_std=self.full_std,
- sde_net_arch=self.sde_net_arch,
- use_expln=self.use_expln,
- features_extractor=self.features_extractor,
- clip_mean=self.clip_mean
- ))
+ data.update(
+ dict(
+ net_arch=self.net_arch,
+ features_dim=self.features_dim,
+ activation_fn=self.activation_fn,
+ use_sde=self.use_sde,
+ log_std_init=self.log_std_init,
+ full_std=self.full_std,
+ sde_net_arch=self.sde_net_arch,
+ use_expln=self.use_expln,
+ features_extractor=self.features_extractor,
+ clip_mean=self.clip_mean,
+ )
+ )
return data
def get_std(self) -> th.Tensor:
@@ -127,7 +137,7 @@ def get_std(self) -> th.Tensor:
:return: (th.Tensor)
"""
- msg = 'get_std() is only available when using gSDE'
+ msg = "get_std() is only available when using gSDE"
assert isinstance(self.action_dist, StateDependentNoiseDistribution), msg
return self.action_dist.get_std(self.log_std)
@@ -137,7 +147,7 @@ def reset_noise(self, batch_size: int = 1) -> None:
:param batch_size: (int)
"""
- msg = 'reset_noise() is only available when using gSDE'
+ msg = "reset_noise() is only available when using gSDE"
assert isinstance(self.action_dist, StateDependentNoiseDistribution), msg
self.action_dist.sample_weights(self.log_std, batch_size=batch_size)
@@ -167,8 +177,7 @@ def get_action_dist_params(self, obs: th.Tensor) -> Tuple[th.Tensor, th.Tensor,
def forward(self, obs: th.Tensor, deterministic: bool = False) -> th.Tensor:
mean_actions, log_std, kwargs = self.get_action_dist_params(obs)
# Note: the action is squashed
- return self.action_dist.actions_from_params(mean_actions, log_std,
- deterministic=deterministic, **kwargs)
+ return self.action_dist.actions_from_params(mean_actions, log_std, deterministic=deterministic, **kwargs)
def action_log_prob(self, obs: th.Tensor) -> Tuple[th.Tensor, th.Tensor]:
mean_actions, log_std, kwargs = self.get_action_dist_params(obs)
@@ -210,30 +219,36 @@ class SACPolicy(BasePolicy):
:param n_critics: (int) Number of critic networks to create.
"""
- def __init__(self, observation_space: gym.spaces.Space,
- action_space: gym.spaces.Space,
- lr_schedule: Callable,
- net_arch: Optional[List[int]] = None,
- device: Union[th.device, str] = 'auto',
- activation_fn: Type[nn.Module] = nn.ReLU,
- use_sde: bool = False,
- log_std_init: float = -3,
- sde_net_arch: Optional[List[int]] = None,
- use_expln: bool = False,
- clip_mean: float = 2.0,
- features_extractor_class: Type[BaseFeaturesExtractor] = FlattenExtractor,
- features_extractor_kwargs: Optional[Dict[str, Any]] = None,
- normalize_images: bool = True,
- optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
- optimizer_kwargs: Optional[Dict[str, Any]] = None,
- n_critics: int = 2):
- super(SACPolicy, self).__init__(observation_space, action_space,
- device,
- features_extractor_class,
- features_extractor_kwargs,
- optimizer_class=optimizer_class,
- optimizer_kwargs=optimizer_kwargs,
- squash_output=True)
+ def __init__(
+ self,
+ observation_space: gym.spaces.Space,
+ action_space: gym.spaces.Space,
+ lr_schedule: Callable,
+ net_arch: Optional[List[int]] = None,
+ device: Union[th.device, str] = "auto",
+ activation_fn: Type[nn.Module] = nn.ReLU,
+ use_sde: bool = False,
+ log_std_init: float = -3,
+ sde_net_arch: Optional[List[int]] = None,
+ use_expln: bool = False,
+ clip_mean: float = 2.0,
+ features_extractor_class: Type[BaseFeaturesExtractor] = FlattenExtractor,
+ features_extractor_kwargs: Optional[Dict[str, Any]] = None,
+ normalize_images: bool = True,
+ optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
+ optimizer_kwargs: Optional[Dict[str, Any]] = None,
+ n_critics: int = 2,
+ ):
+ super(SACPolicy, self).__init__(
+ observation_space,
+ action_space,
+ device,
+ features_extractor_class,
+ features_extractor_kwargs,
+ optimizer_class=optimizer_class,
+ optimizer_kwargs=optimizer_kwargs,
+ squash_output=True,
+ )
if net_arch is None:
if features_extractor_class == FlattenExtractor:
@@ -242,33 +257,32 @@ def __init__(self, observation_space: gym.spaces.Space,
net_arch = []
# Create shared features extractor
- self.features_extractor = features_extractor_class(self.observation_space,
- **self.features_extractor_kwargs)
+ self.features_extractor = features_extractor_class(self.observation_space, **self.features_extractor_kwargs)
self.features_dim = self.features_extractor.features_dim
self.net_arch = net_arch
self.activation_fn = activation_fn
self.net_args = {
- 'observation_space': self.observation_space,
- 'action_space': self.action_space,
- 'features_extractor': self.features_extractor,
- 'features_dim': self.features_dim,
- 'net_arch': self.net_arch,
- 'activation_fn': self.activation_fn,
- 'normalize_images': normalize_images,
- 'device': device
+ "observation_space": self.observation_space,
+ "action_space": self.action_space,
+ "features_extractor": self.features_extractor,
+ "features_dim": self.features_dim,
+ "net_arch": self.net_arch,
+ "activation_fn": self.activation_fn,
+ "normalize_images": normalize_images,
+ "device": device,
}
self.actor_kwargs = self.net_args.copy()
sde_kwargs = {
- 'use_sde': use_sde,
- 'log_std_init': log_std_init,
- 'sde_net_arch': sde_net_arch,
- 'use_expln': use_expln,
- 'clip_mean': clip_mean
+ "use_sde": use_sde,
+ "log_std_init": log_std_init,
+ "sde_net_arch": sde_net_arch,
+ "use_expln": use_expln,
+ "clip_mean": clip_mean,
}
self.actor_kwargs.update(sde_kwargs)
self.critic_kwargs = self.net_args.copy()
- self.critic_kwargs.update({'n_critics': n_critics})
+ self.critic_kwargs.update({"n_critics": n_critics})
self.actor, self.actor_target = None, None
self.critic, self.critic_target = None, None
@@ -277,8 +291,7 @@ def __init__(self, observation_space: gym.spaces.Space,
def _build(self, lr_schedule: Callable) -> None:
self.actor = self.make_actor()
- self.actor.optimizer = self.optimizer_class(self.actor.parameters(), lr=lr_schedule(1),
- **self.optimizer_kwargs)
+ self.actor.optimizer = self.optimizer_class(self.actor.parameters(), lr=lr_schedule(1), **self.optimizer_kwargs)
self.critic = self.make_critic()
self.critic_target = self.make_critic()
@@ -286,29 +299,29 @@ def _build(self, lr_schedule: Callable) -> None:
# Do not optimize the shared feature extractor with the critic loss
# otherwise, there are gradient computation issues
# Another solution: having duplicated features extractor but requires more memory and computation
- critic_parameters = [param for name, param in self.critic.named_parameters() if
- 'features_extractor' not in name]
- self.critic.optimizer = self.optimizer_class(critic_parameters, lr=lr_schedule(1),
- **self.optimizer_kwargs)
+ critic_parameters = [param for name, param in self.critic.named_parameters() if "features_extractor" not in name]
+ self.critic.optimizer = self.optimizer_class(critic_parameters, lr=lr_schedule(1), **self.optimizer_kwargs)
def _get_data(self) -> Dict[str, Any]:
data = super()._get_data()
- data.update(dict(
- net_arch=self.net_args['net_arch'],
- activation_fn=self.net_args['activation_fn'],
- use_sde=self.actor_kwargs['use_sde'],
- log_std_init=self.actor_kwargs['log_std_init'],
- sde_net_arch=self.actor_kwargs['sde_net_arch'],
- use_expln=self.actor_kwargs['use_expln'],
- clip_mean=self.actor_kwargs['clip_mean'],
- n_critics=self.critic_kwargs['n_critics'],
- lr_schedule=self._dummy_schedule, # dummy lr schedule, not needed for loading policy alone
- optimizer_class=self.optimizer_class,
- optimizer_kwargs=self.optimizer_kwargs,
- features_extractor_class=self.features_extractor_class,
- features_extractor_kwargs=self.features_extractor_kwargs
- ))
+ data.update(
+ dict(
+ net_arch=self.net_args["net_arch"],
+ activation_fn=self.net_args["activation_fn"],
+ use_sde=self.actor_kwargs["use_sde"],
+ log_std_init=self.actor_kwargs["log_std_init"],
+ sde_net_arch=self.actor_kwargs["sde_net_arch"],
+ use_expln=self.actor_kwargs["use_expln"],
+ clip_mean=self.actor_kwargs["clip_mean"],
+ n_critics=self.critic_kwargs["n_critics"],
+ lr_schedule=self._dummy_schedule, # dummy lr schedule, not needed for loading policy alone
+ optimizer_class=self.optimizer_class,
+ optimizer_kwargs=self.optimizer_kwargs,
+ features_extractor_class=self.features_extractor_class,
+ features_extractor_kwargs=self.features_extractor_kwargs,
+ )
+ )
return data
def reset_noise(self, batch_size: int = 1) -> None:
@@ -364,40 +377,45 @@ class CnnPolicy(SACPolicy):
:param n_critics: (int) Number of critic networks to create.
"""
- def __init__(self, observation_space: gym.spaces.Space,
- action_space: gym.spaces.Space,
- lr_schedule: Callable,
- net_arch: Optional[List[int]] = None,
- device: Union[th.device, str] = 'auto',
- activation_fn: Type[nn.Module] = nn.ReLU,
- use_sde: bool = False,
- log_std_init: float = -3,
- sde_net_arch: Optional[List[int]] = None,
- use_expln: bool = False,
- clip_mean: float = 2.0,
- features_extractor_class: Type[BaseFeaturesExtractor] = NatureCNN,
- features_extractor_kwargs: Optional[Dict[str, Any]] = None,
- normalize_images: bool = True,
- optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
- optimizer_kwargs: Optional[Dict[str, Any]] = None,
- n_critics: int = 2):
- super(CnnPolicy, self).__init__(observation_space,
- action_space,
- lr_schedule,
- net_arch,
- device,
- activation_fn,
- use_sde,
- log_std_init,
- sde_net_arch,
- use_expln,
- clip_mean,
- features_extractor_class,
- features_extractor_kwargs,
- normalize_images,
- optimizer_class,
- optimizer_kwargs,
- n_critics)
+ def __init__(
+ self,
+ observation_space: gym.spaces.Space,
+ action_space: gym.spaces.Space,
+ lr_schedule: Callable,
+ net_arch: Optional[List[int]] = None,
+ device: Union[th.device, str] = "auto",
+ activation_fn: Type[nn.Module] = nn.ReLU,
+ use_sde: bool = False,
+ log_std_init: float = -3,
+ sde_net_arch: Optional[List[int]] = None,
+ use_expln: bool = False,
+ clip_mean: float = 2.0,
+ features_extractor_class: Type[BaseFeaturesExtractor] = NatureCNN,
+ features_extractor_kwargs: Optional[Dict[str, Any]] = None,
+ normalize_images: bool = True,
+ optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
+ optimizer_kwargs: Optional[Dict[str, Any]] = None,
+ n_critics: int = 2,
+ ):
+ super(CnnPolicy, self).__init__(
+ observation_space,
+ action_space,
+ lr_schedule,
+ net_arch,
+ device,
+ activation_fn,
+ use_sde,
+ log_std_init,
+ sde_net_arch,
+ use_expln,
+ clip_mean,
+ features_extractor_class,
+ features_extractor_kwargs,
+ normalize_images,
+ optimizer_class,
+ optimizer_kwargs,
+ n_critics,
+ )
register_policy("MlpPolicy", MlpPolicy)
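The sac/policies.py hunk above is formatting-only, but the line it rewraps (filtering out `features_extractor` parameters before building the critic optimizer) encodes a deliberate design choice spelled out in the comment: the shared extractor must only be trained through the actor loss to avoid gradient-computation issues. A minimal, self-contained sketch of that filtering pattern, where `SimpleCritic` is a hypothetical stand-in and not a stable-baselines3 class:

```python
import torch as th
from torch import nn


class SimpleCritic(nn.Module):
    """Hypothetical stand-in: a critic that shares a features extractor with the actor."""

    def __init__(self, obs_dim: int = 8, action_dim: int = 2):
        super().__init__()
        self.features_extractor = nn.Linear(obs_dim, 16)  # shared with the actor
        self.q_net = nn.Linear(16 + action_dim, 1)  # critic-only head

    def forward(self, obs: th.Tensor, action: th.Tensor) -> th.Tensor:
        return self.q_net(th.cat([self.features_extractor(obs), action], dim=1))


critic = SimpleCritic()
# Same filtering idea as in the hunk above: exclude the shared extractor from the
# critic optimizer so that only the actor loss updates it.
critic_parameters = [param for name, param in critic.named_parameters() if "features_extractor" not in name]
optimizer = th.optim.Adam(critic_parameters, lr=3e-4)
print([name for name, _ in critic.named_parameters() if "features_extractor" not in name])  # only q_net.* remain
```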
diff --git a/stable_baselines3/sac/sac.py b/stable_baselines3/sac/sac.py
index c177df342..39b8e3a1a 100644
--- a/stable_baselines3/sac/sac.py
+++ b/stable_baselines3/sac/sac.py
@@ -1,12 +1,13 @@
-from typing import List, Tuple, Type, Union, Callable, Optional, Dict, Any
-import torch as th
-import torch.nn.functional as F
+from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
+
import numpy as np
+import torch as th
+from torch.nn import functional as F
from stable_baselines3.common import logger
+from stable_baselines3.common.noise import ActionNoise
from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback
-from stable_baselines3.common.noise import ActionNoise
from stable_baselines3.sac.policies import SACPolicy
@@ -68,44 +69,61 @@ class SAC(OffPolicyAlgorithm):
:param _init_setup_model: (bool) Whether or not to build the network at the creation of the instance
"""
- def __init__(self, policy: Union[str, Type[SACPolicy]],
- env: Union[GymEnv, str],
- learning_rate: Union[float, Callable] = 3e-4,
- buffer_size: int = int(1e6),
- learning_starts: int = 100,
- batch_size: int = 256,
- tau: float = 0.005,
- gamma: float = 0.99,
- train_freq: int = 1,
- gradient_steps: int = 1,
- n_episodes_rollout: int = -1,
- action_noise: Optional[ActionNoise] = None,
- optimize_memory_usage: bool = False,
- ent_coef: Union[str, float] = 'auto',
- target_update_interval: int = 1,
- target_entropy: Union[str, float] = 'auto',
- use_sde: bool = False,
- sde_sample_freq: int = -1,
- use_sde_at_warmup: bool = False,
- tensorboard_log: Optional[str] = None,
- create_eval_env: bool = False,
- policy_kwargs: Dict[str, Any] = None,
- verbose: int = 0,
- seed: Optional[int] = None,
- device: Union[th.device, str] = 'auto',
- _init_setup_model: bool = True):
-
- super(SAC, self).__init__(policy, env, SACPolicy, learning_rate,
- buffer_size, learning_starts, batch_size,
- tau, gamma, train_freq, gradient_steps,
- n_episodes_rollout, action_noise,
- policy_kwargs=policy_kwargs,
- tensorboard_log=tensorboard_log,
- verbose=verbose, device=device,
- create_eval_env=create_eval_env, seed=seed,
- use_sde=use_sde, sde_sample_freq=sde_sample_freq,
- use_sde_at_warmup=use_sde_at_warmup,
- optimize_memory_usage=optimize_memory_usage)
+ def __init__(
+ self,
+ policy: Union[str, Type[SACPolicy]],
+ env: Union[GymEnv, str],
+ learning_rate: Union[float, Callable] = 3e-4,
+ buffer_size: int = int(1e6),
+ learning_starts: int = 100,
+ batch_size: int = 256,
+ tau: float = 0.005,
+ gamma: float = 0.99,
+ train_freq: int = 1,
+ gradient_steps: int = 1,
+ n_episodes_rollout: int = -1,
+ action_noise: Optional[ActionNoise] = None,
+ optimize_memory_usage: bool = False,
+ ent_coef: Union[str, float] = "auto",
+ target_update_interval: int = 1,
+ target_entropy: Union[str, float] = "auto",
+ use_sde: bool = False,
+ sde_sample_freq: int = -1,
+ use_sde_at_warmup: bool = False,
+ tensorboard_log: Optional[str] = None,
+ create_eval_env: bool = False,
+ policy_kwargs: Dict[str, Any] = None,
+ verbose: int = 0,
+ seed: Optional[int] = None,
+ device: Union[th.device, str] = "auto",
+ _init_setup_model: bool = True,
+ ):
+
+ super(SAC, self).__init__(
+ policy,
+ env,
+ SACPolicy,
+ learning_rate,
+ buffer_size,
+ learning_starts,
+ batch_size,
+ tau,
+ gamma,
+ train_freq,
+ gradient_steps,
+ n_episodes_rollout,
+ action_noise,
+ policy_kwargs=policy_kwargs,
+ tensorboard_log=tensorboard_log,
+ verbose=verbose,
+ device=device,
+ create_eval_env=create_eval_env,
+ seed=seed,
+ use_sde=use_sde,
+ sde_sample_freq=sde_sample_freq,
+ use_sde_at_warmup=use_sde_at_warmup,
+ optimize_memory_usage=optimize_memory_usage,
+ )
self.target_entropy = target_entropy
self.log_ent_coef = None # type: Optional[th.Tensor]
@@ -122,7 +140,7 @@ def _setup_model(self) -> None:
super(SAC, self)._setup_model()
self._create_aliases()
# Target entropy is used when learning the entropy coefficient
- if self.target_entropy == 'auto':
+ if self.target_entropy == "auto":
# automatically set target entropy if needed
self.target_entropy = -np.prod(self.env.action_space.shape).astype(np.float32)
else:
@@ -133,12 +151,12 @@ def _setup_model(self) -> None:
# The entropy coefficient or entropy can be learned automatically
# see Automating Entropy Adjustment for Maximum Entropy RL section
# of https://arxiv.org/abs/1812.05905
- if isinstance(self.ent_coef, str) and self.ent_coef.startswith('auto'):
+ if isinstance(self.ent_coef, str) and self.ent_coef.startswith("auto"):
# Default initial value of ent_coef when learned
init_value = 1.0
- if '_' in self.ent_coef:
- init_value = float(self.ent_coef.split('_')[1])
- assert init_value > 0., "The initial value of ent_coef must be greater than 0"
+ if "_" in self.ent_coef:
+ init_value = float(self.ent_coef.split("_")[1])
+ assert init_value > 0.0, "The initial value of ent_coef must be greater than 0"
# Note: we optimize the log of the entropy coeff which is slightly different from the paper
# as discussed in https://github.com/rail-berkeley/softlearning/issues/37
@@ -243,28 +261,37 @@ def train(self, gradient_steps: int, batch_size: int = 64) -> None:
self._n_updates += gradient_steps
- logger.record("train/n_updates", self._n_updates, exclude='tensorboard')
+ logger.record("train/n_updates", self._n_updates, exclude="tensorboard")
logger.record("train/ent_coef", np.mean(ent_coefs))
logger.record("train/actor_loss", np.mean(actor_losses))
logger.record("train/critic_loss", np.mean(critic_losses))
if len(ent_coef_losses) > 0:
logger.record("train/ent_coef_loss", np.mean(ent_coef_losses))
- def learn(self,
- total_timesteps: int,
- callback: MaybeCallback = None,
- log_interval: int = 4,
- eval_env: Optional[GymEnv] = None,
- eval_freq: int = -1,
- n_eval_episodes: int = 5,
- tb_log_name: str = "SAC",
- eval_log_path: Optional[str] = None,
- reset_num_timesteps: bool = True) -> OffPolicyAlgorithm:
-
- return super(SAC, self).learn(total_timesteps=total_timesteps, callback=callback, log_interval=log_interval,
- eval_env=eval_env, eval_freq=eval_freq, n_eval_episodes=n_eval_episodes,
- tb_log_name=tb_log_name, eval_log_path=eval_log_path,
- reset_num_timesteps=reset_num_timesteps)
+ def learn(
+ self,
+ total_timesteps: int,
+ callback: MaybeCallback = None,
+ log_interval: int = 4,
+ eval_env: Optional[GymEnv] = None,
+ eval_freq: int = -1,
+ n_eval_episodes: int = 5,
+ tb_log_name: str = "SAC",
+ eval_log_path: Optional[str] = None,
+ reset_num_timesteps: bool = True,
+ ) -> OffPolicyAlgorithm:
+
+ return super(SAC, self).learn(
+ total_timesteps=total_timesteps,
+ callback=callback,
+ log_interval=log_interval,
+ eval_env=eval_env,
+ eval_freq=eval_freq,
+ n_eval_episodes=n_eval_episodes,
+ tb_log_name=tb_log_name,
+ eval_log_path=eval_log_path,
+ reset_num_timesteps=reset_num_timesteps,
+ )
def excluded_save_params(self) -> List[str]:
"""
@@ -281,9 +308,9 @@ def get_torch_variables(self) -> Tuple[List[str], List[str]]:
cf base class
"""
state_dicts = ["policy", "actor.optimizer", "critic.optimizer"]
- saved_tensors = ['log_ent_coef']
+ saved_tensors = ["log_ent_coef"]
if self.ent_coef_optimizer is not None:
- state_dicts.append('ent_coef_optimizer')
+ state_dicts.append("ent_coef_optimizer")
else:
- saved_tensors.append('ent_coef_tensor')
+ saved_tensors.append("ent_coef_tensor")
return state_dicts, saved_tensors
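As the reformatted sac.py hunk shows, `ent_coef` accepts either a fixed float or a string such as `"auto"` / `"auto_0.01"`, where the suffix seeds a learned coefficient and the log of that coefficient is what actually gets optimized, keeping it positive. A standalone sketch of that parsing and parametrization; the helper name `init_log_ent_coef` is illustrative, not part of the library:

```python
import torch as th


def init_log_ent_coef(ent_coef, device: str = "cpu") -> th.Tensor:
    """Sketch: turn an 'auto' entropy coefficient spec into a learnable log-parameter."""
    if isinstance(ent_coef, str) and ent_coef.startswith("auto"):
        init_value = 1.0  # default initial value when no suffix is given
        if "_" in ent_coef:
            init_value = float(ent_coef.split("_")[1])
        assert init_value > 0.0, "The initial value of ent_coef must be greater than 0"
        # Optimize log(ent_coef) so the coefficient stays positive during training
        return th.log(th.ones(1, device=device) * init_value).requires_grad_(True)
    # Fixed coefficient: no gradient needed
    return th.tensor(float(ent_coef), device=device)


log_ent_coef = init_log_ent_coef("auto_0.01")
print(th.exp(log_ent_coef))  # ≈ tensor([0.0100])
```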
diff --git a/stable_baselines3/td3/__init__.py b/stable_baselines3/td3/__init__.py
index cde926384..ed054f0d9 100644
--- a/stable_baselines3/td3/__init__.py
+++ b/stable_baselines3/td3/__init__.py
@@ -1,2 +1,2 @@
+from stable_baselines3.td3.policies import CnnPolicy, MlpPolicy
from stable_baselines3.td3.td3 import TD3
-from stable_baselines3.td3.policies import MlpPolicy, CnnPolicy
diff --git a/stable_baselines3/td3/policies.py b/stable_baselines3/td3/policies.py
index 325640f5a..bcc64a136 100644
--- a/stable_baselines3/td3/policies.py
+++ b/stable_baselines3/td3/policies.py
@@ -1,12 +1,12 @@
-from typing import Optional, List, Callable, Union, Type, Any, Dict
+from typing import Any, Callable, Dict, List, Optional, Type, Union
import gym
import torch as th
-import torch.nn as nn
+from torch import nn as nn
+from stable_baselines3.common.policies import BasePolicy, ContinuousCritic, register_policy
from stable_baselines3.common.preprocessing import get_action_dim
-from stable_baselines3.common.policies import BasePolicy, register_policy, ContinuousCritic
-from stable_baselines3.common.torch_layers import create_mlp, NatureCNN, BaseFeaturesExtractor, FlattenExtractor
+from stable_baselines3.common.torch_layers import BaseFeaturesExtractor, FlattenExtractor, NatureCNN, create_mlp
class Actor(BasePolicy):
@@ -25,20 +25,25 @@ class Actor(BasePolicy):
:param device: (Union[th.device, str]) Device on which the code should run.
"""
- def __init__(self,
- observation_space: gym.spaces.Space,
- action_space: gym.spaces.Space,
- net_arch: List[int],
- features_extractor: nn.Module,
- features_dim: int,
- activation_fn: Type[nn.Module] = nn.ReLU,
- normalize_images: bool = True,
- device: Union[th.device, str] = 'auto'):
- super(Actor, self).__init__(observation_space, action_space,
- features_extractor=features_extractor,
- normalize_images=normalize_images,
- device=device,
- squash_output=True)
+ def __init__(
+ self,
+ observation_space: gym.spaces.Space,
+ action_space: gym.spaces.Space,
+ net_arch: List[int],
+ features_extractor: nn.Module,
+ features_dim: int,
+ activation_fn: Type[nn.Module] = nn.ReLU,
+ normalize_images: bool = True,
+ device: Union[th.device, str] = "auto",
+ ):
+ super(Actor, self).__init__(
+ observation_space,
+ action_space,
+ features_extractor=features_extractor,
+ normalize_images=normalize_images,
+ device=device,
+ squash_output=True,
+ )
self.features_extractor = features_extractor
self.normalize_images = normalize_images
@@ -54,12 +59,14 @@ def __init__(self,
def _get_data(self) -> Dict[str, Any]:
data = super()._get_data()
- data.update(dict(
- net_arch=self.net_arch,
- features_dim=self.features_dim,
- activation_fn=self.activation_fn,
- features_extractor=self.features_extractor
- ))
+ data.update(
+ dict(
+ net_arch=self.net_arch,
+ features_dim=self.features_dim,
+ activation_fn=self.activation_fn,
+ features_extractor=self.features_extractor,
+ )
+ )
return data
def forward(self, obs: th.Tensor, deterministic: bool = True) -> th.Tensor:
@@ -93,25 +100,31 @@ class TD3Policy(BasePolicy):
:param n_critics: (int) Number of critic networks to create.
"""
- def __init__(self, observation_space: gym.spaces.Space,
- action_space: gym.spaces.Space,
- lr_schedule: Callable,
- net_arch: Optional[List[int]] = None,
- device: Union[th.device, str] = 'auto',
- activation_fn: Type[nn.Module] = nn.ReLU,
- features_extractor_class: Type[BaseFeaturesExtractor] = FlattenExtractor,
- features_extractor_kwargs: Optional[Dict[str, Any]] = None,
- normalize_images: bool = True,
- optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
- optimizer_kwargs: Optional[Dict[str, Any]] = None,
- n_critics: int = 2):
- super(TD3Policy, self).__init__(observation_space, action_space,
- device,
- features_extractor_class,
- features_extractor_kwargs,
- optimizer_class=optimizer_class,
- optimizer_kwargs=optimizer_kwargs,
- squash_output=True)
+ def __init__(
+ self,
+ observation_space: gym.spaces.Space,
+ action_space: gym.spaces.Space,
+ lr_schedule: Callable,
+ net_arch: Optional[List[int]] = None,
+ device: Union[th.device, str] = "auto",
+ activation_fn: Type[nn.Module] = nn.ReLU,
+ features_extractor_class: Type[BaseFeaturesExtractor] = FlattenExtractor,
+ features_extractor_kwargs: Optional[Dict[str, Any]] = None,
+ normalize_images: bool = True,
+ optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
+ optimizer_kwargs: Optional[Dict[str, Any]] = None,
+ n_critics: int = 2,
+ ):
+ super(TD3Policy, self).__init__(
+ observation_space,
+ action_space,
+ device,
+ features_extractor_class,
+ features_extractor_kwargs,
+ optimizer_class=optimizer_class,
+ optimizer_kwargs=optimizer_kwargs,
+ squash_output=True,
+ )
# Default network architecture, from the original paper
if net_arch is None:
@@ -120,24 +133,23 @@ def __init__(self, observation_space: gym.spaces.Space,
else:
net_arch = []
- self.features_extractor = features_extractor_class(self.observation_space,
- **self.features_extractor_kwargs)
+ self.features_extractor = features_extractor_class(self.observation_space, **self.features_extractor_kwargs)
self.features_dim = self.features_extractor.features_dim
self.net_arch = net_arch
self.activation_fn = activation_fn
self.net_args = {
- 'observation_space': self.observation_space,
- 'action_space': self.action_space,
- 'features_extractor': self.features_extractor,
- 'features_dim': self.features_dim,
- 'net_arch': self.net_arch,
- 'activation_fn': self.activation_fn,
- 'normalize_images': normalize_images,
- 'device': device
+ "observation_space": self.observation_space,
+ "action_space": self.action_space,
+ "features_extractor": self.features_extractor,
+ "features_dim": self.features_dim,
+ "net_arch": self.net_arch,
+ "activation_fn": self.activation_fn,
+ "normalize_images": normalize_images,
+ "device": device,
}
self.critic_kwargs = self.net_args.copy()
- self.critic_kwargs.update({'n_critics': n_critics})
+ self.critic_kwargs.update({"n_critics": n_critics})
self.actor, self.actor_target = None, None
self.critic, self.critic_target = None, None
@@ -147,27 +159,27 @@ def _build(self, lr_schedule: Callable) -> None:
self.actor = self.make_actor()
self.actor_target = self.make_actor()
self.actor_target.load_state_dict(self.actor.state_dict())
- self.actor.optimizer = self.optimizer_class(self.actor.parameters(), lr=lr_schedule(1),
- **self.optimizer_kwargs)
+ self.actor.optimizer = self.optimizer_class(self.actor.parameters(), lr=lr_schedule(1), **self.optimizer_kwargs)
self.critic = self.make_critic()
self.critic_target = self.make_critic()
self.critic_target.load_state_dict(self.critic.state_dict())
- self.critic.optimizer = self.optimizer_class(self.critic.parameters(), lr=lr_schedule(1),
- **self.optimizer_kwargs)
+ self.critic.optimizer = self.optimizer_class(self.critic.parameters(), lr=lr_schedule(1), **self.optimizer_kwargs)
def _get_data(self) -> Dict[str, Any]:
data = super()._get_data()
- data.update(dict(
- net_arch=self.net_args['net_arch'],
- activation_fn=self.net_args['activation_fn'],
- n_critics=self.critic_kwargs['n_critics'],
- lr_schedule=self._dummy_schedule, # dummy lr schedule, not needed for loading policy alone
- optimizer_class=self.optimizer_class,
- optimizer_kwargs=self.optimizer_kwargs,
- features_extractor_class=self.features_extractor_class,
- features_extractor_kwargs=self.features_extractor_kwargs
- ))
+ data.update(
+ dict(
+ net_arch=self.net_args["net_arch"],
+ activation_fn=self.net_args["activation_fn"],
+ n_critics=self.critic_kwargs["n_critics"],
+ lr_schedule=self._dummy_schedule, # dummy lr schedule, not needed for loading policy alone
+ optimizer_class=self.optimizer_class,
+ optimizer_kwargs=self.optimizer_kwargs,
+ features_extractor_class=self.features_extractor_class,
+ features_extractor_kwargs=self.features_extractor_kwargs,
+ )
+ )
return data
def make_actor(self) -> Actor:
@@ -208,30 +220,35 @@ class CnnPolicy(TD3Policy):
:param n_critics: (int) Number of critic networks to create.
"""
- def __init__(self, observation_space: gym.spaces.Space,
- action_space: gym.spaces.Space,
- lr_schedule: Callable,
- net_arch: Optional[List[int]] = None,
- device: Union[th.device, str] = 'auto',
- activation_fn: Type[nn.Module] = nn.ReLU,
- features_extractor_class: Type[BaseFeaturesExtractor] = NatureCNN,
- features_extractor_kwargs: Optional[Dict[str, Any]] = None,
- normalize_images: bool = True,
- optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
- optimizer_kwargs: Optional[Dict[str, Any]] = None,
- n_critics: int = 2):
- super(CnnPolicy, self).__init__(observation_space,
- action_space,
- lr_schedule,
- net_arch,
- device,
- activation_fn,
- features_extractor_class,
- features_extractor_kwargs,
- normalize_images,
- optimizer_class,
- optimizer_kwargs,
- n_critics)
+ def __init__(
+ self,
+ observation_space: gym.spaces.Space,
+ action_space: gym.spaces.Space,
+ lr_schedule: Callable,
+ net_arch: Optional[List[int]] = None,
+ device: Union[th.device, str] = "auto",
+ activation_fn: Type[nn.Module] = nn.ReLU,
+ features_extractor_class: Type[BaseFeaturesExtractor] = NatureCNN,
+ features_extractor_kwargs: Optional[Dict[str, Any]] = None,
+ normalize_images: bool = True,
+ optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
+ optimizer_kwargs: Optional[Dict[str, Any]] = None,
+ n_critics: int = 2,
+ ):
+ super(CnnPolicy, self).__init__(
+ observation_space,
+ action_space,
+ lr_schedule,
+ net_arch,
+ device,
+ activation_fn,
+ features_extractor_class,
+ features_extractor_kwargs,
+ normalize_images,
+ optimizer_class,
+ optimizer_kwargs,
+ n_critics,
+ )
register_policy("MlpPolicy", MlpPolicy)
diff --git a/stable_baselines3/td3/td3.py b/stable_baselines3/td3/td3.py
index 83b7447aa..368bce16c 100644
--- a/stable_baselines3/td3/td3.py
+++ b/stable_baselines3/td3/td3.py
@@ -1,10 +1,11 @@
+from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
+
import torch as th
-import torch.nn.functional as F
-from typing import List, Tuple, Type, Union, Callable, Optional, Dict, Any
+from torch.nn import functional as F
from stable_baselines3.common import logger
-from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
from stable_baselines3.common.noise import ActionNoise
+from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback
from stable_baselines3.td3.policies import TD3Policy
@@ -55,39 +56,56 @@ class TD3(OffPolicyAlgorithm):
:param _init_setup_model: (bool) Whether or not to build the network at the creation of the instance
"""
- def __init__(self, policy: Union[str, Type[TD3Policy]],
- env: Union[GymEnv, str],
- learning_rate: Union[float, Callable] = 1e-3,
- buffer_size: int = int(1e6),
- learning_starts: int = 100,
- batch_size: int = 100,
- tau: float = 0.005,
- gamma: float = 0.99,
- train_freq: int = -1,
- gradient_steps: int = -1,
- n_episodes_rollout: int = 1,
- action_noise: Optional[ActionNoise] = None,
- optimize_memory_usage: bool = False,
- policy_delay: int = 2,
- target_policy_noise: float = 0.2,
- target_noise_clip: float = 0.5,
- tensorboard_log: Optional[str] = None,
- create_eval_env: bool = False,
- policy_kwargs: Dict[str, Any] = None,
- verbose: int = 0,
- seed: Optional[int] = None,
- device: Union[th.device, str] = 'auto',
- _init_setup_model: bool = True):
-
- super(TD3, self).__init__(policy, env, TD3Policy, learning_rate,
- buffer_size, learning_starts, batch_size,
- tau, gamma, train_freq, gradient_steps,
- n_episodes_rollout, action_noise=action_noise,
- policy_kwargs=policy_kwargs,
- tensorboard_log=tensorboard_log,
- verbose=verbose, device=device,
- create_eval_env=create_eval_env, seed=seed,
- sde_support=False, optimize_memory_usage=optimize_memory_usage)
+ def __init__(
+ self,
+ policy: Union[str, Type[TD3Policy]],
+ env: Union[GymEnv, str],
+ learning_rate: Union[float, Callable] = 1e-3,
+ buffer_size: int = int(1e6),
+ learning_starts: int = 100,
+ batch_size: int = 100,
+ tau: float = 0.005,
+ gamma: float = 0.99,
+ train_freq: int = -1,
+ gradient_steps: int = -1,
+ n_episodes_rollout: int = 1,
+ action_noise: Optional[ActionNoise] = None,
+ optimize_memory_usage: bool = False,
+ policy_delay: int = 2,
+ target_policy_noise: float = 0.2,
+ target_noise_clip: float = 0.5,
+ tensorboard_log: Optional[str] = None,
+ create_eval_env: bool = False,
+ policy_kwargs: Dict[str, Any] = None,
+ verbose: int = 0,
+ seed: Optional[int] = None,
+ device: Union[th.device, str] = "auto",
+ _init_setup_model: bool = True,
+ ):
+
+ super(TD3, self).__init__(
+ policy,
+ env,
+ TD3Policy,
+ learning_rate,
+ buffer_size,
+ learning_starts,
+ batch_size,
+ tau,
+ gamma,
+ train_freq,
+ gradient_steps,
+ n_episodes_rollout,
+ action_noise=action_noise,
+ policy_kwargs=policy_kwargs,
+ tensorboard_log=tensorboard_log,
+ verbose=verbose,
+ device=device,
+ create_eval_env=create_eval_env,
+ seed=seed,
+ sde_support=False,
+ optimize_memory_usage=optimize_memory_usage,
+ )
self.policy_delay = policy_delay
self.target_noise_clip = target_noise_clip
@@ -141,8 +159,7 @@ def train(self, gradient_steps: int, batch_size: int = 100) -> None:
# Delayed policy updates
if gradient_step % self.policy_delay == 0:
# Compute actor loss
- actor_loss = -self.critic.q1_forward(replay_data.observations,
- self.actor(replay_data.observations)).mean()
+ actor_loss = -self.critic.q1_forward(replay_data.observations, self.actor(replay_data.observations)).mean()
# Optimize the actor
self.actor.optimizer.zero_grad()
@@ -157,23 +174,32 @@ def train(self, gradient_steps: int, batch_size: int = 100) -> None:
target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
self._n_updates += gradient_steps
- logger.record("train/n_updates", self._n_updates, exclude='tensorboard')
-
- def learn(self,
- total_timesteps: int,
- callback: MaybeCallback = None,
- log_interval: int = 4,
- eval_env: Optional[GymEnv] = None,
- eval_freq: int = -1,
- n_eval_episodes: int = 5,
- tb_log_name: str = "TD3",
- eval_log_path: Optional[str] = None,
- reset_num_timesteps: bool = True) -> OffPolicyAlgorithm:
-
- return super(TD3, self).learn(total_timesteps=total_timesteps, callback=callback, log_interval=log_interval,
- eval_env=eval_env, eval_freq=eval_freq, n_eval_episodes=n_eval_episodes,
- tb_log_name=tb_log_name, eval_log_path=eval_log_path,
- reset_num_timesteps=reset_num_timesteps)
+ logger.record("train/n_updates", self._n_updates, exclude="tensorboard")
+
+ def learn(
+ self,
+ total_timesteps: int,
+ callback: MaybeCallback = None,
+ log_interval: int = 4,
+ eval_env: Optional[GymEnv] = None,
+ eval_freq: int = -1,
+ n_eval_episodes: int = 5,
+ tb_log_name: str = "TD3",
+ eval_log_path: Optional[str] = None,
+ reset_num_timesteps: bool = True,
+ ) -> OffPolicyAlgorithm:
+
+ return super(TD3, self).learn(
+ total_timesteps=total_timesteps,
+ callback=callback,
+ log_interval=log_interval,
+ eval_env=eval_env,
+ eval_freq=eval_freq,
+ n_eval_episodes=n_eval_episodes,
+ tb_log_name=tb_log_name,
+ eval_log_path=eval_log_path,
+ reset_num_timesteps=reset_num_timesteps,
+ )
def excluded_save_params(self) -> List[str]:
"""
diff --git a/tests/test_callbacks.py b/tests/test_callbacks.py
index fd4abb5ba..b1a3339e8 100644
--- a/tests/test_callbacks.py
+++ b/tests/test_callbacks.py
@@ -1,23 +1,28 @@
import os
import shutil
-import pytest
import gym
+import pytest
-from stable_baselines3 import A2C, PPO, SAC, TD3, DQN, DDPG
-from stable_baselines3.common.callbacks import (CallbackList, CheckpointCallback, EvalCallback,
- EveryNTimesteps, StopTrainingOnRewardThreshold)
+from stable_baselines3 import A2C, DDPG, DQN, PPO, SAC, TD3
+from stable_baselines3.common.callbacks import (
+ CallbackList,
+ CheckpointCallback,
+ EvalCallback,
+ EveryNTimesteps,
+ StopTrainingOnRewardThreshold,
+)
@pytest.mark.parametrize("model_class", [A2C, PPO, SAC, TD3, DQN, DDPG])
def test_callbacks(tmp_path, model_class):
- log_folder = tmp_path / 'logs/callbacks/'
+ log_folder = tmp_path / "logs/callbacks/"
    # DQN only supports discrete actions
env_name = select_env(model_class)
# Create RL model
# Small network for fast test
- model = model_class('MlpPolicy', env_name, policy_kwargs=dict(net_arch=[32]))
+ model = model_class("MlpPolicy", env_name, policy_kwargs=dict(net_arch=[32]))
checkpoint_callback = CheckpointCallback(save_freq=1000, save_path=log_folder)
@@ -25,14 +30,13 @@ def test_callbacks(tmp_path, model_class):
# Stop training if the performance is good enough
callback_on_best = StopTrainingOnRewardThreshold(reward_threshold=-1200, verbose=1)
- eval_callback = EvalCallback(eval_env, callback_on_new_best=callback_on_best,
- best_model_save_path=log_folder,
- log_path=log_folder, eval_freq=100)
+ eval_callback = EvalCallback(
+ eval_env, callback_on_new_best=callback_on_best, best_model_save_path=log_folder, log_path=log_folder, eval_freq=100
+ )
# Equivalent to the `checkpoint_callback`
# but here in an event-driven manner
- checkpoint_on_event = CheckpointCallback(save_freq=1, save_path=log_folder,
- name_prefix='event')
+ checkpoint_on_event = CheckpointCallback(save_freq=1, save_path=log_folder, name_prefix="event")
event_callback = EveryNTimesteps(n_steps=500, callback=checkpoint_on_event)
callback = CallbackList([checkpoint_callback, eval_callback, event_callback])
@@ -49,6 +53,6 @@ def test_callbacks(tmp_path, model_class):
def select_env(model_class) -> str:
if model_class is DQN:
- return 'CartPole-v0'
+ return "CartPole-v0"
else:
- return 'Pendulum-v0'
+ return "Pendulum-v0"
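The callback objects exercised in the test diff above compose the same way outside of pytest. A usage sketch mirroring that setup as a plain training script; paths, environment, and hyperparameters here are illustrative:

```python
import gym

from stable_baselines3 import SAC
from stable_baselines3.common.callbacks import (
    CallbackList,
    CheckpointCallback,
    EvalCallback,
    EveryNTimesteps,
    StopTrainingOnRewardThreshold,
)

log_folder = "logs/callbacks/"  # illustrative path
eval_env = gym.make("Pendulum-v0")

checkpoint_callback = CheckpointCallback(save_freq=1000, save_path=log_folder)
# Stop training once the evaluation reward crosses the threshold
callback_on_best = StopTrainingOnRewardThreshold(reward_threshold=-1200, verbose=1)
eval_callback = EvalCallback(
    eval_env, callback_on_new_best=callback_on_best, best_model_save_path=log_folder, log_path=log_folder, eval_freq=100
)
# Event-driven checkpointing, equivalent to `checkpoint_callback` but triggered every 500 steps
checkpoint_on_event = CheckpointCallback(save_freq=1, save_path=log_folder, name_prefix="event")
event_callback = EveryNTimesteps(n_steps=500, callback=checkpoint_on_event)

model = SAC("MlpPolicy", "Pendulum-v0", policy_kwargs=dict(net_arch=[32]))
model.learn(5000, callback=CallbackList([checkpoint_callback, eval_callback, event_callback]))
```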
diff --git a/tests/test_cnn.py b/tests/test_cnn.py
index 8c1727663..58c80b6af 100644
--- a/tests/test_cnn.py
+++ b/tests/test_cnn.py
@@ -3,26 +3,24 @@
import numpy as np
import pytest
-from stable_baselines3 import A2C, PPO, SAC, TD3, DQN
+from stable_baselines3 import A2C, DQN, PPO, SAC, TD3
from stable_baselines3.common.identity_env import FakeImageEnv
-@pytest.mark.parametrize('model_class', [A2C, PPO, SAC, TD3, DQN])
+@pytest.mark.parametrize("model_class", [A2C, PPO, SAC, TD3, DQN])
def test_cnn(tmp_path, model_class):
- SAVE_NAME = 'cnn_model.zip'
+ SAVE_NAME = "cnn_model.zip"
# Fake grayscale with frameskip
# Atari after preprocessing: 84x84x1, here we are using lower resolution
# to check that the network handle it automatically
- env = FakeImageEnv(screen_height=40, screen_width=40, n_channels=1,
- discrete=model_class not in {SAC, TD3})
+ env = FakeImageEnv(screen_height=40, screen_width=40, n_channels=1, discrete=model_class not in {SAC, TD3})
if model_class in {A2C, PPO}:
kwargs = dict(n_steps=100)
else:
# Avoid memory error when using replay buffer
# Reduce the size of the features
- kwargs = dict(buffer_size=250,
- policy_kwargs=dict(features_extractor_kwargs=dict(features_dim=32)))
- model = model_class('CnnPolicy', env, **kwargs).learn(250)
+ kwargs = dict(buffer_size=250, policy_kwargs=dict(features_extractor_kwargs=dict(features_dim=32)))
+ model = model_class("CnnPolicy", env, **kwargs).learn(250)
obs = env.reset()
diff --git a/tests/test_custom_policy.py b/tests/test_custom_policy.py
index 9637f4e52..cd379126c 100644
--- a/tests/test_custom_policy.py
+++ b/tests/test_custom_policy.py
@@ -4,31 +4,31 @@
from stable_baselines3 import A2C, PPO, SAC, TD3
-@pytest.mark.parametrize('net_arch', [
- [12, dict(vf=[16], pi=[8])],
- [4],
- [],
- [4, 4],
- [12, dict(vf=[8, 4], pi=[8])],
- [12, dict(vf=[8], pi=[8, 4])],
- [12, dict(pi=[8])],
-])
-@pytest.mark.parametrize('model_class', [A2C, PPO])
+@pytest.mark.parametrize(
+ "net_arch",
+ [
+ [12, dict(vf=[16], pi=[8])],
+ [4],
+ [],
+ [4, 4],
+ [12, dict(vf=[8, 4], pi=[8])],
+ [12, dict(vf=[8], pi=[8, 4])],
+ [12, dict(pi=[8])],
+ ],
+)
+@pytest.mark.parametrize("model_class", [A2C, PPO])
def test_flexible_mlp(model_class, net_arch):
- _ = model_class('MlpPolicy', 'CartPole-v1', policy_kwargs=dict(net_arch=net_arch), n_steps=100).learn(1000)
+ _ = model_class("MlpPolicy", "CartPole-v1", policy_kwargs=dict(net_arch=net_arch), n_steps=100).learn(1000)
-@pytest.mark.parametrize('net_arch', [
- [4],
- [4, 4],
-])
-@pytest.mark.parametrize('model_class', [SAC, TD3])
+@pytest.mark.parametrize("net_arch", [[4], [4, 4],])
+@pytest.mark.parametrize("model_class", [SAC, TD3])
def test_custom_offpolicy(model_class, net_arch):
- _ = model_class('MlpPolicy', 'Pendulum-v0', policy_kwargs=dict(net_arch=net_arch)).learn(1000)
+ _ = model_class("MlpPolicy", "Pendulum-v0", policy_kwargs=dict(net_arch=net_arch)).learn(1000)
-@pytest.mark.parametrize('model_class', [A2C, PPO, SAC, TD3])
-@pytest.mark.parametrize('optimizer_kwargs', [None, dict(weight_decay=0.0)])
+@pytest.mark.parametrize("model_class", [A2C, PPO, SAC, TD3])
+@pytest.mark.parametrize("optimizer_kwargs", [None, dict(weight_decay=0.0)])
def test_custom_optimizer(model_class, optimizer_kwargs):
policy_kwargs = dict(optimizer_class=th.optim.AdamW, optimizer_kwargs=optimizer_kwargs, net_arch=[32])
- _ = model_class('MlpPolicy', 'Pendulum-v0', policy_kwargs=policy_kwargs).learn(1000)
+ _ = model_class("MlpPolicy", "Pendulum-v0", policy_kwargs=policy_kwargs).learn(1000)
diff --git a/tests/test_deterministic.py b/tests/test_deterministic.py
index 645c4d3ff..16cdcfacf 100644
--- a/tests/test_deterministic.py
+++ b/tests/test_deterministic.py
@@ -1,6 +1,6 @@
import pytest
-from stable_baselines3 import A2C, PPO, SAC, TD3, DQN
+from stable_baselines3 import A2C, DQN, PPO, SAC, TD3
from stable_baselines3.common.noise import NormalActionNoise
N_STEPS_TRAINING = 3000
@@ -12,18 +12,17 @@ def test_deterministic_training_common(algo):
results = [[], []]
rewards = [[], []]
# Smaller network
- kwargs = {'policy_kwargs': dict(net_arch=[64])}
+ kwargs = {"policy_kwargs": dict(net_arch=[64])}
if algo in [TD3, SAC]:
- env_id = 'Pendulum-v0'
- kwargs.update({'action_noise': NormalActionNoise(0.0, 0.1),
- 'learning_starts': 100})
+ env_id = "Pendulum-v0"
+ kwargs.update({"action_noise": NormalActionNoise(0.0, 0.1), "learning_starts": 100})
else:
- env_id = 'CartPole-v1'
+ env_id = "CartPole-v1"
if algo == DQN:
- kwargs.update({'learning_starts': 100})
+ kwargs.update({"learning_starts": 100})
for i in range(2):
- model = algo('MlpPolicy', env_id, seed=SEED, **kwargs)
+ model = algo("MlpPolicy", env_id, seed=SEED, **kwargs)
model.learn(N_STEPS_TRAINING)
env = model.get_env()
obs = env.reset()
diff --git a/tests/test_distributions.py b/tests/test_distributions.py
index 0461e17f4..a73b81ede 100644
--- a/tests/test_distributions.py
+++ b/tests/test_distributions.py
@@ -2,13 +2,17 @@
import torch as th
from stable_baselines3 import A2C, PPO
-from stable_baselines3.common.distributions import (DiagGaussianDistribution, TanhBijector,
- StateDependentNoiseDistribution,
- CategoricalDistribution, SquashedDiagGaussianDistribution,
- MultiCategoricalDistribution, BernoulliDistribution)
+from stable_baselines3.common.distributions import (
+ BernoulliDistribution,
+ CategoricalDistribution,
+ DiagGaussianDistribution,
+ MultiCategoricalDistribution,
+ SquashedDiagGaussianDistribution,
+ StateDependentNoiseDistribution,
+ TanhBijector,
+)
from stable_baselines3.common.utils import set_random_seed
-
N_ACTIONS = 2
N_FEATURES = 3
N_SAMPLES = int(5e6)
@@ -33,7 +37,7 @@ def test_squashed_gaussian(model_class):
"""
Test run with squashed Gaussian (notably entropy computation)
"""
- model = model_class('MlpPolicy', 'Pendulum-v0', use_sde=True, n_steps=100, policy_kwargs=dict(squash_output=True))
+ model = model_class("MlpPolicy", "Pendulum-v0", use_sde=True, n_steps=100, policy_kwargs=dict(squash_output=True))
model.learn(500)
gaussian_mean = th.rand(N_SAMPLES, N_ACTIONS)
@@ -62,10 +66,9 @@ def test_sde_distribution():
# TODO: analytical form for squashed Gaussian?
-@pytest.mark.parametrize("dist", [
- DiagGaussianDistribution(N_ACTIONS),
- StateDependentNoiseDistribution(N_ACTIONS, squash_output=False),
-])
+@pytest.mark.parametrize(
+ "dist", [DiagGaussianDistribution(N_ACTIONS), StateDependentNoiseDistribution(N_ACTIONS, squash_output=False),]
+)
def test_entropy(dist):
# The entropy can be approximated by averaging the negative log likelihood
# mean negative log likelihood == differential entropy
@@ -89,7 +92,7 @@ def test_entropy(dist):
categorical_params = [
(CategoricalDistribution(N_ACTIONS), N_ACTIONS),
(MultiCategoricalDistribution([2, 3]), sum([2, 3])),
- (BernoulliDistribution(N_ACTIONS), N_ACTIONS)
+ (BernoulliDistribution(N_ACTIONS), N_ACTIONS),
]
diff --git a/tests/test_envs.py b/tests/test_envs.py
index 98197d397..e5f6bf3a9 100644
--- a/tests/test_envs.py
+++ b/tests/test_envs.py
@@ -1,18 +1,22 @@
-import pytest
import gym
-from gym import spaces
import numpy as np
+import pytest
+from gym import spaces
-from stable_baselines3.common.env_checker import check_env
from stable_baselines3.common.bit_flipping_env import BitFlippingEnv
-from stable_baselines3.common.identity_env import (IdentityEnv, IdentityEnvBox, FakeImageEnv,
- IdentityEnvMultiBinary, IdentityEnvMultiDiscrete)
+from stable_baselines3.common.env_checker import check_env
+from stable_baselines3.common.identity_env import (
+ FakeImageEnv,
+ IdentityEnv,
+ IdentityEnvBox,
+ IdentityEnvMultiBinary,
+ IdentityEnvMultiDiscrete,
+)
-ENV_CLASSES = [BitFlippingEnv, IdentityEnv, IdentityEnvBox, IdentityEnvMultiBinary,
- IdentityEnvMultiDiscrete, FakeImageEnv]
+ENV_CLASSES = [BitFlippingEnv, IdentityEnv, IdentityEnvBox, IdentityEnvMultiBinary, IdentityEnvMultiDiscrete, FakeImageEnv]
-@pytest.mark.parametrize("env_id", ['CartPole-v0', 'Pendulum-v0'])
+@pytest.mark.parametrize("env_id", ["CartPole-v0", "Pendulum-v0"])
def test_env(env_id):
"""
    Check that environments integrated in Gym pass the test.
@@ -25,7 +29,7 @@ def test_env(env_id):
# Pendulum-v0 will produce a warning because the action space is
# in [-2, 2] and not [-1, 1]
- if env_id == 'Pendulum-v0':
+ if env_id == "Pendulum-v0":
assert len(record) == 1
else:
# The other environments must pass without warning
@@ -50,24 +54,28 @@ def test_high_dimension_action_space():
# Patch to avoid error
def patched_step(_action):
return env.observation_space.sample(), 0.0, False, {}
+
env.step = patched_step
check_env(env)
-@pytest.mark.parametrize("new_obs_space", [
- # Small image
- spaces.Box(low=0, high=255, shape=(32, 32, 3), dtype=np.uint8),
- # Range not in [0, 255]
- spaces.Box(low=0, high=1, shape=(64, 64, 3), dtype=np.uint8),
- # Wrong dtype
- spaces.Box(low=0, high=255, shape=(64, 64, 3), dtype=np.float32),
- # Not an image, it should be a 1D vector
- spaces.Box(low=-1, high=1, shape=(64, 3), dtype=np.float32),
- # Tuple space is not supported by SB
- spaces.Tuple([spaces.Discrete(5), spaces.Discrete(10)]),
- # Dict space is not supported by SB when env is not a GoalEnv
- spaces.Dict({"position": spaces.Discrete(5)}),
-])
+@pytest.mark.parametrize(
+ "new_obs_space",
+ [
+ # Small image
+ spaces.Box(low=0, high=255, shape=(32, 32, 3), dtype=np.uint8),
+ # Range not in [0, 255]
+ spaces.Box(low=0, high=1, shape=(64, 64, 3), dtype=np.uint8),
+ # Wrong dtype
+ spaces.Box(low=0, high=255, shape=(64, 64, 3), dtype=np.float32),
+ # Not an image, it should be a 1D vector
+ spaces.Box(low=-1, high=1, shape=(64, 3), dtype=np.float32),
+ # Tuple space is not supported by SB
+ spaces.Tuple([spaces.Discrete(5), spaces.Discrete(10)]),
+ # Dict space is not supported by SB when env is not a GoalEnv
+ spaces.Dict({"position": spaces.Discrete(5)}),
+ ],
+)
def test_non_default_spaces(new_obs_space):
env = FakeImageEnv()
env.observation_space = new_obs_space
diff --git a/tests/test_identity.py b/tests/test_identity.py
index 72d8c4153..38c6570be 100644
--- a/tests/test_identity.py
+++ b/tests/test_identity.py
@@ -1,13 +1,11 @@
import numpy as np
import pytest
-from stable_baselines3 import A2C, PPO, SAC, TD3, DQN, DDPG
-from stable_baselines3.common.identity_env import (IdentityEnvBox, IdentityEnv,
- IdentityEnvMultiBinary, IdentityEnvMultiDiscrete)
-
-from stable_baselines3.common.vec_env import DummyVecEnv
+from stable_baselines3 import A2C, DDPG, DQN, PPO, SAC, TD3
from stable_baselines3.common.evaluation import evaluate_policy
+from stable_baselines3.common.identity_env import IdentityEnv, IdentityEnvBox, IdentityEnvMultiBinary, IdentityEnvMultiDiscrete
from stable_baselines3.common.noise import NormalActionNoise
+from stable_baselines3.common.vec_env import DummyVecEnv
DIM = 4
@@ -25,7 +23,7 @@ def test_discrete(model_class, env):
if isinstance(env, (IdentityEnvMultiDiscrete, IdentityEnvMultiBinary)):
return
- model = model_class('MlpPolicy', env_, gamma=0.5, seed=1, **kwargs).learn(n_steps)
+ model = model_class("MlpPolicy", env_, gamma=0.5, seed=1, **kwargs).learn(n_steps)
evaluate_policy(model, env_, n_eval_episodes=20, reward_threshold=90)
obs = env.reset()
@@ -37,24 +35,14 @@ def test_discrete(model_class, env):
def test_continuous(model_class):
env = IdentityEnvBox(eps=0.5)
- n_steps = {
- A2C: 3500,
- PPO: 3000,
- SAC: 700,
- TD3: 500,
- DDPG: 500
- }[model_class]
-
- kwargs = dict(
- policy_kwargs=dict(net_arch=[64, 64]),
- seed=0,
- gamma=0.95
- )
+ n_steps = {A2C: 3500, PPO: 3000, SAC: 700, TD3: 500, DDPG: 500}[model_class]
+
+ kwargs = dict(policy_kwargs=dict(net_arch=[64, 64]), seed=0, gamma=0.95)
if model_class in [TD3]:
n_actions = 1
action_noise = NormalActionNoise(mean=np.zeros(n_actions), sigma=0.1 * np.ones(n_actions))
- kwargs['action_noise'] = action_noise
+ kwargs["action_noise"] = action_noise
- model = model_class('MlpPolicy', env, **kwargs).learn(n_steps)
+ model = model_class("MlpPolicy", env, **kwargs).learn(n_steps)
evaluate_policy(model, env, n_eval_episodes=20, reward_threshold=90)
diff --git a/tests/test_logger.py b/tests/test_logger.py
index 6975ce5c6..c399c9ee4 100644
--- a/tests/test_logger.py
+++ b/tests/test_logger.py
@@ -1,9 +1,24 @@
-import pytest
import numpy as np
+import pytest
-from stable_baselines3.common.logger import (make_output_format, read_csv, read_json, DEBUG, ScopedConfigure,
- info, debug, set_level, configure, record, record_dict,
- dump, record_mean, warn, error, reset)
+from stable_baselines3.common.logger import (
+ DEBUG,
+ ScopedConfigure,
+ configure,
+ debug,
+ dump,
+ error,
+ info,
+ make_output_format,
+ read_csv,
+ read_json,
+ record,
+ record_dict,
+ record_mean,
+ reset,
+ set_level,
+ warn,
+)
KEY_VALUES = {
"test": 1,
@@ -55,23 +70,23 @@ def test_main(tmp_path):
record_dict({"test": 1})
-@pytest.mark.parametrize('_format', ['stdout', 'log', 'json', 'csv', 'tensorboard'])
+@pytest.mark.parametrize("_format", ["stdout", "log", "json", "csv", "tensorboard"])
def test_make_output(tmp_path, _format):
"""
test make output
:param _format: (str) output format
"""
- if _format == 'tensorboard':
+ if _format == "tensorboard":
# Skip if no tensorboard installed
pytest.importorskip("tensorboard")
writer = make_output_format(_format, tmp_path)
writer.write(KEY_VALUES, KEY_EXCLUDED)
if _format == "csv":
- read_csv(tmp_path / 'progress.csv')
- elif _format == 'json':
- read_json(tmp_path / 'progress.json')
+ read_csv(tmp_path / "progress.csv")
+ elif _format == "json":
+ read_json(tmp_path / "progress.json")
writer.close()
@@ -80,4 +95,4 @@ def test_make_output_fail(tmp_path):
test value error on logger
"""
with pytest.raises(ValueError):
- make_output_format('dummy_format', tmp_path)
+ make_output_format("dummy_format", tmp_path)
diff --git a/tests/test_monitor.py b/tests/test_monitor.py
index 00a78026f..d3d041b4d 100644
--- a/tests/test_monitor.py
+++ b/tests/test_monitor.py
@@ -1,9 +1,9 @@
-import uuid
import json
import os
+import uuid
-import pandas
import gym
+import pandas
from stable_baselines3.common.monitor import Monitor, get_monitor_files, load_results
@@ -37,15 +37,15 @@ def test_monitor(tmp_path):
assert sum(monitor_env.get_episode_rewards()) == sum(ep_rewards)
_ = monitor_env.get_episode_times()
- with open(monitor_file, 'rt') as file_handler:
+ with open(monitor_file, "rt") as file_handler:
first_line = file_handler.readline()
- assert first_line.startswith('#')
+ assert first_line.startswith("#")
metadata = json.loads(first_line[1:])
- assert metadata['env_id'] == "CartPole-v1"
- assert set(metadata.keys()) == {'env_id', 't_start'}, "Incorrect keys in monitor metadata"
+ assert metadata["env_id"] == "CartPole-v1"
+ assert set(metadata.keys()) == {"env_id", "t_start"}, "Incorrect keys in monitor metadata"
last_logline = pandas.read_csv(file_handler, index_col=None)
- assert set(last_logline.keys()) == {'l', 't', 'r'}, "Incorrect keys in monitor logline"
+ assert set(last_logline.keys()) == {"l", "t", "r"}, "Incorrect keys in monitor logline"
os.remove(monitor_file)
diff --git a/tests/test_predict.py b/tests/test_predict.py
index c95e19f77..4f60f1bc0 100644
--- a/tests/test_predict.py
+++ b/tests/test_predict.py
@@ -1,7 +1,7 @@
import gym
import pytest
-from stable_baselines3 import A2C, PPO, SAC, TD3, DQN
+from stable_baselines3 import A2C, DQN, PPO, SAC, TD3
from stable_baselines3.common.vec_env import DummyVecEnv
MODEL_LIST = [
@@ -19,26 +19,26 @@ def test_auto_wrap(model_class):
# Use different environment for DQN
if model_class is DQN:
- env_name = 'CartPole-v0'
+ env_name = "CartPole-v0"
else:
- env_name = 'Pendulum-v0'
+ env_name = "Pendulum-v0"
env = gym.make(env_name)
eval_env = gym.make(env_name)
- model = model_class('MlpPolicy', env)
+ model = model_class("MlpPolicy", env)
model.learn(100, eval_env=eval_env)
@pytest.mark.parametrize("model_class", MODEL_LIST)
-@pytest.mark.parametrize("env_id", ['Pendulum-v0', 'CartPole-v1'])
+@pytest.mark.parametrize("env_id", ["Pendulum-v0", "CartPole-v1"])
def test_predict(model_class, env_id):
- if env_id == 'CartPole-v1':
+ if env_id == "CartPole-v1":
if model_class in [SAC, TD3]:
return
elif model_class in [DQN]:
return
# test detection of different shapes by the predict method
- model = model_class('MlpPolicy', env_id)
+ model = model_class("MlpPolicy", env_id)
env = gym.make(env_id)
vec_env = DummyVecEnv([lambda: gym.make(env_id), lambda: gym.make(env_id)])
diff --git a/tests/test_run.py b/tests/test_run.py
index d22d346c8..368381ee2 100644
--- a/tests/test_run.py
+++ b/tests/test_run.py
@@ -1,60 +1,97 @@
import numpy as np
import pytest
-from stable_baselines3 import A2C, PPO, SAC, TD3, DQN, DDPG
+from stable_baselines3 import A2C, DDPG, DQN, PPO, SAC, TD3
from stable_baselines3.common.noise import NormalActionNoise, OrnsteinUhlenbeckActionNoise
normal_action_noise = NormalActionNoise(np.zeros(1), 0.1 * np.ones(1))
-@pytest.mark.parametrize('model_class', [TD3, DDPG])
-@pytest.mark.parametrize('action_noise', [normal_action_noise, OrnsteinUhlenbeckActionNoise(np.zeros(1), 0.1 * np.ones(1))])
+@pytest.mark.parametrize("model_class", [TD3, DDPG])
+@pytest.mark.parametrize("action_noise", [normal_action_noise, OrnsteinUhlenbeckActionNoise(np.zeros(1), 0.1 * np.ones(1))])
def test_deterministic_pg(model_class, action_noise):
"""
Test for DDPG and variants (TD3).
"""
- model = model_class('MlpPolicy', 'Pendulum-v0', policy_kwargs=dict(net_arch=[64, 64]),
- learning_starts=100, verbose=1, create_eval_env=True, action_noise=action_noise)
+ model = model_class(
+ "MlpPolicy",
+ "Pendulum-v0",
+ policy_kwargs=dict(net_arch=[64, 64]),
+ learning_starts=100,
+ verbose=1,
+ create_eval_env=True,
+ action_noise=action_noise,
+ )
model.learn(total_timesteps=1000, eval_freq=500)
-@pytest.mark.parametrize("env_id", ['CartPole-v1', 'Pendulum-v0'])
+@pytest.mark.parametrize("env_id", ["CartPole-v1", "Pendulum-v0"])
def test_a2c(env_id):
- model = A2C('MlpPolicy', env_id, seed=0, policy_kwargs=dict(net_arch=[16]), verbose=1, create_eval_env=True)
+ model = A2C("MlpPolicy", env_id, seed=0, policy_kwargs=dict(net_arch=[16]), verbose=1, create_eval_env=True)
model.learn(total_timesteps=1000, eval_freq=500)
-@pytest.mark.parametrize("env_id", ['CartPole-v1', 'Pendulum-v0'])
+@pytest.mark.parametrize("env_id", ["CartPole-v1", "Pendulum-v0"])
@pytest.mark.parametrize("clip_range_vf", [None, 0.2, -0.2])
def test_ppo(env_id, clip_range_vf):
if clip_range_vf is not None and clip_range_vf < 0:
# Should throw an error
with pytest.raises(AssertionError):
- model = PPO('MlpPolicy', env_id, seed=0, policy_kwargs=dict(net_arch=[16]), verbose=1, create_eval_env=True,
- clip_range_vf=clip_range_vf)
+ model = PPO(
+ "MlpPolicy",
+ env_id,
+ seed=0,
+ policy_kwargs=dict(net_arch=[16]),
+ verbose=1,
+ create_eval_env=True,
+ clip_range_vf=clip_range_vf,
+ )
else:
- model = PPO('MlpPolicy', env_id, seed=0, policy_kwargs=dict(net_arch=[16]), verbose=1, create_eval_env=True,
- clip_range_vf=clip_range_vf)
+ model = PPO(
+ "MlpPolicy",
+ env_id,
+ seed=0,
+ policy_kwargs=dict(net_arch=[16]),
+ verbose=1,
+ create_eval_env=True,
+ clip_range_vf=clip_range_vf,
+ )
model.learn(total_timesteps=1000, eval_freq=500)
-@pytest.mark.parametrize("ent_coef", ['auto', 0.01, 'auto_0.01'])
+@pytest.mark.parametrize("ent_coef", ["auto", 0.01, "auto_0.01"])
def test_sac(ent_coef):
- model = SAC('MlpPolicy', 'Pendulum-v0', policy_kwargs=dict(net_arch=[64, 64]),
- learning_starts=100, verbose=1, create_eval_env=True, ent_coef=ent_coef,
- action_noise=NormalActionNoise(np.zeros(1), np.zeros(1)))
+ model = SAC(
+ "MlpPolicy",
+ "Pendulum-v0",
+ policy_kwargs=dict(net_arch=[64, 64]),
+ learning_starts=100,
+ verbose=1,
+ create_eval_env=True,
+ ent_coef=ent_coef,
+ action_noise=NormalActionNoise(np.zeros(1), np.zeros(1)),
+ )
model.learn(total_timesteps=1000, eval_freq=500)
@pytest.mark.parametrize("n_critics", [1, 3])
def test_n_critics(n_critics):
# Test SAC with different number of critics, for TD3, n_critics=1 corresponds to DDPG
- model = SAC('MlpPolicy', 'Pendulum-v0', policy_kwargs=dict(net_arch=[64, 64], n_critics=n_critics),
- learning_starts=100, verbose=1)
+ model = SAC(
+ "MlpPolicy", "Pendulum-v0", policy_kwargs=dict(net_arch=[64, 64], n_critics=n_critics), learning_starts=100, verbose=1
+ )
model.learn(total_timesteps=1000)
def test_dqn():
- model = DQN('MlpPolicy', 'CartPole-v1', policy_kwargs=dict(net_arch=[64, 64]),
- learning_starts=500, buffer_size=500, learning_rate=3e-4, verbose=1, create_eval_env=True)
+ model = DQN(
+ "MlpPolicy",
+ "CartPole-v1",
+ policy_kwargs=dict(net_arch=[64, 64]),
+ learning_starts=500,
+ buffer_size=500,
+ learning_rate=3e-4,
+ verbose=1,
+ create_eval_env=True,
+ )
model.learn(total_timesteps=1000, eval_freq=500)
diff --git a/tests/test_save_load.py b/tests/test_save_load.py
index 69df31a98..d56c40d5e 100644
--- a/tests/test_save_load.py
+++ b/tests/test_save_load.py
@@ -1,33 +1,21 @@
-import os
import io
+import os
+import pathlib
import warnings
from copy import deepcopy
-import pathlib
-import pytest
import gym
import numpy as np
+import pytest
import torch as th
-from stable_baselines3 import A2C, PPO, SAC, TD3, DQN, DDPG
+from stable_baselines3 import A2C, DDPG, DQN, PPO, SAC, TD3
from stable_baselines3.common.base_class import BaseAlgorithm
-from stable_baselines3.common.identity_env import IdentityEnvBox, IdentityEnv
+from stable_baselines3.common.identity_env import FakeImageEnv, IdentityEnv, IdentityEnvBox
+from stable_baselines3.common.save_util import load_from_pkl, open_path, save_to_pkl
from stable_baselines3.common.vec_env import DummyVecEnv
-from stable_baselines3.common.identity_env import FakeImageEnv
-from stable_baselines3.common.save_util import (
- open_path,
- save_to_pkl,
- load_from_pkl,
-)
-
-MODEL_LIST = [
- PPO,
- A2C,
- TD3,
- SAC,
- DQN,
- DDPG
-]
+
+MODEL_LIST = [PPO, A2C, TD3, SAC, DQN, DDPG]
def select_env(model_class: BaseAlgorithm) -> gym.Env:
@@ -58,17 +46,13 @@ def test_save_load(tmp_path, model_class):
model.learn(total_timesteps=500, eval_freq=250)
env.reset()
- observations = np.concatenate(
- [env.step([env.action_space.sample()])[0] for _ in range(10)], axis=0
- )
+ observations = np.concatenate([env.step([env.action_space.sample()])[0] for _ in range(10)], axis=0)
# Get dictionary of current parameters
params = deepcopy(model.policy.state_dict())
# Modify all parameters to be random values
- random_params = dict(
- (param_name, th.rand_like(param)) for param_name, param in params.items()
- )
+ random_params = dict((param_name, th.rand_like(param)) for param_name, param in params.items())
# Update model parameters with the new random values
model.policy.load_state_dict(random_params)
@@ -76,9 +60,7 @@ def test_save_load(tmp_path, model_class):
new_params = model.policy.state_dict()
# Check that all params are different now
for k in params:
- assert not th.allclose(
- params[k], new_params[k]
- ), "Parameters did not change as expected."
+ assert not th.allclose(params[k], new_params[k]), "Parameters did not change as expected."
params = new_params
@@ -95,9 +77,7 @@ def test_save_load(tmp_path, model_class):
# Check that all params are the same as before save load procedure now
for key in params:
- assert th.allclose(
- params[key], new_params[key]
- ), "Model parameters not the same after save and load."
+ assert th.allclose(params[key], new_params[key]), "Model parameters not the same after save and load."
# check if model still selects the same actions
new_selected_actions, _ = model.predict(observations, deterministic=True)
@@ -229,9 +209,7 @@ def test_warn_buffer(recwarn, model_class, optimize_memory_usage):
if optimize_memory_usage:
assert len(recwarn) == 1
warning = recwarn.pop(UserWarning)
- assert "The last trajectory in the replay buffer will be truncated" in str(
- warning.message
- )
+ assert "The last trajectory in the replay buffer will be truncated" in str(warning.message)
else:
assert len(recwarn) == 0
@@ -253,22 +231,16 @@ def test_save_load_policy(tmp_path, model_class, policy_str):
# Avoid memory error when using replay buffer
# Reduce the size of the features
kwargs = dict(buffer_size=250)
- env = FakeImageEnv(
- screen_height=40, screen_width=40, n_channels=2, discrete=model_class == DQN
- )
+ env = FakeImageEnv(screen_height=40, screen_width=40, n_channels=2, discrete=model_class == DQN)
env = DummyVecEnv([lambda: env])
# create model
- model = model_class(
- policy_str, env, policy_kwargs=dict(net_arch=[16]), verbose=1, **kwargs
- )
+ model = model_class(policy_str, env, policy_kwargs=dict(net_arch=[16]), verbose=1, **kwargs)
model.learn(total_timesteps=500, eval_freq=250)
env.reset()
- observations = np.concatenate(
- [env.step([env.action_space.sample()])[0] for _ in range(10)], axis=0
- )
+ observations = np.concatenate([env.step([env.action_space.sample()])[0] for _ in range(10)], axis=0)
policy = model.policy
policy_class = policy.__class__
@@ -281,9 +253,7 @@ def test_save_load_policy(tmp_path, model_class, policy_str):
params = deepcopy(policy.state_dict())
# Modify all parameters to be random values
- random_params = dict(
- (param_name, th.rand_like(param)) for param_name, param in params.items()
- )
+ random_params = dict((param_name, th.rand_like(param)) for param_name, param in params.items())
# Update model parameters with the new random values
policy.load_state_dict(random_params)
@@ -291,9 +261,7 @@ def test_save_load_policy(tmp_path, model_class, policy_str):
new_params = policy.state_dict()
# Check that all params are different now
for k in params:
- assert not th.allclose(
- params[k], new_params[k]
- ), "Parameters did not change as expected."
+ assert not th.allclose(params[k], new_params[k]), "Parameters did not change as expected."
params = new_params
@@ -320,9 +288,7 @@ def test_save_load_policy(tmp_path, model_class, policy_str):
# Check that all params are the same as before save load procedure now
for key in params:
- assert th.allclose(
- params[key], new_params[key]
- ), "Policy parameters not the same after save and load."
+ assert th.allclose(params[key], new_params[key]), "Policy parameters not the same after save and load."
# check if model still selects the same actions
new_selected_actions, _ = policy.predict(observations, deterministic=True)
@@ -370,17 +336,11 @@ def test_open_file_str_pathlib(tmp_path, pathtype):
save_to_pkl(fp1, "foo")
assert fp1.closed
with pytest.warns(None) as record:
- assert (
- load_from_pkl(open_path(pathtype(f"{tmp_path}/t2"), "r", suffix="pkl"))
- == "foo"
- )
+ assert load_from_pkl(open_path(pathtype(f"{tmp_path}/t2"), "r", suffix="pkl")) == "foo"
assert len(record) == 0
with pytest.warns(None) as record:
- assert (
- load_from_pkl(open_path(pathtype(f"{tmp_path}/t2"), "r", suffix="pkl", verbose=2))
- == "foo"
- )
+ assert load_from_pkl(open_path(pathtype(f"{tmp_path}/t2"), "r", suffix="pkl", verbose=2)) == "foo"
assert len(record) == 1
fp = pathlib.Path(f"{tmp_path}/t2").open("w")
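The save/load tests above rely on a plain PyTorch pattern: snapshot `state_dict()`, overwrite the parameters with random values, reload, and compare with `th.allclose`. A minimal standalone sketch of that round trip, using a toy `nn.Linear` module instead of an SB3 policy purely for illustration:

```python
import io
from copy import deepcopy

import torch as th
import torch.nn as nn

# Toy module standing in for a policy network (illustration only).
net = nn.Linear(4, 2)

# Snapshot the current parameters.
params = deepcopy(net.state_dict())

# Overwrite every parameter with random values and check they changed.
net.load_state_dict({name: th.rand_like(p) for name, p in params.items()})
assert all(not th.allclose(params[k], net.state_dict()[k]) for k in params)

# Save/load round trip through an in-memory buffer.
buffer = io.BytesIO()
th.save(params, buffer)
buffer.seek(0)
net.load_state_dict(th.load(buffer))

# Parameters must match the original snapshot again after loading.
assert all(th.allclose(params[k], net.state_dict()[k]) for k in params)
```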
diff --git a/tests/test_sde.py b/tests/test_sde.py
index dbf196368..d26e06cec 100644
--- a/tests/test_sde.py
+++ b/tests/test_sde.py
@@ -2,7 +2,7 @@
import torch as th
from torch.distributions import Normal
-from stable_baselines3 import A2C, SAC, PPO
+from stable_baselines3 import A2C, PPO, SAC
def test_state_dependent_exploration_grad():
@@ -58,6 +58,13 @@ def test_state_dependent_exploration_grad():
@pytest.mark.parametrize("sde_net_arch", [None, [32, 16], []])
@pytest.mark.parametrize("use_expln", [False, True])
def test_state_dependent_offpolicy_noise(model_class, sde_net_arch, use_expln):
- model = model_class('MlpPolicy', 'Pendulum-v0', use_sde=True, seed=None, create_eval_env=True,
- verbose=1, policy_kwargs=dict(log_std_init=-2, sde_net_arch=sde_net_arch, use_expln=use_expln))
+ model = model_class(
+ "MlpPolicy",
+ "Pendulum-v0",
+ use_sde=True,
+ seed=None,
+ create_eval_env=True,
+ verbose=1,
+ policy_kwargs=dict(log_std_init=-2, sde_net_arch=sde_net_arch, use_expln=use_expln),
+ )
model.learn(total_timesteps=int(500), eval_freq=250)
diff --git a/tests/test_spaces.py b/tests/test_spaces.py
index e057f97b3..98a1953fc 100644
--- a/tests/test_spaces.py
+++ b/tests/test_spaces.py
@@ -1,6 +1,6 @@
+import gym
import numpy as np
import pytest
-import gym
from stable_baselines3 import DQN, SAC, TD3
from stable_baselines3.common.evaluation import evaluate_policy
diff --git a/tests/test_tensorboard.py b/tests/test_tensorboard.py
index e5441a85c..3f755a7aa 100644
--- a/tests/test_tensorboard.py
+++ b/tests/test_tensorboard.py
@@ -1,13 +1,14 @@
import os
+
import pytest
from stable_baselines3 import A2C, PPO, SAC, TD3
MODEL_DICT = {
- 'a2c': (A2C, 'CartPole-v1'),
- 'ppo': (PPO, 'CartPole-v1'),
- 'sac': (SAC, 'Pendulum-v0'),
- 'td3': (TD3, 'Pendulum-v0'),
+ "a2c": (A2C, "CartPole-v1"),
+ "ppo": (PPO, "CartPole-v1"),
+ "sac": (SAC, "Pendulum-v0"),
+ "td3": (TD3, "Pendulum-v0"),
}
N_STEPS = 100
@@ -20,7 +21,7 @@ def test_tensorboard(tmp_path, model_name):
logname = model_name.upper()
algo, env_id = MODEL_DICT[model_name]
- model = algo('MlpPolicy', env_id, verbose=1, tensorboard_log=tmp_path)
+ model = algo("MlpPolicy", env_id, verbose=1, tensorboard_log=tmp_path)
model.learn(N_STEPS)
model.learn(N_STEPS, reset_num_timesteps=False)
diff --git a/tests/test_utils.py b/tests/test_utils.py
index 4f0473871..7f03c1d0c 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -1,27 +1,25 @@
import os
import shutil
-import pytest
import gym
import numpy as np
+import pytest
from stable_baselines3 import A2C
-from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.atari_wrappers import ClipRewardEnv
+from stable_baselines3.common.cmd_util import make_atari_env, make_vec_env
from stable_baselines3.common.evaluation import evaluate_policy
-from stable_baselines3.common.cmd_util import make_vec_env, make_atari_env
+from stable_baselines3.common.monitor import Monitor
+from stable_baselines3.common.noise import ActionNoise, OrnsteinUhlenbeckActionNoise, VectorizedActionNoise
from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv
-from stable_baselines3.common.noise import (VectorizedActionNoise,
- OrnsteinUhlenbeckActionNoise, ActionNoise)
-@pytest.mark.parametrize("env_id", ['CartPole-v1', lambda: gym.make('CartPole-v1')])
+@pytest.mark.parametrize("env_id", ["CartPole-v1", lambda: gym.make("CartPole-v1")])
@pytest.mark.parametrize("n_envs", [1, 2])
@pytest.mark.parametrize("vec_env_cls", [None, SubprocVecEnv])
@pytest.mark.parametrize("wrapper_class", [None, gym.wrappers.TimeLimit])
def test_make_vec_env(env_id, n_envs, vec_env_cls, wrapper_class):
- env = make_vec_env(env_id, n_envs, vec_env_cls=vec_env_cls,
- wrapper_class=wrapper_class, monitor_dir=None, seed=0)
+ env = make_vec_env(env_id, n_envs, vec_env_cls=vec_env_cls, wrapper_class=wrapper_class, monitor_dir=None, seed=0)
assert env.num_envs == n_envs
@@ -37,13 +35,12 @@ def test_make_vec_env(env_id, n_envs, vec_env_cls, wrapper_class):
env.close()
-@pytest.mark.parametrize("env_id", ['BreakoutNoFrameskip-v4'])
+@pytest.mark.parametrize("env_id", ["BreakoutNoFrameskip-v4"])
@pytest.mark.parametrize("n_envs", [1, 2])
@pytest.mark.parametrize("wrapper_kwargs", [None, dict(clip_reward=False, screen_size=60)])
def test_make_atari_env(env_id, n_envs, wrapper_kwargs):
- env_id = 'BreakoutNoFrameskip-v4'
- env = make_atari_env(env_id, n_envs,
- wrapper_kwargs=wrapper_kwargs, monitor_dir=None, seed=0)
+ env_id = "BreakoutNoFrameskip-v4"
+ env = make_atari_env(env_id, n_envs, wrapper_kwargs=wrapper_kwargs, monitor_dir=None, seed=0)
assert env.num_envs == n_envs
@@ -70,10 +67,15 @@ def test_custom_vec_env(tmp_path):
"""
Stand alone test for a special case (passing a custom VecEnv class) to avoid doubling the number of tests.
"""
- monitor_dir = tmp_path / 'test_make_vec_env/'
- env = make_vec_env('CartPole-v1', n_envs=1,
- monitor_dir=monitor_dir, seed=0,
- vec_env_cls=SubprocVecEnv, vec_env_kwargs={'start_method': None})
+ monitor_dir = tmp_path / "test_make_vec_env/"
+ env = make_vec_env(
+ "CartPole-v1",
+ n_envs=1,
+ monitor_dir=monitor_dir,
+ seed=0,
+ vec_env_cls=SubprocVecEnv,
+ vec_env_kwargs={"start_method": None},
+ )
assert env.num_envs == 1
assert isinstance(env, SubprocVecEnv)
@@ -85,20 +87,27 @@ def test_custom_vec_env(tmp_path):
# This should fail because DummyVecEnv does not have any keyword argument
with pytest.raises(TypeError):
- make_vec_env('CartPole-v1', n_envs=1, vec_env_kwargs={'dummy': False})
+ make_vec_env("CartPole-v1", n_envs=1, vec_env_kwargs={"dummy": False})
def test_evaluate_policy():
- model = A2C('MlpPolicy', 'Pendulum-v0', seed=0)
+ model = A2C("MlpPolicy", "Pendulum-v0", seed=0)
n_steps_per_episode, n_eval_episodes = 200, 2
model.n_callback_calls = 0
def dummy_callback(locals_, _globals):
- locals_['model'].n_callback_calls += 1
-
- _, episode_lengths = evaluate_policy(model, model.get_env(), n_eval_episodes, deterministic=True,
- render=False, callback=dummy_callback, reward_threshold=None,
- return_episode_rewards=True)
+ locals_["model"].n_callback_calls += 1
+
+ _, episode_lengths = evaluate_policy(
+ model,
+ model.get_env(),
+ n_eval_episodes,
+ deterministic=True,
+ render=False,
+ callback=dummy_callback,
+ reward_threshold=None,
+ return_episode_rewards=True,
+ )
n_steps = sum(episode_lengths)
assert n_steps == n_steps_per_episode * n_eval_episodes
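The `evaluate_policy` call reformatted above accepts a per-step callback. A short hedged sketch of that usage (the `count_steps` helper and the counter dict are illustrative, not part of the library):

```python
from stable_baselines3 import A2C
from stable_baselines3.common.evaluation import evaluate_policy

model = A2C("MlpPolicy", "CartPole-v1", seed=0)

# Per-step callback: receives the local and global variables of the
# evaluation loop; here it simply counts how often it is called.
n_calls = {"count": 0}


def count_steps(locals_, globals_):
    n_calls["count"] += 1


mean_reward, std_reward = evaluate_policy(
    model,
    model.get_env(),
    n_eval_episodes=2,
    deterministic=True,
    callback=count_steps,
)
print(mean_reward, std_reward, n_calls["count"])
```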
diff --git a/tests/test_vec_check_nan.py b/tests/test_vec_check_nan.py
index c6db91a94..265da2ed9 100644
--- a/tests/test_vec_check_nan.py
+++ b/tests/test_vec_check_nan.py
@@ -1,14 +1,15 @@
import gym
-from gym import spaces
import numpy as np
import pytest
+from gym import spaces
from stable_baselines3.common.vec_env import DummyVecEnv, VecCheckNan
class NanAndInfEnv(gym.Env):
"""Custom Environment that raised NaNs and Infs"""
- metadata = {'render.modes': ['human']}
+
+ metadata = {"render.modes": ["human"]}
def __init__(self):
super(NanAndInfEnv, self).__init__()
@@ -18,9 +19,9 @@ def __init__(self):
@staticmethod
def step(action):
if np.all(np.array(action) > 0):
- obs = float('NaN')
+ obs = float("NaN")
elif np.all(np.array(action) < 0):
- obs = float('inf')
+ obs = float("inf")
else:
obs = 0
return [obs], 0.0, False, {}
@@ -29,7 +30,7 @@ def step(action):
def reset():
return [0.0]
- def render(self, mode='human', close=False):
+ def render(self, mode="human", close=False):
pass
@@ -42,10 +43,10 @@ def test_check_nan():
env.step([[0]])
with pytest.raises(ValueError):
- env.step([[float('NaN')]])
+ env.step([[float("NaN")]])
with pytest.raises(ValueError):
- env.step([[float('inf')]])
+ env.step([[float("inf")]])
with pytest.raises(ValueError):
env.step([[-1]])
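To make explicit what `VecCheckNan` is guarding against in the hunk above, here is a hedged, self-contained sketch; `TinyNaNEnv` is a hypothetical stand-in for the test's `NanAndInfEnv`, and the old Gym step/reset API is assumed to match the era of this patch:

```python
import gym
import numpy as np
import pytest
from gym import spaces

from stable_baselines3.common.vec_env import DummyVecEnv, VecCheckNan


class TinyNaNEnv(gym.Env):
    """Hypothetical toy env: positive actions produce NaN observations."""

    observation_space = spaces.Box(low=-np.inf, high=np.inf, shape=(1,), dtype=np.float64)
    action_space = spaces.Box(low=-1.0, high=1.0, shape=(1,), dtype=np.float64)

    def reset(self):
        return np.array([0.0])

    def step(self, action):
        obs = float("NaN") if np.all(np.array(action) > 0) else 0.0
        return np.array([obs]), 0.0, False, {}


env = VecCheckNan(DummyVecEnv([TinyNaNEnv]), raise_exception=True)
env.reset()
env.step(np.array([[-1.0]]))  # valid action, finite observation
with pytest.raises(ValueError):
    env.step(np.array([[1.0]]))  # NaN observation -> ValueError
```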
diff --git a/tests/test_vec_envs.py b/tests/test_vec_envs.py
index 5e559dd19..8c33341c5 100644
--- a/tests/test_vec_envs.py
+++ b/tests/test_vec_envs.py
@@ -3,11 +3,11 @@
import itertools
import multiprocessing
-import pytest
import gym
import numpy as np
+import pytest
-from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv, VecNormalize, VecFrameStack
+from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv, VecFrameStack, VecNormalize
N_ENVS = 3
VEC_ENV_CLASSES = [DummyVecEnv, SubprocVecEnv]
@@ -39,8 +39,8 @@ def step(self, action):
def _choose_next_state(self):
self.state = self.observation_space.sample()
- def render(self, mode='human'):
- if mode == 'rgb_array':
+ def render(self, mode="human"):
+ if mode == "rgb_array":
return np.zeros((4, 4, 3))
def seed(self, seed=None):
@@ -59,8 +59,8 @@ def custom_method(dim_0=1, dim_1=1):
return np.ones((dim_0, dim_1))
-@pytest.mark.parametrize('vec_env_class', VEC_ENV_CLASSES)
-@pytest.mark.parametrize('vec_env_wrapper', VEC_ENV_WRAPPERS)
+@pytest.mark.parametrize("vec_env_class", VEC_ENV_CLASSES)
+@pytest.mark.parametrize("vec_env_wrapper", VEC_ENV_WRAPPERS)
def test_vecenv_custom_calls(vec_env_class, vec_env_wrapper):
"""Test access to methods/attributes of vectorized environments"""
@@ -79,14 +79,14 @@ def make_env():
vec_env.seed(0)
# Test render method call
# vec_env.render() # we need a X server to test the "human" mode
- vec_env.render(mode='rgb_array')
- env_method_results = vec_env.env_method('custom_method', 1, indices=None, dim_1=2)
+ vec_env.render(mode="rgb_array")
+ env_method_results = vec_env.env_method("custom_method", 1, indices=None, dim_1=2)
setattr_results = []
# Set current_step to an arbitrary value
for env_idx in range(N_ENVS):
- setattr_results.append(vec_env.set_attr('current_step', env_idx, indices=env_idx))
+ setattr_results.append(vec_env.set_attr("current_step", env_idx, indices=env_idx))
# Retrieve the value for each environment
- getattr_results = vec_env.get_attr('current_step')
+ getattr_results = vec_env.get_attr("current_step")
assert len(env_method_results) == N_ENVS
assert len(setattr_results) == N_ENVS
@@ -98,34 +98,34 @@ def make_env():
assert getattr_results[env_idx] == env_idx
# Call env_method on a subset of the VecEnv
- env_method_subset = vec_env.env_method('custom_method', 1, indices=[0, 2], dim_1=3)
+ env_method_subset = vec_env.env_method("custom_method", 1, indices=[0, 2], dim_1=3)
assert (env_method_subset[0] == np.ones((1, 3))).all()
assert (env_method_subset[1] == np.ones((1, 3))).all()
assert len(env_method_subset) == 2
# Test to change value for all the environments
- setattr_result = vec_env.set_attr('current_step', 42, indices=None)
- getattr_result = vec_env.get_attr('current_step')
+ setattr_result = vec_env.set_attr("current_step", 42, indices=None)
+ getattr_result = vec_env.get_attr("current_step")
assert setattr_result is None
assert getattr_result == [42 for _ in range(N_ENVS)]
# Additional tests for setattr that does not affect all the environments
vec_env.reset()
- setattr_result = vec_env.set_attr('current_step', 12, indices=[0, 1])
- getattr_result = vec_env.get_attr('current_step')
- getattr_result_subset = vec_env.get_attr('current_step', indices=[0, 1])
+ setattr_result = vec_env.set_attr("current_step", 12, indices=[0, 1])
+ getattr_result = vec_env.get_attr("current_step")
+ getattr_result_subset = vec_env.get_attr("current_step", indices=[0, 1])
assert setattr_result is None
assert getattr_result == [12 for _ in range(2)] + [0 for _ in range(N_ENVS - 2)]
assert getattr_result_subset == [12, 12]
- assert vec_env.get_attr('current_step', indices=[0, 2]) == [12, 0]
+ assert vec_env.get_attr("current_step", indices=[0, 2]) == [12, 0]
vec_env.reset()
# Change value only for first and last environment
- setattr_result = vec_env.set_attr('current_step', 12, indices=[0, -1])
- getattr_result = vec_env.get_attr('current_step')
+ setattr_result = vec_env.set_attr("current_step", 12, indices=[0, -1])
+ getattr_result = vec_env.get_attr("current_step")
assert setattr_result is None
assert getattr_result == [12] + [0 for _ in range(N_ENVS - 2)] + [12]
- assert vec_env.get_attr('current_step', indices=[-1]) == [12]
+ assert vec_env.get_attr("current_step", indices=[-1]) == [12]
vec_env.close()
@@ -135,24 +135,23 @@ def __init__(self, max_steps):
"""Gym environment for testing that terminal observation is inserted
correctly."""
self.action_space = gym.spaces.Discrete(2)
- self.observation_space = gym.spaces.Box(np.array([0]), np.array([999]),
- dtype='int')
+ self.observation_space = gym.spaces.Box(np.array([0]), np.array([999]), dtype="int")
self.max_steps = max_steps
self.current_step = 0
def reset(self):
self.current_step = 0
- return np.array([self.current_step], dtype='int')
+ return np.array([self.current_step], dtype="int")
def step(self, action):
prev_step = self.current_step
self.current_step += 1
done = self.current_step >= self.max_steps
- return np.array([prev_step], dtype='int'), 0.0, done, {}
+ return np.array([prev_step], dtype="int"), 0.0, done, {}
-@pytest.mark.parametrize('vec_env_class', VEC_ENV_CLASSES)
-@pytest.mark.parametrize('vec_env_wrapper', VEC_ENV_WRAPPERS)
+@pytest.mark.parametrize("vec_env_class", VEC_ENV_CLASSES)
+@pytest.mark.parametrize("vec_env_wrapper", VEC_ENV_WRAPPERS)
def test_vecenv_terminal_obs(vec_env_class, vec_env_wrapper):
"""Test that 'terminal_observation' gets added to info dict upon
termination."""
@@ -165,7 +164,7 @@ def test_vecenv_terminal_obs(vec_env_class, vec_env_wrapper):
else:
vec_env = vec_env_wrapper(vec_env)
- zero_acts = np.zeros((N_ENVS,), dtype='int')
+ zero_acts = np.zeros((N_ENVS,), dtype="int")
prev_obs_b = vec_env.reset()
for step_num in range(1, max(step_nums) + 1):
obs_b, _, done_b, info_b = vec_env.step(zero_acts)
@@ -176,9 +175,9 @@ def test_vecenv_terminal_obs(vec_env_class, vec_env_wrapper):
for prev_obs, obs, done, info, final_step_num in env_iter:
assert done == (step_num == final_step_num)
if not done:
- assert 'terminal_observation' not in info
+ assert "terminal_observation" not in info
else:
- terminal_obs = info['terminal_observation']
+ terminal_obs = info["terminal_observation"]
# do some rough ordering checks that should work for all
# wrappers, including VecNormalize
@@ -196,12 +195,14 @@ def test_vecenv_terminal_obs(vec_env_class, vec_env_wrapper):
vec_env.close()
-SPACES = collections.OrderedDict([
- ('discrete', gym.spaces.Discrete(2)),
- ('multidiscrete', gym.spaces.MultiDiscrete([2, 3])),
- ('multibinary', gym.spaces.MultiBinary(3)),
- ('continuous', gym.spaces.Box(low=np.zeros(2), high=np.ones(2))),
-])
+SPACES = collections.OrderedDict(
+ [
+ ("discrete", gym.spaces.Discrete(2)),
+ ("multidiscrete", gym.spaces.MultiDiscrete([2, 3])),
+ ("multibinary", gym.spaces.MultiBinary(3)),
+ ("continuous", gym.spaces.Box(low=np.zeros(2), high=np.ones(2))),
+ ]
+)
def check_vecenv_spaces(vec_env_class, space, obs_assert):
@@ -230,7 +231,7 @@ def check_vecenv_obs(obs, space):
assert space.contains(value)
-@pytest.mark.parametrize('vec_env_class,space', itertools.product(VEC_ENV_CLASSES, SPACES.values()))
+@pytest.mark.parametrize("vec_env_class,space", itertools.product(VEC_ENV_CLASSES, SPACES.values()))
def test_vecenv_single_space(vec_env_class, space):
def obs_assert(obs):
return check_vecenv_obs(obs, space)
@@ -245,7 +246,7 @@ def sample(self):
return dict(super().sample())
-@pytest.mark.parametrize('vec_env_class', VEC_ENV_CLASSES)
+@pytest.mark.parametrize("vec_env_class", VEC_ENV_CLASSES)
def test_vecenv_dict_spaces(vec_env_class):
"""Test dictionary observation spaces with vectorized environments."""
space = gym.spaces.Dict(SPACES)
@@ -263,7 +264,7 @@ def obs_assert(obs):
check_vecenv_spaces(vec_env_class, unordered_space, obs_assert)
-@pytest.mark.parametrize('vec_env_class', VEC_ENV_CLASSES)
+@pytest.mark.parametrize("vec_env_class", VEC_ENV_CLASSES)
def test_vecenv_tuple_spaces(vec_env_class):
"""Test tuple observation spaces with vectorized environments."""
space = gym.spaces.Tuple(tuple(SPACES.values()))
@@ -281,7 +282,7 @@ def test_subproc_start_method():
start_methods = [None]
# Only test thread-safe methods. Others may deadlock tests! (gh/428)
# Note: adding unsafe `fork` method as we are now using PyTorch
- all_methods = {'forkserver', 'spawn', 'fork'}
+ all_methods = {"forkserver", "spawn", "fork"}
available_methods = multiprocessing.get_all_start_methods()
start_methods += list(all_methods.intersection(available_methods))
space = gym.spaces.Discrete(2)
@@ -294,20 +295,20 @@ def obs_assert(obs):
check_vecenv_spaces(vec_env_class, space, obs_assert)
with pytest.raises(ValueError, match="cannot find context for 'illegal_method'"):
- vec_env_class = functools.partial(SubprocVecEnv, start_method='illegal_method')
+ vec_env_class = functools.partial(SubprocVecEnv, start_method="illegal_method")
check_vecenv_spaces(vec_env_class, space, obs_assert)
class CustomWrapperA(VecNormalize):
def __init__(self, venv):
VecNormalize.__init__(self, venv)
- self.var_a = 'a'
+ self.var_a = "a"
class CustomWrapperB(VecNormalize):
def __init__(self, venv):
VecNormalize.__init__(self, venv)
- self.var_b = 'b'
+ self.var_b = "b"
def func_b(self):
return self.var_b
@@ -319,7 +320,7 @@ def name_test(self):
class CustomWrapperBB(CustomWrapperB):
def __init__(self, venv):
CustomWrapperB.__init__(self, venv)
- self.var_bb = 'bb'
+ self.var_bb = "bb"
def test_vecenv_wrapper_getattr():
@@ -328,10 +329,10 @@ def make_env():
vec_env = DummyVecEnv([make_env for _ in range(N_ENVS)])
wrapped = CustomWrapperA(CustomWrapperBB(vec_env))
- assert wrapped.var_a == 'a'
- assert wrapped.var_b == 'b'
- assert wrapped.var_bb == 'bb'
- assert wrapped.func_b() == 'b'
+ assert wrapped.var_a == "a"
+ assert wrapped.var_b == "b"
+ assert wrapped.var_bb == "bb"
+ assert wrapped.func_b() == "b"
assert wrapped.name_test() == CustomWrapperBB
double_wrapped = CustomWrapperA(CustomWrapperB(wrapped))
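The `env_method`/`get_attr`/`set_attr` calls exercised above are the standard VecEnv helpers. A minimal sketch of the same pattern with a throwaway environment (`TinyEnv` and `custom_method` are illustrative, mirroring the test's own helper env):

```python
import gym
import numpy as np

from stable_baselines3.common.vec_env import DummyVecEnv


class TinyEnv(gym.Env):
    """Hypothetical minimal env used only to demonstrate the VecEnv helpers."""

    observation_space = gym.spaces.Box(low=0.0, high=1.0, shape=(2,), dtype=np.float32)
    action_space = gym.spaces.Discrete(2)

    def __init__(self):
        self.current_step = 0

    def reset(self):
        self.current_step = 0
        return self.observation_space.sample()

    def step(self, action):
        self.current_step += 1
        return self.observation_space.sample(), 0.0, False, {}

    def custom_method(self, dim_0=1, dim_1=1):
        return np.ones((dim_0, dim_1))


vec_env = DummyVecEnv([TinyEnv for _ in range(3)])
vec_env.reset()

# Call a method on every sub-environment (or a subset via `indices`).
results = vec_env.env_method("custom_method", 1, dim_1=2)
assert all((res == np.ones((1, 2))).all() for res in results)

# Write an attribute on selected sub-environments and read it back.
vec_env.set_attr("current_step", 42, indices=[0, 2])
assert vec_env.get_attr("current_step", indices=[0, 2]) == [42, 42]
assert vec_env.get_attr("current_step") == [42, 0, 42]

vec_env.close()
```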
diff --git a/tests/test_vec_normalize.py b/tests/test_vec_normalize.py
index 979ea67e2..311e3c92e 100644
--- a/tests/test_vec_normalize.py
+++ b/tests/test_vec_normalize.py
@@ -1,13 +1,18 @@
import gym
-import pytest
import numpy as np
+import pytest
-from stable_baselines3.common.running_mean_std import RunningMeanStd
-from stable_baselines3.common.vec_env import (DummyVecEnv, VecNormalize, VecFrameStack, sync_envs_normalization,
- unwrap_vec_normalize)
from stable_baselines3 import SAC, TD3
+from stable_baselines3.common.running_mean_std import RunningMeanStd
+from stable_baselines3.common.vec_env import (
+ DummyVecEnv,
+ VecFrameStack,
+ VecNormalize,
+ sync_envs_normalization,
+ unwrap_vec_normalize,
+)
-ENV_ID = 'Pendulum-v0'
+ENV_ID = "Pendulum-v0"
def make_env():
@@ -54,8 +59,9 @@ def _make_warmstart_cartpole():
def test_runningmeanstd():
"""Test RunningMeanStd object"""
for (x_1, x_2, x_3) in [
- (np.random.randn(3), np.random.randn(4), np.random.randn(5)),
- (np.random.randn(3, 2), np.random.randn(4, 2), np.random.randn(5, 2))]:
+ (np.random.randn(3), np.random.randn(4), np.random.randn(5)),
+ (np.random.randn(3, 2), np.random.randn(4, 2), np.random.randn(5, 2)),
+ ]:
rms = RunningMeanStd(epsilon=0.0, shape=x_1.shape[1:])
x_cat = np.concatenate([x_1, x_2, x_3], axis=0)
@@ -120,12 +126,12 @@ def test_normalize_external():
@pytest.mark.parametrize("model_class", [SAC, TD3])
def test_offpolicy_normalization(model_class):
env = DummyVecEnv([make_env])
- env = VecNormalize(env, norm_obs=True, norm_reward=True, clip_obs=10., clip_reward=10.)
+ env = VecNormalize(env, norm_obs=True, norm_reward=True, clip_obs=10.0, clip_reward=10.0)
eval_env = DummyVecEnv([make_env])
- eval_env = VecNormalize(eval_env, training=False, norm_obs=True, norm_reward=False, clip_obs=10., clip_reward=10.)
+ eval_env = VecNormalize(eval_env, training=False, norm_obs=True, norm_reward=False, clip_obs=10.0, clip_reward=10.0)
- model = model_class('MlpPolicy', env, verbose=1, policy_kwargs=dict(net_arch=[64]))
+ model = model_class("MlpPolicy", env, verbose=1, policy_kwargs=dict(net_arch=[64]))
model.learn(total_timesteps=1000, eval_env=eval_env, eval_freq=500)
# Check getter
assert isinstance(model.get_vec_normalize_env(), VecNormalize)
@@ -136,7 +142,7 @@ def test_sync_vec_normalize():
assert unwrap_vec_normalize(env) is None
- env = VecNormalize(env, norm_obs=True, norm_reward=True, clip_obs=100., clip_reward=100.)
+ env = VecNormalize(env, norm_obs=True, norm_reward=True, clip_obs=100.0, clip_reward=100.0)
assert isinstance(unwrap_vec_normalize(env), VecNormalize)
@@ -145,8 +151,7 @@ def test_sync_vec_normalize():
assert isinstance(unwrap_vec_normalize(env), VecNormalize)
eval_env = DummyVecEnv([make_env])
- eval_env = VecNormalize(eval_env, training=False, norm_obs=True, norm_reward=True,
- clip_obs=100., clip_reward=100.)
+ eval_env = VecNormalize(eval_env, training=False, norm_obs=True, norm_reward=True, clip_obs=100.0, clip_reward=100.0)
eval_env = VecFrameStack(eval_env, 1)
env.seed(0)