diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md index 761e41e6a..2b35d61e8 100644 --- a/.github/PULL_REQUEST_TEMPLATE.md +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -24,7 +24,11 @@ - [ ] My change requires a change to the documentation. - [ ] I have updated the tests accordingly (*required for a bug fix or a new feature*). - [ ] I have updated the documentation accordingly. -- [ ] I have checked the codestyle using `make lint` -- [ ] I have ensured `make pytest` and `make type` both pass. +- [ ] I have reformatted the code using `make format` (**required**) +- [ ] I have checked the codestyle using `make check-codestyle` and `make lint` (**required**) +- [ ] I have ensured `make pytest` and `make type` both pass. (**required**) + + +Note: we are using a maximum length of 127 characters per line diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index cbad4a146..47ca91980 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -38,6 +38,9 @@ jobs: - name: Type check run: | make type + - name: Check codestyle + run: | + make check-codestyle - name: Lint with flake8 run: | make lint diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index 8eeaf41f8..695c14543 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -1,4 +1,4 @@ -image: stablebaselines/stable-baselines3-cpu:0.8.0a1 +image: stablebaselines/stable-baselines3-cpu:0.8.0a4 type-check: script: @@ -15,4 +15,5 @@ doc-build: lint-check: script: + - make check-codestyle - make lint diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index ff56c13e3..e04e59ee4 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -38,23 +38,9 @@ pip install -e .[docs,tests,extra] ## Codestyle -We follow the [PEP8 codestyle](https://www.python.org/dev/peps/pep-0008/). Please order the imports as follows: +We are using [black codestyle](https://github.com/psf/black) (max line length of 127 characters) together with [isort](https://github.com/timothycrosley/isort) to sort the imports. -1. built-in -2. packages -3. current module - -with one space between each, that gives for instance: -```python -import os -import warnings - -import numpy as np - -from stable_baselines3 import PPO -``` - -In general, we recommend using pycharm to format everything in an efficient way. +**Please run `make format`** to reformat your code. You can check the codestyle using `make check-codestyle` and `make lint`. Please document each function/method and [type](https://google.github.io/pytype/user_guide.html) them using the following template: @@ -77,11 +63,10 @@ def my_function(arg1: type1, arg2: type2) -> returntype: Before proposing a PR, please open an issue, where the feature will be discussed. This prevent from duplicated PR to be proposed and also ease the code review process. Each PR need to be reviewed and accepted by at least one of the maintainers (@hill-a, @araffin, @erniejunior, @AdamGleave or @Miffyli). -A PR must pass the Continuous Integration tests (travis + codacy) to be merged with the master branch. +A PR must pass the Continuous Integration tests to be merged with the master branch. -Note: in rare cases, we can create exception for codacy failure. -## Test +## Tests All new features must add tests in the `tests/` folder ensuring that everything works fine. We use [pytest](https://pytest.org/). 
@@ -99,12 +84,18 @@ Type checking with `pytype`: make type ``` -Codestyle check with `flake8`: +Codestyle check with `black`, `isort` and `flake8`: ``` +make check-codestyle make lint ``` +To run `pytype`, `format` and `lint` in one command: +``` +make commit-checks +``` + Build the documentation: ``` @@ -121,6 +112,7 @@ make spelling ## Changelog and Documentation Please do not forget to update the changelog (`docs/misc/changelog.rst`) and add documentation if needed. +You should add your username next to each changelog entry that you added. If this is your first contribution, please add your username at the bottom too. A README is present in the `docs/` folder for instructions on how to build the documentation. diff --git a/Dockerfile b/Dockerfile index 0b30e6d76..2ed1d2309 100644 --- a/Dockerfile +++ b/Dockerfile @@ -10,7 +10,8 @@ RUN apt-get update && apt-get install -y --no-install-recommends \ curl \ ca-certificates \ libjpeg-dev \ - libpng-dev && \ + libpng-dev \ + libglib2.0-0 && \ rm -rf /var/lib/apt/lists/* # Install anaconda abd dependencies @@ -37,15 +38,4 @@ RUN \ pip install opencv-python-headless && \ rm -rf $HOME/.cache/pip - -# Codacy deps -RUN apt-get update && apt-get install -y --no-install-recommends \ - default-jre \ - jq && \ - rm -rf /var/lib/apt/lists/* - -# Codacy code coverage report: used for partial code coverage reporting -RUN cd $CODE_DIR && \ - curl -Ls -o codacy-coverage-reporter.jar "$(curl -Ls https://api.github.com/repos/codacy/codacy-coverage-reporter/releases/latest | jq -r '.assets | map({name, browser_download_url} | select(.name | (startswith("codacy-coverage-reporter") and contains("assembly") and endswith(".jar")))) | .[0].browser_download_url')" - CMD /bin/bash diff --git a/Makefile b/Makefile index 05ce4d298..749bc026b 100644 --- a/Makefile +++ b/Makefile @@ -14,6 +14,20 @@ lint: # exit-zero treats all errors as warnings. 
flake8 ${LINT_PATHS} --count --exit-zero --statistics +format: + # Sort imports + isort ${LINT_PATHS} + # Reformat using black + black -l 127 ${LINT_PATHS} + +check-codestyle: + # Sort imports + isort --check ${LINT_PATHS} + # Reformat using black + black --check -l 127 ${LINT_PATHS} + +commit-checks: format type lint + doc: cd docs && make html @@ -23,8 +37,6 @@ spelling: clean: cd docs && make clean -.PHONY: clean spelling doc lint - # Build docker images # If you do export RELEASE=True, it will also push them docker: docker-cpu docker-gpu @@ -46,3 +58,5 @@ test-release: python setup.py sdist python setup.py bdist_wheel twine upload --repository-url https://test.pypi.org/legacy/ dist/* + +.PHONY: clean spelling doc lint format check-codestyle commit-checks diff --git a/README.md b/README.md index d6338af7f..95e7c1536 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,7 @@ [![pipeline status](https://gitlab.com/araffin/stable-baselines3/badges/master/pipeline.svg)](https://gitlab.com/araffin/stable-baselines3/-/commits/master) [![Documentation Status](https://readthedocs.org/projects/stable-baselines/badge/?version=master)](https://stable-baselines3.readthedocs.io/en/master/?badge=master) [![coverage report](https://gitlab.com/araffin/stable-baselines3/badges/master/coverage.svg)](https://gitlab.com/araffin/stable-baselines3/-/commits/master) - +[![codestyle](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black) **WARNING: Stable Baselines3 is currently in a beta version, breaking changes may occur before 1.0 is released** diff --git a/docs/conf.py b/docs/conf.py index 78138f600..320834c0a 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -20,12 +20,13 @@ # PyEnchant. try: import sphinxcontrib.spelling # noqa: F401 + enable_spell_check = True except ImportError: enable_spell_check = False # source code directory, relative to this file, for sphinx-autobuild -sys.path.insert(0, os.path.abspath('..')) +sys.path.insert(0, os.path.abspath("..")) class Mock(MagicMock): @@ -44,18 +45,18 @@ def __getattr__(cls, name): sys.modules.update((mod_name, Mock()) for mod_name in MOCK_MODULES) # Read version from file -version_file = os.path.join(os.path.dirname(__file__), '../stable_baselines3', 'version.txt') -with open(version_file, 'r') as file_handler: +version_file = os.path.join(os.path.dirname(__file__), "../stable_baselines3", "version.txt") +with open(version_file, "r") as file_handler: __version__ = file_handler.read().strip() # -- Project information ----------------------------------------------------- -project = 'Stable Baselines3' -copyright = '2020, Stable Baselines3' -author = 'Stable Baselines3 Contributors' +project = "Stable Baselines3" +copyright = "2020, Stable Baselines3" +author = "Stable Baselines3 Contributors" # The short X.Y version -version = 'master (' + __version__ + ' )' +version = "master (" + __version__ + " )" # The full version, including alpha/beta/rc tags release = __version__ @@ -70,30 +71,30 @@ def __getattr__(cls, name): # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom # ones. 
extensions = [ - 'sphinx.ext.autodoc', + "sphinx.ext.autodoc", # 'sphinx_autodoc_typehints', - 'sphinx.ext.autosummary', - 'sphinx.ext.mathjax', - 'sphinx.ext.ifconfig', - 'sphinx.ext.viewcode', + "sphinx.ext.autosummary", + "sphinx.ext.mathjax", + "sphinx.ext.ifconfig", + "sphinx.ext.viewcode", # 'sphinx.ext.intersphinx', # 'sphinx.ext.doctest' ] if enable_spell_check: - extensions.append('sphinxcontrib.spelling') + extensions.append("sphinxcontrib.spelling") # Add any paths that contain templates here, relative to this directory. -templates_path = ['_templates'] +templates_path = ["_templates"] # The suffix(es) of source filenames. # You can specify multiple suffix as a list of string: # # source_suffix = ['.rst', '.md'] -source_suffix = '.rst' +source_suffix = ".rst" # The master toctree document. -master_doc = 'index' +master_doc = "index" # The language for content autogenerated by Sphinx. Refer to documentation # for a list of supported languages. @@ -105,10 +106,10 @@ def __getattr__(cls, name): # List of patterns, relative to source directory, that match files and # directories to ignore when looking for source files. # This pattern also affects html_static_path and html_extra_path . -exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] +exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"] # The name of the Pygments (syntax highlighting) style to use. -pygments_style = 'sphinx' +pygments_style = "sphinx" # -- Options for HTML output ------------------------------------------------- @@ -117,13 +118,13 @@ def __getattr__(cls, name): # a list of builtin themes. # Fix for read the docs -on_rtd = os.environ.get('READTHEDOCS') == 'True' +on_rtd = os.environ.get("READTHEDOCS") == "True" if on_rtd: - html_theme = 'default' + html_theme = "default" else: - html_theme = 'sphinx_rtd_theme' + html_theme = "sphinx_rtd_theme" -html_logo = '_static/img/logo.png' +html_logo = "_static/img/logo.png" def setup(app): @@ -139,7 +140,7 @@ def setup(app): # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". -html_static_path = ['_static'] +html_static_path = ["_static"] # Custom sidebar templates, must be a dictionary that maps document names # to template names. @@ -155,7 +156,7 @@ def setup(app): # -- Options for HTMLHelp output --------------------------------------------- # Output file base name for HTML help builder. -htmlhelp_basename = 'StableBaselines3doc' +htmlhelp_basename = "StableBaselines3doc" # -- Options for LaTeX output ------------------------------------------------ @@ -164,15 +165,12 @@ def setup(app): # The paper size ('letterpaper' or 'a4paper'). # # 'papersize': 'letterpaper', - # The font size ('10pt', '11pt' or '12pt'). # # 'pointsize': '10pt', - # Additional stuff for the LaTeX preamble. # # 'preamble': '', - # Latex figure (float) alignment # # 'figure_align': 'htbp', @@ -182,8 +180,7 @@ def setup(app): # (source start file, target name, title, # author, documentclass [howto, manual, or own class]). latex_documents = [ - (master_doc, 'StableBaselines3.tex', 'Stable Baselines3 Documentation', - 'Stable Baselines3 Contributors', 'manual'), + (master_doc, "StableBaselines3.tex", "Stable Baselines3 Documentation", "Stable Baselines3 Contributors", "manual"), ] @@ -191,10 +188,7 @@ def setup(app): # One entry per manual page. 
List of tuples # (source start file, name, description, authors, manual section). -man_pages = [ - (master_doc, 'stablebaselines3', 'Stable Baselines3 Documentation', - [author], 1) -] +man_pages = [(master_doc, "stablebaselines3", "Stable Baselines3 Documentation", [author], 1)] # -- Options for Texinfo output ---------------------------------------------- @@ -203,9 +197,15 @@ def setup(app): # (source start file, target name, title, author, # dir menu entry, description, category) texinfo_documents = [ - (master_doc, 'StableBaselines3', 'Stable Baselines3 Documentation', - author, 'StableBaselines3', 'One line description of project.', - 'Miscellaneous'), + ( + master_doc, + "StableBaselines3", + "Stable Baselines3 Documentation", + author, + "StableBaselines3", + "One line description of project.", + "Miscellaneous", + ), ] diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index bb5cecafa..f49fb4579 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -40,8 +40,10 @@ Others: - Split the ``collect_rollout()`` method for off-policy algorithms - Added ``_on_step()`` for off-policy base class - Optimized replay buffer size by removing the need of ``next_observations`` numpy array +- Switch to ``black`` codestyle and added ``make format``, ``make check-codestyle`` and ``commit-checks`` - Ignored errors from newer pytype version - Added a check when using ``gSDE`` +- Removed codacy dependency from Dockerfile Documentation: ^^^^^^^^^^^^^^ diff --git a/setup.cfg b/setup.cfg index 1d0170067..011c3d9b1 100644 --- a/setup.cfg +++ b/setup.cfg @@ -21,7 +21,7 @@ filterwarnings = inputs = stable_baselines3 [flake8] -ignore = W503,W504 # line breaks before and after binary operators +ignore = W503,W504,E203,E231 # line breaks before and after binary operators # Ignore import not used when aliases are defined per-file-ignores = ./stable_baselines3/__init__.py:F401 @@ -48,3 +48,8 @@ exclude = max-complexity = 15 # The GitHub editor is 127 chars wide max-line-length = 127 + +[isort] +profile = black +line_length = 127 +src_paths = stable_baselines3 diff --git a/setup.py b/setup.py index 5733441ee..8401c7c1d 100644 --- a/setup.py +++ b/setup.py @@ -1,7 +1,8 @@ import os -from setuptools import setup, find_packages -with open(os.path.join('stable_baselines3', 'version.txt'), 'r') as file_handler: +from setuptools import find_packages, setup + +with open(os.path.join("stable_baselines3", "version.txt"), "r") as file_handler: __version__ = file_handler.read().strip() @@ -64,66 +65,69 @@ """ # noqa:E501 -setup(name='stable_baselines3', - packages=[package for package in find_packages() - if package.startswith('stable_baselines3')], - package_data={ - 'stable_baselines3': ['py.typed', 'version.txt'] - }, - install_requires=[ - 'gym>=0.17', - 'numpy', - 'torch>=1.4.0', - # For saving models - 'cloudpickle', - # For reading logs - 'pandas', - # Plotting learning curves - 'matplotlib' - ], - extras_require={ - 'tests': [ - # Run tests and coverage - 'pytest', - 'pytest-cov', - 'pytest-env', - 'pytest-xdist', - # Type check - 'pytype', - # Lint code - 'flake8>=3.8' - ], - 'docs': [ - 'sphinx', - 'sphinx-autobuild', - 'sphinx-rtd-theme', - # For spelling - 'sphinxcontrib.spelling', - # Type hints support - # 'sphinx-autodoc-typehints' - ], - 'extra': [ - # For render - 'opencv-python', - # For atari games, - 'atari_py~=0.2.0', 'pillow', - # Tensorboard support - 'tensorboard', - # Checking memory taken by replay buffer - 'psutil' - ] - }, - description='Pytorch version of Stable 
Baselines, implementations of reinforcement learning algorithms.', - author='Antonin Raffin', - url='https://github.com/DLR-RM/stable-baselines3', - author_email='antonin.raffin@dlr.de', - keywords="reinforcement-learning-algorithms reinforcement-learning machine-learning " - "gym openai stable baselines toolbox python data-science", - license="MIT", - long_description=long_description, - long_description_content_type='text/markdown', - version=__version__, - ) +setup( + name="stable_baselines3", + packages=[package for package in find_packages() if package.startswith("stable_baselines3")], + package_data={"stable_baselines3": ["py.typed", "version.txt"]}, + install_requires=[ + "gym>=0.17", + "numpy", + "torch>=1.4.0", + # For saving models + "cloudpickle", + # For reading logs + "pandas", + # Plotting learning curves + "matplotlib", + ], + extras_require={ + "tests": [ + # Run tests and coverage + "pytest", + "pytest-cov", + "pytest-env", + "pytest-xdist", + # Type check + "pytype", + # Lint code + "flake8>=3.8", + # Sort imports + "isort>=5.0", + # Reformat + "black", + ], + "docs": [ + "sphinx", + "sphinx-autobuild", + "sphinx-rtd-theme", + # For spelling + "sphinxcontrib.spelling", + # Type hints support + # 'sphinx-autodoc-typehints' + ], + "extra": [ + # For render + "opencv-python", + # For atari games, + "atari_py~=0.2.0", + "pillow", + # Tensorboard support + "tensorboard", + # Checking memory taken by replay buffer + "psutil", + ], + }, + description="Pytorch version of Stable Baselines, implementations of reinforcement learning algorithms.", + author="Antonin Raffin", + url="https://github.com/DLR-RM/stable-baselines3", + author_email="antonin.raffin@dlr.de", + keywords="reinforcement-learning-algorithms reinforcement-learning machine-learning " + "gym openai stable baselines toolbox python data-science", + license="MIT", + long_description=long_description, + long_description_content_type="text/markdown", + version=__version__, +) # python setup.py sdist # python setup.py bdist_wheel diff --git a/stable_baselines3/__init__.py b/stable_baselines3/__init__.py index 1bdb2e9e9..b88ca5d4c 100644 --- a/stable_baselines3/__init__.py +++ b/stable_baselines3/__init__.py @@ -8,6 +8,6 @@ from stable_baselines3.td3 import TD3 # Read version from file -version_file = os.path.join(os.path.dirname(__file__), 'version.txt') -with open(version_file, 'r') as file_handler: +version_file = os.path.join(os.path.dirname(__file__), "version.txt") +with open(version_file, "r") as file_handler: __version__ = file_handler.read().strip() diff --git a/stable_baselines3/a2c/__init__.py b/stable_baselines3/a2c/__init__.py index f74a6497c..e6aeda5bb 100644 --- a/stable_baselines3/a2c/__init__.py +++ b/stable_baselines3/a2c/__init__.py @@ -1,2 +1,2 @@ from stable_baselines3.a2c.a2c import A2C -from stable_baselines3.a2c.policies import MlpPolicy, CnnPolicy +from stable_baselines3.a2c.policies import CnnPolicy, MlpPolicy diff --git a/stable_baselines3/a2c/a2c.py b/stable_baselines3/a2c/a2c.py index a27d864c5..c2c7b34e1 100644 --- a/stable_baselines3/a2c/a2c.py +++ b/stable_baselines3/a2c/a2c.py @@ -1,13 +1,14 @@ +from typing import Any, Callable, Dict, Optional, Type, Union + import torch as th -import torch.nn.functional as F from gym import spaces -from typing import Type, Union, Callable, Optional, Dict, Any +from torch.nn import functional as F from stable_baselines3.common import logger from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm +from stable_baselines3.common.policies 
import ActorCriticPolicy from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback from stable_baselines3.common.utils import explained_variance -from stable_baselines3.common.policies import ActorCriticPolicy class A2C(OnPolicyAlgorithm): @@ -50,44 +51,59 @@ class A2C(OnPolicyAlgorithm): :param _init_setup_model: (bool) Whether or not to build the network at the creation of the instance """ - def __init__(self, policy: Union[str, Type[ActorCriticPolicy]], - env: Union[GymEnv, str], - learning_rate: Union[float, Callable] = 7e-4, - n_steps: int = 5, - gamma: float = 0.99, - gae_lambda: float = 1.0, - ent_coef: float = 0.0, - vf_coef: float = 0.5, - max_grad_norm: float = 0.5, - rms_prop_eps: float = 1e-5, - use_rms_prop: bool = True, - use_sde: bool = False, - sde_sample_freq: int = -1, - normalize_advantage: bool = False, - tensorboard_log: Optional[str] = None, - create_eval_env: bool = False, - policy_kwargs: Optional[Dict[str, Any]] = None, - verbose: int = 0, - seed: Optional[int] = None, - device: Union[th.device, str] = 'auto', - _init_setup_model: bool = True): - - super(A2C, self).__init__(policy, env, learning_rate=learning_rate, - n_steps=n_steps, gamma=gamma, gae_lambda=gae_lambda, - ent_coef=ent_coef, vf_coef=vf_coef, max_grad_norm=max_grad_norm, - use_sde=use_sde, sde_sample_freq=sde_sample_freq, - tensorboard_log=tensorboard_log, policy_kwargs=policy_kwargs, - verbose=verbose, device=device, create_eval_env=create_eval_env, - seed=seed, _init_setup_model=False) + def __init__( + self, + policy: Union[str, Type[ActorCriticPolicy]], + env: Union[GymEnv, str], + learning_rate: Union[float, Callable] = 7e-4, + n_steps: int = 5, + gamma: float = 0.99, + gae_lambda: float = 1.0, + ent_coef: float = 0.0, + vf_coef: float = 0.5, + max_grad_norm: float = 0.5, + rms_prop_eps: float = 1e-5, + use_rms_prop: bool = True, + use_sde: bool = False, + sde_sample_freq: int = -1, + normalize_advantage: bool = False, + tensorboard_log: Optional[str] = None, + create_eval_env: bool = False, + policy_kwargs: Optional[Dict[str, Any]] = None, + verbose: int = 0, + seed: Optional[int] = None, + device: Union[th.device, str] = "auto", + _init_setup_model: bool = True, + ): + + super(A2C, self).__init__( + policy, + env, + learning_rate=learning_rate, + n_steps=n_steps, + gamma=gamma, + gae_lambda=gae_lambda, + ent_coef=ent_coef, + vf_coef=vf_coef, + max_grad_norm=max_grad_norm, + use_sde=use_sde, + sde_sample_freq=sde_sample_freq, + tensorboard_log=tensorboard_log, + policy_kwargs=policy_kwargs, + verbose=verbose, + device=device, + create_eval_env=create_eval_env, + seed=seed, + _init_setup_model=False, + ) self.normalize_advantage = normalize_advantage # Update optimizer inside the policy if we want to use RMSProp # (original implementation) rather than Adam - if use_rms_prop and 'optimizer_class' not in self.policy_kwargs: - self.policy_kwargs['optimizer_class'] = th.optim.RMSprop - self.policy_kwargs['optimizer_kwargs'] = dict(alpha=0.99, eps=rms_prop_eps, - weight_decay=0) + if use_rms_prop and "optimizer_class" not in self.policy_kwargs: + self.policy_kwargs["optimizer_class"] = th.optim.RMSprop + self.policy_kwargs["optimizer_kwargs"] = dict(alpha=0.99, eps=rms_prop_eps, weight_decay=0) if _init_setup_model: self._setup_model() @@ -139,8 +155,7 @@ def train(self) -> None: th.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm) self.policy.optimizer.step() - explained_var = explained_variance(self.rollout_buffer.returns.flatten(), - 
self.rollout_buffer.values.flatten()) + explained_var = explained_variance(self.rollout_buffer.returns.flatten(), self.rollout_buffer.values.flatten()) self._n_updates += 1 logger.record("train/n_updates", self._n_updates, exclude="tensorboard") @@ -148,21 +163,30 @@ def train(self) -> None: logger.record("train/entropy_loss", entropy_loss.item()) logger.record("train/policy_loss", policy_loss.item()) logger.record("train/value_loss", value_loss.item()) - if hasattr(self.policy, 'log_std'): + if hasattr(self.policy, "log_std"): logger.record("train/std", th.exp(self.policy.log_std).mean().item()) - def learn(self, - total_timesteps: int, - callback: MaybeCallback = None, - log_interval: int = 100, - eval_env: Optional[GymEnv] = None, - eval_freq: int = -1, - n_eval_episodes: int = 5, - tb_log_name: str = "A2C", - eval_log_path: Optional[str] = None, - reset_num_timesteps: bool = True) -> 'A2C': - - return super(A2C, self).learn(total_timesteps=total_timesteps, callback=callback, - log_interval=log_interval, eval_env=eval_env, eval_freq=eval_freq, - n_eval_episodes=n_eval_episodes, tb_log_name=tb_log_name, - eval_log_path=eval_log_path, reset_num_timesteps=reset_num_timesteps) + def learn( + self, + total_timesteps: int, + callback: MaybeCallback = None, + log_interval: int = 100, + eval_env: Optional[GymEnv] = None, + eval_freq: int = -1, + n_eval_episodes: int = 5, + tb_log_name: str = "A2C", + eval_log_path: Optional[str] = None, + reset_num_timesteps: bool = True, + ) -> "A2C": + + return super(A2C, self).learn( + total_timesteps=total_timesteps, + callback=callback, + log_interval=log_interval, + eval_env=eval_env, + eval_freq=eval_freq, + n_eval_episodes=n_eval_episodes, + tb_log_name=tb_log_name, + eval_log_path=eval_log_path, + reset_num_timesteps=reset_num_timesteps, + ) diff --git a/stable_baselines3/a2c/policies.py b/stable_baselines3/a2c/policies.py index ae1160d20..eed0ddea1 100644 --- a/stable_baselines3/a2c/policies.py +++ b/stable_baselines3/a2c/policies.py @@ -1,6 +1,6 @@ # This file is here just to define MlpPolicy/CnnPolicy # that work for A2C -from stable_baselines3.common.policies import ActorCriticPolicy, ActorCriticCnnPolicy, register_policy +from stable_baselines3.common.policies import ActorCriticCnnPolicy, ActorCriticPolicy, register_policy MlpPolicy = ActorCriticPolicy CnnPolicy = ActorCriticCnnPolicy diff --git a/stable_baselines3/common/__init__.py b/stable_baselines3/common/__init__.py index 32675071f..275e3ad27 100644 --- a/stable_baselines3/common/__init__.py +++ b/stable_baselines3/common/__init__.py @@ -1,2 +1,2 @@ -from stable_baselines3.common.cmd_util import make_vec_env, make_atari_env +from stable_baselines3.common.cmd_util import make_atari_env, make_vec_env from stable_baselines3.common.utils import set_random_seed diff --git a/stable_baselines3/common/atari_wrappers.py b/stable_baselines3/common/atari_wrappers.py index 7af95fa0e..fdae2d327 100644 --- a/stable_baselines3/common/atari_wrappers.py +++ b/stable_baselines3/common/atari_wrappers.py @@ -1,8 +1,10 @@ import gym -from gym import spaces import numpy as np +from gym import spaces + try: import cv2 # pytype:disable=import-error + cv2.ocl.setUseOpenCL(False) except ImportError: cv2 = None @@ -23,7 +25,7 @@ def __init__(self, env: gym.Env, noop_max: int = 30): self.noop_max = noop_max self.override_num_noops = None self.noop_action = 0 - assert env.unwrapped.get_action_meanings()[0] == 'NOOP' + assert env.unwrapped.get_action_meanings()[0] == "NOOP" def reset(self, **kwargs) -> np.ndarray: 
self.env.reset(**kwargs) @@ -48,7 +50,7 @@ def __init__(self, env: gym.Env): :param env: (gym.Env) the environment to wrap """ gym.Wrapper.__init__(self, env) - assert env.unwrapped.get_action_meanings()[1] == 'FIRE' + assert env.unwrapped.get_action_meanings()[1] == "FIRE" assert len(env.unwrapped.get_action_meanings()) >= 3 def reset(self, **kwargs) -> np.ndarray: @@ -180,8 +182,9 @@ def __init__(self, env: gym.Env, width: int = 84, height: int = 84): gym.ObservationWrapper.__init__(self, env) self.width = width self.height = height - self.observation_space = spaces.Box(low=0, high=255, shape=(self.height, self.width, 1), - dtype=env.observation_space.dtype) + self.observation_space = spaces.Box( + low=0, high=255, shape=(self.height, self.width, 1), dtype=env.observation_space.dtype + ) def observation(self, frame: np.ndarray) -> np.ndarray: """ @@ -217,17 +220,21 @@ class AtariWrapper(gym.Wrapper): life is lost. :param clip_reward: (bool) If True (default), the reward is clip to {-1, 0, 1} depending on its sign. """ - def __init__(self, env: gym.Env, - noop_max: int = 30, - frame_skip: int = 4, - screen_size: int = 84, - terminal_on_life_loss: bool = True, - clip_reward: bool = True): + + def __init__( + self, + env: gym.Env, + noop_max: int = 30, + frame_skip: int = 4, + screen_size: int = 84, + terminal_on_life_loss: bool = True, + clip_reward: bool = True, + ): env = NoopResetEnv(env, noop_max=noop_max) env = MaxAndSkipEnv(env, skip=frame_skip) if terminal_on_life_loss: env = EpisodicLifeEnv(env) - if 'FIRE' in env.unwrapped.get_action_meanings(): + if "FIRE" in env.unwrapped.get_action_meanings(): env = FireResetEnv(env) env = WarpFrame(env, width=screen_size, height=screen_size) if clip_reward: diff --git a/stable_baselines3/common/base_class.py b/stable_baselines3/common/base_class.py index 56f9f9642..1ba7cc0ce 100644 --- a/stable_baselines3/common/base_class.py +++ b/stable_baselines3/common/base_class.py @@ -1,28 +1,32 @@ """Abstract base classes for RL algorithms.""" +import io +import pathlib import time -from typing import Union, Type, Optional, Dict, Any, Iterable, List, Tuple, Callable from abc import ABC, abstractmethod from collections import deque -import pathlib -import io +from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Type, Union import gym -import torch as th import numpy as np +import torch as th from stable_baselines3.common import logger, utils -from stable_baselines3.common.policies import BasePolicy, get_policy_from_name -from stable_baselines3.common.utils import (set_random_seed, get_schedule_fn, update_learning_rate, get_device, - check_for_correct_spaces) -from stable_baselines3.common.vec_env import DummyVecEnv, VecEnv, unwrap_vec_normalize, VecNormalize, VecTransposeImage -from stable_baselines3.common.preprocessing import is_image_space -from stable_baselines3.common.save_util import (recursive_getattr, recursive_setattr, save_to_zip_file, - load_from_zip_file) -from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback from stable_baselines3.common.callbacks import BaseCallback, CallbackList, ConvertCallback, EvalCallback from stable_baselines3.common.monitor import Monitor from stable_baselines3.common.noise import ActionNoise +from stable_baselines3.common.policies import BasePolicy, get_policy_from_name +from stable_baselines3.common.preprocessing import is_image_space +from stable_baselines3.common.save_util import load_from_zip_file, recursive_getattr, recursive_setattr, save_to_zip_file +from 
stable_baselines3.common.type_aliases import GymEnv, MaybeCallback +from stable_baselines3.common.utils import ( + check_for_correct_spaces, + get_device, + get_schedule_fn, + set_random_seed, + update_learning_rate, +) +from stable_baselines3.common.vec_env import DummyVecEnv, VecEnv, VecNormalize, VecTransposeImage, unwrap_vec_normalize def maybe_make_env(env: Union[GymEnv, str, None], monitor_wrapper: bool, verbose: int) -> Optional[GymEnv]: @@ -72,21 +76,23 @@ class BaseAlgorithm(ABC): Default: -1 (only sample at the beginning of the rollout) """ - def __init__(self, - policy: Type[BasePolicy], - env: Union[GymEnv, str, None], - policy_base: Type[BasePolicy], - learning_rate: Union[float, Callable], - policy_kwargs: Dict[str, Any] = None, - tensorboard_log: Optional[str] = None, - verbose: int = 0, - device: Union[th.device, str] = 'auto', - support_multi_env: bool = False, - create_eval_env: bool = False, - monitor_wrapper: bool = True, - seed: Optional[int] = None, - use_sde: bool = False, - sde_sample_freq: int = -1): + def __init__( + self, + policy: Type[BasePolicy], + env: Union[GymEnv, str, None], + policy_base: Type[BasePolicy], + learning_rate: Union[float, Callable], + policy_kwargs: Dict[str, Any] = None, + tensorboard_log: Optional[str] = None, + verbose: int = 0, + device: Union[th.device, str] = "auto", + support_multi_env: bool = False, + create_eval_env: bool = False, + monitor_wrapper: bool = True, + seed: Optional[int] = None, + use_sde: bool = False, + sde_sample_freq: int = -1, + ): if isinstance(policy, str) and policy_base is not None: self.policy_class = get_policy_from_name(policy_base, policy) @@ -147,12 +153,12 @@ def __init__(self, self.env = env if not support_multi_env and self.n_envs > 1: - raise ValueError("Error: the model does not support multiple envs; it requires " - "a single vectorized environment.") + raise ValueError( + "Error: the model does not support multiple envs; it requires " "a single vectorized environment." + ) if self.use_sde and not isinstance(self.observation_space, gym.spaces.Box): - raise ValueError("generalized State-Dependent Exploration (gSDE) can only " - "be used with continuous actions.") + raise ValueError("generalized State-Dependent Exploration (gSDE) can only be used with continuous actions.") def _wrap_env(self, env: GymEnv) -> VecEnv: if not isinstance(env, VecEnv): @@ -262,15 +268,18 @@ def get_torch_variables(self) -> Tuple[List[str], List[str]]: return state_dicts, [] @abstractmethod - def learn(self, total_timesteps: int, - callback: MaybeCallback = None, - log_interval: int = 100, - tb_log_name: str = "run", - eval_env: Optional[GymEnv] = None, - eval_freq: int = -1, - n_eval_episodes: int = 5, - eval_log_path: Optional[str] = None, - reset_num_timesteps: bool = True) -> 'BaseAlgorithm': + def learn( + self, + total_timesteps: int, + callback: MaybeCallback = None, + log_interval: int = 100, + tb_log_name: str = "run", + eval_env: Optional[GymEnv] = None, + eval_freq: int = -1, + n_eval_episodes: int = 5, + eval_log_path: Optional[str] = None, + reset_num_timesteps: bool = True, + ) -> "BaseAlgorithm": """ Return a trained model. 
@@ -286,10 +295,13 @@ def learn(self, total_timesteps: int, :return: (BaseAlgorithm) the trained model """ - def predict(self, observation: np.ndarray, - state: Optional[np.ndarray] = None, - mask: Optional[np.ndarray] = None, - deterministic: bool = False) -> Tuple[np.ndarray, Optional[np.ndarray]]: + def predict( + self, + observation: np.ndarray, + state: Optional[np.ndarray] = None, + mask: Optional[np.ndarray] = None, + deterministic: bool = False, + ) -> Tuple[np.ndarray, Optional[np.ndarray]]: """ Get the model's action(s) from an observation @@ -303,7 +315,7 @@ def predict(self, observation: np.ndarray, return self.policy.predict(observation, state, mask, deterministic) @classmethod - def load(cls, load_path: str, env: Optional[GymEnv] = None, **kwargs) -> 'BaseAlgorithm': + def load(cls, load_path: str, env: Optional[GymEnv] = None, **kwargs) -> "BaseAlgorithm": """ Load the model from a zip-file @@ -314,14 +326,16 @@ def load(cls, load_path: str, env: Optional[GymEnv] = None, **kwargs) -> 'BaseAl """ data, params, tensors = load_from_zip_file(load_path) - if 'policy_kwargs' in data: - for arg_to_remove in ['device']: - if arg_to_remove in data['policy_kwargs']: - del data['policy_kwargs'][arg_to_remove] + if "policy_kwargs" in data: + for arg_to_remove in ["device"]: + if arg_to_remove in data["policy_kwargs"]: + del data["policy_kwargs"][arg_to_remove] - if 'policy_kwargs' in kwargs and kwargs['policy_kwargs'] != data['policy_kwargs']: - raise ValueError(f"The specified policy kwargs do not equal the stored policy kwargs." - f"Stored kwargs: {data['policy_kwargs']}, specified kwargs: {kwargs['policy_kwargs']}") + if "policy_kwargs" in kwargs and kwargs["policy_kwargs"] != data["policy_kwargs"]: + raise ValueError( + f"The specified policy kwargs do not equal the stored policy kwargs." + f"Stored kwargs: {data['policy_kwargs']}, specified kwargs: {kwargs['policy_kwargs']}" + ) # check if observation space and action space are part of the saved parameters if "observation_space" not in data or "action_space" not in data: @@ -334,8 +348,12 @@ def load(cls, load_path: str, env: Optional[GymEnv] = None, **kwargs) -> 'BaseAl env = data["env"] # noinspection PyArgumentList - model = cls(policy=data["policy_class"], env=env, - device='auto', _init_setup_model=False) # pytype: disable=not-instantiable,wrong-keyword-args + model = cls( + policy=data["policy_class"], + env=env, + device="auto", + _init_setup_model=False, # pytype: disable=not-instantiable,wrong-keyword-args + ) # load parameters model.__dict__.update(data) @@ -367,19 +385,21 @@ def set_random_seed(self, seed: Optional[int] = None) -> None: """ if seed is None: return - set_random_seed(seed, using_cuda=self.device == th.device('cuda')) + set_random_seed(seed, using_cuda=self.device == th.device("cuda")) self.action_space.seed(seed) if self.env is not None: self.env.seed(seed) if self.eval_env is not None: self.eval_env.seed(seed) - def _init_callback(self, - callback: MaybeCallback, - eval_env: Optional[VecEnv] = None, - eval_freq: int = 10000, - n_eval_episodes: int = 5, - log_path: Optional[str] = None) -> BaseCallback: + def _init_callback( + self, + callback: MaybeCallback, + eval_env: Optional[VecEnv] = None, + eval_freq: int = 10000, + n_eval_episodes: int = 5, + log_path: Optional[str] = None, + ) -> BaseCallback: """ :param callback: (MaybeCallback) Callback(s) called at every step with state of the algorithm. :param eval_freq: (Optional[int]) How many steps between evaluations; if None, do not evaluate. 
@@ -398,24 +418,29 @@ def _init_callback(self, # Create eval callback in charge of the evaluation if eval_env is not None: - eval_callback = EvalCallback(eval_env, - best_model_save_path=log_path, - log_path=log_path, eval_freq=eval_freq, n_eval_episodes=n_eval_episodes) + eval_callback = EvalCallback( + eval_env, + best_model_save_path=log_path, + log_path=log_path, + eval_freq=eval_freq, + n_eval_episodes=n_eval_episodes, + ) callback = CallbackList([callback, eval_callback]) callback.init_callback(self) return callback - def _setup_learn(self, - total_timesteps: int, - eval_env: Optional[GymEnv], - callback: MaybeCallback = None, - eval_freq: int = 10000, - n_eval_episodes: int = 5, - log_path: Optional[str] = None, - reset_num_timesteps: bool = True, - tb_log_name: str = 'run', - ) -> Tuple[int, BaseCallback]: + def _setup_learn( + self, + total_timesteps: int, + eval_env: Optional[GymEnv], + callback: MaybeCallback = None, + eval_freq: int = 10000, + n_eval_episodes: int = 5, + log_path: Optional[str] = None, + reset_num_timesteps: bool = True, + tb_log_name: str = "run", + ) -> Tuple[int, BaseCallback]: """ Initialize different variables needed for training. @@ -476,8 +501,8 @@ def _update_info_buffer(self, infos: List[Dict[str, Any]], dones: Optional[np.nd if dones is None: dones = np.array([False] * len(infos)) for idx, info in enumerate(infos): - maybe_ep_info = info.get('episode') - maybe_is_success = info.get('is_success') + maybe_ep_info = info.get("episode") + maybe_is_success = info.get("is_success") if maybe_ep_info is not None: self.ep_info_buffer.extend([maybe_ep_info]) if maybe_is_success is not None and dones[idx]: @@ -522,7 +547,7 @@ def save( torch_variables = state_dicts_names + tensors_names for torch_var in torch_variables: # we need to get only the name of the top most module as we'll remove that - var_name = torch_var.split('.')[0] + var_name = torch_var.split(".")[0] exclude.add(var_name) # Remove parameter entries of parameters which are to be excluded diff --git a/stable_baselines3/common/bit_flipping_env.py b/stable_baselines3/common/bit_flipping_env.py index 1dd881502..b579fe157 100644 --- a/stable_baselines3/common/bit_flipping_env.py +++ b/stable_baselines3/common/bit_flipping_env.py @@ -21,27 +21,31 @@ class BitFlippingEnv(GoalEnv): :param discrete_obs_space: (bool) Whether to use the discrete observation version or not, by default, it uses the MultiBinary one """ - def __init__(self, n_bits: int = 10, - continuous: bool = False, - max_steps: Optional[int] = None, - discrete_obs_space: bool = False): + + def __init__( + self, n_bits: int = 10, continuous: bool = False, max_steps: Optional[int] = None, discrete_obs_space: bool = False + ): super(BitFlippingEnv, self).__init__() # The achieved goal is determined by the current state # here, it is a special where they are equal if discrete_obs_space: # In the discrete case, the agent act on the binary # representation of the observation - self.observation_space = spaces.Dict({ - 'observation': spaces.Discrete(2 ** n_bits - 1), - 'achieved_goal': spaces.Discrete(2 ** n_bits - 1), - 'desired_goal': spaces.Discrete(2 ** n_bits - 1) - }) + self.observation_space = spaces.Dict( + { + "observation": spaces.Discrete(2 ** n_bits - 1), + "achieved_goal": spaces.Discrete(2 ** n_bits - 1), + "desired_goal": spaces.Discrete(2 ** n_bits - 1), + } + ) else: - self.observation_space = spaces.Dict({ - 'observation': spaces.MultiBinary(n_bits), - 'achieved_goal': spaces.MultiBinary(n_bits), - 'desired_goal': 
spaces.MultiBinary(n_bits) - }) + self.observation_space = spaces.Dict( + { + "observation": spaces.MultiBinary(n_bits), + "achieved_goal": spaces.MultiBinary(n_bits), + "desired_goal": spaces.MultiBinary(n_bits), + } + ) self.obs_space = spaces.MultiBinary(n_bits) @@ -69,7 +73,7 @@ def convert_if_needed(self, state: np.ndarray) -> Union[int, np.ndarray]: if self.discrete_obs_space: # The internal state is the binary representation of the # observed one - return int(sum([state[i] * 2**i for i in range(len(state))])) + return int(sum([state[i] * 2 ** i for i in range(len(state))])) return state def _get_obs(self) -> OrderedDict: @@ -78,11 +82,13 @@ def _get_obs(self) -> OrderedDict: :return: (OrderedDict) """ - return OrderedDict([ - ('observation', self.convert_if_needed(self.state.copy())), - ('achieved_goal', self.convert_if_needed(self.state.copy())), - ('desired_goal', self.convert_if_needed(self.desired_goal.copy())) - ]) + return OrderedDict( + [ + ("observation", self.convert_if_needed(self.state.copy())), + ("achieved_goal", self.convert_if_needed(self.state.copy())), + ("desired_goal", self.convert_if_needed(self.desired_goal.copy())), + ] + ) def reset(self) -> OrderedDict: self.current_step = 0 @@ -95,25 +101,22 @@ def step(self, action: Union[np.ndarray, int]) -> GymStepReturn: else: self.state[action] = 1 - self.state[action] obs = self._get_obs() - reward = self.compute_reward(obs['achieved_goal'], obs['desired_goal'], None) + reward = self.compute_reward(obs["achieved_goal"], obs["desired_goal"], None) done = reward == 0 self.current_step += 1 # Episode terminate when we reached the goal or the max number of steps - info = {'is_success': done} + info = {"is_success": done} done = done or self.current_step >= self.max_steps return obs, reward, done, info - def compute_reward(self, - achieved_goal: np.ndarray, - desired_goal: np.ndarray, - _info) -> float: + def compute_reward(self, achieved_goal: np.ndarray, desired_goal: np.ndarray, _info) -> float: # Deceptive reward: it is positive only when the goal is achieved if self.discrete_obs_space: return 0.0 if achieved_goal == desired_goal else -1.0 return 0.0 if (achieved_goal == desired_goal).all() else -1.0 - def render(self, mode: str = 'human') -> Optional[np.ndarray]: - if mode == 'rgb_array': + def render(self, mode: str = "human") -> Optional[np.ndarray]: + if mode == "rgb_array": return self.state.copy() print(self.state) diff --git a/stable_baselines3/common/buffers.py b/stable_baselines3/common/buffers.py index 24ed8aee3..4534063a2 100644 --- a/stable_baselines3/common/buffers.py +++ b/stable_baselines3/common/buffers.py @@ -1,5 +1,5 @@ -from typing import Union, Optional, Generator import warnings +from typing import Generator, Optional, Union import numpy as np import torch as th @@ -11,9 +11,9 @@ except ImportError: psutil = None -from stable_baselines3.common.vec_env import VecNormalize -from stable_baselines3.common.type_aliases import RolloutBufferSamples, ReplayBufferSamples from stable_baselines3.common.preprocessing import get_action_dim, get_obs_shape +from stable_baselines3.common.type_aliases import ReplayBufferSamples, RolloutBufferSamples +from stable_baselines3.common.vec_env import VecNormalize class BaseBuffer(object): @@ -28,12 +28,14 @@ class BaseBuffer(object): :param n_envs: (int) Number of parallel environments """ - def __init__(self, - buffer_size: int, - observation_space: spaces.Space, - action_space: spaces.Space, - device: Union[th.device, str] = 'cpu', - n_envs: int = 1): + def __init__( + 
self, + buffer_size: int, + observation_space: spaces.Space, + action_space: spaces.Space, + device: Union[th.device, str] = "cpu", + n_envs: int = 1, + ): super(BaseBuffer, self).__init__() self.buffer_size = buffer_size self.observation_space = observation_space @@ -89,10 +91,7 @@ def reset(self) -> None: self.pos = 0 self.full = False - def sample(self, - batch_size: int, - env: Optional[VecNormalize] = None - ): + def sample(self, batch_size: int, env: Optional[VecNormalize] = None): """ :param batch_size: (int) Number of element to sample :param env: (Optional[VecNormalize]) associated gym VecEnv @@ -103,10 +102,7 @@ def sample(self, batch_inds = np.random.randint(0, upper_bound, size=batch_size) return self._get_samples(batch_inds, env=env) - def _get_samples(self, - batch_inds: np.ndarray, - env: Optional[VecNormalize] = None - ): + def _get_samples(self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None): """ :param batch_inds: (th.Tensor) :param env: (Optional[VecNormalize]) @@ -129,15 +125,13 @@ def to_torch(self, array: np.ndarray, copy: bool = True) -> th.Tensor: return th.as_tensor(array).to(self.device) @staticmethod - def _normalize_obs(obs: np.ndarray, - env: Optional[VecNormalize] = None) -> np.ndarray: + def _normalize_obs(obs: np.ndarray, env: Optional[VecNormalize] = None) -> np.ndarray: if env is not None: return env.normalize_obs(obs).astype(np.float32) return obs @staticmethod - def _normalize_reward(reward: np.ndarray, - env: Optional[VecNormalize] = None) -> np.ndarray: + def _normalize_reward(reward: np.ndarray, env: Optional[VecNormalize] = None) -> np.ndarray: if env is not None: return env.normalize_reward(reward).astype(np.float32) return reward @@ -158,15 +152,17 @@ class ReplayBuffer(BaseBuffer): See https://github.com/DLR-RM/stable-baselines3/issues/37#issuecomment-637501195 and https://github.com/DLR-RM/stable-baselines3/pull/28#issuecomment-637559274 """ - def __init__(self, - buffer_size: int, - observation_space: spaces.Space, - action_space: spaces.Space, - device: Union[th.device, str] = 'cpu', - n_envs: int = 1, - optimize_memory_usage: bool = False): - super(ReplayBuffer, self).__init__(buffer_size, observation_space, - action_space, device, n_envs=n_envs) + + def __init__( + self, + buffer_size: int, + observation_space: spaces.Space, + action_space: spaces.Space, + device: Union[th.device, str] = "cpu", + n_envs: int = 1, + optimize_memory_usage: bool = False, + ): + super(ReplayBuffer, self).__init__(buffer_size, observation_space, action_space, device, n_envs=n_envs) assert n_envs == 1, "Replay buffer only support single environment for now" @@ -186,8 +182,7 @@ def __init__(self, self.dones = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32) if psutil is not None: - total_memory_usage = (self.observations.nbytes + self.actions.nbytes - + self.rewards.nbytes + self.dones.nbytes) + total_memory_usage = self.observations.nbytes + self.actions.nbytes + self.rewards.nbytes + self.dones.nbytes if self.next_observations is not None: total_memory_usage += self.next_observations.nbytes @@ -195,15 +190,12 @@ def __init__(self, # Convert to GB total_memory_usage /= 1e9 mem_available /= 1e9 - warnings.warn("This system does not have apparently enough memory to store the complete " - f"replay buffer {total_memory_usage:.2f}GB > {mem_available:.2f}GB") - - def add(self, - obs: np.ndarray, - next_obs: np.ndarray, - action: np.ndarray, - reward: np.ndarray, - done: np.ndarray) -> None: + warnings.warn( + "This system does not have apparently 
enough memory to store the complete " + f"replay buffer {total_memory_usage:.2f}GB > {mem_available:.2f}GB" + ) + + def add(self, obs: np.ndarray, next_obs: np.ndarray, action: np.ndarray, reward: np.ndarray, done: np.ndarray) -> None: # Copy to avoid modification by reference self.observations[self.pos] = np.array(obs).copy() if self.optimize_memory_usage: @@ -220,10 +212,7 @@ def add(self, self.full = True self.pos = 0 - def sample(self, - batch_size: int, - env: Optional[VecNormalize] = None - ) -> ReplayBufferSamples: + def sample(self, batch_size: int, env: Optional[VecNormalize] = None) -> ReplayBufferSamples: """ Sample elements from the replay buffer. Custom sampling when using memory efficient variant, @@ -245,20 +234,19 @@ def sample(self, batch_inds = np.random.randint(0, self.pos, size=batch_size) return self._get_samples(batch_inds, env=env) - def _get_samples(self, - batch_inds: np.ndarray, - env: Optional[VecNormalize] = None - ) -> ReplayBufferSamples: + def _get_samples(self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None) -> ReplayBufferSamples: if self.optimize_memory_usage: next_obs = self._normalize_obs(self.observations[(batch_inds + 1) % self.buffer_size, 0, :], env) else: next_obs = self._normalize_obs(self.next_observations[batch_inds, 0, :], env) - data = (self._normalize_obs(self.observations[batch_inds, 0, :], env), - self.actions[batch_inds, 0, :], - next_obs, - self.dones[batch_inds], - self._normalize_reward(self.rewards[batch_inds], env)) + data = ( + self._normalize_obs(self.observations[batch_inds, 0, :], env), + self.actions[batch_inds, 0, :], + next_obs, + self.dones[batch_inds], + self._normalize_reward(self.rewards[batch_inds], env), + ) return ReplayBufferSamples(*tuple(map(self.to_torch, data))) @@ -276,17 +264,18 @@ class RolloutBuffer(BaseBuffer): :param n_envs: (int) Number of parallel environments """ - def __init__(self, - buffer_size: int, - observation_space: spaces.Space, - action_space: spaces.Space, - device: Union[th.device, str] = 'cpu', - gae_lambda: float = 1, - gamma: float = 0.99, - n_envs: int = 1): - - super(RolloutBuffer, self).__init__(buffer_size, observation_space, - action_space, device, n_envs=n_envs) + def __init__( + self, + buffer_size: int, + observation_space: spaces.Space, + action_space: spaces.Space, + device: Union[th.device, str] = "cpu", + gae_lambda: float = 1, + gamma: float = 0.99, + n_envs: int = 1, + ): + + super(RolloutBuffer, self).__init__(buffer_size, observation_space, action_space, device, n_envs=n_envs) self.gae_lambda = gae_lambda self.gamma = gamma self.observations, self.actions, self.rewards, self.advantages = None, None, None, None @@ -295,8 +284,7 @@ def __init__(self, self.reset() def reset(self) -> None: - self.observations = np.zeros((self.buffer_size, self.n_envs,) + self.obs_shape, - dtype=np.float32) + self.observations = np.zeros((self.buffer_size, self.n_envs,) + self.obs_shape, dtype=np.float32) self.actions = np.zeros((self.buffer_size, self.n_envs, self.action_dim), dtype=np.float32) self.rewards = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32) self.returns = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32) @@ -307,9 +295,7 @@ def reset(self) -> None: self.generator_ready = False super(RolloutBuffer, self).reset() - def compute_returns_and_advantage(self, - last_value: th.Tensor, - dones: np.ndarray) -> None: + def compute_returns_and_advantage(self, last_value: th.Tensor, dones: np.ndarray) -> None: """ Post-processing step: compute the returns (sum 
of discounted rewards) and GAE advantage. @@ -340,13 +326,9 @@ def compute_returns_and_advantage(self, self.advantages[step] = last_gae_lam self.returns = self.advantages + self.values - def add(self, - obs: np.ndarray, - action: np.ndarray, - reward: np.ndarray, - done: np.ndarray, - value: th.Tensor, - log_prob: th.Tensor) -> None: + def add( + self, obs: np.ndarray, action: np.ndarray, reward: np.ndarray, done: np.ndarray, value: th.Tensor, log_prob: th.Tensor + ) -> None: """ :param obs: (np.ndarray) Observation :param action: (np.ndarray) Action @@ -372,12 +354,11 @@ def add(self, self.full = True def get(self, batch_size: Optional[int] = None) -> Generator[RolloutBufferSamples, None, None]: - assert self.full, '' + assert self.full, "" indices = np.random.permutation(self.buffer_size * self.n_envs) # Prepare the data if not self.generator_ready: - for tensor in ['observations', 'actions', 'values', - 'log_probs', 'advantages', 'returns']: + for tensor in ["observations", "actions", "values", "log_probs", "advantages", "returns"]: self.__dict__[tensor] = self.swap_and_flatten(self.__dict__[tensor]) self.generator_ready = True @@ -387,15 +368,16 @@ def get(self, batch_size: Optional[int] = None) -> Generator[RolloutBufferSample start_idx = 0 while start_idx < self.buffer_size * self.n_envs: - yield self._get_samples(indices[start_idx:start_idx + batch_size]) + yield self._get_samples(indices[start_idx : start_idx + batch_size]) start_idx += batch_size - def _get_samples(self, batch_inds: np.ndarray, - env: Optional[VecNormalize] = None) -> RolloutBufferSamples: - data = (self.observations[batch_inds], - self.actions[batch_inds], - self.values[batch_inds].flatten(), - self.log_probs[batch_inds].flatten(), - self.advantages[batch_inds].flatten(), - self.returns[batch_inds].flatten()) + def _get_samples(self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None) -> RolloutBufferSamples: + data = ( + self.observations[batch_inds], + self.actions[batch_inds], + self.values[batch_inds].flatten(), + self.log_probs[batch_inds].flatten(), + self.advantages[batch_inds].flatten(), + self.returns[batch_inds].flatten(), + ) return RolloutBufferSamples(*tuple(map(self.to_torch, data))) diff --git a/stable_baselines3/common/callbacks.py b/stable_baselines3/common/callbacks.py index ccd812f08..c5e53e58e 100644 --- a/stable_baselines3/common/callbacks.py +++ b/stable_baselines3/common/callbacks.py @@ -1,15 +1,15 @@ import os -from abc import ABC, abstractmethod -import warnings import typing -from typing import Union, List, Dict, Any, Optional +import warnings +from abc import ABC, abstractmethod +from typing import Any, Dict, List, Optional, Union import gym import numpy as np -from stable_baselines3.common.vec_env import DummyVecEnv, VecEnv, sync_envs_normalization -from stable_baselines3.common.evaluation import evaluate_policy from stable_baselines3.common import logger +from stable_baselines3.common.evaluation import evaluate_policy +from stable_baselines3.common.vec_env import DummyVecEnv, VecEnv, sync_envs_normalization if typing.TYPE_CHECKING: from stable_baselines3.common.base_class import BaseAlgorithm # pytype: disable=pyi-error @@ -21,6 +21,7 @@ class BaseCallback(ABC): :param verbose: (int) """ + def __init__(self, verbose: int = 0): super(BaseCallback, self).__init__() # The RL model @@ -40,7 +41,7 @@ def __init__(self, verbose: int = 0): self.parent = None # type: Optional[BaseCallback] # Type hint as string to avoid circular import - def init_callback(self, model: 'BaseAlgorithm') 
-> None: + def init_callback(self, model: "BaseAlgorithm") -> None: """ Initialize the callback by saving references to the RL model and the training environment for convenience. @@ -111,6 +112,7 @@ class EventCallback(BaseCallback): when an event is triggered. :param verbose: (int) """ + def __init__(self, callback: Optional[BaseCallback] = None, verbose: int = 0): super(EventCallback, self).__init__(verbose=verbose) self.callback = callback @@ -118,7 +120,7 @@ def __init__(self, callback: Optional[BaseCallback] = None, verbose: int = 0): if callback is not None: self.callback.parent = self - def init_callback(self, model: 'BaseAlgorithm') -> None: + def init_callback(self, model: "BaseAlgorithm") -> None: super(EventCallback, self).init_callback(model) if self.callback is not None: self.callback.init_callback(self.model) @@ -143,6 +145,7 @@ class CallbackList(BaseCallback): :param callbacks: (List[BaseCallback]) A list of callbacks that will be called sequentially. """ + def __init__(self, callbacks: List[BaseCallback]): super(CallbackList, self).__init__() assert isinstance(callbacks, list) @@ -184,7 +187,8 @@ class CheckpointCallback(BaseCallback): :param save_path: (str) Path to the folder where the model will be saved. :param name_prefix: (str) Common prefix to the saved models """ - def __init__(self, save_freq: int, save_path: str, name_prefix='rl_model', verbose=0): + + def __init__(self, save_freq: int, save_path: str, name_prefix="rl_model", verbose=0): super(CheckpointCallback, self).__init__(verbose) self.save_freq = save_freq self.save_path = save_path @@ -197,7 +201,7 @@ def _init_callback(self) -> None: def _on_step(self) -> bool: if self.n_calls % self.save_freq == 0: - path = os.path.join(self.save_path, f'{self.name_prefix}_{self.num_timesteps}_steps') + path = os.path.join(self.save_path, f"{self.name_prefix}_{self.num_timesteps}_steps") self.model.save(path) if self.verbose > 1: print(f"Saving model checkpoint to {path}") @@ -211,6 +215,7 @@ class ConvertCallback(BaseCallback): :param callback: (callable) :param verbose: (int) """ + def __init__(self, callback, verbose=0): super(ConvertCallback, self).__init__(verbose) self.callback = callback @@ -240,15 +245,19 @@ class EvalCallback(EventCallback): :param render: (bool) Whether to render or not the environment during evaluation :param verbose: (int) """ - def __init__(self, eval_env: Union[gym.Env, VecEnv], - callback_on_new_best: Optional[BaseCallback] = None, - n_eval_episodes: int = 5, - eval_freq: int = 10000, - log_path: str = None, - best_model_save_path: str = None, - deterministic: bool = True, - render: bool = False, - verbose: int = 1): + + def __init__( + self, + eval_env: Union[gym.Env, VecEnv], + callback_on_new_best: Optional[BaseCallback] = None, + n_eval_episodes: int = 5, + eval_freq: int = 10000, + log_path: str = None, + best_model_save_path: str = None, + deterministic: bool = True, + render: bool = False, + verbose: int = 1, + ): super(EvalCallback, self).__init__(callback_on_new_best, verbose=verbose) self.n_eval_episodes = n_eval_episodes self.eval_freq = eval_freq @@ -268,7 +277,7 @@ def __init__(self, eval_env: Union[gym.Env, VecEnv], self.best_model_save_path = best_model_save_path # Logs will be written in ``evaluations.npz`` if log_path is not None: - log_path = os.path.join(log_path, 'evaluations') + log_path = os.path.join(log_path, "evaluations") self.log_path = log_path self.evaluations_results = [] self.evaluations_timesteps = [] @@ -277,8 +286,7 @@ def __init__(self, eval_env: 
Union[gym.Env, VecEnv], def _init_callback(self): # Does not work in some corner cases, where the wrapper is not the same if not isinstance(self.training_env, type(self.eval_env)): - warnings.warn("Training and eval env are not of the same type" - f"{self.training_env} != {self.eval_env}") + warnings.warn("Training and eval env are not of the same type" f"{self.training_env} != {self.eval_env}") # Create folders if needed if self.best_model_save_path is not None: @@ -292,36 +300,42 @@ def _on_step(self) -> bool: # Sync training and eval env if there is VecNormalize sync_envs_normalization(self.training_env, self.eval_env) - episode_rewards, episode_lengths = evaluate_policy(self.model, self.eval_env, - n_eval_episodes=self.n_eval_episodes, - render=self.render, - deterministic=self.deterministic, - return_episode_rewards=True) + episode_rewards, episode_lengths = evaluate_policy( + self.model, + self.eval_env, + n_eval_episodes=self.n_eval_episodes, + render=self.render, + deterministic=self.deterministic, + return_episode_rewards=True, + ) if self.log_path is not None: self.evaluations_timesteps.append(self.num_timesteps) self.evaluations_results.append(episode_rewards) self.evaluations_length.append(episode_lengths) - np.savez(self.log_path, timesteps=self.evaluations_timesteps, - results=self.evaluations_results, ep_lengths=self.evaluations_length) + np.savez( + self.log_path, + timesteps=self.evaluations_timesteps, + results=self.evaluations_results, + ep_lengths=self.evaluations_length, + ) mean_reward, std_reward = np.mean(episode_rewards), np.std(episode_rewards) mean_ep_length, std_ep_length = np.mean(episode_lengths), np.std(episode_lengths) self.last_mean_reward = mean_reward if self.verbose > 0: - print(f"Eval num_timesteps={self.num_timesteps}, " - f"episode_reward={mean_reward:.2f} +/- {std_reward:.2f}") + print(f"Eval num_timesteps={self.num_timesteps}, " f"episode_reward={mean_reward:.2f} +/- {std_reward:.2f}") print(f"Episode length: {mean_ep_length:.2f} +/- {std_ep_length:.2f}") # Add to current Logger - self.logger.record('eval/mean_reward', float(mean_reward)) - self.logger.record('eval/mean_ep_length', mean_ep_length) + self.logger.record("eval/mean_reward", float(mean_reward)) + self.logger.record("eval/mean_ep_length", mean_ep_length) if mean_reward > self.best_mean_reward: if self.verbose > 0: print("New best mean reward!") if self.best_model_save_path is not None: - self.model.save(os.path.join(self.best_model_save_path, 'best_model')) + self.model.save(os.path.join(self.best_model_save_path, "best_model")) self.best_mean_reward = mean_reward # Trigger callback if needed if self.callback is not None: @@ -341,18 +355,20 @@ class StopTrainingOnRewardThreshold(BaseCallback): to stop training. 
:param verbose: (int) """ + def __init__(self, reward_threshold: float, verbose: int = 0): super(StopTrainingOnRewardThreshold, self).__init__(verbose=verbose) self.reward_threshold = reward_threshold def _on_step(self) -> bool: - assert self.parent is not None, ("``StopTrainingOnMinimumReward`` callback must be used " - "with an ``EvalCallback``") + assert self.parent is not None, "``StopTrainingOnMinimumReward`` callback must be used " "with an ``EvalCallback``" # Convert np.bool to bool, otherwise callback() is False won't work continue_training = bool(self.parent.best_mean_reward < self.reward_threshold) if self.verbose > 0 and not continue_training: - print(f"Stopping training because the mean reward {self.parent.best_mean_reward:.2f} " - f" is above the threshold {self.reward_threshold}") + print( + f"Stopping training because the mean reward {self.parent.best_mean_reward:.2f} " + f" is above the threshold {self.reward_threshold}" + ) return continue_training @@ -364,6 +380,7 @@ class EveryNTimesteps(EventCallback): :param callback: (BaseCallback) Callback that will be called when the event is triggered. """ + def __init__(self, n_steps: int, callback: BaseCallback): super(EveryNTimesteps, self).__init__(callback) self.n_steps = n_steps diff --git a/stable_baselines3/common/cmd_util.py b/stable_baselines3/common/cmd_util.py index ceeced7b7..4f49ccce5 100644 --- a/stable_baselines3/common/cmd_util.py +++ b/stable_baselines3/common/cmd_util.py @@ -1,23 +1,25 @@ import os import warnings -from typing import Dict, Any, Optional, Callable, Type, Union +from typing import Any, Callable, Dict, Optional, Type, Union import gym -from stable_baselines3.common.monitor import Monitor from stable_baselines3.common.atari_wrappers import AtariWrapper +from stable_baselines3.common.monitor import Monitor from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv -def make_vec_env(env_id: Union[str, Type[gym.Env]], - n_envs: int = 1, - seed: Optional[int] = None, - start_index: int = 0, - monitor_dir: Optional[str] = None, - wrapper_class: Optional[Callable] = None, - env_kwargs: Optional[Dict[str, Any]] = None, - vec_env_cls: Optional[Union[DummyVecEnv, SubprocVecEnv]] = None, - vec_env_kwargs: Optional[Dict[str, Any]] = None): +def make_vec_env( + env_id: Union[str, Type[gym.Env]], + n_envs: int = 1, + seed: Optional[int] = None, + start_index: int = 0, + monitor_dir: Optional[str] = None, + wrapper_class: Optional[Callable] = None, + env_kwargs: Optional[Dict[str, Any]] = None, + vec_env_cls: Optional[Union[DummyVecEnv, SubprocVecEnv]] = None, + vec_env_kwargs: Optional[Dict[str, Any]] = None, +): """ Create a wrapped, monitored ``VecEnv``. 
By default it uses a ``DummyVecEnv`` which is usually faster @@ -62,6 +64,7 @@ def _init(): if wrapper_class is not None: env = wrapper_class(env) return env + return _init # No custom VecEnv is passed @@ -72,15 +75,17 @@ def _init(): return vec_env_cls([make_env(i + start_index) for i in range(n_envs)], **vec_env_kwargs) -def make_atari_env(env_id: Union[str, Type[gym.Env]], - n_envs: int = 1, - seed: Optional[int] = None, - start_index: int = 0, - monitor_dir: Optional[str] = None, - wrapper_kwargs: Optional[Dict[str, Any]] = None, - env_kwargs: Optional[Dict[str, Any]] = None, - vec_env_cls: Optional[Union[DummyVecEnv, SubprocVecEnv]] = None, - vec_env_kwargs: Optional[Dict[str, Any]] = None): +def make_atari_env( + env_id: Union[str, Type[gym.Env]], + n_envs: int = 1, + seed: Optional[int] = None, + start_index: int = 0, + monitor_dir: Optional[str] = None, + wrapper_kwargs: Optional[Dict[str, Any]] = None, + env_kwargs: Optional[Dict[str, Any]] = None, + vec_env_cls: Optional[Union[DummyVecEnv, SubprocVecEnv]] = None, + vec_env_kwargs: Optional[Dict[str, Any]] = None, +): """ Create a wrapped, monitored VecEnv for Atari. It is a wrapper around ``make_vec_env`` that includes common preprocessing for Atari games. @@ -105,6 +110,14 @@ def atari_wrapper(env: gym.Env) -> gym.Env: env = AtariWrapper(env, **wrapper_kwargs) return env - return make_vec_env(env_id, n_envs=n_envs, seed=seed, start_index=start_index, - monitor_dir=monitor_dir, wrapper_class=atari_wrapper, - env_kwargs=env_kwargs, vec_env_cls=vec_env_cls, vec_env_kwargs=vec_env_kwargs) + return make_vec_env( + env_id, + n_envs=n_envs, + seed=seed, + start_index=start_index, + monitor_dir=monitor_dir, + wrapper_class=atari_wrapper, + env_kwargs=env_kwargs, + vec_env_cls=vec_env_cls, + vec_env_kwargs=vec_env_kwargs, + ) diff --git a/stable_baselines3/common/distributions.py b/stable_baselines3/common/distributions.py index 1a20928f2..f46691de6 100644 --- a/stable_baselines3/common/distributions.py +++ b/stable_baselines3/common/distributions.py @@ -1,12 +1,13 @@ """Probability distributions.""" from abc import ABC, abstractmethod -from typing import Optional, Tuple, Dict, Any, List +from typing import Any, Dict, List, Optional, Tuple + import gym import torch as th -import torch.nn as nn -from torch.distributions import Normal, Categorical, Bernoulli from gym import spaces +from torch import nn as nn +from torch.distributions import Bernoulli, Categorical, Normal from stable_baselines3.common.preprocessing import get_action_dim @@ -25,7 +26,7 @@ def proba_distribution_net(self, *args, **kwargs): concrete classes.""" @abstractmethod - def proba_distribution(self, *args, **kwargs) -> 'Distribution': + def proba_distribution(self, *args, **kwargs) -> "Distribution": """Set parameters of the distribution. 
:return: (Distribution) self @@ -124,8 +125,7 @@ def __init__(self, action_dim: int): self.mean_actions = None self.log_std = None - def proba_distribution_net(self, latent_dim: int, - log_std_init: float = 0.0) -> Tuple[nn.Module, nn.Parameter]: + def proba_distribution_net(self, latent_dim: int, log_std_init: float = 0.0) -> Tuple[nn.Module, nn.Parameter]: """ Create the layers and parameter that represent the distribution: one output will be the mean of the Gaussian, the other parameter will be the @@ -140,8 +140,7 @@ def proba_distribution_net(self, latent_dim: int, log_std = nn.Parameter(th.ones(self.action_dim) * log_std_init, requires_grad=True) return mean_actions, log_std - def proba_distribution(self, mean_actions: th.Tensor, - log_std: th.Tensor) -> 'DiagGaussianDistribution': + def proba_distribution(self, mean_actions: th.Tensor, log_std: th.Tensor) -> "DiagGaussianDistribution": """ Create the distribution given its parameters (mean, std) @@ -174,15 +173,12 @@ def sample(self) -> th.Tensor: def mode(self) -> th.Tensor: return self.distribution.mean - def actions_from_params(self, mean_actions: th.Tensor, - log_std: th.Tensor, - deterministic: bool = False) -> th.Tensor: + def actions_from_params(self, mean_actions: th.Tensor, log_std: th.Tensor, deterministic: bool = False) -> th.Tensor: # Update the proba distribution self.proba_distribution(mean_actions, log_std) return self.get_actions(deterministic=deterministic) - def log_prob_from_params(self, mean_actions: th.Tensor, - log_std: th.Tensor) -> Tuple[th.Tensor, th.Tensor]: + def log_prob_from_params(self, mean_actions: th.Tensor, log_std: th.Tensor) -> Tuple[th.Tensor, th.Tensor]: """ Compute the log probability of taking an action given the distribution parameters. @@ -210,13 +206,11 @@ def __init__(self, action_dim: int, epsilon: float = 1e-6): self.epsilon = epsilon self.gaussian_actions = None - def proba_distribution(self, mean_actions: th.Tensor, - log_std: th.Tensor) -> 'SquashedDiagGaussianDistribution': + def proba_distribution(self, mean_actions: th.Tensor, log_std: th.Tensor) -> "SquashedDiagGaussianDistribution": super(SquashedDiagGaussianDistribution, self).proba_distribution(mean_actions, log_std) return self - def log_prob(self, actions: th.Tensor, - gaussian_actions: Optional[th.Tensor] = None) -> th.Tensor: + def log_prob(self, actions: th.Tensor, gaussian_actions: Optional[th.Tensor] = None) -> th.Tensor: # Inverse tanh # Naive implementation (not stable): 0.5 * torch.log((1 + x) / (1 - x)) # We use numpy to avoid numerical instability @@ -246,8 +240,7 @@ def mode(self) -> th.Tensor: # Squash the output return th.tanh(self.gaussian_actions) - def log_prob_from_params(self, mean_actions: th.Tensor, - log_std: th.Tensor) -> Tuple[th.Tensor, th.Tensor]: + def log_prob_from_params(self, mean_actions: th.Tensor, log_std: th.Tensor) -> Tuple[th.Tensor, th.Tensor]: action = self.actions_from_params(mean_actions, log_std) log_prob = self.log_prob(action, self.gaussian_actions) return action, log_prob @@ -278,7 +271,7 @@ def proba_distribution_net(self, latent_dim: int) -> nn.Module: action_logits = nn.Linear(latent_dim, self.action_dim) return action_logits - def proba_distribution(self, action_logits: th.Tensor) -> 'CategoricalDistribution': + def proba_distribution(self, action_logits: th.Tensor) -> "CategoricalDistribution": self.distribution = Categorical(logits=action_logits) return self @@ -294,8 +287,7 @@ def sample(self) -> th.Tensor: def mode(self) -> th.Tensor: return th.argmax(self.distribution.probs, 
dim=1) - def actions_from_params(self, action_logits: th.Tensor, - deterministic: bool = False) -> th.Tensor: + def actions_from_params(self, action_logits: th.Tensor, deterministic: bool = False) -> th.Tensor: # Update the proba distribution self.proba_distribution(action_logits) return self.get_actions(deterministic=deterministic) @@ -332,14 +324,15 @@ def proba_distribution_net(self, latent_dim: int) -> nn.Module: action_logits = nn.Linear(latent_dim, sum(self.action_dims)) return action_logits - def proba_distribution(self, action_logits: th.Tensor) -> 'MultiCategoricalDistribution': + def proba_distribution(self, action_logits: th.Tensor) -> "MultiCategoricalDistribution": self.distributions = [Categorical(logits=split) for split in th.split(action_logits, tuple(self.action_dims), dim=1)] return self def log_prob(self, actions: th.Tensor) -> th.Tensor: # Extract each discrete action and compute log prob for their respective distributions - return th.stack([dist.log_prob(action) for dist, action in zip(self.distributions, - th.unbind(actions, dim=1))], dim=1).sum(dim=1) + return th.stack( + [dist.log_prob(action) for dist, action in zip(self.distributions, th.unbind(actions, dim=1))], dim=1 + ).sum(dim=1) def entropy(self) -> th.Tensor: return th.stack([dist.entropy() for dist in self.distributions], dim=1).sum(dim=1) @@ -350,8 +343,7 @@ def sample(self) -> th.Tensor: def mode(self) -> th.Tensor: return th.stack([th.argmax(dist.probs, dim=1) for dist in self.distributions], dim=1) - def actions_from_params(self, action_logits: th.Tensor, - deterministic: bool = False) -> th.Tensor: + def actions_from_params(self, action_logits: th.Tensor, deterministic: bool = False) -> th.Tensor: # Update the proba distribution self.proba_distribution(action_logits) return self.get_actions(deterministic=deterministic) @@ -386,7 +378,7 @@ def proba_distribution_net(self, latent_dim: int) -> nn.Module: action_logits = nn.Linear(latent_dim, self.action_dims) return action_logits - def proba_distribution(self, action_logits: th.Tensor) -> 'BernoulliDistribution': + def proba_distribution(self, action_logits: th.Tensor) -> "BernoulliDistribution": self.distribution = Bernoulli(logits=action_logits) return self @@ -402,8 +394,7 @@ def sample(self) -> th.Tensor: def mode(self) -> th.Tensor: return th.round(self.distribution.probs) - def actions_from_params(self, action_logits: th.Tensor, - deterministic: bool = False) -> th.Tensor: + def actions_from_params(self, action_logits: th.Tensor, deterministic: bool = False) -> th.Tensor: # Update the proba distribution self.proba_distribution(action_logits) return self.get_actions(deterministic=deterministic) @@ -436,12 +427,15 @@ class StateDependentNoiseDistribution(Distribution): :param epsilon: (float) small value to avoid NaN due to numerical imprecision. 
""" - def __init__(self, action_dim: int, - full_std: bool = True, - use_expln: bool = False, - squash_output: bool = False, - learn_features: bool = False, - epsilon: float = 1e-6): + def __init__( + self, + action_dim: int, + full_std: bool = True, + use_expln: bool = False, + squash_output: bool = False, + learn_features: bool = False, + epsilon: float = 1e-6, + ): super(StateDependentNoiseDistribution, self).__init__() self.distribution = None self.action_dim = action_dim @@ -501,8 +495,9 @@ def sample_weights(self, log_std: th.Tensor, batch_size: int = 1) -> None: # Pre-compute matrices in case of parallel exploration self.exploration_matrices = self.weights_dist.rsample((batch_size,)) - def proba_distribution_net(self, latent_dim: int, log_std_init: float = -2.0, - latent_sde_dim: Optional[int] = None) -> Tuple[nn.Module, nn.Parameter]: + def proba_distribution_net( + self, latent_dim: int, log_std_init: float = -2.0, latent_sde_dim: Optional[int] = None + ) -> Tuple[nn.Module, nn.Parameter]: """ Create the layers and parameter that represent the distribution: one output will be the deterministic action, the other parameter will be the @@ -527,9 +522,9 @@ def proba_distribution_net(self, latent_dim: int, log_std_init: float = -2.0, self.sample_weights(log_std) return mean_actions_net, log_std - def proba_distribution(self, mean_actions: th.Tensor, - log_std: th.Tensor, - latent_sde: th.Tensor) -> 'StateDependentNoiseDistribution': + def proba_distribution( + self, mean_actions: th.Tensor, log_std: th.Tensor, latent_sde: th.Tensor + ) -> "StateDependentNoiseDistribution": """ Create the distribution given its parameters (mean, std) @@ -591,17 +586,16 @@ def get_noise(self, latent_sde: th.Tensor) -> th.Tensor: noise = th.bmm(latent_sde, self.exploration_matrices) return noise.squeeze(1) - def actions_from_params(self, mean_actions: th.Tensor, - log_std: th.Tensor, - latent_sde: th.Tensor, - deterministic: bool = False) -> th.Tensor: + def actions_from_params( + self, mean_actions: th.Tensor, log_std: th.Tensor, latent_sde: th.Tensor, deterministic: bool = False + ) -> th.Tensor: # Update the proba distribution self.proba_distribution(mean_actions, log_std, latent_sde) return self.get_actions(deterministic=deterministic) - def log_prob_from_params(self, mean_actions: th.Tensor, - log_std: th.Tensor, - latent_sde: th.Tensor) -> Tuple[th.Tensor, th.Tensor]: + def log_prob_from_params( + self, mean_actions: th.Tensor, log_std: th.Tensor, latent_sde: th.Tensor + ) -> Tuple[th.Tensor, th.Tensor]: actions = self.actions_from_params(mean_actions, log_std, latent_sde) log_prob = self.log_prob(actions) return actions, log_prob @@ -644,16 +638,16 @@ def inverse(y: th.Tensor) -> th.Tensor: """ eps = th.finfo(y.dtype).eps # Clip the action to avoid NaN - return TanhBijector.atanh(y.clamp(min=-1. + eps, max=1. 
- eps)) + return TanhBijector.atanh(y.clamp(min=-1.0 + eps, max=1.0 - eps)) def log_prob_correction(self, x: th.Tensor) -> th.Tensor: # Squash correction (from original SAC implementation) return th.log(1.0 - th.tanh(x) ** 2 + self.epsilon) -def make_proba_distribution(action_space: gym.spaces.Space, - use_sde: bool = False, - dist_kwargs: Optional[Dict[str, Any]] = None) -> Distribution: +def make_proba_distribution( + action_space: gym.spaces.Space, use_sde: bool = False, dist_kwargs: Optional[Dict[str, Any]] = None +) -> Distribution: """ Return an instance of Distribution for the correct type of action space @@ -677,6 +671,8 @@ def make_proba_distribution(action_space: gym.spaces.Space, elif isinstance(action_space, spaces.MultiBinary): return BernoulliDistribution(action_space.n, **dist_kwargs) else: - raise NotImplementedError("Error: probability distribution, not implemented for action space" - f"of type {type(action_space)}." - " Must be of type Gym Spaces: Box, Discrete, MultiDiscrete or MultiBinary.") + raise NotImplementedError( + "Error: probability distribution, not implemented for action space" + f"of type {type(action_space)}." + " Must be of type Gym Spaces: Box, Discrete, MultiDiscrete or MultiBinary." + ) diff --git a/stable_baselines3/common/env_checker.py b/stable_baselines3/common/env_checker.py index dbed5b8bf..558a9fe0a 100644 --- a/stable_baselines3/common/env_checker.py +++ b/stable_baselines3/common/env_checker.py @@ -2,8 +2,8 @@ from typing import Union import gym -from gym import spaces import numpy as np +from gym import spaces from stable_baselines3.common.vec_env import DummyVecEnv, VecCheckNan @@ -22,41 +22,48 @@ def _check_image_input(observation_space: spaces.Box) -> None: when the observation is apparently an image. """ if observation_space.dtype != np.uint8: - warnings.warn("It seems that your observation is an image but the `dtype` " - "of your observation_space is not `np.uint8`. " - "If your observation is not an image, we recommend you to flatten the observation " - "to have only a 1D vector") + warnings.warn( + "It seems that your observation is an image but the `dtype` " + "of your observation_space is not `np.uint8`. " + "If your observation is not an image, we recommend you to flatten the observation " + "to have only a 1D vector" + ) if np.any(observation_space.low != 0) or np.any(observation_space.high != 255): - warnings.warn("It seems that your observation space is an image but the " - "upper and lower bounds are not in [0, 255]. " - "Because the CNN policy normalize automatically the observation " - "you may encounter issue if the values are not in that range." - ) + warnings.warn( + "It seems that your observation space is an image but the " + "upper and lower bounds are not in [0, 255]. " + "Because the CNN policy normalize automatically the observation " + "you may encounter issue if the values are not in that range." + ) if observation_space.shape[0] < 36 or observation_space.shape[1] < 36: - warnings.warn("The minimal resolution for an image is 36x36 for the default CnnPolicy. " - "You might need to use a custom `cnn_extractor` " - "cf https://stable-baselines3.readthedocs.io/en/master/guide/custom_policy.html") + warnings.warn( + "The minimal resolution for an image is 36x36 for the default CnnPolicy. 
" + "You might need to use a custom `cnn_extractor` " + "cf https://stable-baselines3.readthedocs.io/en/master/guide/custom_policy.html" + ) def _check_unsupported_obs_spaces(env: gym.Env, observation_space: spaces.Space) -> None: """Emit warnings when the observation space used is not supported by Stable-Baselines.""" if isinstance(observation_space, spaces.Dict) and not isinstance(env, gym.GoalEnv): - warnings.warn("The observation space is a Dict but the environment is not a gym.GoalEnv " - "(cf https://github.com/openai/gym/blob/master/gym/core.py), " - "this is currently not supported by Stable Baselines " - "(cf https://github.com/hill-a/stable-baselines/issues/133), " - "you will need to use a custom policy. " - ) + warnings.warn( + "The observation space is a Dict but the environment is not a gym.GoalEnv " + "(cf https://github.com/openai/gym/blob/master/gym/core.py), " + "this is currently not supported by Stable Baselines " + "(cf https://github.com/hill-a/stable-baselines/issues/133), " + "you will need to use a custom policy. " + ) if isinstance(observation_space, spaces.Tuple): - warnings.warn("The observation space is a Tuple," - "this is currently not supported by Stable Baselines " - "(cf https://github.com/hill-a/stable-baselines/issues/133), " - "you will need to flatten the observation and maybe use a custom policy. " - ) + warnings.warn( + "The observation space is a Tuple," + "this is currently not supported by Stable Baselines " + "(cf https://github.com/hill-a/stable-baselines/issues/133), " + "you will need to flatten the observation and maybe use a custom policy. " + ) def _check_nan(env: gym.Env) -> None: @@ -67,26 +74,27 @@ def _check_nan(env: gym.Env) -> None: _, _, _, _ = vec_env.step(action) -def _check_obs(obs: Union[tuple, dict, np.ndarray, int], - observation_space: spaces.Space, - method_name: str) -> None: +def _check_obs(obs: Union[tuple, dict, np.ndarray, int], observation_space: spaces.Space, method_name: str) -> None: """ Check that the observation returned by the environment correspond to the declared one. 
""" if not isinstance(observation_space, spaces.Tuple): - assert not isinstance(obs, tuple), ("The observation returned by the `{}()` " - "method should be a single value, not a tuple".format(method_name)) + assert not isinstance( + obs, tuple + ), "The observation returned by the `{}()` method should be a single value, not a tuple".format(method_name) # The check for a GoalEnv is done by the base class if isinstance(observation_space, spaces.Discrete): assert isinstance(obs, int), "The observation returned by `{}()` method must be an int".format(method_name) elif _enforce_array_obs(observation_space): - assert isinstance(obs, np.ndarray), ("The observation returned by `{}()` " - "method must be a numpy array".format(method_name)) + assert isinstance(obs, np.ndarray), "The observation returned by `{}()` method must be a numpy array".format( + method_name + ) - assert observation_space.contains(obs), ("The observation returned by the `{}()` " - "method does not match the given observation space".format(method_name)) + assert observation_space.contains( + obs + ), "The observation returned by the `{}()` method does not match the given observation space".format(method_name) def _check_returned_values(env: gym.Env, observation_space: spaces.Space, action_space: spaces.Space) -> None: @@ -96,7 +104,7 @@ def _check_returned_values(env: gym.Env, observation_space: spaces.Space, action # because env inherits from gym.Env, we assume that `reset()` and `step()` methods exists obs = env.reset() - _check_obs(obs, observation_space, 'reset') + _check_obs(obs, observation_space, "reset") # Sample a random action action = action_space.sample() @@ -107,7 +115,7 @@ def _check_returned_values(env: gym.Env, observation_space: spaces.Space, action # Unpack obs, reward, done, info = data - _check_obs(obs, observation_space, 'step') + _check_obs(obs, observation_space, "step") # We also allow int because the reward will be cast to float assert isinstance(reward, (float, int)), "The reward returned by `step()` must be a float" @@ -116,7 +124,7 @@ def _check_returned_values(env: gym.Env, observation_space: spaces.Space, action if isinstance(env, gym.GoalEnv): # For a GoalEnv, the keys are checked at reset - assert reward == env.compute_reward(obs['achieved_goal'], obs['desired_goal'], info) + assert reward == env.compute_reward(obs["achieved_goal"], obs["desired_goal"], info) def _check_spaces(env: gym.Env) -> None: @@ -127,11 +135,10 @@ def _check_spaces(env: gym.Env) -> None: # Helper to link to the code, because gym has no proper documentation gym_spaces = " cf https://github.com/openai/gym/blob/master/gym/spaces/" - assert hasattr(env, 'observation_space'), "You must specify an observation space (cf gym.spaces)" + gym_spaces - assert hasattr(env, 'action_space'), "You must specify an action space (cf gym.spaces)" + gym_spaces + assert hasattr(env, "observation_space"), "You must specify an observation space (cf gym.spaces)" + gym_spaces + assert hasattr(env, "action_space"), "You must specify an action space (cf gym.spaces)" + gym_spaces - assert isinstance(env.observation_space, - spaces.Space), "The observation space must inherit from gym.spaces" + gym_spaces + assert isinstance(env.observation_space, spaces.Space), "The observation space must inherit from gym.spaces" + gym_spaces assert isinstance(env.action_space, spaces.Space), "The action space must inherit from gym.spaces" + gym_spaces @@ -145,18 +152,20 @@ def _check_render(env: gym.Env, warn: bool = True, headless: bool = False) -> No :param 
headless: (bool) Whether to disable render modes that require a graphical interface. False by default. """ - render_modes = env.metadata.get('render.modes') + render_modes = env.metadata.get("render.modes") if render_modes is None: if warn: - warnings.warn("No render modes was declared in the environment " - " (env.metadata['render.modes'] is None or not defined), " - "you may have trouble when calling `.render()`") + warnings.warn( + "No render modes was declared in the environment " + " (env.metadata['render.modes'] is None or not defined), " + "you may have trouble when calling `.render()`" + ) else: # Don't check render mode that require a # graphical interface (useful for CI) - if headless and 'human' in render_modes: - render_modes.remove('human') + if headless and "human" in render_modes: + render_modes.remove("human") # Check all declared render modes for render_mode in render_modes: env.render(mode=render_mode) @@ -178,8 +187,9 @@ def check_env(env: gym.Env, warn: bool = True, skip_render_check: bool = True) - :param skip_render_check: (bool) Whether to skip the checks for the render method. True by default (useful for the CI) """ - assert isinstance(env, gym.Env), ("You environment must inherit from gym.Env class " - " cf https://github.com/openai/gym/blob/master/gym/core.py") + assert isinstance( + env, gym.Env + ), "You environment must inherit from gym.Env class cf https://github.com/openai/gym/blob/master/gym/core.py" # ============= Check the spaces (observation and action) ================ _check_spaces(env) @@ -199,16 +209,22 @@ def check_env(env: gym.Env, warn: bool = True, skip_render_check: bool = True) - _check_image_input(observation_space) if isinstance(observation_space, spaces.Box) and len(observation_space.shape) not in [1, 3]: - warnings.warn("Your observation has an unconventional shape (neither an image, nor a 1D vector). " - "We recommend you to flatten the observation " - "to have only a 1D vector") + warnings.warn( + "Your observation has an unconventional shape (neither an image, nor a 1D vector). 
" + "We recommend you to flatten the observation " + "to have only a 1D vector" + ) # Check for the action space, it may lead to hard-to-debug issues - if (isinstance(action_space, spaces.Box) and - (np.any(np.abs(action_space.low) != np.abs(action_space.high)) - or np.any(np.abs(action_space.low) > 1) or np.any(np.abs(action_space.high) > 1))): - warnings.warn("We recommend you to use a symmetric and normalized Box action space (range=[-1, 1]) " - "cf https://stable-baselines3.readthedocs.io/en/master/guide/rl_tips.html") + if isinstance(action_space, spaces.Box) and ( + np.any(np.abs(action_space.low) != np.abs(action_space.high)) + or np.any(np.abs(action_space.low) > 1) + or np.any(np.abs(action_space.high) > 1) + ): + warnings.warn( + "We recommend you to use a symmetric and normalized Box action space (range=[-1, 1]) " + "cf https://stable-baselines3.readthedocs.io/en/master/guide/rl_tips.html" + ) # ============ Check the returned values =============== _check_returned_values(env, observation_space, action_space) diff --git a/stable_baselines3/common/evaluation.py b/stable_baselines3/common/evaluation.py index b298b65b4..6dac4d5a7 100644 --- a/stable_baselines3/common/evaluation.py +++ b/stable_baselines3/common/evaluation.py @@ -4,9 +4,16 @@ from stable_baselines3.common.vec_env import VecEnv -def evaluate_policy(model, env, n_eval_episodes=10, deterministic=True, - render=False, callback=None, reward_threshold=None, - return_episode_rewards=False): +def evaluate_policy( + model, + env, + n_eval_episodes=10, + deterministic=True, + render=False, + callback=None, + reward_threshold=None, + return_episode_rewards=False, +): """ Runs policy for ``n_eval_episodes`` episodes and returns average reward. This is made to work only with one env. @@ -49,8 +56,7 @@ def evaluate_policy(model, env, n_eval_episodes=10, deterministic=True, mean_reward = np.mean(episode_rewards) std_reward = np.std(episode_rewards) if reward_threshold is not None: - assert mean_reward > reward_threshold, ('Mean reward below threshold: ' - f'{mean_reward:.2f} < {reward_threshold:.2f}') + assert mean_reward > reward_threshold, "Mean reward below threshold: " f"{mean_reward:.2f} < {reward_threshold:.2f}" if return_episode_rewards: return episode_rewards, episode_lengths return mean_reward, std_reward diff --git a/stable_baselines3/common/identity_env.py b/stable_baselines3/common/identity_env.py index df77f6ac2..4a492e75a 100644 --- a/stable_baselines3/common/identity_env.py +++ b/stable_baselines3/common/identity_env.py @@ -1,18 +1,14 @@ -from typing import Union, Optional +from typing import Optional, Union import numpy as np from gym import Env, Space -from gym.spaces import Discrete, MultiDiscrete, MultiBinary, Box +from gym.spaces import Box, Discrete, MultiBinary, MultiDiscrete - -from stable_baselines3.common.type_aliases import GymStepReturn, GymObs +from stable_baselines3.common.type_aliases import GymObs, GymStepReturn class IdentityEnv(Env): - def __init__(self, - dim: Optional[int] = None, - space: Optional[Space] = None, - ep_length: int = 100): + def __init__(self, dim: Optional[int] = None, space: Optional[Space] = None, ep_length: int = 100): """ Identity environment for testing purposes @@ -55,14 +51,12 @@ def _choose_next_state(self) -> None: def _get_reward(self, action: Union[int, np.ndarray]) -> float: return 1.0 if np.all(self.state == action) else 0.0 - def render(self, mode: str = 'human') -> None: + def render(self, mode: str = "human") -> None: pass class IdentityEnvBox(IdentityEnv): - def 
__init__(self, low: float = -1.0, - high: float = 1.0, eps: float = 0.05, - ep_length: int = 100): + def __init__(self, low: float = -1.0, high: float = 1.0, eps: float = 0.05, ep_length: int = 100): """ Identity environment for testing purposes @@ -120,14 +114,12 @@ class FakeImageEnv(Env): :param n_channels: (int) Number of color channels :param discrete: (bool) """ - def __init__(self, action_dim: int = 6, - screen_height: int = 84, - screen_width: int = 84, - n_channels: int = 1, - discrete: bool = True): - - self.observation_space = Box(low=0, high=255, shape=(screen_height, screen_width, - n_channels), dtype=np.uint8) + + def __init__( + self, action_dim: int = 6, screen_height: int = 84, screen_width: int = 84, n_channels: int = 1, discrete: bool = True + ): + + self.observation_space = Box(low=0, high=255, shape=(screen_height, screen_width, n_channels), dtype=np.uint8) if discrete: self.action_space = Discrete(action_dim) else: @@ -145,5 +137,5 @@ def step(self, action: Union[np.ndarray, int]) -> GymStepReturn: done = self.current_step >= self.ep_length return self.observation_space.sample(), reward, done, {} - def render(self, mode: str = 'human') -> None: + def render(self, mode: str = "human") -> None: pass diff --git a/stable_baselines3/common/logger.py b/stable_baselines3/common/logger.py index d23726e64..bed63273e 100644 --- a/stable_baselines3/common/logger.py +++ b/stable_baselines3/common/logger.py @@ -1,15 +1,16 @@ -import sys import datetime import json import os +import sys import tempfile import warnings from collections import defaultdict -from typing import Dict, List, TextIO, Union, Any, Optional, Tuple +from typing import Any, Dict, List, Optional, TextIO, Tuple, Union -import pandas import numpy as np +import pandas import torch as th + try: from torch.utils.tensorboard import SummaryWriter except ImportError: @@ -27,8 +28,7 @@ class KVWriter(object): Key Value writer """ - def write(self, key_values: Dict[str, Any], - key_excluded: Dict[str, Union[str, Tuple[str, ...]]], step: int = 0) -> None: + def write(self, key_values: Dict[str, Any], key_excluded: Dict[str, Union[str, Tuple[str, ...]]], step: int = 0) -> None: """ Write a dictionary to file @@ -67,10 +67,10 @@ def __init__(self, filename_or_file: Union[str, TextIO]): :param filename_or_file: (str or File) the file to write the log to """ if isinstance(filename_or_file, str): - self.file = open(filename_or_file, 'wt') + self.file = open(filename_or_file, "wt") self.own_file = True else: - assert hasattr(filename_or_file, 'write'), f'Expected file or str, got {filename_or_file}' + assert hasattr(filename_or_file, "write"), f"Expected file or str, got {filename_or_file}" self.file = filename_or_file self.own_file = False @@ -80,56 +80,56 @@ def write(self, key_values: Dict, key_excluded: Dict, step: int = 0) -> None: tag = None for (key, value), (_, excluded) in zip(sorted(key_values.items()), sorted(key_excluded.items())): - if excluded is not None and 'stdout' in excluded: + if excluded is not None and "stdout" in excluded: continue if isinstance(value, float): # Align left - value_str = f'{value:<8.3g}' + value_str = f"{value:<8.3g}" else: value_str = str(value) - if key.find('/') > 0: # Find tag and add it to the dict - tag = key[:key.find('/') + 1] - key2str[self._truncate(tag)] = '' + if key.find("/") > 0: # Find tag and add it to the dict + tag = key[: key.find("/") + 1] + key2str[self._truncate(tag)] = "" # Remove tag from key if tag is not None and tag in key: - key = str(' ' + key[len(tag):]) + key 
= str(" " + key[len(tag) :]) key2str[self._truncate(key)] = self._truncate(value_str) # Find max widths if len(key2str) == 0: - warnings.warn('Tried to write empty key-value dict') + warnings.warn("Tried to write empty key-value dict") return else: key_width = max(map(len, key2str.keys())) val_width = max(map(len, key2str.values())) # Write out the data - dashes = '-' * (key_width + val_width + 7) + dashes = "-" * (key_width + val_width + 7) lines = [dashes] for key, value in key2str.items(): - key_space = ' ' * (key_width - len(key)) - val_space = ' ' * (val_width - len(value)) + key_space = " " * (key_width - len(key)) + val_space = " " * (val_width - len(value)) lines.append(f"| {key}{key_space} | {value}{val_space} |") lines.append(dashes) - self.file.write('\n'.join(lines) + '\n') + self.file.write("\n".join(lines) + "\n") # Flush the output to the file self.file.flush() @classmethod def _truncate(cls, string: str, max_length: int = 23) -> str: - return string[:max_length - 3] + '...' if len(string) > max_length else string + return string[: max_length - 3] + "..." if len(string) > max_length else string def write_sequence(self, sequence: List) -> None: sequence = list(sequence) for i, elem in enumerate(sequence): self.file.write(elem) if i < len(sequence) - 1: # add space unless this is the last one - self.file.write(' ') - self.file.write('\n') + self.file.write(" ") + self.file.write("\n") self.file.flush() def close(self) -> None: @@ -147,23 +147,22 @@ def __init__(self, filename: str): :param filename: (str) the file to write the log to """ - self.file = open(filename, 'wt') + self.file = open(filename, "wt") - def write(self, key_values: Dict[str, Any], - key_excluded: Dict[str, Union[str, Tuple[str, ...]]], step: int = 0) -> None: + def write(self, key_values: Dict[str, Any], key_excluded: Dict[str, Union[str, Tuple[str, ...]]], step: int = 0) -> None: for (key, value), (_, excluded) in zip(sorted(key_values.items()), sorted(key_excluded.items())): - if excluded is not None and 'json' in excluded: + if excluded is not None and "json" in excluded: continue - if hasattr(value, 'dtype'): + if hasattr(value, "dtype"): if value.shape == () or len(value) == 1: # if value is a dimensionless numpy array or of length 1, serialize as a float key_values[key] = float(value) else: # otherwise, a value is a numpy array, serialize as a list or nested lists key_values[key] = value.tolist() - self.file.write(json.dumps(key_values) + '\n') + self.file.write(json.dumps(key_values) + "\n") self.file.flush() def close(self) -> None: @@ -182,12 +181,11 @@ def __init__(self, filename: str): :param filename: (str) the file to write the log to """ - self.file = open(filename, 'w+t') + self.file = open(filename, "w+t") self.keys = [] - self.separator = ',' + self.separator = "," - def write(self, key_values: Dict[str, Any], - key_excluded: Dict[str, Union[str, Tuple[str, ...]]], step: int = 0) -> None: + def write(self, key_values: Dict[str, Any], key_excluded: Dict[str, Union[str, Tuple[str, ...]]], step: int = 0) -> None: # Add our current row to the history extra_keys = key_values.keys() - self.keys if extra_keys: @@ -197,20 +195,20 @@ def write(self, key_values: Dict[str, Any], self.file.seek(0) for (i, key) in enumerate(self.keys): if i > 0: - self.file.write(',') + self.file.write(",") self.file.write(key) - self.file.write('\n') + self.file.write("\n") for line in lines[1:]: self.file.write(line[:-1]) self.file.write(self.separator * len(extra_keys)) - self.file.write('\n') + 
self.file.write("\n") for i, key in enumerate(self.keys): if i > 0: - self.file.write(',') + self.file.write(",") value = key_values.get(key) if value is not None: self.file.write(str(value)) - self.file.write('\n') + self.file.write("\n") self.file.flush() def close(self) -> None: @@ -227,17 +225,14 @@ def __init__(self, folder: str): :param folder: (str) the folder to write the log to """ - assert SummaryWriter is not None, ("tensorboard is not installed, you can use " - "pip install tensorboard to do so") + assert SummaryWriter is not None, "tensorboard is not installed, you can use " "pip install tensorboard to do so" self.writer = SummaryWriter(log_dir=folder) - def write(self, key_values: Dict[str, Any], - key_excluded: Dict[str, Union[str, Tuple[str, ...]]], step: int = 0) -> None: + def write(self, key_values: Dict[str, Any], key_excluded: Dict[str, Union[str, Tuple[str, ...]]], step: int = 0) -> None: - for (key, value), (_, excluded) in zip(sorted(key_values.items()), - sorted(key_excluded.items())): + for (key, value), (_, excluded) in zip(sorted(key_values.items()), sorted(key_excluded.items())): - if excluded is not None and 'tensorboard' in excluded: + if excluded is not None and "tensorboard" in excluded: continue if isinstance(value, np.ScalarType): @@ -258,7 +253,7 @@ def close(self) -> None: self.writer = None -def make_output_format(_format: str, log_dir: str, log_suffix: str = '') -> KVWriter: +def make_output_format(_format: str, log_dir: str, log_suffix: str = "") -> KVWriter: """ return a logger for the requested format @@ -268,26 +263,26 @@ def make_output_format(_format: str, log_dir: str, log_suffix: str = '') -> KVWr :return: (KVWriter) the logger """ os.makedirs(log_dir, exist_ok=True) - if _format == 'stdout': + if _format == "stdout": return HumanOutputFormat(sys.stdout) - elif _format == 'log': - return HumanOutputFormat(os.path.join(log_dir, f'log{log_suffix}.txt')) - elif _format == 'json': - return JSONOutputFormat(os.path.join(log_dir, f'progress{log_suffix}.json')) - elif _format == 'csv': - return CSVOutputFormat(os.path.join(log_dir, f'progress{log_suffix}.csv')) - elif _format == 'tensorboard': + elif _format == "log": + return HumanOutputFormat(os.path.join(log_dir, f"log{log_suffix}.txt")) + elif _format == "json": + return JSONOutputFormat(os.path.join(log_dir, f"progress{log_suffix}.json")) + elif _format == "csv": + return CSVOutputFormat(os.path.join(log_dir, f"progress{log_suffix}.csv")) + elif _format == "tensorboard": return TensorBoardOutputFormat(log_dir) else: - raise ValueError(f'Unknown format specified: {_format}') + raise ValueError(f"Unknown format specified: {_format}") # ================================================================ # API # ================================================================ -def record(key: str, value: Any, - exclude: Optional[Union[str, Tuple[str, ...]]] = None) -> None: + +def record(key: str, value: Any, exclude: Optional[Union[str, Tuple[str, ...]]] = None) -> None: """ Log a value of some diagnostic Call this once for each diagnostic quantity, each iteration @@ -300,8 +295,7 @@ def record(key: str, value: Any, Logger.CURRENT.record(key, value, exclude) -def record_mean(key: str, value: Union[int, float], - exclude: Optional[Union[str, Tuple[str, ...]]] = None) -> None: +def record_mean(key: str, value: Union[int, float], exclude: Optional[Union[str, Tuple[str, ...]]] = None) -> None: """ The same as record(), but if called many times, values averaged. 
@@ -431,6 +425,7 @@ def get_dir() -> str: # Backend # ================================================================ + class Logger(object): # A logger with no output files. (See right below class definition) # So that you can still log to the terminal without setting up any output files @@ -453,8 +448,7 @@ def __init__(self, folder: Optional[str], output_formats: List[KVWriter]): # Logging API, forwarded # ---------------------------------------- - def record(self, key: str, value: Any, - exclude: Optional[Union[str, Tuple[str, ...]]] = None) -> None: + def record(self, key: str, value: Any, exclude: Optional[Union[str, Tuple[str, ...]]] = None) -> None: """ Log a value of some diagnostic Call this once for each diagnostic quantity, each iteration @@ -467,8 +461,7 @@ def record(self, key: str, value: Any, self.name_to_value[key] = value self.name_to_excluded[key] = exclude - def record_mean(self, key: str, value: Any, - exclude: Optional[Union[str, Tuple[str, ...]]] = None) -> None: + def record_mean(self, key: str, value: Any, exclude: Optional[Union[str, Tuple[str, ...]]] = None) -> None: """ The same as record(), but if called many times, values averaged. @@ -565,21 +558,21 @@ def configure(folder: Optional[str] = None, format_strings: Optional[List[str]] (if None, $SB3_LOG_FORMAT, if still None, ['stdout', 'log', 'csv']) """ if folder is None: - folder = os.getenv('SB3_LOGDIR') + folder = os.getenv("SB3_LOGDIR") if folder is None: folder = os.path.join(tempfile.gettempdir(), datetime.datetime.now().strftime("SB3-%Y-%m-%d-%H-%M-%S-%f")) assert isinstance(folder, str) os.makedirs(folder, exist_ok=True) - log_suffix = '' + log_suffix = "" if format_strings is None: - format_strings = os.getenv('SB3_LOG_FORMAT', 'stdout,log,csv').split(',') + format_strings = os.getenv("SB3_LOG_FORMAT", "stdout,log,csv").split(",") format_strings = filter(None, format_strings) output_formats = [make_output_format(f, folder, log_suffix) for f in format_strings] Logger.CURRENT = Logger(folder=folder, output_formats=output_formats) - log(f'Logging to {folder}') + log(f"Logging to {folder}") def reset() -> None: @@ -589,7 +582,7 @@ def reset() -> None: if Logger.CURRENT is not Logger.DEFAULT: Logger.CURRENT.close() Logger.CURRENT = Logger.DEFAULT - log('Reset logger') + log("Reset logger") class ScopedConfigure(object): @@ -621,6 +614,7 @@ def __exit__(self, *args) -> None: # Readers # ================================================================ + def read_json(filename: str) -> pandas.DataFrame: """ read a json file using pandas @@ -629,7 +623,7 @@ def read_json(filename: str) -> pandas.DataFrame: :return: (pandas.DataFrame) the data in the json """ data = [] - with open(filename, 'rt') as file_handler: + with open(filename, "rt") as file_handler: for line in file_handler: data.append(json.loads(line)) return pandas.DataFrame(data) @@ -642,4 +636,4 @@ def read_csv(filename: str) -> pandas.DataFrame: :param filename: (str) the file path to read :return: (pandas.DataFrame) the data in the csv """ - return pandas.read_csv(filename, index_col=None, comment='#') + return pandas.read_csv(filename, index_col=None, comment="#") diff --git a/stable_baselines3/common/monitor.py b/stable_baselines3/common/monitor.py index 3210a74d2..df6719f2a 100644 --- a/stable_baselines3/common/monitor.py +++ b/stable_baselines3/common/monitor.py @@ -1,15 +1,15 @@ -__all__ = ['Monitor', 'get_monitor_files', 'load_results'] +__all__ = ["Monitor", "get_monitor_files", "load_results"] import csv import json import os import time 
from glob import glob -from typing import Tuple, Dict, Any, List, Optional +from typing import Any, Dict, List, Optional, Tuple import gym -import pandas import numpy as np +import pandas class Monitor(gym.Wrapper): @@ -23,14 +23,17 @@ class Monitor(gym.Wrapper): if extra parameters are needed at reset :param info_keywords: (Tuple[str, ...]) extra information to log, from the information return of env.step() """ + EXT = "monitor.csv" - def __init__(self, - env: gym.Env, - filename: Optional[str] = None, - allow_early_resets: bool = True, - reset_keywords: Tuple[str, ...] = (), - info_keywords: Tuple[str, ...] = ()): + def __init__( + self, + env: gym.Env, + filename: Optional[str] = None, + allow_early_resets: bool = True, + reset_keywords: Tuple[str, ...] = (), + info_keywords: Tuple[str, ...] = (), + ): super(Monitor, self).__init__(env=env) self.t_start = time.time() if filename is None: @@ -43,9 +46,8 @@ def __init__(self, else: filename = filename + "." + Monitor.EXT self.file_handler = open(filename, "wt") - self.file_handler.write('#%s\n' % json.dumps({"t_start": self.t_start, 'env_id': env.spec and env.spec.id})) - self.logger = csv.DictWriter(self.file_handler, - fieldnames=('r', 'l', 't') + reset_keywords + info_keywords) + self.file_handler.write("#%s\n" % json.dumps({"t_start": self.t_start, "env_id": env.spec and env.spec.id})) + self.logger = csv.DictWriter(self.file_handler, fieldnames=("r", "l", "t") + reset_keywords + info_keywords) self.logger.writeheader() self.file_handler.flush() @@ -68,14 +70,16 @@ def reset(self, **kwargs) -> np.ndarray: :return: (np.ndarray) the first observation of the environment """ if not self.allow_early_resets and not self.needs_reset: - raise RuntimeError("Tried to reset an environment before done. If you want to allow early resets, " - "wrap your env with Monitor(env, path, allow_early_resets=True)") + raise RuntimeError( + "Tried to reset an environment before done. If you want to allow early resets, " + "wrap your env with Monitor(env, path, allow_early_resets=True)" + ) self.rewards = [] self.needs_reset = False for key in self.reset_keywords: value = kwargs.get(key) if value is None: - raise ValueError('Expected you to pass kwarg {} into reset'.format(key)) + raise ValueError("Expected you to pass kwarg {} into reset".format(key)) self.current_reset_info[key] = value return self.env.reset(**kwargs) @@ -104,7 +108,7 @@ def step(self, action: np.ndarray) -> Tuple[np.ndarray, float, bool, Dict[Any, A if self.logger: self.logger.writerow(ep_info) self.file_handler.flush() - info['episode'] = ep_info + info["episode"] = ep_info self.total_steps += 1 return observation, reward, done, info @@ -153,6 +157,7 @@ class LoadMonitorResultsError(Exception): """ Raised when loading the monitor log fails. 
""" + pass @@ -178,16 +183,16 @@ def load_results(path: str) -> pandas.DataFrame: raise LoadMonitorResultsError("no monitor files of the form *%s found in %s" % (Monitor.EXT, path)) data_frames, headers = [], [] for file_name in monitor_files: - with open(file_name, 'rt') as file_handler: + with open(file_name, "rt") as file_handler: first_line = file_handler.readline() - assert first_line[0] == '#' + assert first_line[0] == "#" header = json.loads(first_line[1:]) data_frame = pandas.read_csv(file_handler, index_col=None) headers.append(header) - data_frame['t'] += header['t_start'] + data_frame["t"] += header["t_start"] data_frames.append(data_frame) data_frame = pandas.concat(data_frames) - data_frame.sort_values('t', inplace=True) + data_frame.sort_values("t", inplace=True) data_frame.reset_index(inplace=True) - data_frame['t'] -= min(header['t_start'] for header in headers) + data_frame["t"] -= min(header["t_start"] for header in headers) return data_frame diff --git a/stable_baselines3/common/noise.py b/stable_baselines3/common/noise.py index 63b4b4a48..f7b8b7372 100644 --- a/stable_baselines3/common/noise.py +++ b/stable_baselines3/common/noise.py @@ -1,6 +1,6 @@ -from typing import Optional, List, Iterable -from abc import ABC, abstractmethod import copy +from abc import ABC, abstractmethod +from typing import Iterable, List, Optional import numpy as np @@ -41,7 +41,7 @@ def __call__(self) -> np.ndarray: return np.random.normal(self._mu, self._sigma) def __repr__(self) -> str: - return f'NormalActionNoise(mu={self._mu}, sigma={self._sigma})' + return f"NormalActionNoise(mu={self._mu}, sigma={self._sigma})" class OrnsteinUhlenbeckActionNoise(ActionNoise): @@ -57,11 +57,14 @@ class OrnsteinUhlenbeckActionNoise(ActionNoise): :param initial_noise: (Optional[np.ndarray]) the initial value for the noise output, (if None: 0) """ - def __init__(self, mean: np.ndarray, - sigma: np.ndarray, - theta: float = .15, - dt: float = 1e-2, - initial_noise: Optional[np.ndarray] = None): + def __init__( + self, + mean: np.ndarray, + sigma: np.ndarray, + theta: float = 0.15, + dt: float = 1e-2, + initial_noise: Optional[np.ndarray] = None, + ): self._theta = theta self._mu = mean self._sigma = sigma @@ -72,8 +75,11 @@ def __init__(self, mean: np.ndarray, super(OrnsteinUhlenbeckActionNoise, self).__init__() def __call__(self) -> np.ndarray: - noise = (self.noise_prev + self._theta * (self._mu - self.noise_prev) * self._dt - + self._sigma * np.sqrt(self._dt) * np.random.normal(size=self._mu.shape)) + noise = ( + self.noise_prev + + self._theta * (self._mu - self.noise_prev) * self._dt + + self._sigma * np.sqrt(self._dt) * np.random.normal(size=self._mu.shape) + ) self.noise_prev = noise return noise @@ -84,7 +90,7 @@ def reset(self) -> None: self.noise_prev = self.initial_noise if self.initial_noise is not None else np.zeros_like(self._mu) def __repr__(self) -> str: - return f'OrnsteinUhlenbeckActionNoise(mu={self._mu}, sigma={self._sigma})' + return f"OrnsteinUhlenbeckActionNoise(mu={self._mu}, sigma={self._sigma})" class VectorizedActionNoise(ActionNoise): @@ -149,15 +155,11 @@ def noises(self, noises: List[ActionNoise]) -> None: noises = list(noises) # raises TypeError if not iterable assert len(noises) == self.n_envs, f"Expected a list of {self.n_envs} ActionNoises, found {len(noises)}." 
- different_types = [ - i for i, noise in enumerate(noises) - if not isinstance(noise, type(self.base_noise)) - ] + different_types = [i for i, noise in enumerate(noises) if not isinstance(noise, type(self.base_noise))] if len(different_types): raise ValueError( - f"Noise instances at indices {different_types} don't match the type of base_noise", - type(self.base_noise) + f"Noise instances at indices {different_types} don't match the type of base_noise", type(self.base_noise) ) self._noises = noises diff --git a/stable_baselines3/common/off_policy_algorithm.py b/stable_baselines3/common/off_policy_algorithm.py index 5a8d085b0..b2814df14 100644 --- a/stable_baselines3/common/off_policy_algorithm.py +++ b/stable_baselines3/common/off_policy_algorithm.py @@ -1,23 +1,23 @@ +import io +import pathlib import time import warnings -import pathlib -from typing import Union, Type, Optional, Dict, Any, Callable, List, Tuple -import io +from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union import gym -import torch as th import numpy as np +import torch as th from stable_baselines3.common import logger from stable_baselines3.common.base_class import BaseAlgorithm +from stable_baselines3.common.buffers import ReplayBuffer +from stable_baselines3.common.callbacks import BaseCallback +from stable_baselines3.common.noise import ActionNoise from stable_baselines3.common.policies import BasePolicy +from stable_baselines3.common.save_util import load_from_pkl, save_to_pkl +from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, RolloutReturn from stable_baselines3.common.utils import safe_mean from stable_baselines3.common.vec_env import VecEnv -from stable_baselines3.common.type_aliases import GymEnv, RolloutReturn, MaybeCallback -from stable_baselines3.common.callbacks import BaseCallback -from stable_baselines3.common.noise import ActionNoise -from stable_baselines3.common.buffers import ReplayBuffer -from stable_baselines3.common.save_util import save_to_pkl, load_from_pkl class OffPolicyAlgorithm(BaseAlgorithm): @@ -69,40 +69,52 @@ class OffPolicyAlgorithm(BaseAlgorithm): :param sde_support: (bool) Whether the model support gSDE or not """ - def __init__(self, - policy: Type[BasePolicy], - env: Union[GymEnv, str], - policy_base: Type[BasePolicy], - learning_rate: Union[float, Callable], - buffer_size: int = int(1e6), - learning_starts: int = 100, - batch_size: int = 256, - tau: float = 0.005, - gamma: float = 0.99, - train_freq: int = 1, - gradient_steps: int = 1, - n_episodes_rollout: int = -1, - action_noise: Optional[ActionNoise] = None, - optimize_memory_usage: bool = False, - policy_kwargs: Dict[str, Any] = None, - tensorboard_log: Optional[str] = None, - verbose: int = 0, - device: Union[th.device, str] = 'auto', - support_multi_env: bool = False, - create_eval_env: bool = False, - monitor_wrapper: bool = True, - seed: Optional[int] = None, - use_sde: bool = False, - sde_sample_freq: int = -1, - use_sde_at_warmup: bool = False, - sde_support: bool = True): - - super(OffPolicyAlgorithm, self).__init__(policy=policy, env=env, policy_base=policy_base, - learning_rate=learning_rate, policy_kwargs=policy_kwargs, - tensorboard_log=tensorboard_log, verbose=verbose, - device=device, support_multi_env=support_multi_env, - create_eval_env=create_eval_env, monitor_wrapper=monitor_wrapper, - seed=seed, use_sde=use_sde, sde_sample_freq=sde_sample_freq) + def __init__( + self, + policy: Type[BasePolicy], + env: Union[GymEnv, str], + policy_base: Type[BasePolicy], + 
learning_rate: Union[float, Callable], + buffer_size: int = int(1e6), + learning_starts: int = 100, + batch_size: int = 256, + tau: float = 0.005, + gamma: float = 0.99, + train_freq: int = 1, + gradient_steps: int = 1, + n_episodes_rollout: int = -1, + action_noise: Optional[ActionNoise] = None, + optimize_memory_usage: bool = False, + policy_kwargs: Dict[str, Any] = None, + tensorboard_log: Optional[str] = None, + verbose: int = 0, + device: Union[th.device, str] = "auto", + support_multi_env: bool = False, + create_eval_env: bool = False, + monitor_wrapper: bool = True, + seed: Optional[int] = None, + use_sde: bool = False, + sde_sample_freq: int = -1, + use_sde_at_warmup: bool = False, + sde_support: bool = True, + ): + + super(OffPolicyAlgorithm, self).__init__( + policy=policy, + env=env, + policy_base=policy_base, + learning_rate=learning_rate, + policy_kwargs=policy_kwargs, + tensorboard_log=tensorboard_log, + verbose=verbose, + device=device, + support_multi_env=support_multi_env, + create_eval_env=create_eval_env, + monitor_wrapper=monitor_wrapper, + seed=seed, + use_sde=use_sde, + sde_sample_freq=sde_sample_freq, + ) self.buffer_size = buffer_size self.batch_size = batch_size self.learning_starts = learning_starts @@ -115,30 +127,40 @@ def __init__(self, self.optimize_memory_usage = optimize_memory_usage if train_freq > 0 and n_episodes_rollout > 0: - warnings.warn("You passed a positive value for `train_freq` and `n_episodes_rollout`." - "Please make sure this is intended. " - "The agent will collect data by stepping in the environment " - "until both conditions are true: " - "`number of steps in the env` >= `train_freq` and " - "`number of episodes` > `n_episodes_rollout`") + warnings.warn( + "You passed a positive value for `train_freq` and `n_episodes_rollout`." + "Please make sure this is intended. " + "The agent will collect data by stepping in the environment " + "until both conditions are true: " + "`number of steps in the env` >= `train_freq` and " + "`number of episodes` > `n_episodes_rollout`" + ) self.actor = None # type: Optional[th.nn.Module] self.replay_buffer = None # type: Optional[ReplayBuffer] # Update policy keyword arguments if sde_support: - self.policy_kwargs['use_sde'] = self.use_sde - self.policy_kwargs['device'] = self.device + self.policy_kwargs["use_sde"] = self.use_sde + self.policy_kwargs["device"] = self.device # For gSDE only self.use_sde_at_warmup = use_sde_at_warmup def _setup_model(self) -> None: self._setup_lr_schedule() self.set_random_seed(self.seed) - self.replay_buffer = ReplayBuffer(self.buffer_size, self.observation_space, - self.action_space, self.device, - optimize_memory_usage=self.optimize_memory_usage) - self.policy = self.policy_class(self.observation_space, self.action_space, - self.lr_schedule, **self.policy_kwargs) # pytype:disable=not-instantiable + self.replay_buffer = ReplayBuffer( + self.buffer_size, + self.observation_space, + self.action_space, + self.device, + optimize_memory_usage=self.optimize_memory_usage, + ) + self.policy = self.policy_class( + self.observation_space, + self.action_space, + self.lr_schedule, + **self.policy_kwargs # pytype:disable=not-instantiable + ) self.policy = self.policy.to(self.device) def save_replay_buffer(self, path: Union[str, pathlib.Path, io.BufferedIOBase]) -> None: @@ -158,65 +180,78 @@ def load_replay_buffer(self, path: Union[str, pathlib.Path, io.BufferedIOBase]) :param path: (Union[str, pathlib.Path, io.BufferedIOBase]) Path to the pickled replay buffer. 
""" self.replay_buffer = load_from_pkl(path, self.verbose) - assert isinstance(self.replay_buffer, ReplayBuffer), 'The replay buffer must inherit from ReplayBuffer class' - - def _setup_learn(self, - total_timesteps: int, - eval_env: Optional[GymEnv], - callback: Union[None, Callable, List[BaseCallback], BaseCallback] = None, - eval_freq: int = 10000, - n_eval_episodes: int = 5, - log_path: Optional[str] = None, - reset_num_timesteps: bool = True, - tb_log_name: str = 'run', - ) -> Tuple[int, BaseCallback]: + assert isinstance(self.replay_buffer, ReplayBuffer), "The replay buffer must inherit from ReplayBuffer class" + + def _setup_learn( + self, + total_timesteps: int, + eval_env: Optional[GymEnv], + callback: Union[None, Callable, List[BaseCallback], BaseCallback] = None, + eval_freq: int = 10000, + n_eval_episodes: int = 5, + log_path: Optional[str] = None, + reset_num_timesteps: bool = True, + tb_log_name: str = "run", + ) -> Tuple[int, BaseCallback]: """ cf `BaseAlgorithm`. """ # Prevent continuity issue by truncating trajectory # when using memory efficient replay buffer # see https://github.com/DLR-RM/stable-baselines3/issues/46 - truncate_last_traj = (self.optimize_memory_usage and reset_num_timesteps - and self.replay_buffer is not None - and (self.replay_buffer.full or self.replay_buffer.pos > 0)) + truncate_last_traj = ( + self.optimize_memory_usage + and reset_num_timesteps + and self.replay_buffer is not None + and (self.replay_buffer.full or self.replay_buffer.pos > 0) + ) if truncate_last_traj: - warnings.warn("The last trajectory in the replay buffer will be truncated, " - "see https://github.com/DLR-RM/stable-baselines3/issues/46." - "You should use `reset_num_timesteps=False` or `optimize_memory_usage=False`" - "to avoid that issue.") + warnings.warn( + "The last trajectory in the replay buffer will be truncated, " + "see https://github.com/DLR-RM/stable-baselines3/issues/46." + "You should use `reset_num_timesteps=False` or `optimize_memory_usage=False`" + "to avoid that issue." 
+ ) # Go to the previous index pos = (self.replay_buffer.pos - 1) % self.replay_buffer.buffer_size self.replay_buffer.dones[pos] = True - return super()._setup_learn(total_timesteps, eval_env, callback, eval_freq, - n_eval_episodes, log_path, reset_num_timesteps, tb_log_name) - - def learn(self, - total_timesteps: int, - callback: MaybeCallback = None, - log_interval: int = 4, - eval_env: Optional[GymEnv] = None, - eval_freq: int = -1, - n_eval_episodes: int = 5, - tb_log_name: str = "run", - eval_log_path: Optional[str] = None, - reset_num_timesteps: bool = True) -> 'OffPolicyAlgorithm': - - total_timesteps, callback = self._setup_learn(total_timesteps, eval_env, callback, eval_freq, - n_eval_episodes, eval_log_path, reset_num_timesteps, - tb_log_name) + return super()._setup_learn( + total_timesteps, eval_env, callback, eval_freq, n_eval_episodes, log_path, reset_num_timesteps, tb_log_name + ) + + def learn( + self, + total_timesteps: int, + callback: MaybeCallback = None, + log_interval: int = 4, + eval_env: Optional[GymEnv] = None, + eval_freq: int = -1, + n_eval_episodes: int = 5, + tb_log_name: str = "run", + eval_log_path: Optional[str] = None, + reset_num_timesteps: bool = True, + ) -> "OffPolicyAlgorithm": + + total_timesteps, callback = self._setup_learn( + total_timesteps, eval_env, callback, eval_freq, n_eval_episodes, eval_log_path, reset_num_timesteps, tb_log_name + ) callback.on_training_start(locals(), globals()) while self.num_timesteps < total_timesteps: - rollout = self.collect_rollouts(self.env, n_episodes=self.n_episodes_rollout, - n_steps=self.train_freq, action_noise=self.action_noise, - callback=callback, - learning_starts=self.learning_starts, - replay_buffer=self.replay_buffer, - log_interval=log_interval) + rollout = self.collect_rollouts( + self.env, + n_episodes=self.n_episodes_rollout, + n_steps=self.train_freq, + action_noise=self.action_noise, + callback=callback, + learning_starts=self.learning_starts, + replay_buffer=self.replay_buffer, + log_interval=log_interval, + ) if rollout.continue_training is False: break @@ -238,8 +273,9 @@ def train(self, gradient_steps: int, batch_size: int) -> None: """ raise NotImplementedError() - def _sample_action(self, learning_starts: int, - action_noise: Optional[ActionNoise] = None) -> Tuple[np.ndarray, np.ndarray]: + def _sample_action( + self, learning_starts: int, action_noise: Optional[ActionNoise] = None + ) -> Tuple[np.ndarray, np.ndarray]: """ Sample an action according to the exploration policy. 
This is either done by sampling the probability distribution of the policy, @@ -288,16 +324,16 @@ def _dump_logs(self) -> None: fps = int(self.num_timesteps / (time.time() - self.start_time)) logger.record("time/episodes", self._episode_num, exclude="tensorboard") if len(self.ep_info_buffer) > 0 and len(self.ep_info_buffer[0]) > 0: - logger.record('rollout/ep_rew_mean', safe_mean([ep_info['r'] for ep_info in self.ep_info_buffer])) - logger.record('rollout/ep_len_mean', safe_mean([ep_info['l'] for ep_info in self.ep_info_buffer])) + logger.record("rollout/ep_rew_mean", safe_mean([ep_info["r"] for ep_info in self.ep_info_buffer])) + logger.record("rollout/ep_len_mean", safe_mean([ep_info["l"] for ep_info in self.ep_info_buffer])) logger.record("time/fps", fps) - logger.record('time/time_elapsed', int(time.time() - self.start_time), exclude="tensorboard") + logger.record("time/time_elapsed", int(time.time() - self.start_time), exclude="tensorboard") logger.record("time/total timesteps", self.num_timesteps, exclude="tensorboard") if self.use_sde: logger.record("train/std", (self.actor.get_std()).mean().item()) if len(self.ep_success_buffer) > 0: - logger.record('rollout/success rate', safe_mean(self.ep_success_buffer)) + logger.record("rollout/success rate", safe_mean(self.ep_success_buffer)) # Pass the number of timesteps for tensorboard logger.dump(step=self.num_timesteps) @@ -309,15 +345,17 @@ def _on_step(self) -> None: """ pass - def collect_rollouts(self, - env: VecEnv, - callback: BaseCallback, - n_episodes: int = 1, - n_steps: int = -1, - action_noise: Optional[ActionNoise] = None, - learning_starts: int = 0, - replay_buffer: Optional[ReplayBuffer] = None, - log_interval: Optional[int] = None) -> RolloutReturn: + def collect_rollouts( + self, + env: VecEnv, + callback: BaseCallback, + n_episodes: int = 1, + n_steps: int = -1, + action_noise: Optional[ActionNoise] = None, + learning_starts: int = 0, + replay_buffer: Optional[ReplayBuffer] = None, + log_interval: Optional[int] = None, + ) -> RolloutReturn: """ Collect experiences and store them into a ReplayBuffer. 
diff --git a/stable_baselines3/common/on_policy_algorithm.py b/stable_baselines3/common/on_policy_algorithm.py index 2937b77cb..f84d18f34 100644 --- a/stable_baselines3/common/on_policy_algorithm.py +++ b/stable_baselines3/common/on_policy_algorithm.py @@ -1,18 +1,18 @@ import time -from typing import Union, Type, Optional, Dict, Any, List, Tuple, Callable +from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union import gym -import torch as th import numpy as np +import torch as th from stable_baselines3.common import logger -from stable_baselines3.common.utils import safe_mean from stable_baselines3.common.base_class import BaseAlgorithm +from stable_baselines3.common.buffers import RolloutBuffer +from stable_baselines3.common.callbacks import BaseCallback from stable_baselines3.common.policies import ActorCriticPolicy -from stable_baselines3.common.vec_env import VecEnv from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback -from stable_baselines3.common.callbacks import BaseCallback -from stable_baselines3.common.buffers import RolloutBuffer +from stable_baselines3.common.utils import safe_mean +from stable_baselines3.common.vec_env import VecEnv class OnPolicyAlgorithm(BaseAlgorithm): @@ -48,32 +48,44 @@ class OnPolicyAlgorithm(BaseAlgorithm): :param _init_setup_model: (bool) Whether or not to build the network at the creation of the instance """ - def __init__(self, - policy: Union[str, Type[ActorCriticPolicy]], - env: Union[GymEnv, str], - learning_rate: Union[float, Callable], - n_steps: int, - gamma: float, - gae_lambda: float, - ent_coef: float, - vf_coef: float, - max_grad_norm: float, - use_sde: bool, - sde_sample_freq: int, - tensorboard_log: Optional[str] = None, - create_eval_env: bool = False, - monitor_wrapper: bool = True, - policy_kwargs: Optional[Dict[str, Any]] = None, - verbose: int = 0, - seed: Optional[int] = None, - device: Union[th.device, str] = 'auto', - _init_setup_model: bool = True): - - super(OnPolicyAlgorithm, self).__init__(policy=policy, env=env, policy_base=ActorCriticPolicy, - learning_rate=learning_rate, policy_kwargs=policy_kwargs, - verbose=verbose, device=device, use_sde=use_sde, - sde_sample_freq=sde_sample_freq, create_eval_env=create_eval_env, - support_multi_env=True, seed=seed, tensorboard_log=tensorboard_log) + def __init__( + self, + policy: Union[str, Type[ActorCriticPolicy]], + env: Union[GymEnv, str], + learning_rate: Union[float, Callable], + n_steps: int, + gamma: float, + gae_lambda: float, + ent_coef: float, + vf_coef: float, + max_grad_norm: float, + use_sde: bool, + sde_sample_freq: int, + tensorboard_log: Optional[str] = None, + create_eval_env: bool = False, + monitor_wrapper: bool = True, + policy_kwargs: Optional[Dict[str, Any]] = None, + verbose: int = 0, + seed: Optional[int] = None, + device: Union[th.device, str] = "auto", + _init_setup_model: bool = True, + ): + + super(OnPolicyAlgorithm, self).__init__( + policy=policy, + env=env, + policy_base=ActorCriticPolicy, + learning_rate=learning_rate, + policy_kwargs=policy_kwargs, + verbose=verbose, + device=device, + use_sde=use_sde, + sde_sample_freq=sde_sample_freq, + create_eval_env=create_eval_env, + support_multi_env=True, + seed=seed, + tensorboard_log=tensorboard_log, + ) self.n_steps = n_steps self.gamma = gamma @@ -90,20 +102,28 @@ def _setup_model(self) -> None: self._setup_lr_schedule() self.set_random_seed(self.seed) - self.rollout_buffer = RolloutBuffer(self.n_steps, self.observation_space, - self.action_space, self.device, - 
gamma=self.gamma, gae_lambda=self.gae_lambda, - n_envs=self.n_envs) - self.policy = self.policy_class(self.observation_space, self.action_space, - self.lr_schedule, use_sde=self.use_sde, device=self.device, - **self.policy_kwargs) # pytype:disable=not-instantiable + self.rollout_buffer = RolloutBuffer( + self.n_steps, + self.observation_space, + self.action_space, + self.device, + gamma=self.gamma, + gae_lambda=self.gae_lambda, + n_envs=self.n_envs, + ) + self.policy = self.policy_class( + self.observation_space, + self.action_space, + self.lr_schedule, + use_sde=self.use_sde, + device=self.device, + **self.policy_kwargs # pytype:disable=not-instantiable + ) self.policy = self.policy.to(self.device) - def collect_rollouts(self, - env: VecEnv, - callback: BaseCallback, - rollout_buffer: RolloutBuffer, - n_rollout_steps: int) -> bool: + def collect_rollouts( + self, env: VecEnv, callback: BaseCallback, rollout_buffer: RolloutBuffer, n_rollout_steps: int + ) -> bool: """ Collect rollouts using the current policy and fill a `RolloutBuffer`. @@ -169,29 +189,29 @@ def train(self) -> None: """ raise NotImplementedError - def learn(self, - total_timesteps: int, - callback: MaybeCallback = None, - log_interval: int = 1, - eval_env: Optional[GymEnv] = None, - eval_freq: int = -1, - n_eval_episodes: int = 5, - tb_log_name: str = "OnPolicyAlgorithm", - eval_log_path: Optional[str] = None, - reset_num_timesteps: bool = True) -> 'OnPolicyAlgorithm': + def learn( + self, + total_timesteps: int, + callback: MaybeCallback = None, + log_interval: int = 1, + eval_env: Optional[GymEnv] = None, + eval_freq: int = -1, + n_eval_episodes: int = 5, + tb_log_name: str = "OnPolicyAlgorithm", + eval_log_path: Optional[str] = None, + reset_num_timesteps: bool = True, + ) -> "OnPolicyAlgorithm": iteration = 0 - total_timesteps, callback = self._setup_learn(total_timesteps, eval_env, callback, eval_freq, - n_eval_episodes, eval_log_path, reset_num_timesteps, - tb_log_name) + total_timesteps, callback = self._setup_learn( + total_timesteps, eval_env, callback, eval_freq, n_eval_episodes, eval_log_path, reset_num_timesteps, tb_log_name + ) callback.on_training_start(locals(), globals()) while self.num_timesteps < total_timesteps: - continue_training = self.collect_rollouts(self.env, callback, - self.rollout_buffer, - n_rollout_steps=self.n_steps) + continue_training = self.collect_rollouts(self.env, callback, self.rollout_buffer, n_rollout_steps=self.n_steps) if continue_training is False: break @@ -204,10 +224,8 @@ def learn(self, fps = int(self.num_timesteps / (time.time() - self.start_time)) logger.record("time/iterations", iteration, exclude="tensorboard") if len(self.ep_info_buffer) > 0 and len(self.ep_info_buffer[0]) > 0: - logger.record("rollout/ep_rew_mean", - safe_mean([ep_info["r"] for ep_info in self.ep_info_buffer])) - logger.record("rollout/ep_len_mean", - safe_mean([ep_info["l"] for ep_info in self.ep_info_buffer])) + logger.record("rollout/ep_rew_mean", safe_mean([ep_info["r"] for ep_info in self.ep_info_buffer])) + logger.record("rollout/ep_len_mean", safe_mean([ep_info["l"] for ep_info in self.ep_info_buffer])) logger.record("time/fps", fps) logger.record("time/time_elapsed", int(time.time() - self.start_time), exclude="tensorboard") logger.record("time/total_timesteps", self.num_timesteps, exclude="tensorboard") diff --git a/stable_baselines3/common/policies.py b/stable_baselines3/common/policies.py index 276a0f318..acadd8845 100644 --- a/stable_baselines3/common/policies.py +++ 
b/stable_baselines3/common/policies.py @@ -1,24 +1,28 @@ """Policies: abstract base class and concrete implementations.""" -from abc import ABC, abstractmethod import collections -from typing import Union, Type, Dict, List, Tuple, Optional, Any, Callable +from abc import ABC, abstractmethod from functools import partial +from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union import gym -import torch as th -import torch.nn as nn import numpy as np - -from stable_baselines3.common.preprocessing import preprocess_obs, is_image_space, get_action_dim -from stable_baselines3.common.torch_layers import (FlattenExtractor, BaseFeaturesExtractor, create_mlp, - NatureCNN, MlpExtractor) +import torch as th +from torch import nn as nn + +from stable_baselines3.common.distributions import ( + BernoulliDistribution, + CategoricalDistribution, + DiagGaussianDistribution, + Distribution, + MultiCategoricalDistribution, + StateDependentNoiseDistribution, + make_proba_distribution, +) +from stable_baselines3.common.preprocessing import get_action_dim, is_image_space, preprocess_obs +from stable_baselines3.common.torch_layers import BaseFeaturesExtractor, FlattenExtractor, MlpExtractor, NatureCNN, create_mlp from stable_baselines3.common.utils import get_device, is_vectorized_observation from stable_baselines3.common.vec_env import VecTransposeImage -from stable_baselines3.common.distributions import (make_proba_distribution, Distribution, - DiagGaussianDistribution, CategoricalDistribution, - MultiCategoricalDistribution, BernoulliDistribution, - StateDependentNoiseDistribution) class BaseModel(nn.Module, ABC): @@ -44,16 +48,18 @@ class BaseModel(nn.Module, ABC): excluding the learning rate, to pass to the optimizer """ - def __init__(self, - observation_space: gym.spaces.Space, - action_space: gym.spaces.Space, - device: Union[th.device, str] = 'auto', - features_extractor_class: Type[BaseFeaturesExtractor] = FlattenExtractor, - features_extractor_kwargs: Optional[Dict[str, Any]] = None, - features_extractor: Optional[nn.Module] = None, - normalize_images: bool = True, - optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam, - optimizer_kwargs: Optional[Dict[str, Any]] = None): + def __init__( + self, + observation_space: gym.spaces.Space, + action_space: gym.spaces.Space, + device: Union[th.device, str] = "auto", + features_extractor_class: Type[BaseFeaturesExtractor] = FlattenExtractor, + features_extractor_kwargs: Optional[Dict[str, Any]] = None, + features_extractor: Optional[nn.Module] = None, + normalize_images: bool = True, + optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam, + optimizer_kwargs: Optional[Dict[str, Any]] = None, + ): super(BaseModel, self).__init__() if optimizer_kwargs is None: @@ -86,7 +92,7 @@ def extract_features(self, obs: th.Tensor) -> th.Tensor: :param obs: (th.Tensor) :return: (th.Tensor) """ - assert self.features_extractor is not None, 'No feature extractor was set' + assert self.features_extractor is not None, "No feature extractor was set" preprocessed_obs = preprocess_obs(obs, self.observation_space, normalize_images=self.normalize_images) return self.features_extractor(preprocessed_obs) @@ -112,10 +118,10 @@ def save(self, path: str) -> None: :param path: (str) """ - th.save({'state_dict': self.state_dict(), 'data': self._get_data()}, path) + th.save({"state_dict": self.state_dict(), "data": self._get_data()}, path) @classmethod - def load(cls, path: str, device: Union[th.device, str] = 'auto') -> 'BaseModel': + def load(cls, path: str, 
device: Union[th.device, str] = "auto") -> "BaseModel": """ Load model from path. @@ -126,9 +132,9 @@ def load(cls, path: str, device: Union[th.device, str] = 'auto') -> 'BaseModel': device = get_device(device) saved_variables = th.load(path, map_location=device) # Create policy object - model = cls(**saved_variables['data']) # pytype: disable=not-instantiable + model = cls(**saved_variables["data"]) # pytype: disable=not-instantiable # Load weights - model.load_state_dict(saved_variables['state_dict']) + model.load_state_dict(saved_variables["state_dict"]) model.to(device) return model @@ -159,6 +165,7 @@ class BasePolicy(BaseModel): :param squash_output: (bool) For continuous actions, whether the output is squashed or not using a ``tanh()`` function. """ + def __init__(self, *args, squash_output: bool = False, **kwargs): super(BasePolicy, self).__init__(*args, **kwargs) self._squash_output = squash_output @@ -196,11 +203,13 @@ def _predict(self, observation: th.Tensor, deterministic: bool = False) -> th.Te :return: (th.Tensor) Taken action according to the policy """ - def predict(self, - observation: np.ndarray, - state: Optional[np.ndarray] = None, - mask: Optional[np.ndarray] = None, - deterministic: bool = False) -> Tuple[np.ndarray, Optional[np.ndarray]]: + def predict( + self, + observation: np.ndarray, + state: Optional[np.ndarray] = None, + mask: Optional[np.ndarray] = None, + deterministic: bool = False, + ) -> Tuple[np.ndarray, Optional[np.ndarray]]: """ Get the policy action and state from an observation (and optional state). Includes sugar-coating to handle different observations (e.g. normalizing images). @@ -222,12 +231,15 @@ def predict(self, # Handle the different cases for images # as PyTorch use channel first format if is_image_space(self.observation_space): - if not (observation.shape == self.observation_space.shape - or observation.shape[1:] == self.observation_space.shape): + if not ( + observation.shape == self.observation_space.shape or observation.shape[1:] == self.observation_space.shape + ): # Try to re-order the channels transpose_obs = VecTransposeImage.transpose_image(observation) - if (transpose_obs.shape == self.observation_space.shape - or transpose_obs.shape[1:] == self.observation_space.shape): + if ( + transpose_obs.shape == self.observation_space.shape + or transpose_obs.shape[1:] == self.observation_space.shape + ): observation = transpose_obs vectorized_env = is_vectorized_observation(observation, self.observation_space) @@ -313,40 +325,44 @@ class ActorCriticPolicy(BasePolicy): excluding the learning rate, to pass to the optimizer """ - def __init__(self, - observation_space: gym.spaces.Space, - action_space: gym.spaces.Space, - lr_schedule: Callable[[float], float], - net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None, - device: Union[th.device, str] = 'auto', - activation_fn: Type[nn.Module] = nn.Tanh, - ortho_init: bool = True, - use_sde: bool = False, - log_std_init: float = 0.0, - full_std: bool = True, - sde_net_arch: Optional[List[int]] = None, - use_expln: bool = False, - squash_output: bool = False, - features_extractor_class: Type[BaseFeaturesExtractor] = FlattenExtractor, - features_extractor_kwargs: Optional[Dict[str, Any]] = None, - normalize_images: bool = True, - optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam, - optimizer_kwargs: Optional[Dict[str, Any]] = None): + def __init__( + self, + observation_space: gym.spaces.Space, + action_space: gym.spaces.Space, + lr_schedule: Callable[[float], float], + 
net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None, + device: Union[th.device, str] = "auto", + activation_fn: Type[nn.Module] = nn.Tanh, + ortho_init: bool = True, + use_sde: bool = False, + log_std_init: float = 0.0, + full_std: bool = True, + sde_net_arch: Optional[List[int]] = None, + use_expln: bool = False, + squash_output: bool = False, + features_extractor_class: Type[BaseFeaturesExtractor] = FlattenExtractor, + features_extractor_kwargs: Optional[Dict[str, Any]] = None, + normalize_images: bool = True, + optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam, + optimizer_kwargs: Optional[Dict[str, Any]] = None, + ): if optimizer_kwargs is None: optimizer_kwargs = {} # Small values to avoid NaN in Adam optimizer if optimizer_class == th.optim.Adam: - optimizer_kwargs['eps'] = 1e-5 - - super(ActorCriticPolicy, self).__init__(observation_space, - action_space, - device, - features_extractor_class, - features_extractor_kwargs, - optimizer_class=optimizer_class, - optimizer_kwargs=optimizer_kwargs, - squash_output=squash_output) + optimizer_kwargs["eps"] = 1e-5 + + super(ActorCriticPolicy, self).__init__( + observation_space, + action_space, + device, + features_extractor_class, + features_extractor_kwargs, + optimizer_class=optimizer_class, + optimizer_kwargs=optimizer_kwargs, + squash_output=squash_output, + ) # Default network architecture, from stable-baselines if net_arch is None: @@ -359,8 +375,7 @@ def __init__(self, self.activation_fn = activation_fn self.ortho_init = ortho_init - self.features_extractor = features_extractor_class(self.observation_space, - **self.features_extractor_kwargs) + self.features_extractor = features_extractor_class(self.observation_space, **self.features_extractor_kwargs) self.features_dim = self.features_extractor.features_dim self.normalize_images = normalize_images @@ -369,10 +384,10 @@ def __init__(self, # Keyword arguments for gSDE distribution if use_sde: dist_kwargs = { - 'full_std': full_std, - 'squash_output': squash_output, - 'use_expln': use_expln, - 'learn_features': sde_net_arch is not None + "full_std": full_std, + "squash_output": squash_output, + "use_expln": use_expln, + "learn_features": sde_net_arch is not None, } self.sde_features_extractor = None @@ -390,22 +405,24 @@ def _get_data(self) -> Dict[str, Any]: default_none_kwargs = self.dist_kwargs or collections.defaultdict(lambda: None) - data.update(dict( - net_arch=self.net_arch, - activation_fn=self.activation_fn, - use_sde=self.use_sde, - log_std_init=self.log_std_init, - squash_output=default_none_kwargs['squash_output'], - full_std=default_none_kwargs['full_std'], - sde_net_arch=default_none_kwargs['sde_net_arch'], - use_expln=default_none_kwargs['use_expln'], - lr_schedule=self._dummy_schedule, # dummy lr schedule, not needed for loading policy alone - ortho_init=self.ortho_init, - optimizer_class=self.optimizer_class, - optimizer_kwargs=self.optimizer_kwargs, - features_extractor_class=self.features_extractor_class, - features_extractor_kwargs=self.features_extractor_kwargs - )) + data.update( + dict( + net_arch=self.net_arch, + activation_fn=self.activation_fn, + use_sde=self.use_sde, + log_std_init=self.log_std_init, + squash_output=default_none_kwargs["squash_output"], + full_std=default_none_kwargs["full_std"], + sde_net_arch=default_none_kwargs["sde_net_arch"], + use_expln=default_none_kwargs["use_expln"], + lr_schedule=self._dummy_schedule, # dummy lr schedule, not needed for loading policy alone + ortho_init=self.ortho_init, + 
optimizer_class=self.optimizer_class, + optimizer_kwargs=self.optimizer_kwargs, + features_extractor_class=self.features_extractor_class, + features_extractor_kwargs=self.features_extractor_kwargs, + ) + ) return data def reset_noise(self, n_envs: int = 1) -> None: @@ -414,8 +431,7 @@ def reset_noise(self, n_envs: int = 1) -> None: :param n_envs: (int) """ - assert isinstance(self.action_dist, - StateDependentNoiseDistribution), 'reset_noise() is only available when using gSDE' + assert isinstance(self.action_dist, StateDependentNoiseDistribution), "reset_noise() is only available when using gSDE" self.action_dist.sample_weights(self.log_std, batch_size=n_envs) def _build(self, lr_schedule: Callable[[float], float]) -> None: @@ -428,25 +444,27 @@ def _build(self, lr_schedule: Callable[[float], float]) -> None: # Note: If net_arch is None and some features extractor is used, # net_arch here is an empty list and mlp_extractor does not # really contain any layers (acts like an identity module). - self.mlp_extractor = MlpExtractor(self.features_dim, net_arch=self.net_arch, - activation_fn=self.activation_fn, device=self.device) + self.mlp_extractor = MlpExtractor( + self.features_dim, net_arch=self.net_arch, activation_fn=self.activation_fn, device=self.device + ) latent_dim_pi = self.mlp_extractor.latent_dim_pi # Separate feature extractor for gSDE if self.sde_net_arch is not None: - self.sde_features_extractor, latent_sde_dim = create_sde_features_extractor(self.features_dim, - self.sde_net_arch, - self.activation_fn) + self.sde_features_extractor, latent_sde_dim = create_sde_features_extractor( + self.features_dim, self.sde_net_arch, self.activation_fn + ) if isinstance(self.action_dist, DiagGaussianDistribution): - self.action_net, self.log_std = self.action_dist.proba_distribution_net(latent_dim=latent_dim_pi, - log_std_init=self.log_std_init) + self.action_net, self.log_std = self.action_dist.proba_distribution_net( + latent_dim=latent_dim_pi, log_std_init=self.log_std_init + ) elif isinstance(self.action_dist, StateDependentNoiseDistribution): latent_sde_dim = latent_dim_pi if self.sde_net_arch is None else latent_sde_dim - self.action_net, self.log_std = self.action_dist.proba_distribution_net(latent_dim=latent_dim_pi, - latent_sde_dim=latent_sde_dim, - log_std_init=self.log_std_init) + self.action_net, self.log_std = self.action_dist.proba_distribution_net( + latent_dim=latent_dim_pi, latent_sde_dim=latent_sde_dim, log_std_init=self.log_std_init + ) elif isinstance(self.action_dist, CategoricalDistribution): self.action_net = self.action_dist.proba_distribution_net(latent_dim=latent_dim_pi) elif isinstance(self.action_dist, MultiCategoricalDistribution): @@ -468,7 +486,7 @@ def _build(self, lr_schedule: Callable[[float], float]) -> None: self.features_extractor: np.sqrt(2), self.mlp_extractor: np.sqrt(2), self.action_net: 0.01, - self.value_net: 1 + self.value_net: 1, } for module, gain in module_gains.items(): module.apply(partial(self.init_weights, gain=gain)) @@ -476,8 +494,7 @@ def _build(self, lr_schedule: Callable[[float], float]) -> None: # Setup optimizer with initial learning rate self.optimizer = self.optimizer_class(self.parameters(), lr=lr_schedule(1), **self.optimizer_kwargs) - def forward(self, obs: th.Tensor, - deterministic: bool = False) -> Tuple[th.Tensor, th.Tensor, th.Tensor]: + def forward(self, obs: th.Tensor, deterministic: bool = False) -> Tuple[th.Tensor, th.Tensor, th.Tensor]: """ Forward pass in all the networks (actor and critic) @@ -512,8 +529,7 @@ def 
_get_latent(self, obs: th.Tensor) -> Tuple[th.Tensor, th.Tensor, th.Tensor]: latent_sde = self.sde_features_extractor(features) return latent_pi, latent_vf, latent_sde - def _get_action_dist_from_latent(self, latent_pi: th.Tensor, - latent_sde: Optional[th.Tensor] = None) -> Distribution: + def _get_action_dist_from_latent(self, latent_pi: th.Tensor, latent_sde: Optional[th.Tensor] = None) -> Distribution: """ Retrieve action distribution given the latent codes. @@ -537,7 +553,7 @@ def _get_action_dist_from_latent(self, latent_pi: th.Tensor, elif isinstance(self.action_dist, StateDependentNoiseDistribution): return self.action_dist.proba_distribution(mean_actions, self.log_std, latent_sde) else: - raise ValueError('Invalid action distribution') + raise ValueError("Invalid action distribution") def _predict(self, observation: th.Tensor, deterministic: bool = False) -> th.Tensor: """ @@ -551,8 +567,7 @@ def _predict(self, observation: th.Tensor, deterministic: bool = False) -> th.Te distribution = self._get_action_dist_from_latent(latent_pi, latent_sde) return distribution.get_actions(deterministic=deterministic) - def evaluate_actions(self, obs: th.Tensor, - actions: th.Tensor) -> Tuple[th.Tensor, th.Tensor, th.Tensor]: + def evaluate_actions(self, obs: th.Tensor, actions: th.Tensor) -> Tuple[th.Tensor, th.Tensor, th.Tensor]: """ Evaluate actions according to the current policy, given the observations. @@ -604,43 +619,47 @@ class ActorCriticCnnPolicy(ActorCriticPolicy): excluding the learning rate, to pass to the optimizer """ - def __init__(self, - observation_space: gym.spaces.Space, - action_space: gym.spaces.Space, - lr_schedule: Callable, - net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None, - device: Union[th.device, str] = 'auto', - activation_fn: Type[nn.Module] = nn.Tanh, - ortho_init: bool = True, - use_sde: bool = False, - log_std_init: float = 0.0, - full_std: bool = True, - sde_net_arch: Optional[List[int]] = None, - use_expln: bool = False, - squash_output: bool = False, - features_extractor_class: Type[BaseFeaturesExtractor] = NatureCNN, - features_extractor_kwargs: Optional[Dict[str, Any]] = None, - normalize_images: bool = True, - optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam, - optimizer_kwargs: Optional[Dict[str, Any]] = None): - super(ActorCriticCnnPolicy, self).__init__(observation_space, - action_space, - lr_schedule, - net_arch, - device, - activation_fn, - ortho_init, - use_sde, - log_std_init, - full_std, - sde_net_arch, - use_expln, - squash_output, - features_extractor_class, - features_extractor_kwargs, - normalize_images, - optimizer_class, - optimizer_kwargs) + def __init__( + self, + observation_space: gym.spaces.Space, + action_space: gym.spaces.Space, + lr_schedule: Callable, + net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None, + device: Union[th.device, str] = "auto", + activation_fn: Type[nn.Module] = nn.Tanh, + ortho_init: bool = True, + use_sde: bool = False, + log_std_init: float = 0.0, + full_std: bool = True, + sde_net_arch: Optional[List[int]] = None, + use_expln: bool = False, + squash_output: bool = False, + features_extractor_class: Type[BaseFeaturesExtractor] = NatureCNN, + features_extractor_kwargs: Optional[Dict[str, Any]] = None, + normalize_images: bool = True, + optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam, + optimizer_kwargs: Optional[Dict[str, Any]] = None, + ): + super(ActorCriticCnnPolicy, self).__init__( + observation_space, + action_space, + lr_schedule, + net_arch, + device, + 
activation_fn,
+            ortho_init,
+            use_sde,
+            log_std_init,
+            full_std,
+            sde_net_arch,
+            use_expln,
+            squash_output,
+            features_extractor_class,
+            features_extractor_kwargs,
+            normalize_images,
+            optimizer_class,
+            optimizer_kwargs,
+        )


 class ContinuousCritic(BaseModel):
@@ -669,19 +688,25 @@ class ContinuousCritic(BaseModel):
     :param n_critics: (int) Number of critic networks to create.
     """

-    def __init__(self, observation_space: gym.spaces.Space,
-                 action_space: gym.spaces.Space,
-                 net_arch: List[int],
-                 features_extractor: nn.Module,
-                 features_dim: int,
-                 activation_fn: Type[nn.Module] = nn.ReLU,
-                 normalize_images: bool = True,
-                 device: Union[th.device, str] = 'auto',
-                 n_critics: int = 2):
-        super().__init__(observation_space, action_space,
-                         features_extractor=features_extractor,
-                         normalize_images=normalize_images,
-                         device=device)
+    def __init__(
+        self,
+        observation_space: gym.spaces.Space,
+        action_space: gym.spaces.Space,
+        net_arch: List[int],
+        features_extractor: nn.Module,
+        features_dim: int,
+        activation_fn: Type[nn.Module] = nn.ReLU,
+        normalize_images: bool = True,
+        device: Union[th.device, str] = "auto",
+        n_critics: int = 2,
+    ):
+        super().__init__(
+            observation_space,
+            action_space,
+            features_extractor=features_extractor,
+            normalize_images=normalize_images,
+            device=device,
+        )

         action_dim = get_action_dim(self.action_space)

@@ -690,7 +715,7 @@ def __init__(self, observation_space: gym.spaces.Space,
         for idx in range(n_critics):
             q_net = create_mlp(features_dim + action_dim, 1, net_arch, activation_fn)
             q_net = nn.Sequential(*q_net)
-            self.add_module(f'qf{idx}', q_net)
+            self.add_module(f"qf{idx}", q_net)
             self.q_networks.append(q_net)

     def forward(self, obs: th.Tensor, actions: th.Tensor) -> Tuple[th.Tensor, ...]:
@@ -711,9 +736,9 @@ def q1_forward(self, obs: th.Tensor, actions: th.Tensor) -> th.Tensor:
         return self.q_networks[0](th.cat([features, actions], dim=1))


-def create_sde_features_extractor(features_dim: int,
-                                  sde_net_arch: List[int],
-                                  activation_fn: Type[nn.Module]) -> Tuple[nn.Sequential, int]:
+def create_sde_features_extractor(
+    features_dim: int, sde_net_arch: List[int], activation_fn: Type[nn.Module]
+) -> Tuple[nn.Sequential, int]:
     """
     Create the neural network that will be used to extract features
     for the gSDE exploration function.
@@ -747,8 +772,10 @@ def get_policy_from_name(base_policy_type: Type[BasePolicy], name: str) -> Type[
     if base_policy_type not in _policy_registry:
         raise KeyError(f"Error: the policy type {base_policy_type} is not registered!")
     if name not in _policy_registry[base_policy_type]:
-        raise KeyError(f"Error: unknown policy type {name},"
-                       f"the only registed policy type are: {list(_policy_registry[base_policy_type].keys())}!")
+        raise KeyError(
+            f"Error: unknown policy type {name}, "
+            f"the only registered policy types are: {list(_policy_registry[base_policy_type].keys())}!"
+ ) return _policy_registry[base_policy_type][name] diff --git a/stable_baselines3/common/preprocessing.py b/stable_baselines3/common/preprocessing.py index 849756f17..3efaf9b53 100644 --- a/stable_baselines3/common/preprocessing.py +++ b/stable_baselines3/common/preprocessing.py @@ -1,14 +1,12 @@ from typing import Tuple +import numpy as np import torch as th -import torch.nn.functional as F from gym import spaces -import numpy as np +from torch.nn import functional as F -def is_image_space(observation_space: spaces.Space, - channels_last: bool = True, - check_channels: bool = False) -> bool: +def is_image_space(observation_space: spaces.Space, channels_last: bool = True, check_channels: bool = False) -> bool: """ Check if a observation space has the shape, limits and dtype of a valid image. @@ -45,8 +43,7 @@ def is_image_space(observation_space: spaces.Space, return False -def preprocess_obs(obs: th.Tensor, observation_space: spaces.Space, - normalize_images: bool = True) -> th.Tensor: +def preprocess_obs(obs: th.Tensor, observation_space: spaces.Space, normalize_images: bool = True) -> th.Tensor: """ Preprocess observation to be to a neural network. For images, it normalizes the values by dividing them by 255 (to have values in [0, 1]) @@ -69,9 +66,13 @@ def preprocess_obs(obs: th.Tensor, observation_space: spaces.Space, elif isinstance(observation_space, spaces.MultiDiscrete): # Tensor concatenation of one hot encodings of each Categorical sub-space - return th.cat([F.one_hot(obs_.long(), num_classes=int(observation_space.nvec[idx])).float() - for idx, obs_ in enumerate(th.split(obs.long(), 1, dim=1))], - dim=-1).view(obs.shape[0], sum(observation_space.nvec)) + return th.cat( + [ + F.one_hot(obs_.long(), num_classes=int(observation_space.nvec[idx])).float() + for idx, obs_ in enumerate(th.split(obs.long(), 1, dim=1)) + ], + dim=-1, + ).view(obs.shape[0], sum(observation_space.nvec)) elif isinstance(observation_space, spaces.MultiBinary): return obs.float() @@ -91,13 +92,13 @@ def get_obs_shape(observation_space: spaces.Space) -> Tuple[int, ...]: return observation_space.shape elif isinstance(observation_space, spaces.Discrete): # Observation is an int - return 1, + return (1,) elif isinstance(observation_space, spaces.MultiDiscrete): # Number of discrete features - return int(len(observation_space.nvec)), + return (int(len(observation_space.nvec)),) elif isinstance(observation_space, spaces.MultiBinary): # Number of binary features - return int(observation_space.n), + return (int(observation_space.n),) else: raise NotImplementedError() diff --git a/stable_baselines3/common/results_plotter.py b/stable_baselines3/common/results_plotter.py index 4f879be2a..8fe805c1c 100644 --- a/stable_baselines3/common/results_plotter.py +++ b/stable_baselines3/common/results_plotter.py @@ -1,17 +1,17 @@ -from typing import Tuple, Callable, List, Optional +from typing import Callable, List, Optional, Tuple import numpy as np import pandas as pd + # import matplotlib # matplotlib.use('TkAgg') # Can change to 'Agg' for non-interactive mode -import matplotlib.pyplot as plt +from matplotlib import pyplot as plt from stable_baselines3.common.monitor import load_results - -X_TIMESTEPS = 'timesteps' -X_EPISODES = 'episodes' -X_WALLTIME = 'walltime_hrs' +X_TIMESTEPS = "timesteps" +X_EPISODES = "episodes" +X_WALLTIME = "walltime_hrs" POSSIBLE_X_AXES = [X_TIMESTEPS, X_EPISODES, X_WALLTIME] EPISODES_WINDOW = 100 @@ -29,8 +29,7 @@ def rolling_window(array: np.ndarray, window: int) -> np.ndarray: return 
np.lib.stride_tricks.as_strided(array, shape=shape, strides=strides) -def window_func(var_1: np.ndarray, var_2: np.ndarray, - window: int, func: Callable) -> Tuple[np.ndarray, np.ndarray]: +def window_func(var_1: np.ndarray, var_2: np.ndarray, window: int, func: Callable) -> Tuple[np.ndarray, np.ndarray]: """ Apply a function to the rolling window of 2 arrays @@ -42,7 +41,7 @@ def window_func(var_1: np.ndarray, var_2: np.ndarray, """ var_2_window = rolling_window(var_2, window) function_on_var2 = func(var_2_window, axis=-1) - return var_1[window - 1:], function_on_var2 + return var_1[window - 1 :], function_on_var2 def ts2xy(data_frame: pd.DataFrame, x_axis: str) -> Tuple[np.ndarray, np.ndarray]: @@ -62,15 +61,16 @@ def ts2xy(data_frame: pd.DataFrame, x_axis: str) -> Tuple[np.ndarray, np.ndarray y_var = data_frame.r.values elif x_axis == X_WALLTIME: # Convert to hours - x_var = data_frame.t.values / 3600. + x_var = data_frame.t.values / 3600.0 y_var = data_frame.r.values else: raise NotImplementedError return x_var, y_var -def plot_curves(xy_list: List[Tuple[np.ndarray, np.ndarray]], - x_axis: str, title: str, figsize: Tuple[int, int] = (8, 2)) -> None: +def plot_curves( + xy_list: List[Tuple[np.ndarray, np.ndarray]], x_axis: str, title: str, figsize: Tuple[int, int] = (8, 2) +) -> None: """ plot the curves @@ -98,8 +98,9 @@ def plot_curves(xy_list: List[Tuple[np.ndarray, np.ndarray]], plt.tight_layout() -def plot_results(dirs: List[str], num_timesteps: Optional[int], - x_axis: str, task_name: str, figsize: Tuple[int, int] = (8, 2)) -> None: +def plot_results( + dirs: List[str], num_timesteps: Optional[int], x_axis: str, task_name: str, figsize: Tuple[int, int] = (8, 2) +) -> None: """ Plot the results using csv files from ``Monitor`` wrapper. diff --git a/stable_baselines3/common/running_mean_std.py b/stable_baselines3/common/running_mean_std.py index f50310676..b98fccc65 100644 --- a/stable_baselines3/common/running_mean_std.py +++ b/stable_baselines3/common/running_mean_std.py @@ -22,9 +22,7 @@ def update(self, arr: np.ndarray) -> None: batch_count = arr.shape[0] self.update_from_moments(batch_mean, batch_var, batch_count) - def update_from_moments(self, batch_mean: np.ndarray, - batch_var: np.ndarray, - batch_count: int) -> None: + def update_from_moments(self, batch_mean: np.ndarray, batch_var: np.ndarray, batch_count: int) -> None: delta = batch_mean - self.mean tot_count = self.count + batch_count diff --git a/stable_baselines3/common/save_util.py b/stable_baselines3/common/save_util.py index 0d2fe9823..51fa8cd17 100644 --- a/stable_baselines3/common/save_util.py +++ b/stable_baselines3/common/save_util.py @@ -2,19 +2,19 @@ Save util taken from stable_baselines used to serialize data (class parameters) of model classes """ -import io -import os -import json import base64 import functools -from typing import Dict, Any, Tuple, Optional, Union -import warnings -import zipfile +import io +import json +import os import pathlib import pickle +import warnings +import zipfile +from typing import Any, Dict, Optional, Tuple, Union -import torch as th import cloudpickle +import torch as th from stable_baselines3.common.type_aliases import TensorDict from stable_baselines3.common.utils import get_device @@ -112,11 +112,7 @@ def data_to_json(data: Dict[str, Any]) -> str: # e.g. 
numpy scalars) if hasattr(data_item, "__dict__") or isinstance(data_item, dict): # Take elements from __dict__ for custom classes - item_generator = ( - data_item.items - if isinstance(data_item, dict) - else data_item.__dict__.items - ) + item_generator = data_item.items if isinstance(data_item, dict) else data_item.__dict__.items for variable_name, variable_item in item_generator(): # Check if serializable. If not, just include the # string-representation of the object. @@ -130,9 +126,7 @@ def data_to_json(data: Dict[str, Any]) -> str: return json_string -def json_to_data( - json_string: str, custom_objects: Optional[Dict[str, Any]] = None -) -> Dict[str, Any]: +def json_to_data(json_string: str, custom_objects: Optional[Dict[str, Any]] = None) -> Dict[str, Any]: """ Turn JSON serialization of class-parameters back into dictionary. @@ -181,9 +175,7 @@ def json_to_data( @functools.singledispatch -def open_path( - path: Union[str, pathlib.Path, io.BufferedIOBase], mode: str, verbose=0, suffix=None -): +def open_path(path: Union[str, pathlib.Path, io.BufferedIOBase], mode: str, verbose=0, suffix=None): """ Opens a path for reading or writing with a preferred suffix and raises debug information. If the provided path is a derivative of io.BufferedIOBase it ensures that the file @@ -221,9 +213,7 @@ def open_path( @open_path.register(str) -def open_path_str( - path: str, mode: str, verbose=0, suffix=None -) -> io.BufferedIOBase: +def open_path_str(path: str, mode: str, verbose=0, suffix=None) -> io.BufferedIOBase: """ Open a path given by a string. If writing to the path, the function ensures that the path exists. @@ -240,9 +230,7 @@ def open_path_str( @open_path.register(pathlib.Path) -def open_path_pathlib( - path: pathlib.Path, mode: str, verbose=0, suffix=None -) -> io.BufferedIOBase: +def open_path_pathlib(path: pathlib.Path, mode: str, verbose=0, suffix=None) -> io.BufferedIOBase: """ Open a path given by a string. If writing to the path, the function ensures that the path exists. @@ -331,9 +319,7 @@ def save_to_zip_file( th.save(dict_, param_file) -def save_to_pkl( - path: Union[str, pathlib.Path, io.BufferedIOBase], obj, verbose=0 -) -> None: +def save_to_pkl(path: Union[str, pathlib.Path, io.BufferedIOBase], obj, verbose=0) -> None: """ Save an object to path creating the necessary folders along the way. If the path exists and is a directory, it will raise a warning and rename the path. 
@@ -364,9 +350,7 @@ def load_from_pkl(path: Union[str, pathlib.Path, io.BufferedIOBase], verbose=0) def load_from_zip_file( - load_path: Union[str, pathlib.Path, io.BufferedIOBase], - load_data: bool = True, - verbose=0, + load_path: Union[str, pathlib.Path, io.BufferedIOBase], load_data: bool = True, verbose=0, ) -> (Tuple[Optional[Dict[str, Any]], Optional[TensorDict], Optional[TensorDict]]): """ Load model data from a .zip archive @@ -412,10 +396,7 @@ def load_from_zip_file( # check for all other .pth files other_files = [ - file_name - for file_name in namelist - if os.path.splitext(file_name)[1] == ".pth" - and file_name != "tensors.pth" + file_name for file_name in namelist if os.path.splitext(file_name)[1] == ".pth" and file_name != "tensors.pth" ] # if there are any other files which end with .pth and aren't "params.pth" # assume that they each are optimizer parameters @@ -429,9 +410,7 @@ def load_from_zip_file( # go to start of file file_content.seek(0) # load the parameters with the right ``map_location`` - params[os.path.splitext(file_path)[0]] = th.load( - file_content, map_location=device - ) + params[os.path.splitext(file_path)[0]] = th.load(file_content, map_location=device) except zipfile.BadZipFile: # load_path wasn't a zip file raise ValueError(f"Error: the file {load_path} wasn't a zip-file") diff --git a/stable_baselines3/common/torch_layers.py b/stable_baselines3/common/torch_layers.py index f92b50fc8..9c74017cb 100644 --- a/stable_baselines3/common/torch_layers.py +++ b/stable_baselines3/common/torch_layers.py @@ -1,10 +1,9 @@ -from typing import Union, Type, Dict, List, Tuple - from itertools import zip_longest +from typing import Dict, List, Tuple, Type, Union import gym import torch as th -import torch.nn as nn +from torch import nn as nn from stable_baselines3.common.preprocessing import get_flattened_obs_dim, is_image_space from stable_baselines3.common.utils import get_device @@ -60,22 +59,25 @@ class NatureCNN(BaseFeaturesExtractor): This corresponds to the number of unit for the last layer. 
""" - def __init__(self, observation_space: gym.spaces.Box, - features_dim: int = 512): + def __init__(self, observation_space: gym.spaces.Box, features_dim: int = 512): super(NatureCNN, self).__init__(observation_space, features_dim) # We assume CxHxW images (channels first) # Re-ordering will be done by pre-preprocessing or wrapper - assert is_image_space(observation_space), ('You should use NatureCNN ' - f'only with images not with {observation_space} ' - '(you are probably using `CnnPolicy` instead of `MlpPolicy`)') + assert is_image_space(observation_space), ( + "You should use NatureCNN " + f"only with images not with {observation_space} " + "(you are probably using `CnnPolicy` instead of `MlpPolicy`)" + ) n_input_channels = observation_space.shape[0] - self.cnn = nn.Sequential(nn.Conv2d(n_input_channels, 32, kernel_size=8, stride=4, padding=0), - nn.ReLU(), - nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=0), - nn.ReLU(), - nn.Conv2d(64, 32, kernel_size=3, stride=1, padding=0), - nn.ReLU(), - nn.Flatten()) + self.cnn = nn.Sequential( + nn.Conv2d(n_input_channels, 32, kernel_size=8, stride=4, padding=0), + nn.ReLU(), + nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=0), + nn.ReLU(), + nn.Conv2d(64, 32, kernel_size=3, stride=1, padding=0), + nn.ReLU(), + nn.Flatten(), + ) # Compute shape by doing one forward pass with th.no_grad(): @@ -87,11 +89,9 @@ def forward(self, observations: th.Tensor) -> th.Tensor: return self.linear(self.cnn(observations)) -def create_mlp(input_dim: int, - output_dim: int, - net_arch: List[int], - activation_fn: Type[nn.Module] = nn.ReLU, - squash_output: bool = False) -> List[nn.Module]: +def create_mlp( + input_dim: int, output_dim: int, net_arch: List[int], activation_fn: Type[nn.Module] = nn.ReLU, squash_output: bool = False +) -> List[nn.Module]: """ Create a multi layer perceptron (MLP), which is a collection of fully-connected layers each followed by an activation function. @@ -152,10 +152,13 @@ class MlpExtractor(nn.Module): :param device: (th.device) """ - def __init__(self, feature_dim: int, - net_arch: List[Union[int, Dict[str, List[int]]]], - activation_fn: Type[nn.Module], - device: Union[th.device, str] = 'auto'): + def __init__( + self, + feature_dim: int, + net_arch: List[Union[int, Dict[str, List[int]]]], + activation_fn: Type[nn.Module], + device: Union[th.device, str] = "auto", + ): super(MlpExtractor, self).__init__() device = get_device(device) shared_net, policy_net, value_net = [], [], [] @@ -173,13 +176,13 @@ def __init__(self, feature_dim: int, last_layer_dim_shared = layer_size else: assert isinstance(layer, dict), "Error: the net_arch list can only contain ints and dicts" - if 'pi' in layer: - assert isinstance(layer['pi'], list), "Error: net_arch[-1]['pi'] must contain a list of integers." - policy_only_layers = layer['pi'] + if "pi" in layer: + assert isinstance(layer["pi"], list), "Error: net_arch[-1]['pi'] must contain a list of integers." + policy_only_layers = layer["pi"] - if 'vf' in layer: - assert isinstance(layer['vf'], list), "Error: net_arch[-1]['vf'] must contain a list of integers." - value_only_layers = layer['vf'] + if "vf" in layer: + assert isinstance(layer["vf"], list), "Error: net_arch[-1]['vf'] must contain a list of integers." 
+ value_only_layers = layer["vf"] break # From here on the network splits up in policy and value network last_layer_dim_pi = last_layer_dim_shared diff --git a/stable_baselines3/common/type_aliases.py b/stable_baselines3/common/type_aliases.py index 5ef94f91e..e16f43539 100644 --- a/stable_baselines3/common/type_aliases.py +++ b/stable_baselines3/common/type_aliases.py @@ -1,14 +1,13 @@ """Common aliases for type hints""" -from typing import Union, Dict, Any, NamedTuple, List, Callable, Tuple +from typing import Any, Callable, Dict, List, NamedTuple, Tuple, Union +import gym import numpy as np import torch as th -import gym -from stable_baselines3.common.vec_env import VecEnv from stable_baselines3.common.callbacks import BaseCallback - +from stable_baselines3.common.vec_env import VecEnv GymEnv = Union[gym.Env, VecEnv] GymObs = Union[Tuple, Dict[str, Any], np.ndarray, int] diff --git a/stable_baselines3/common/utils.py b/stable_baselines3/common/utils.py index fbe357219..9c00f0e6b 100644 --- a/stable_baselines3/common/utils.py +++ b/stable_baselines3/common/utils.py @@ -1,13 +1,13 @@ -from collections import deque -from typing import Callable, Union, Optional -import random -import os import glob - +import os +import random +from collections import deque +from typing import Callable, Optional, Union import gym import numpy as np import torch as th + # Check if tensorboard is available for pytorch try: from torch.utils.tensorboard import SummaryWriter @@ -15,8 +15,8 @@ SummaryWriter = None from stable_baselines3.common import logger -from stable_baselines3.common.type_aliases import GymEnv from stable_baselines3.common.preprocessing import is_image_space +from stable_baselines3.common.type_aliases import GymEnv from stable_baselines3.common.vec_env import VecTransposeImage @@ -68,7 +68,7 @@ def update_learning_rate(optimizer: th.optim.Optimizer, learning_rate: float) -> :param learning_rate: (float) """ for param_group in optimizer.param_groups: - param_group['lr'] = learning_rate + param_group["lr"] = learning_rate def get_schedule_fn(value_schedule: Union[Callable, float]) -> Callable: @@ -128,7 +128,7 @@ def func(_): return func -def get_device(device: Union[th.device, str] = 'auto') -> th.device: +def get_device(device: Union[th.device, str] = "auto") -> th.device: """ Retrieve PyTorch device. It checks that the requested device is available first. @@ -139,19 +139,19 @@ def get_device(device: Union[th.device, str] = 'auto') -> th.device: :return: (th.device) """ # Cuda by default - if device == 'auto': - device = 'cuda' + if device == "auto": + device = "cuda" # Force conversion to th.device device = th.device(device) # Cuda not available - if device == th.device('cuda') and not th.cuda.is_available(): - return th.device('cpu') + if device == th.device("cuda") and not th.cuda.is_available(): + return th.device("cpu") return device -def get_latest_run_id(log_path: Optional[str] = None, log_name: str = '') -> int: +def get_latest_run_id(log_path: Optional[str] = None, log_name: str = "") -> int: """ Returns the latest run number for the given log name and log path, by finding the greatest number in the directories. 
@@ -167,8 +167,9 @@ def get_latest_run_id(log_path: Optional[str] = None, log_name: str = '') -> int return max_run_id -def configure_logger(verbose: int = 0, tensorboard_log: Optional[str] = None, - tb_log_name: str = '', reset_num_timesteps: bool = True) -> None: +def configure_logger( + verbose: int = 0, tensorboard_log: Optional[str] = None, tb_log_name: str = "", reset_num_timesteps: bool = True +) -> None: """ Configure the logger's outputs. @@ -202,13 +203,17 @@ def check_for_correct_spaces(env: GymEnv, observation_space: gym.spaces.Space, a :param observation_space: (gym.spaces.Space) Observation space to check against :param action_space: (gym.spaces.Space) Action space to check against """ - if (observation_space != env.observation_space + if ( + observation_space != env.observation_space # Special cases for images that need to be transposed - and not (is_image_space(env.observation_space) - and observation_space == VecTransposeImage.transpose_space(env.observation_space))): - raise ValueError(f'Observation spaces do not match: {observation_space} != {env.observation_space}') + and not ( + is_image_space(env.observation_space) + and observation_space == VecTransposeImage.transpose_space(env.observation_space) + ) + ): + raise ValueError(f"Observation spaces do not match: {observation_space} != {env.observation_space}") if action_space != env.action_space: - raise ValueError(f'Action spaces do not match: {action_space} != {env.action_space}') + raise ValueError(f"Action spaces do not match: {action_space} != {env.action_space}") def is_vectorized_observation(observation: np.ndarray, observation_space: gym.spaces.Space) -> bool: @@ -226,18 +231,21 @@ def is_vectorized_observation(observation: np.ndarray, observation_space: gym.sp elif observation.shape[1:] == observation_space.shape: return True else: - raise ValueError(f"Error: Unexpected observation shape {observation.shape} for " - + f"Box environment, please use {observation_space.shape} " - + "or (n_env, {}) for the observation shape." - .format(", ".join(map(str, observation_space.shape)))) + raise ValueError( + f"Error: Unexpected observation shape {observation.shape} for " + + f"Box environment, please use {observation_space.shape} " + + "or (n_env, {}) for the observation shape.".format(", ".join(map(str, observation_space.shape))) + ) elif isinstance(observation_space, gym.spaces.Discrete): if observation.shape == (): # A numpy array of a number, has shape empty tuple '()' return False elif len(observation.shape) == 1: return True else: - raise ValueError(f"Error: Unexpected observation shape {observation.shape} for " - + "Discrete environment, please use (1,) or (n_env, 1) for the observation shape.") + raise ValueError( + f"Error: Unexpected observation shape {observation.shape} for " + + "Discrete environment, please use (1,) or (n_env, 1) for the observation shape." 
+ ) elif isinstance(observation_space, gym.spaces.MultiDiscrete): if observation.shape == (len(observation_space.nvec),): @@ -245,21 +253,26 @@ def is_vectorized_observation(observation: np.ndarray, observation_space: gym.sp elif len(observation.shape) == 2 and observation.shape[1] == len(observation_space.nvec): return True else: - raise ValueError(f"Error: Unexpected observation shape {observation.shape} for MultiDiscrete " - + f"environment, please use ({len(observation_space.nvec)},) or " - + f"(n_env, {len(observation_space.nvec)}) for the observation shape.") + raise ValueError( + f"Error: Unexpected observation shape {observation.shape} for MultiDiscrete " + + f"environment, please use ({len(observation_space.nvec)},) or " + + f"(n_env, {len(observation_space.nvec)}) for the observation shape." + ) elif isinstance(observation_space, gym.spaces.MultiBinary): if observation.shape == (observation_space.n,): return False elif len(observation.shape) == 2 and observation.shape[1] == observation_space.n: return True else: - raise ValueError(f"Error: Unexpected observation shape {observation.shape} for MultiBinary " - + f"environment, please use ({observation_space.n},) or " - + f"(n_env, {observation_space.n}) for the observation shape.") + raise ValueError( + f"Error: Unexpected observation shape {observation.shape} for MultiBinary " + + f"environment, please use ({observation_space.n},) or " + + f"(n_env, {observation_space.n}) for the observation shape." + ) else: - raise ValueError("Error: Cannot determine if the observation is vectorized " - + f" with the space type {observation_space}.") + raise ValueError( + "Error: Cannot determine if the observation is vectorized " + f" with the space type {observation_space}." + ) def safe_mean(arr: Union[np.ndarray, list, deque]) -> np.ndarray: diff --git a/stable_baselines3/common/vec_env/__init__.py b/stable_baselines3/common/vec_env/__init__.py index 535ed031b..9944130b2 100644 --- a/stable_baselines3/common/vec_env/__init__.py +++ b/stable_baselines3/common/vec_env/__init__.py @@ -1,24 +1,29 @@ # flake8: noqa F401 import typing -from typing import Optional, Union from copy import deepcopy +from typing import Optional, Union -from stable_baselines3.common.vec_env.base_vec_env import (AlreadySteppingError, NotSteppingError, - VecEnv, VecEnvWrapper, CloudpickleWrapper) +from stable_baselines3.common.vec_env.base_vec_env import ( + AlreadySteppingError, + CloudpickleWrapper, + NotSteppingError, + VecEnv, + VecEnvWrapper, +) from stable_baselines3.common.vec_env.dummy_vec_env import DummyVecEnv from stable_baselines3.common.vec_env.subproc_vec_env import SubprocVecEnv +from stable_baselines3.common.vec_env.vec_check_nan import VecCheckNan from stable_baselines3.common.vec_env.vec_frame_stack import VecFrameStack from stable_baselines3.common.vec_env.vec_normalize import VecNormalize from stable_baselines3.common.vec_env.vec_transpose import VecTransposeImage from stable_baselines3.common.vec_env.vec_video_recorder import VecVideoRecorder -from stable_baselines3.common.vec_env.vec_check_nan import VecCheckNan # Avoid circular import if typing.TYPE_CHECKING: from stable_baselines3.common.type_aliases import GymEnv -def unwrap_vec_normalize(env: Union['GymEnv', VecEnv]) -> Optional[VecNormalize]: +def unwrap_vec_normalize(env: Union["GymEnv", VecEnv]) -> Optional[VecNormalize]: """ :param env: (gym.Env) :return: (VecNormalize) @@ -32,7 +37,7 @@ def unwrap_vec_normalize(env: Union['GymEnv', VecEnv]) -> Optional[VecNormalize] # Define here to avoid 
circular import -def sync_envs_normalization(env: 'GymEnv', eval_env: 'GymEnv') -> None: +def sync_envs_normalization(env: "GymEnv", eval_env: "GymEnv") -> None: """ Sync eval env and train env when using VecNormalize diff --git a/stable_baselines3/common/vec_env/base_vec_env.py b/stable_baselines3/common/vec_env/base_vec_env.py index b2caa461c..6b5a42143 100644 --- a/stable_baselines3/common/vec_env/base_vec_env.py +++ b/stable_baselines3/common/vec_env/base_vec_env.py @@ -1,10 +1,10 @@ import inspect import pickle from abc import ABC, abstractmethod -from typing import Sequence, Optional, List, Union +from typing import List, Optional, Sequence, Union -import numpy as np import cloudpickle +import numpy as np from stable_baselines3.common import logger @@ -42,7 +42,7 @@ class AlreadySteppingError(Exception): """ def __init__(self): - msg = 'already running an async step' + msg = "already running an async step" Exception.__init__(self, msg) @@ -53,7 +53,7 @@ class NotSteppingError(Exception): """ def __init__(self): - msg = 'not running an async step' + msg = "not running an async step" Exception.__init__(self, msg) @@ -65,9 +65,8 @@ class VecEnv(ABC): :param observation_space: (Gym Space) the observation space :param action_space: (Gym Space) the action space """ - metadata = { - 'render.modes': ['human', 'rgb_array'] - } + + metadata = {"render.modes": ["human", "rgb_array"]} def __init__(self, num_envs, observation_space, action_space): self.num_envs = num_envs @@ -168,7 +167,7 @@ def get_images(self) -> Sequence[np.ndarray]: """ raise NotImplementedError - def render(self, mode: str = 'human'): + def render(self, mode: str = "human"): """ Gym environment rendering @@ -177,19 +176,20 @@ def render(self, mode: str = 'human'): try: imgs = self.get_images() except NotImplementedError: - logger.warn(f'Render not defined for {self}') + logger.warn(f"Render not defined for {self}") return # Create a big image by tiling images from subprocesses bigimg = tile_images(imgs) - if mode == 'human': + if mode == "human": import cv2 # pytype:disable=import-error - cv2.imshow('vecenv', bigimg[:, :, ::-1]) + + cv2.imshow("vecenv", bigimg[:, :, ::-1]) cv2.waitKey(1) - elif mode == 'rgb_array': + elif mode == "rgb_array": return bigimg else: - raise NotImplementedError(f'Render mode {mode} is not supported by VecEnvs') + raise NotImplementedError(f"Render mode {mode} is not supported by VecEnvs") @abstractmethod def seed(self, seed: Optional[int] = None) -> List[Union[None, int]]: @@ -247,8 +247,12 @@ class VecEnvWrapper(VecEnv): def __init__(self, venv, observation_space=None, action_space=None): self.venv = venv - VecEnv.__init__(self, num_envs=venv.num_envs, observation_space=observation_space or venv.observation_space, - action_space=action_space or venv.action_space) + VecEnv.__init__( + self, + num_envs=venv.num_envs, + observation_space=observation_space or venv.observation_space, + action_space=action_space or venv.action_space, + ) self.class_attributes = dict(inspect.getmembers(self.__class__)) def step_async(self, actions): @@ -268,7 +272,7 @@ def seed(self, seed=None): def close(self): return self.venv.close() - def render(self, mode: str = 'human'): + def render(self, mode: str = "human"): return self.venv.render(mode=mode) def get_images(self): @@ -291,8 +295,10 @@ def __getattr__(self, name): blocked_class = self.getattr_depth_check(name, already_found=False) if blocked_class is not None: own_class = f"{type(self).__module__}.{type(self).__name__}" - error_str = (f"Error: Recursive 
attribute lookup for {name} from {own_class} is " - "ambiguous and hides attribute from {blocked_class}") + error_str = ( + f"Error: Recursive attribute lookup for {name} from {own_class} is " + "ambiguous and hides attribute from {blocked_class}" + ) raise AttributeError(error_str) return self.getattr_recursive(name) @@ -315,7 +321,7 @@ def getattr_recursive(self, name): all_attributes = self._get_all_attributes() if name in all_attributes: # attribute is present in this wrapper attr = getattr(self, name) - elif hasattr(self.venv, 'getattr_recursive'): + elif hasattr(self.venv, "getattr_recursive"): # Attribute not present, child is wrapper. Call getattr_recursive rather than getattr # to avoid a duplicate call to getattr_depth_check. attr = self.venv.getattr_recursive(name) diff --git a/stable_baselines3/common/vec_env/dummy_vec_env.py b/stable_baselines3/common/vec_env/dummy_vec_env.py index a11108b6b..c577bd87f 100644 --- a/stable_baselines3/common/vec_env/dummy_vec_env.py +++ b/stable_baselines3/common/vec_env/dummy_vec_env.py @@ -26,9 +26,7 @@ def __init__(self, env_fns): obs_space = env.observation_space self.keys, shapes, dtypes = obs_space_info(obs_space) - self.buf_obs = OrderedDict([ - (k, np.zeros((self.num_envs,) + tuple(shapes[k]), dtype=dtypes[k])) - for k in self.keys]) + self.buf_obs = OrderedDict([(k, np.zeros((self.num_envs,) + tuple(shapes[k]), dtype=dtypes[k])) for k in self.keys]) self.buf_dones = np.zeros((self.num_envs,), dtype=np.bool) self.buf_rews = np.zeros((self.num_envs,), dtype=np.float32) self.buf_infos = [{} for _ in range(self.num_envs)] @@ -40,15 +38,15 @@ def step_async(self, actions): def step_wait(self): for env_idx in range(self.num_envs): - obs, self.buf_rews[env_idx], self.buf_dones[env_idx], self.buf_infos[env_idx] =\ - self.envs[env_idx].step(self.actions[env_idx]) + obs, self.buf_rews[env_idx], self.buf_dones[env_idx], self.buf_infos[env_idx] = self.envs[env_idx].step( + self.actions[env_idx] + ) if self.buf_dones[env_idx]: # save final observation where user can get it, then reset - self.buf_infos[env_idx]['terminal_observation'] = obs + self.buf_infos[env_idx]["terminal_observation"] = obs obs = self.envs[env_idx].reset() self._save_obs(env_idx, obs) - return (self._obs_from_buf(), np.copy(self.buf_rews), np.copy(self.buf_dones), - deepcopy(self.buf_infos)) + return (self._obs_from_buf(), np.copy(self.buf_rews), np.copy(self.buf_dones), deepcopy(self.buf_infos)) def seed(self, seed=None): seeds = list() @@ -67,9 +65,9 @@ def close(self): env.close() def get_images(self) -> Sequence[np.ndarray]: - return [env.render(mode='rgb_array') for env in self.envs] + return [env.render(mode="rgb_array") for env in self.envs] - def render(self, mode: str = 'human'): + def render(self, mode: str = "human"): """ Gym environment rendering. If there are multiple environments then they are tiled together in one image via ``BaseVecEnv.render()``. 
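
Aside (not part of this patch): the `DummyVecEnv.render()` docstring above notes that frames from the sub-environments are tiled into one image via `VecEnv.render()`. A minimal usage sketch, assuming a classic-control `gym` environment that supports `rgb_array` rendering:

```python
import gym
import numpy as np

from stable_baselines3.common.vec_env import DummyVecEnv

# Two copies of the same environment, each created by a zero-argument callable.
vec_env = DummyVecEnv([lambda: gym.make("CartPole-v1") for _ in range(2)])

obs = vec_env.reset()  # stacked observations, shape (num_envs,) + observation_space.shape
actions = np.array([vec_env.action_space.sample() for _ in range(vec_env.num_envs)])
obs, rewards, dones, infos = vec_env.step(actions)

# With mode="rgb_array", VecEnv.render() gathers one frame per sub-environment
# and returns them tiled into a single image.
frame = vec_env.render(mode="rgb_array")
vec_env.close()
```
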
diff --git a/stable_baselines3/common/vec_env/subproc_vec_env.py b/stable_baselines3/common/vec_env/subproc_vec_env.py index b12218f62..937e94365 100644 --- a/stable_baselines3/common/vec_env/subproc_vec_env.py +++ b/stable_baselines3/common/vec_env/subproc_vec_env.py @@ -5,7 +5,7 @@ import gym import numpy as np -from stable_baselines3.common.vec_env.base_vec_env import VecEnv, CloudpickleWrapper +from stable_baselines3.common.vec_env.base_vec_env import CloudpickleWrapper, VecEnv def _worker(remote, parent_remote, env_fn_wrapper): @@ -14,32 +14,32 @@ def _worker(remote, parent_remote, env_fn_wrapper): while True: try: cmd, data = remote.recv() - if cmd == 'step': + if cmd == "step": observation, reward, done, info = env.step(data) if done: # save final observation where user can get it, then reset - info['terminal_observation'] = observation + info["terminal_observation"] = observation observation = env.reset() remote.send((observation, reward, done, info)) - elif cmd == 'seed': + elif cmd == "seed": remote.send(env.seed(data)) - elif cmd == 'reset': + elif cmd == "reset": observation = env.reset() remote.send(observation) - elif cmd == 'render': + elif cmd == "render": remote.send(env.render(data)) - elif cmd == 'close': + elif cmd == "close": env.close() remote.close() break - elif cmd == 'get_spaces': + elif cmd == "get_spaces": remote.send((env.observation_space, env.action_space)) - elif cmd == 'env_method': + elif cmd == "env_method": method = getattr(env, data[0]) remote.send(method(*data[1], **data[2])) - elif cmd == 'get_attr': + elif cmd == "get_attr": remote.send(getattr(env, data)) - elif cmd == 'set_attr': + elif cmd == "set_attr": remote.send(setattr(env, data[0], data[1])) else: raise NotImplementedError(f"`{cmd}` is not implemented in the worker") @@ -80,8 +80,8 @@ def __init__(self, env_fns, start_method=None): # Fork is not a thread safe method (see issue #217) # but is more user friendly (does not require to wrap the code in # a `if __name__ == "__main__":`) - forkserver_available = 'forkserver' in multiprocessing.get_all_start_methods() - start_method = 'forkserver' if forkserver_available else 'spawn' + forkserver_available = "forkserver" in multiprocessing.get_all_start_methods() + start_method = "forkserver" if forkserver_available else "spawn" ctx = multiprocessing.get_context(start_method) self.remotes, self.work_remotes = zip(*[ctx.Pipe() for _ in range(n_envs)]) @@ -94,13 +94,13 @@ def __init__(self, env_fns, start_method=None): self.processes.append(process) work_remote.close() - self.remotes[0].send(('get_spaces', None)) + self.remotes[0].send(("get_spaces", None)) observation_space, action_space = self.remotes[0].recv() VecEnv.__init__(self, len(env_fns), observation_space, action_space) def step_async(self, actions): for remote, action in zip(self.remotes, actions): - remote.send(('step', action)) + remote.send(("step", action)) self.waiting = True def step_wait(self): @@ -111,12 +111,12 @@ def step_wait(self): def seed(self, seed=None): for idx, remote in enumerate(self.remotes): - remote.send(('seed', seed + idx)) + remote.send(("seed", seed + idx)) return [remote.recv() for remote in self.remotes] def reset(self): for remote in self.remotes: - remote.send(('reset', None)) + remote.send(("reset", None)) obs = [remote.recv() for remote in self.remotes] return _flatten_obs(obs, self.observation_space) @@ -127,7 +127,7 @@ def close(self): for remote in self.remotes: remote.recv() for remote in self.remotes: - remote.send(('close', None)) + 
remote.send(("close", None)) for process in self.processes: process.join() self.closed = True @@ -136,7 +136,7 @@ def get_images(self) -> Sequence[np.ndarray]: for pipe in self.remotes: # gather images from subprocesses # `mode` will be taken into account later - pipe.send(('render', 'rgb_array')) + pipe.send(("render", "rgb_array")) imgs = [pipe.recv() for pipe in self.remotes] return imgs @@ -144,14 +144,14 @@ def get_attr(self, attr_name, indices=None): """Return attribute from vectorized environment (see base class).""" target_remotes = self._get_target_remotes(indices) for remote in target_remotes: - remote.send(('get_attr', attr_name)) + remote.send(("get_attr", attr_name)) return [remote.recv() for remote in target_remotes] def set_attr(self, attr_name, value, indices=None): """Set attribute inside vectorized environments (see base class).""" target_remotes = self._get_target_remotes(indices) for remote in target_remotes: - remote.send(('set_attr', (attr_name, value))) + remote.send(("set_attr", (attr_name, value))) for remote in target_remotes: remote.recv() @@ -159,7 +159,7 @@ def env_method(self, method_name, *method_args, indices=None, **method_kwargs): """Call instance methods of vectorized environments.""" target_remotes = self._get_target_remotes(indices) for remote in target_remotes: - remote.send(('env_method', (method_name, method_args, method_kwargs))) + remote.send(("env_method", (method_name, method_args, method_kwargs))) return [remote.recv() for remote in target_remotes] def _get_target_remotes(self, indices): diff --git a/stable_baselines3/common/vec_env/util.py b/stable_baselines3/common/vec_env/util.py index 0e8629a71..0ebecb05b 100644 --- a/stable_baselines3/common/vec_env/util.py +++ b/stable_baselines3/common/vec_env/util.py @@ -61,7 +61,7 @@ def obs_space_info(obs_space): elif isinstance(obs_space, gym.spaces.Tuple): subspaces = {i: space for i, space in enumerate(obs_space.spaces)} else: - assert not hasattr(obs_space, 'spaces'), f"Unsupported structured space '{type(obs_space)}'" + assert not hasattr(obs_space, "spaces"), f"Unsupported structured space '{type(obs_space)}'" subspaces = {None: obs_space} keys = [] shapes = {} diff --git a/stable_baselines3/common/vec_env/vec_frame_stack.py b/stable_baselines3/common/vec_env/vec_frame_stack.py index b32ddebd0..5e6b7f251 100644 --- a/stable_baselines3/common/vec_env/vec_frame_stack.py +++ b/stable_baselines3/common/vec_env/vec_frame_stack.py @@ -30,16 +30,14 @@ def step_wait(self): self.stackedobs = np.roll(self.stackedobs, shift=-last_ax_size, axis=-1) for i, done in enumerate(dones): if done: - if 'terminal_observation' in infos[i]: - old_terminal = infos[i]['terminal_observation'] - new_terminal = np.concatenate( - (self.stackedobs[i, ..., :-last_ax_size], old_terminal), axis=-1) - infos[i]['terminal_observation'] = new_terminal + if "terminal_observation" in infos[i]: + old_terminal = infos[i]["terminal_observation"] + new_terminal = np.concatenate((self.stackedobs[i, ..., :-last_ax_size], old_terminal), axis=-1) + infos[i]["terminal_observation"] = new_terminal else: - warnings.warn( - "VecFrameStack wrapping a VecEnv without terminal_observation info") + warnings.warn("VecFrameStack wrapping a VecEnv without terminal_observation info") self.stackedobs[i] = 0 - self.stackedobs[..., -observations.shape[-1]:] = observations + self.stackedobs[..., -observations.shape[-1] :] = observations return self.stackedobs, rewards, dones, infos def reset(self): @@ -48,7 +46,7 @@ def reset(self): """ obs = 
self.venv.reset() self.stackedobs[...] = 0 - self.stackedobs[..., -obs.shape[-1]:] = obs + self.stackedobs[..., -obs.shape[-1] :] = obs return self.stackedobs def close(self): diff --git a/stable_baselines3/common/vec_env/vec_normalize.py b/stable_baselines3/common/vec_env/vec_normalize.py index 47939a577..213dd6a83 100644 --- a/stable_baselines3/common/vec_env/vec_normalize.py +++ b/stable_baselines3/common/vec_env/vec_normalize.py @@ -2,8 +2,8 @@ import numpy as np -from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvWrapper from stable_baselines3.common.running_mean_std import RunningMeanStd +from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvWrapper class VecNormalize(VecEnvWrapper): @@ -21,8 +21,9 @@ class VecNormalize(VecEnvWrapper): :param epsilon: (float) To avoid division by zero """ - def __init__(self, venv, training=True, norm_obs=True, norm_reward=True, - clip_obs=10., clip_reward=10., gamma=0.99, epsilon=1e-8): + def __init__( + self, venv, training=True, norm_obs=True, norm_reward=True, clip_obs=10.0, clip_reward=10.0, gamma=0.99, epsilon=1e-8 + ): VecEnvWrapper.__init__(self, venv) self.obs_rms = RunningMeanStd(shape=self.observation_space.shape) self.ret_rms = RunningMeanStd(shape=()) @@ -45,10 +46,10 @@ def __getstate__(self): Excludes self.venv, as in general VecEnv's may not be pickleable.""" state = self.__dict__.copy() # these attributes are not pickleable - del state['venv'] - del state['class_attributes'] + del state["venv"] + del state["class_attributes"] # these attributes depend on the above and so we would prefer not to pickle - del state['ret'] + del state["ret"] return state def __setstate__(self, state): @@ -59,7 +60,7 @@ def __setstate__(self, state): :param state: (dict)""" self.__dict__.update(state) - assert 'venv' not in state + assert "venv" not in state self.venv = None def set_venv(self, venv): @@ -110,9 +111,7 @@ def normalize_obs(self, obs): Calling this method does not update statistics. """ if self.norm_obs: - obs = np.clip((obs - self.obs_rms.mean) / np.sqrt(self.obs_rms.var + self.epsilon), - -self.clip_obs, - self.clip_obs) + obs = np.clip((obs - self.obs_rms.mean) / np.sqrt(self.obs_rms.var + self.epsilon), -self.clip_obs, self.clip_obs) return obs def normalize_reward(self, reward): @@ -121,8 +120,7 @@ def normalize_reward(self, reward): Calling this method does not update statistics. 
""" if self.norm_reward: - reward = np.clip(reward / np.sqrt(self.ret_rms.var + self.epsilon), - -self.clip_reward, self.clip_reward) + reward = np.clip(reward / np.sqrt(self.ret_rms.var + self.epsilon), -self.clip_reward, self.clip_reward) return reward def unnormalize_obs(self, obs): diff --git a/stable_baselines3/common/vec_env/vec_transpose.py b/stable_baselines3/common/vec_env/vec_transpose.py index 3d25649b1..5dc0364dd 100644 --- a/stable_baselines3/common/vec_env/vec_transpose.py +++ b/stable_baselines3/common/vec_env/vec_transpose.py @@ -1,9 +1,10 @@ import typing + import numpy as np from gym import spaces -from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvWrapper from stable_baselines3.common.preprocessing import is_image_space +from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvWrapper if typing.TYPE_CHECKING: from stable_baselines3.common.type_aliases import GymStepReturn # noqa: F401 @@ -18,7 +19,7 @@ class VecTransposeImage(VecEnvWrapper): """ def __init__(self, venv: VecEnv): - assert is_image_space(venv.observation_space), 'The observation space must be an image' + assert is_image_space(venv.observation_space), "The observation space must be an image" observation_space = self.transpose_space(venv.observation_space) super(VecTransposeImage, self).__init__(venv, observation_space=observation_space) @@ -31,7 +32,7 @@ def transpose_space(observation_space: spaces.Box) -> spaces.Box: :param observation_space: (spaces.Box) :return: (spaces.Box) """ - assert is_image_space(observation_space), 'The observation space must be an image' + assert is_image_space(observation_space), "The observation space must be an image" width, height, channels = observation_space.shape new_shape = (channels, width, height) return spaces.Box(low=0, high=255, shape=new_shape, dtype=observation_space.dtype) @@ -48,7 +49,7 @@ def transpose_image(image: np.ndarray) -> np.ndarray: return np.transpose(image, (2, 0, 1)) return np.transpose(image, (0, 3, 1, 2)) - def step_wait(self) -> 'GymStepReturn': + def step_wait(self) -> "GymStepReturn": observations, rewards, dones, infos = self.venv.step_wait() return self.transpose_image(observations), rewards, dones, infos diff --git a/stable_baselines3/common/vec_env/vec_video_recorder.py b/stable_baselines3/common/vec_env/vec_video_recorder.py index 1343ce66d..f7b6ffac7 100644 --- a/stable_baselines3/common/vec_env/vec_video_recorder.py +++ b/stable_baselines3/common/vec_env/vec_video_recorder.py @@ -24,8 +24,7 @@ class VecVideoRecorder(VecEnvWrapper): :param name_prefix: (str) Prefix to the video name """ - def __init__(self, venv, video_folder, record_video_trigger, - video_length=200, name_prefix='rl-video'): + def __init__(self, venv, video_folder, record_video_trigger, video_length=200, name_prefix="rl-video"): VecEnvWrapper.__init__(self, venv) @@ -39,7 +38,7 @@ def __init__(self, venv, video_folder, record_video_trigger, temp_env = temp_env.venv if isinstance(temp_env, DummyVecEnv) or isinstance(temp_env, SubprocVecEnv): - metadata = temp_env.get_attr('metadata')[0] + metadata = temp_env.get_attr("metadata")[0] else: metadata = temp_env.metadata @@ -67,12 +66,11 @@ def reset(self): def start_video_recorder(self): self.close_video_recorder() - video_name = f'{self.name_prefix}-step-{self.step_id}-to-step-{self.step_id + self.video_length}' + video_name = f"{self.name_prefix}-step-{self.step_id}-to-step-{self.step_id + self.video_length}" base_path = os.path.join(self.video_folder, video_name) - self.video_recorder = 
video_recorder.VideoRecorder(env=self.env, - base_path=base_path, - metadata={'step_id': self.step_id} - ) + self.video_recorder = video_recorder.VideoRecorder( + env=self.env, base_path=base_path, metadata={"step_id": self.step_id} + ) self.video_recorder.capture_frame() self.recorded_frames = 1 diff --git a/stable_baselines3/ddpg/__init__.py b/stable_baselines3/ddpg/__init__.py index 11a890b95..0b164b2de 100644 --- a/stable_baselines3/ddpg/__init__.py +++ b/stable_baselines3/ddpg/__init__.py @@ -1,2 +1,2 @@ from stable_baselines3.ddpg.ddpg import DDPG -from stable_baselines3.ddpg.policies import MlpPolicy, CnnPolicy +from stable_baselines3.ddpg.policies import CnnPolicy, MlpPolicy diff --git a/stable_baselines3/ddpg/ddpg.py b/stable_baselines3/ddpg/ddpg.py index 83325b450..cc6e2899f 100644 --- a/stable_baselines3/ddpg/ddpg.py +++ b/stable_baselines3/ddpg/ddpg.py @@ -1,11 +1,12 @@ +from typing import Any, Callable, Dict, Optional, Type, Union + import torch as th -from typing import Type, Union, Callable, Optional, Dict, Any -from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm from stable_baselines3.common.noise import ActionNoise +from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback -from stable_baselines3.td3.td3 import TD3 from stable_baselines3.td3.policies import TD3Policy +from stable_baselines3.td3.td3 import TD3 class DDPG(TD3): @@ -50,67 +51,86 @@ class DDPG(TD3): :param _init_setup_model: (bool) Whether or not to build the network at the creation of the instance """ - def __init__(self, policy: Union[str, Type[TD3Policy]], - env: Union[GymEnv, str], - learning_rate: Union[float, Callable] = 1e-3, - buffer_size: int = int(1e6), - learning_starts: int = 100, - batch_size: int = 100, - tau: float = 0.005, - gamma: float = 0.99, - train_freq: int = -1, - gradient_steps: int = -1, - n_episodes_rollout: int = 1, - action_noise: Optional[ActionNoise] = None, - optimize_memory_usage: bool = False, - tensorboard_log: Optional[str] = None, - create_eval_env: bool = False, - policy_kwargs: Dict[str, Any] = None, - verbose: int = 0, - seed: Optional[int] = None, - device: Union[th.device, str] = 'auto', - _init_setup_model: bool = True): + def __init__( + self, + policy: Union[str, Type[TD3Policy]], + env: Union[GymEnv, str], + learning_rate: Union[float, Callable] = 1e-3, + buffer_size: int = int(1e6), + learning_starts: int = 100, + batch_size: int = 100, + tau: float = 0.005, + gamma: float = 0.99, + train_freq: int = -1, + gradient_steps: int = -1, + n_episodes_rollout: int = 1, + action_noise: Optional[ActionNoise] = None, + optimize_memory_usage: bool = False, + tensorboard_log: Optional[str] = None, + create_eval_env: bool = False, + policy_kwargs: Dict[str, Any] = None, + verbose: int = 0, + seed: Optional[int] = None, + device: Union[th.device, str] = "auto", + _init_setup_model: bool = True, + ): - super(DDPG, self).__init__(policy=policy, - env=env, - learning_rate=learning_rate, - buffer_size=buffer_size, - learning_starts=learning_starts, - batch_size=batch_size, - tau=tau, gamma=gamma, - train_freq=train_freq, - gradient_steps=gradient_steps, - n_episodes_rollout=n_episodes_rollout, - action_noise=action_noise, - policy_kwargs=policy_kwargs, - tensorboard_log=tensorboard_log, - verbose=verbose, device=device, - create_eval_env=create_eval_env, seed=seed, - optimize_memory_usage=optimize_memory_usage, - # Remove all tricks from TD3 to obtain DDPG: - # we 
still need to specify target_policy_noise > 0 to avoid errors - policy_delay=1, target_noise_clip=0.0, target_policy_noise=0.1, - _init_setup_model=False) + super(DDPG, self).__init__( + policy=policy, + env=env, + learning_rate=learning_rate, + buffer_size=buffer_size, + learning_starts=learning_starts, + batch_size=batch_size, + tau=tau, + gamma=gamma, + train_freq=train_freq, + gradient_steps=gradient_steps, + n_episodes_rollout=n_episodes_rollout, + action_noise=action_noise, + policy_kwargs=policy_kwargs, + tensorboard_log=tensorboard_log, + verbose=verbose, + device=device, + create_eval_env=create_eval_env, + seed=seed, + optimize_memory_usage=optimize_memory_usage, + # Remove all tricks from TD3 to obtain DDPG: + # we still need to specify target_policy_noise > 0 to avoid errors + policy_delay=1, + target_noise_clip=0.0, + target_policy_noise=0.1, + _init_setup_model=False, + ) # Use only one critic - if 'n_critics' not in self.policy_kwargs: - self.policy_kwargs['n_critics'] = 1 + if "n_critics" not in self.policy_kwargs: + self.policy_kwargs["n_critics"] = 1 if _init_setup_model: self._setup_model() - def learn(self, - total_timesteps: int, - callback: MaybeCallback = None, - log_interval: int = 4, - eval_env: Optional[GymEnv] = None, - eval_freq: int = -1, - n_eval_episodes: int = 5, - tb_log_name: str = "DDPG", - eval_log_path: Optional[str] = None, - reset_num_timesteps: bool = True) -> OffPolicyAlgorithm: + def learn( + self, + total_timesteps: int, + callback: MaybeCallback = None, + log_interval: int = 4, + eval_env: Optional[GymEnv] = None, + eval_freq: int = -1, + n_eval_episodes: int = 5, + tb_log_name: str = "DDPG", + eval_log_path: Optional[str] = None, + reset_num_timesteps: bool = True, + ) -> OffPolicyAlgorithm: - return super(DDPG, self).learn(total_timesteps=total_timesteps, callback=callback, log_interval=log_interval, - eval_env=eval_env, eval_freq=eval_freq, n_eval_episodes=n_eval_episodes, - tb_log_name=tb_log_name, eval_log_path=eval_log_path, - reset_num_timesteps=reset_num_timesteps) + return super(DDPG, self).learn( + total_timesteps=total_timesteps, + callback=callback, + log_interval=log_interval, + eval_env=eval_env, + eval_freq=eval_freq, + n_eval_episodes=n_eval_episodes, + tb_log_name=tb_log_name, + eval_log_path=eval_log_path, + reset_num_timesteps=reset_num_timesteps, + ) diff --git a/stable_baselines3/ddpg/policies.py b/stable_baselines3/ddpg/policies.py index b826c7af8..64b166f44 100644 --- a/stable_baselines3/ddpg/policies.py +++ b/stable_baselines3/ddpg/policies.py @@ -1,2 +1,2 @@ # DDPG can be view as a special case of TD3 -from stable_baselines3.td3.policies import MlpPolicy, CnnPolicy # noqa:F401 +from stable_baselines3.td3.policies import CnnPolicy, MlpPolicy # noqa:F401 diff --git a/stable_baselines3/dqn/__init__.py b/stable_baselines3/dqn/__init__.py index 8f6968354..4ae42872c 100644 --- a/stable_baselines3/dqn/__init__.py +++ b/stable_baselines3/dqn/__init__.py @@ -1,3 +1,2 @@ from stable_baselines3.dqn.dqn import DQN -from stable_baselines3.dqn.policies import MlpPolicy -from stable_baselines3.dqn.policies import CnnPolicy +from stable_baselines3.dqn.policies import CnnPolicy, MlpPolicy diff --git a/stable_baselines3/dqn/dqn.py b/stable_baselines3/dqn/dqn.py index 05e0a0ae9..55d4fb69a 100644 --- a/stable_baselines3/dqn/dqn.py +++ b/stable_baselines3/dqn/dqn.py @@ -1,8 +1,8 @@ -from typing import List, Tuple, Type, Union, Callable, Optional, Dict, Any +from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union import 
numpy as np import torch as th -import torch.nn.functional as F +from torch.nn import functional as F from stable_baselines3.common import logger from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm @@ -55,41 +55,57 @@ class DQN(OffPolicyAlgorithm): :param _init_setup_model: (bool) Whether or not to build the network at the creation of the instance """ - def __init__(self, policy: Union[str, Type[DQNPolicy]], - env: Union[GymEnv, str], - learning_rate: Union[float, Callable] = 1e-4, - buffer_size: int = 1000000, - learning_starts: int = 50000, - batch_size: Optional[int] = 32, - tau: float = 1.0, - gamma: float = 0.99, - train_freq: int = 4, - gradient_steps: int = 1, - n_episodes_rollout: int = -1, - optimize_memory_usage: bool = False, - target_update_interval: int = 10000, - exploration_fraction: float = 0.1, - exploration_initial_eps: float = 1.0, - exploration_final_eps: float = 0.05, - max_grad_norm: float = 10, - tensorboard_log: Optional[str] = None, - create_eval_env: bool = False, - policy_kwargs: Optional[Dict[str, Any]] = None, - verbose: int = 0, - seed: Optional[int] = None, - device: Union[th.device, str] = 'auto', - _init_setup_model: bool = True): - - super(DQN, self).__init__(policy, env, DQNPolicy, learning_rate, - buffer_size, learning_starts, batch_size, - tau, gamma, train_freq, gradient_steps, - n_episodes_rollout, action_noise=None, # No action noise - policy_kwargs=policy_kwargs, - tensorboard_log=tensorboard_log, - verbose=verbose, device=device, - create_eval_env=create_eval_env, - seed=seed, sde_support=False, - optimize_memory_usage=optimize_memory_usage) + def __init__( + self, + policy: Union[str, Type[DQNPolicy]], + env: Union[GymEnv, str], + learning_rate: Union[float, Callable] = 1e-4, + buffer_size: int = 1000000, + learning_starts: int = 50000, + batch_size: Optional[int] = 32, + tau: float = 1.0, + gamma: float = 0.99, + train_freq: int = 4, + gradient_steps: int = 1, + n_episodes_rollout: int = -1, + optimize_memory_usage: bool = False, + target_update_interval: int = 10000, + exploration_fraction: float = 0.1, + exploration_initial_eps: float = 1.0, + exploration_final_eps: float = 0.05, + max_grad_norm: float = 10, + tensorboard_log: Optional[str] = None, + create_eval_env: bool = False, + policy_kwargs: Optional[Dict[str, Any]] = None, + verbose: int = 0, + seed: Optional[int] = None, + device: Union[th.device, str] = "auto", + _init_setup_model: bool = True, + ): + + super(DQN, self).__init__( + policy, + env, + DQNPolicy, + learning_rate, + buffer_size, + learning_starts, + batch_size, + tau, + gamma, + train_freq, + gradient_steps, + n_episodes_rollout, + action_noise=None, # No action noise + policy_kwargs=policy_kwargs, + tensorboard_log=tensorboard_log, + verbose=verbose, + device=device, + create_eval_env=create_eval_env, + seed=seed, + sde_support=False, + optimize_memory_usage=optimize_memory_usage, + ) self.exploration_initial_eps = exploration_initial_eps self.exploration_final_eps = exploration_final_eps @@ -108,8 +124,9 @@ def __init__(self, policy: Union[str, Type[DQNPolicy]], def _setup_model(self) -> None: super(DQN, self)._setup_model() self._create_aliases() - self.exploration_schedule = get_linear_fn(self.exploration_initial_eps, self.exploration_final_eps, - self.exploration_fraction) + self.exploration_schedule = get_linear_fn( + self.exploration_initial_eps, self.exploration_final_eps, self.exploration_fraction + ) def _create_aliases(self) -> None: self.q_net = self.policy.q_net @@ -164,12 +181,15 @@ 
def train(self, gradient_steps: int, batch_size: int = 100) -> None: # Increase update counter self._n_updates += gradient_steps - logger.record("train/n_updates", self._n_updates, exclude='tensorboard') + logger.record("train/n_updates", self._n_updates, exclude="tensorboard") - def predict(self, observation: np.ndarray, - state: Optional[np.ndarray] = None, - mask: Optional[np.ndarray] = None, - deterministic: bool = False) -> Tuple[np.ndarray, Optional[np.ndarray]]: + def predict( + self, + observation: np.ndarray, + state: Optional[np.ndarray] = None, + mask: Optional[np.ndarray] = None, + deterministic: bool = False, + ) -> Tuple[np.ndarray, Optional[np.ndarray]]: """ Overrides the base_class predict function to include epsilon-greedy exploration. @@ -187,21 +207,30 @@ def predict(self, observation: np.ndarray, action, state = self.policy.predict(observation, state, mask, deterministic) return action, state - def learn(self, - total_timesteps: int, - callback: MaybeCallback = None, - log_interval: int = 4, - eval_env: Optional[GymEnv] = None, - eval_freq: int = -1, - n_eval_episodes: int = 5, - tb_log_name: str = "DQN", - eval_log_path: Optional[str] = None, - reset_num_timesteps: bool = True) -> OffPolicyAlgorithm: - - return super(DQN, self).learn(total_timesteps=total_timesteps, callback=callback, log_interval=log_interval, - eval_env=eval_env, eval_freq=eval_freq, n_eval_episodes=n_eval_episodes, - tb_log_name=tb_log_name, eval_log_path=eval_log_path, - reset_num_timesteps=reset_num_timesteps) + def learn( + self, + total_timesteps: int, + callback: MaybeCallback = None, + log_interval: int = 4, + eval_env: Optional[GymEnv] = None, + eval_freq: int = -1, + n_eval_episodes: int = 5, + tb_log_name: str = "DQN", + eval_log_path: Optional[str] = None, + reset_num_timesteps: bool = True, + ) -> OffPolicyAlgorithm: + + return super(DQN, self).learn( + total_timesteps=total_timesteps, + callback=callback, + log_interval=log_interval, + eval_env=eval_env, + eval_freq=eval_freq, + n_eval_episodes=n_eval_episodes, + tb_log_name=tb_log_name, + eval_log_path=eval_log_path, + reset_num_timesteps=reset_num_timesteps, + ) def excluded_save_params(self) -> List[str]: """ diff --git a/stable_baselines3/dqn/policies.py b/stable_baselines3/dqn/policies.py index 718452229..ea0eaa278 100644 --- a/stable_baselines3/dqn/policies.py +++ b/stable_baselines3/dqn/policies.py @@ -1,10 +1,11 @@ -from typing import Optional, List, Callable, Union, Type, Any, Dict +from typing import Any, Callable, Dict, List, Optional, Type, Union import gym import torch as th -import torch.nn as nn +from torch import nn as nn + from stable_baselines3.common.policies import BasePolicy, register_policy -from stable_baselines3.common.torch_layers import create_mlp, NatureCNN, BaseFeaturesExtractor, FlattenExtractor +from stable_baselines3.common.torch_layers import BaseFeaturesExtractor, FlattenExtractor, NatureCNN, create_mlp class QNetwork(BasePolicy): @@ -20,18 +21,24 @@ class QNetwork(BasePolicy): dividing by 255.0 (True by default) """ - def __init__(self, observation_space: gym.spaces.Space, - action_space: gym.spaces.Space, - features_extractor: nn.Module, - features_dim: int, - net_arch: Optional[List[int]] = None, - device: Union[th.device, str] = 'auto', - activation_fn: Type[nn.Module] = nn.ReLU, - normalize_images: bool = True): - super(QNetwork, self).__init__(observation_space, action_space, - features_extractor=features_extractor, - normalize_images=normalize_images, - device=device) + def __init__( + self, + 
observation_space: gym.spaces.Space, + action_space: gym.spaces.Space, + features_extractor: nn.Module, + features_dim: int, + net_arch: Optional[List[int]] = None, + device: Union[th.device, str] = "auto", + activation_fn: Type[nn.Module] = nn.ReLU, + normalize_images: bool = True, + ): + super(QNetwork, self).__init__( + observation_space, + action_space, + features_extractor=features_extractor, + normalize_images=normalize_images, + device=device, + ) if net_arch is None: net_arch = [64, 64] @@ -63,13 +70,15 @@ def _predict(self, observation: th.Tensor, deterministic: bool = True) -> th.Ten def _get_data(self) -> Dict[str, Any]: data = super()._get_data() - data.update(dict( - net_arch=self.net_arch, - features_dim=self.features_dim, - activation_fn=self.activation_fn, - features_extractor=self.features_extractor, - epsilon=self.epsilon, - )) + data.update( + dict( + net_arch=self.net_arch, + features_dim=self.features_dim, + activation_fn=self.activation_fn, + features_extractor=self.features_extractor, + epsilon=self.epsilon, + ) + ) return data @@ -94,23 +103,29 @@ class DQNPolicy(BasePolicy): excluding the learning rate, to pass to the optimizer """ - def __init__(self, observation_space: gym.spaces.Space, - action_space: gym.spaces.Space, - lr_schedule: Callable, - net_arch: Optional[List[int]] = None, - device: Union[th.device, str] = 'auto', - activation_fn: Type[nn.Module] = nn.ReLU, - features_extractor_class: Type[BaseFeaturesExtractor] = FlattenExtractor, - features_extractor_kwargs: Optional[Dict[str, Any]] = None, - normalize_images: bool = True, - optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam, - optimizer_kwargs: Optional[Dict[str, Any]] = None): - super(DQNPolicy, self).__init__(observation_space, action_space, - device, - features_extractor_class, - features_extractor_kwargs, - optimizer_class=optimizer_class, - optimizer_kwargs=optimizer_kwargs) + def __init__( + self, + observation_space: gym.spaces.Space, + action_space: gym.spaces.Space, + lr_schedule: Callable, + net_arch: Optional[List[int]] = None, + device: Union[th.device, str] = "auto", + activation_fn: Type[nn.Module] = nn.ReLU, + features_extractor_class: Type[BaseFeaturesExtractor] = FlattenExtractor, + features_extractor_kwargs: Optional[Dict[str, Any]] = None, + normalize_images: bool = True, + optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam, + optimizer_kwargs: Optional[Dict[str, Any]] = None, + ): + super(DQNPolicy, self).__init__( + observation_space, + action_space, + device, + features_extractor_class, + features_extractor_kwargs, + optimizer_class=optimizer_class, + optimizer_kwargs=optimizer_kwargs, + ) if net_arch is None: if features_extractor_class == FlattenExtractor: @@ -118,22 +133,21 @@ def __init__(self, observation_space: gym.spaces.Space, else: net_arch = [] - self.features_extractor = features_extractor_class(self.observation_space, - **self.features_extractor_kwargs) + self.features_extractor = features_extractor_class(self.observation_space, **self.features_extractor_kwargs) self.features_dim = self.features_extractor.features_dim self.net_arch = net_arch self.activation_fn = activation_fn self.normalize_images = normalize_images self.net_args = { - 'observation_space': self.observation_space, - 'action_space': self.action_space, - 'features_extractor': self.features_extractor, - 'features_dim': self.features_dim, - 'net_arch': self.net_arch, - 'activation_fn': self.activation_fn, - 'normalize_images': normalize_images, - 'device': device + "observation_space": 
self.observation_space, + "action_space": self.action_space, + "features_extractor": self.features_extractor, + "features_dim": self.features_dim, + "net_arch": self.net_arch, + "activation_fn": self.activation_fn, + "normalize_images": normalize_images, + "device": device, } self.q_net, self.q_net_target = None, None @@ -152,8 +166,7 @@ def _build(self, lr_schedule: Callable) -> None: self.q_net_target.load_state_dict(self.q_net.state_dict()) # Setup optimizer with initial learning rate - self.optimizer = self.optimizer_class(self.parameters(), lr=lr_schedule(1), - **self.optimizer_kwargs) + self.optimizer = self.optimizer_class(self.parameters(), lr=lr_schedule(1), **self.optimizer_kwargs) def make_q_net(self) -> QNetwork: return QNetwork(**self.net_args).to(self.device) @@ -167,15 +180,17 @@ def _predict(self, obs: th.Tensor, deterministic: bool = True) -> th.Tensor: def _get_data(self) -> Dict[str, Any]: data = super()._get_data() - data.update(dict( - net_arch=self.net_args['net_arch'], - activation_fn=self.net_args['activation_fn'], - lr_schedule=self._dummy_schedule, # dummy lr schedule, not needed for loading policy alone - optimizer_class=self.optimizer_class, - optimizer_kwargs=self.optimizer_kwargs, - features_extractor_class=self.features_extractor_class, - features_extractor_kwargs=self.features_extractor_kwargs - )) + data.update( + dict( + net_arch=self.net_args["net_arch"], + activation_fn=self.net_args["activation_fn"], + lr_schedule=self._dummy_schedule, # dummy lr schedule, not needed for loading policy alone + optimizer_class=self.optimizer_class, + optimizer_kwargs=self.optimizer_kwargs, + features_extractor_class=self.features_extractor_class, + features_extractor_kwargs=self.features_extractor_kwargs, + ) + ) return data @@ -201,28 +216,33 @@ class CnnPolicy(DQNPolicy): excluding the learning rate, to pass to the optimizer """ - def __init__(self, observation_space: gym.spaces.Space, - action_space: gym.spaces.Space, - lr_schedule: Callable, - net_arch: Optional[List[int]] = None, - device: Union[th.device, str] = 'auto', - activation_fn: Type[nn.Module] = nn.ReLU, - features_extractor_class: Type[BaseFeaturesExtractor] = NatureCNN, - features_extractor_kwargs: Optional[Dict[str, Any]] = None, - normalize_images: bool = True, - optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam, - optimizer_kwargs: Optional[Dict[str, Any]] = None): - super(CnnPolicy, self).__init__(observation_space, - action_space, - lr_schedule, - net_arch, - device, - activation_fn, - features_extractor_class, - features_extractor_kwargs, - normalize_images, - optimizer_class, - optimizer_kwargs) + def __init__( + self, + observation_space: gym.spaces.Space, + action_space: gym.spaces.Space, + lr_schedule: Callable, + net_arch: Optional[List[int]] = None, + device: Union[th.device, str] = "auto", + activation_fn: Type[nn.Module] = nn.ReLU, + features_extractor_class: Type[BaseFeaturesExtractor] = NatureCNN, + features_extractor_kwargs: Optional[Dict[str, Any]] = None, + normalize_images: bool = True, + optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam, + optimizer_kwargs: Optional[Dict[str, Any]] = None, + ): + super(CnnPolicy, self).__init__( + observation_space, + action_space, + lr_schedule, + net_arch, + device, + activation_fn, + features_extractor_class, + features_extractor_kwargs, + normalize_images, + optimizer_class, + optimizer_kwargs, + ) register_policy("MlpPolicy", MlpPolicy) diff --git a/stable_baselines3/ppo/__init__.py b/stable_baselines3/ppo/__init__.py index 
68fe97f82..c5b80937c 100644 --- a/stable_baselines3/ppo/__init__.py +++ b/stable_baselines3/ppo/__init__.py @@ -1,2 +1,2 @@ +from stable_baselines3.ppo.policies import CnnPolicy, MlpPolicy from stable_baselines3.ppo.ppo import PPO -from stable_baselines3.ppo.policies import MlpPolicy, CnnPolicy diff --git a/stable_baselines3/ppo/policies.py b/stable_baselines3/ppo/policies.py index 95a7ec835..7d21de8bf 100644 --- a/stable_baselines3/ppo/policies.py +++ b/stable_baselines3/ppo/policies.py @@ -1,6 +1,6 @@ # This file is here just to define MlpPolicy/CnnPolicy # that work for PPO -from stable_baselines3.common.policies import ActorCriticPolicy, ActorCriticCnnPolicy, register_policy +from stable_baselines3.common.policies import ActorCriticCnnPolicy, ActorCriticPolicy, register_policy MlpPolicy = ActorCriticPolicy CnnPolicy = ActorCriticCnnPolicy diff --git a/stable_baselines3/ppo/ppo.py b/stable_baselines3/ppo/ppo.py index 2f05c4d57..eadf96162 100644 --- a/stable_baselines3/ppo/ppo.py +++ b/stable_baselines3/ppo/ppo.py @@ -1,15 +1,15 @@ -from typing import Type, Union, Callable, Optional, Dict, Any +from typing import Any, Callable, Dict, Optional, Type, Union -from gym import spaces -import torch as th -import torch.nn.functional as F import numpy as np +import torch as th +from gym import spaces +from torch.nn import functional as F from stable_baselines3.common import logger from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm +from stable_baselines3.common.policies import ActorCriticPolicy from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback from stable_baselines3.common.utils import explained_variance, get_schedule_fn -from stable_baselines3.common.policies import ActorCriticPolicy class PPO(OnPolicyAlgorithm): @@ -62,37 +62,53 @@ class PPO(OnPolicyAlgorithm): :param _init_setup_model: (bool) Whether or not to build the network at the creation of the instance """ - def __init__(self, policy: Union[str, Type[ActorCriticPolicy]], - env: Union[GymEnv, str], - learning_rate: Union[float, Callable] = 3e-4, - n_steps: int = 2048, - batch_size: Optional[int] = 64, - n_epochs: int = 10, - gamma: float = 0.99, - gae_lambda: float = 0.95, - clip_range: float = 0.2, - clip_range_vf: Optional[float] = None, - ent_coef: float = 0.0, - vf_coef: float = 0.5, - max_grad_norm: float = 0.5, - use_sde: bool = False, - sde_sample_freq: int = -1, - target_kl: Optional[float] = None, - tensorboard_log: Optional[str] = None, - create_eval_env: bool = False, - policy_kwargs: Optional[Dict[str, Any]] = None, - verbose: int = 0, - seed: Optional[int] = None, - device: Union[th.device, str] = "auto", - _init_setup_model: bool = True): - - super(PPO, self).__init__(policy, env, learning_rate=learning_rate, - n_steps=n_steps, gamma=gamma, gae_lambda=gae_lambda, - ent_coef=ent_coef, vf_coef=vf_coef, max_grad_norm=max_grad_norm, - use_sde=use_sde, sde_sample_freq=sde_sample_freq, - tensorboard_log=tensorboard_log, policy_kwargs=policy_kwargs, - verbose=verbose, device=device, create_eval_env=create_eval_env, - seed=seed, _init_setup_model=False) + def __init__( + self, + policy: Union[str, Type[ActorCriticPolicy]], + env: Union[GymEnv, str], + learning_rate: Union[float, Callable] = 3e-4, + n_steps: int = 2048, + batch_size: Optional[int] = 64, + n_epochs: int = 10, + gamma: float = 0.99, + gae_lambda: float = 0.95, + clip_range: float = 0.2, + clip_range_vf: Optional[float] = None, + ent_coef: float = 0.0, + vf_coef: float = 0.5, + max_grad_norm: float = 0.5, + use_sde: bool = 
False, + sde_sample_freq: int = -1, + target_kl: Optional[float] = None, + tensorboard_log: Optional[str] = None, + create_eval_env: bool = False, + policy_kwargs: Optional[Dict[str, Any]] = None, + verbose: int = 0, + seed: Optional[int] = None, + device: Union[th.device, str] = "auto", + _init_setup_model: bool = True, + ): + + super(PPO, self).__init__( + policy, + env, + learning_rate=learning_rate, + n_steps=n_steps, + gamma=gamma, + gae_lambda=gae_lambda, + ent_coef=ent_coef, + vf_coef=vf_coef, + max_grad_norm=max_grad_norm, + use_sde=use_sde, + sde_sample_freq=sde_sample_freq, + tensorboard_log=tensorboard_log, + policy_kwargs=policy_kwargs, + verbose=verbose, + device=device, + create_eval_env=create_eval_env, + seed=seed, + _init_setup_model=False, + ) self.batch_size = batch_size self.n_epochs = n_epochs @@ -110,8 +126,7 @@ def _setup_model(self) -> None: self.clip_range = get_schedule_fn(self.clip_range) if self.clip_range_vf is not None: if isinstance(self.clip_range_vf, (float, int)): - assert self.clip_range_vf > 0, ("`clip_range_vf` must be positive, " - "pass `None` to deactivate vf clipping") + assert self.clip_range_vf > 0, "`clip_range_vf` must be positive, " "pass `None` to deactivate vf clipping" self.clip_range_vf = get_schedule_fn(self.clip_range_vf) @@ -173,8 +188,9 @@ def train(self) -> None: else: # Clip the different between old and new value # NOTE: this depends on the reward scaling - values_pred = rollout_data.old_values + th.clamp(values - rollout_data.old_values, -clip_range_vf, - clip_range_vf) + values_pred = rollout_data.old_values + th.clamp( + values - rollout_data.old_values, -clip_range_vf, clip_range_vf + ) # Value loss using the TD(gae_lambda) target value_loss = F.mse_loss(rollout_data.returns, values_pred) value_losses.append(value_loss.item()) @@ -205,8 +221,7 @@ def train(self) -> None: break self._n_updates += self.n_epochs - explained_var = explained_variance(self.rollout_buffer.returns.flatten(), - self.rollout_buffer.values.flatten()) + explained_var = explained_variance(self.rollout_buffer.returns.flatten(), self.rollout_buffer.values.flatten()) # Logs logger.record("train/entropy_loss", np.mean(entropy_losses)) @@ -224,18 +239,27 @@ def train(self) -> None: if self.clip_range_vf is not None: logger.record("train/clip_range_vf", clip_range_vf) - def learn(self, - total_timesteps: int, - callback: MaybeCallback = None, - log_interval: int = 1, - eval_env: Optional[GymEnv] = None, - eval_freq: int = -1, - n_eval_episodes: int = 5, - tb_log_name: str = "PPO", - eval_log_path: Optional[str] = None, - reset_num_timesteps: bool = True) -> "PPO": - - return super(PPO, self).learn(total_timesteps=total_timesteps, callback=callback, - log_interval=log_interval, eval_env=eval_env, eval_freq=eval_freq, - n_eval_episodes=n_eval_episodes, tb_log_name=tb_log_name, - eval_log_path=eval_log_path, reset_num_timesteps=reset_num_timesteps) + def learn( + self, + total_timesteps: int, + callback: MaybeCallback = None, + log_interval: int = 1, + eval_env: Optional[GymEnv] = None, + eval_freq: int = -1, + n_eval_episodes: int = 5, + tb_log_name: str = "PPO", + eval_log_path: Optional[str] = None, + reset_num_timesteps: bool = True, + ) -> "PPO": + + return super(PPO, self).learn( + total_timesteps=total_timesteps, + callback=callback, + log_interval=log_interval, + eval_env=eval_env, + eval_freq=eval_freq, + n_eval_episodes=n_eval_episodes, + tb_log_name=tb_log_name, + eval_log_path=eval_log_path, + reset_num_timesteps=reset_num_timesteps, + ) diff --git 
a/stable_baselines3/sac/__init__.py b/stable_baselines3/sac/__init__.py index 45c7ae740..5b0e89900 100644 --- a/stable_baselines3/sac/__init__.py +++ b/stable_baselines3/sac/__init__.py @@ -1,2 +1,2 @@ +from stable_baselines3.sac.policies import CnnPolicy, MlpPolicy from stable_baselines3.sac.sac import SAC -from stable_baselines3.sac.policies import MlpPolicy, CnnPolicy diff --git a/stable_baselines3/sac/policies.py b/stable_baselines3/sac/policies.py index e40929313..b5a66670c 100644 --- a/stable_baselines3/sac/policies.py +++ b/stable_baselines3/sac/policies.py @@ -1,13 +1,13 @@ -from typing import Optional, List, Tuple, Callable, Union, Type, Dict, Any +from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union import gym import torch as th -import torch.nn as nn +from torch import nn as nn -from stable_baselines3.common.preprocessing import get_action_dim -from stable_baselines3.common.policies import BasePolicy, register_policy, create_sde_features_extractor, ContinuousCritic -from stable_baselines3.common.torch_layers import create_mlp, NatureCNN, BaseFeaturesExtractor, FlattenExtractor from stable_baselines3.common.distributions import SquashedDiagGaussianDistribution, StateDependentNoiseDistribution +from stable_baselines3.common.policies import BasePolicy, ContinuousCritic, create_sde_features_extractor, register_policy +from stable_baselines3.common.preprocessing import get_action_dim +from stable_baselines3.common.torch_layers import BaseFeaturesExtractor, FlattenExtractor, NatureCNN, create_mlp # CAP the standard deviation of the actor LOG_STD_MAX = 2 @@ -41,25 +41,31 @@ class Actor(BasePolicy): :param device: (Union[th.device, str]) Device on which the code should run. """ - def __init__(self, observation_space: gym.spaces.Space, - action_space: gym.spaces.Space, - net_arch: List[int], - features_extractor: nn.Module, - features_dim: int, - activation_fn: Type[nn.Module] = nn.ReLU, - use_sde: bool = False, - log_std_init: float = -3, - full_std: bool = True, - sde_net_arch: Optional[List[int]] = None, - use_expln: bool = False, - clip_mean: float = 2.0, - normalize_images: bool = True, - device: Union[th.device, str] = 'auto'): - super(Actor, self).__init__(observation_space, action_space, - features_extractor=features_extractor, - normalize_images=normalize_images, - device=device, - squash_output=True) + def __init__( + self, + observation_space: gym.spaces.Space, + action_space: gym.spaces.Space, + net_arch: List[int], + features_extractor: nn.Module, + features_dim: int, + activation_fn: Type[nn.Module] = nn.ReLU, + use_sde: bool = False, + log_std_init: float = -3, + full_std: bool = True, + sde_net_arch: Optional[List[int]] = None, + use_expln: bool = False, + clip_mean: float = 2.0, + normalize_images: bool = True, + device: Union[th.device, str] = "auto", + ): + super(Actor, self).__init__( + observation_space, + action_space, + features_extractor=features_extractor, + normalize_images=normalize_images, + device=device, + squash_output=True, + ) # Save arguments to re-create object at loading self.use_sde = use_sde @@ -83,14 +89,16 @@ def __init__(self, observation_space: gym.spaces.Space, latent_sde_dim = last_layer_dim # Separate feature extractor for gSDE if sde_net_arch is not None: - self.sde_features_extractor, latent_sde_dim = create_sde_features_extractor(features_dim, sde_net_arch, - activation_fn) - - self.action_dist = StateDependentNoiseDistribution(action_dim, full_std=full_std, use_expln=use_expln, - learn_features=True, 
squash_output=True) - self.mu, self.log_std = self.action_dist.proba_distribution_net(latent_dim=last_layer_dim, - latent_sde_dim=latent_sde_dim, - log_std_init=log_std_init) + self.sde_features_extractor, latent_sde_dim = create_sde_features_extractor( + features_dim, sde_net_arch, activation_fn + ) + + self.action_dist = StateDependentNoiseDistribution( + action_dim, full_std=full_std, use_expln=use_expln, learn_features=True, squash_output=True + ) + self.mu, self.log_std = self.action_dist.proba_distribution_net( + latent_dim=last_layer_dim, latent_sde_dim=latent_sde_dim, log_std_init=log_std_init + ) # Avoid numerical issues by limiting the mean of the Gaussian # to be in [-clip_mean, clip_mean] if clip_mean > 0.0: @@ -103,18 +111,20 @@ def __init__(self, observation_space: gym.spaces.Space, def _get_data(self) -> Dict[str, Any]: data = super()._get_data() - data.update(dict( - net_arch=self.net_arch, - features_dim=self.features_dim, - activation_fn=self.activation_fn, - use_sde=self.use_sde, - log_std_init=self.log_std_init, - full_std=self.full_std, - sde_net_arch=self.sde_net_arch, - use_expln=self.use_expln, - features_extractor=self.features_extractor, - clip_mean=self.clip_mean - )) + data.update( + dict( + net_arch=self.net_arch, + features_dim=self.features_dim, + activation_fn=self.activation_fn, + use_sde=self.use_sde, + log_std_init=self.log_std_init, + full_std=self.full_std, + sde_net_arch=self.sde_net_arch, + use_expln=self.use_expln, + features_extractor=self.features_extractor, + clip_mean=self.clip_mean, + ) + ) return data def get_std(self) -> th.Tensor: @@ -127,7 +137,7 @@ def get_std(self) -> th.Tensor: :return: (th.Tensor) """ - msg = 'get_std() is only available when using gSDE' + msg = "get_std() is only available when using gSDE" assert isinstance(self.action_dist, StateDependentNoiseDistribution), msg return self.action_dist.get_std(self.log_std) @@ -137,7 +147,7 @@ def reset_noise(self, batch_size: int = 1) -> None: :param batch_size: (int) """ - msg = 'reset_noise() is only available when using gSDE' + msg = "reset_noise() is only available when using gSDE" assert isinstance(self.action_dist, StateDependentNoiseDistribution), msg self.action_dist.sample_weights(self.log_std, batch_size=batch_size) @@ -167,8 +177,7 @@ def get_action_dist_params(self, obs: th.Tensor) -> Tuple[th.Tensor, th.Tensor, def forward(self, obs: th.Tensor, deterministic: bool = False) -> th.Tensor: mean_actions, log_std, kwargs = self.get_action_dist_params(obs) # Note: the action is squashed - return self.action_dist.actions_from_params(mean_actions, log_std, - deterministic=deterministic, **kwargs) + return self.action_dist.actions_from_params(mean_actions, log_std, deterministic=deterministic, **kwargs) def action_log_prob(self, obs: th.Tensor) -> Tuple[th.Tensor, th.Tensor]: mean_actions, log_std, kwargs = self.get_action_dist_params(obs) @@ -210,30 +219,36 @@ class SACPolicy(BasePolicy): :param n_critics: (int) Number of critic networks to create. 
""" - def __init__(self, observation_space: gym.spaces.Space, - action_space: gym.spaces.Space, - lr_schedule: Callable, - net_arch: Optional[List[int]] = None, - device: Union[th.device, str] = 'auto', - activation_fn: Type[nn.Module] = nn.ReLU, - use_sde: bool = False, - log_std_init: float = -3, - sde_net_arch: Optional[List[int]] = None, - use_expln: bool = False, - clip_mean: float = 2.0, - features_extractor_class: Type[BaseFeaturesExtractor] = FlattenExtractor, - features_extractor_kwargs: Optional[Dict[str, Any]] = None, - normalize_images: bool = True, - optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam, - optimizer_kwargs: Optional[Dict[str, Any]] = None, - n_critics: int = 2): - super(SACPolicy, self).__init__(observation_space, action_space, - device, - features_extractor_class, - features_extractor_kwargs, - optimizer_class=optimizer_class, - optimizer_kwargs=optimizer_kwargs, - squash_output=True) + def __init__( + self, + observation_space: gym.spaces.Space, + action_space: gym.spaces.Space, + lr_schedule: Callable, + net_arch: Optional[List[int]] = None, + device: Union[th.device, str] = "auto", + activation_fn: Type[nn.Module] = nn.ReLU, + use_sde: bool = False, + log_std_init: float = -3, + sde_net_arch: Optional[List[int]] = None, + use_expln: bool = False, + clip_mean: float = 2.0, + features_extractor_class: Type[BaseFeaturesExtractor] = FlattenExtractor, + features_extractor_kwargs: Optional[Dict[str, Any]] = None, + normalize_images: bool = True, + optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam, + optimizer_kwargs: Optional[Dict[str, Any]] = None, + n_critics: int = 2, + ): + super(SACPolicy, self).__init__( + observation_space, + action_space, + device, + features_extractor_class, + features_extractor_kwargs, + optimizer_class=optimizer_class, + optimizer_kwargs=optimizer_kwargs, + squash_output=True, + ) if net_arch is None: if features_extractor_class == FlattenExtractor: @@ -242,33 +257,32 @@ def __init__(self, observation_space: gym.spaces.Space, net_arch = [] # Create shared features extractor - self.features_extractor = features_extractor_class(self.observation_space, - **self.features_extractor_kwargs) + self.features_extractor = features_extractor_class(self.observation_space, **self.features_extractor_kwargs) self.features_dim = self.features_extractor.features_dim self.net_arch = net_arch self.activation_fn = activation_fn self.net_args = { - 'observation_space': self.observation_space, - 'action_space': self.action_space, - 'features_extractor': self.features_extractor, - 'features_dim': self.features_dim, - 'net_arch': self.net_arch, - 'activation_fn': self.activation_fn, - 'normalize_images': normalize_images, - 'device': device + "observation_space": self.observation_space, + "action_space": self.action_space, + "features_extractor": self.features_extractor, + "features_dim": self.features_dim, + "net_arch": self.net_arch, + "activation_fn": self.activation_fn, + "normalize_images": normalize_images, + "device": device, } self.actor_kwargs = self.net_args.copy() sde_kwargs = { - 'use_sde': use_sde, - 'log_std_init': log_std_init, - 'sde_net_arch': sde_net_arch, - 'use_expln': use_expln, - 'clip_mean': clip_mean + "use_sde": use_sde, + "log_std_init": log_std_init, + "sde_net_arch": sde_net_arch, + "use_expln": use_expln, + "clip_mean": clip_mean, } self.actor_kwargs.update(sde_kwargs) self.critic_kwargs = self.net_args.copy() - self.critic_kwargs.update({'n_critics': n_critics}) + self.critic_kwargs.update({"n_critics": n_critics}) 
self.actor, self.actor_target = None, None self.critic, self.critic_target = None, None @@ -277,8 +291,7 @@ def __init__(self, observation_space: gym.spaces.Space, def _build(self, lr_schedule: Callable) -> None: self.actor = self.make_actor() - self.actor.optimizer = self.optimizer_class(self.actor.parameters(), lr=lr_schedule(1), - **self.optimizer_kwargs) + self.actor.optimizer = self.optimizer_class(self.actor.parameters(), lr=lr_schedule(1), **self.optimizer_kwargs) self.critic = self.make_critic() self.critic_target = self.make_critic() @@ -286,29 +299,29 @@ def _build(self, lr_schedule: Callable) -> None: # Do not optimize the shared feature extractor with the critic loss # otherwise, there are gradient computation issues # Another solution: having duplicated features extractor but requires more memory and computation - critic_parameters = [param for name, param in self.critic.named_parameters() if - 'features_extractor' not in name] - self.critic.optimizer = self.optimizer_class(critic_parameters, lr=lr_schedule(1), - **self.optimizer_kwargs) + critic_parameters = [param for name, param in self.critic.named_parameters() if "features_extractor" not in name] + self.critic.optimizer = self.optimizer_class(critic_parameters, lr=lr_schedule(1), **self.optimizer_kwargs) def _get_data(self) -> Dict[str, Any]: data = super()._get_data() - data.update(dict( - net_arch=self.net_args['net_arch'], - activation_fn=self.net_args['activation_fn'], - use_sde=self.actor_kwargs['use_sde'], - log_std_init=self.actor_kwargs['log_std_init'], - sde_net_arch=self.actor_kwargs['sde_net_arch'], - use_expln=self.actor_kwargs['use_expln'], - clip_mean=self.actor_kwargs['clip_mean'], - n_critics=self.critic_kwargs['n_critics'], - lr_schedule=self._dummy_schedule, # dummy lr schedule, not needed for loading policy alone - optimizer_class=self.optimizer_class, - optimizer_kwargs=self.optimizer_kwargs, - features_extractor_class=self.features_extractor_class, - features_extractor_kwargs=self.features_extractor_kwargs - )) + data.update( + dict( + net_arch=self.net_args["net_arch"], + activation_fn=self.net_args["activation_fn"], + use_sde=self.actor_kwargs["use_sde"], + log_std_init=self.actor_kwargs["log_std_init"], + sde_net_arch=self.actor_kwargs["sde_net_arch"], + use_expln=self.actor_kwargs["use_expln"], + clip_mean=self.actor_kwargs["clip_mean"], + n_critics=self.critic_kwargs["n_critics"], + lr_schedule=self._dummy_schedule, # dummy lr schedule, not needed for loading policy alone + optimizer_class=self.optimizer_class, + optimizer_kwargs=self.optimizer_kwargs, + features_extractor_class=self.features_extractor_class, + features_extractor_kwargs=self.features_extractor_kwargs, + ) + ) return data def reset_noise(self, batch_size: int = 1) -> None: @@ -364,40 +377,45 @@ class CnnPolicy(SACPolicy): :param n_critics: (int) Number of critic networks to create. 
""" - def __init__(self, observation_space: gym.spaces.Space, - action_space: gym.spaces.Space, - lr_schedule: Callable, - net_arch: Optional[List[int]] = None, - device: Union[th.device, str] = 'auto', - activation_fn: Type[nn.Module] = nn.ReLU, - use_sde: bool = False, - log_std_init: float = -3, - sde_net_arch: Optional[List[int]] = None, - use_expln: bool = False, - clip_mean: float = 2.0, - features_extractor_class: Type[BaseFeaturesExtractor] = NatureCNN, - features_extractor_kwargs: Optional[Dict[str, Any]] = None, - normalize_images: bool = True, - optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam, - optimizer_kwargs: Optional[Dict[str, Any]] = None, - n_critics: int = 2): - super(CnnPolicy, self).__init__(observation_space, - action_space, - lr_schedule, - net_arch, - device, - activation_fn, - use_sde, - log_std_init, - sde_net_arch, - use_expln, - clip_mean, - features_extractor_class, - features_extractor_kwargs, - normalize_images, - optimizer_class, - optimizer_kwargs, - n_critics) + def __init__( + self, + observation_space: gym.spaces.Space, + action_space: gym.spaces.Space, + lr_schedule: Callable, + net_arch: Optional[List[int]] = None, + device: Union[th.device, str] = "auto", + activation_fn: Type[nn.Module] = nn.ReLU, + use_sde: bool = False, + log_std_init: float = -3, + sde_net_arch: Optional[List[int]] = None, + use_expln: bool = False, + clip_mean: float = 2.0, + features_extractor_class: Type[BaseFeaturesExtractor] = NatureCNN, + features_extractor_kwargs: Optional[Dict[str, Any]] = None, + normalize_images: bool = True, + optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam, + optimizer_kwargs: Optional[Dict[str, Any]] = None, + n_critics: int = 2, + ): + super(CnnPolicy, self).__init__( + observation_space, + action_space, + lr_schedule, + net_arch, + device, + activation_fn, + use_sde, + log_std_init, + sde_net_arch, + use_expln, + clip_mean, + features_extractor_class, + features_extractor_kwargs, + normalize_images, + optimizer_class, + optimizer_kwargs, + n_critics, + ) register_policy("MlpPolicy", MlpPolicy) diff --git a/stable_baselines3/sac/sac.py b/stable_baselines3/sac/sac.py index c177df342..39b8e3a1a 100644 --- a/stable_baselines3/sac/sac.py +++ b/stable_baselines3/sac/sac.py @@ -1,12 +1,13 @@ -from typing import List, Tuple, Type, Union, Callable, Optional, Dict, Any -import torch as th -import torch.nn.functional as F +from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union + import numpy as np +import torch as th +from torch.nn import functional as F from stable_baselines3.common import logger +from stable_baselines3.common.noise import ActionNoise from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback -from stable_baselines3.common.noise import ActionNoise from stable_baselines3.sac.policies import SACPolicy @@ -68,44 +69,61 @@ class SAC(OffPolicyAlgorithm): :param _init_setup_model: (bool) Whether or not to build the network at the creation of the instance """ - def __init__(self, policy: Union[str, Type[SACPolicy]], - env: Union[GymEnv, str], - learning_rate: Union[float, Callable] = 3e-4, - buffer_size: int = int(1e6), - learning_starts: int = 100, - batch_size: int = 256, - tau: float = 0.005, - gamma: float = 0.99, - train_freq: int = 1, - gradient_steps: int = 1, - n_episodes_rollout: int = -1, - action_noise: Optional[ActionNoise] = None, - optimize_memory_usage: bool = False, - ent_coef: Union[str, float] = 'auto', 
- target_update_interval: int = 1, - target_entropy: Union[str, float] = 'auto', - use_sde: bool = False, - sde_sample_freq: int = -1, - use_sde_at_warmup: bool = False, - tensorboard_log: Optional[str] = None, - create_eval_env: bool = False, - policy_kwargs: Dict[str, Any] = None, - verbose: int = 0, - seed: Optional[int] = None, - device: Union[th.device, str] = 'auto', - _init_setup_model: bool = True): - - super(SAC, self).__init__(policy, env, SACPolicy, learning_rate, - buffer_size, learning_starts, batch_size, - tau, gamma, train_freq, gradient_steps, - n_episodes_rollout, action_noise, - policy_kwargs=policy_kwargs, - tensorboard_log=tensorboard_log, - verbose=verbose, device=device, - create_eval_env=create_eval_env, seed=seed, - use_sde=use_sde, sde_sample_freq=sde_sample_freq, - use_sde_at_warmup=use_sde_at_warmup, - optimize_memory_usage=optimize_memory_usage) + def __init__( + self, + policy: Union[str, Type[SACPolicy]], + env: Union[GymEnv, str], + learning_rate: Union[float, Callable] = 3e-4, + buffer_size: int = int(1e6), + learning_starts: int = 100, + batch_size: int = 256, + tau: float = 0.005, + gamma: float = 0.99, + train_freq: int = 1, + gradient_steps: int = 1, + n_episodes_rollout: int = -1, + action_noise: Optional[ActionNoise] = None, + optimize_memory_usage: bool = False, + ent_coef: Union[str, float] = "auto", + target_update_interval: int = 1, + target_entropy: Union[str, float] = "auto", + use_sde: bool = False, + sde_sample_freq: int = -1, + use_sde_at_warmup: bool = False, + tensorboard_log: Optional[str] = None, + create_eval_env: bool = False, + policy_kwargs: Dict[str, Any] = None, + verbose: int = 0, + seed: Optional[int] = None, + device: Union[th.device, str] = "auto", + _init_setup_model: bool = True, + ): + + super(SAC, self).__init__( + policy, + env, + SACPolicy, + learning_rate, + buffer_size, + learning_starts, + batch_size, + tau, + gamma, + train_freq, + gradient_steps, + n_episodes_rollout, + action_noise, + policy_kwargs=policy_kwargs, + tensorboard_log=tensorboard_log, + verbose=verbose, + device=device, + create_eval_env=create_eval_env, + seed=seed, + use_sde=use_sde, + sde_sample_freq=sde_sample_freq, + use_sde_at_warmup=use_sde_at_warmup, + optimize_memory_usage=optimize_memory_usage, + ) self.target_entropy = target_entropy self.log_ent_coef = None # type: Optional[th.Tensor] @@ -122,7 +140,7 @@ def _setup_model(self) -> None: super(SAC, self)._setup_model() self._create_aliases() # Target entropy is used when learning the entropy coefficient - if self.target_entropy == 'auto': + if self.target_entropy == "auto": # automatically set target entropy if needed self.target_entropy = -np.prod(self.env.action_space.shape).astype(np.float32) else: @@ -133,12 +151,12 @@ def _setup_model(self) -> None: # The entropy coefficient or entropy can be learned automatically # see Automating Entropy Adjustment for Maximum Entropy RL section # of https://arxiv.org/abs/1812.05905 - if isinstance(self.ent_coef, str) and self.ent_coef.startswith('auto'): + if isinstance(self.ent_coef, str) and self.ent_coef.startswith("auto"): # Default initial value of ent_coef when learned init_value = 1.0 - if '_' in self.ent_coef: - init_value = float(self.ent_coef.split('_')[1]) - assert init_value > 0., "The initial value of ent_coef must be greater than 0" + if "_" in self.ent_coef: + init_value = float(self.ent_coef.split("_")[1]) + assert init_value > 0.0, "The initial value of ent_coef must be greater than 0" # Note: we optimize the log of the entropy coeff 
which is slightly different from the paper # as discussed in https://github.com/rail-berkeley/softlearning/issues/37 @@ -243,28 +261,37 @@ def train(self, gradient_steps: int, batch_size: int = 64) -> None: self._n_updates += gradient_steps - logger.record("train/n_updates", self._n_updates, exclude='tensorboard') + logger.record("train/n_updates", self._n_updates, exclude="tensorboard") logger.record("train/ent_coef", np.mean(ent_coefs)) logger.record("train/actor_loss", np.mean(actor_losses)) logger.record("train/critic_loss", np.mean(critic_losses)) if len(ent_coef_losses) > 0: logger.record("train/ent_coef_loss", np.mean(ent_coef_losses)) - def learn(self, - total_timesteps: int, - callback: MaybeCallback = None, - log_interval: int = 4, - eval_env: Optional[GymEnv] = None, - eval_freq: int = -1, - n_eval_episodes: int = 5, - tb_log_name: str = "SAC", - eval_log_path: Optional[str] = None, - reset_num_timesteps: bool = True) -> OffPolicyAlgorithm: - - return super(SAC, self).learn(total_timesteps=total_timesteps, callback=callback, log_interval=log_interval, - eval_env=eval_env, eval_freq=eval_freq, n_eval_episodes=n_eval_episodes, - tb_log_name=tb_log_name, eval_log_path=eval_log_path, - reset_num_timesteps=reset_num_timesteps) + def learn( + self, + total_timesteps: int, + callback: MaybeCallback = None, + log_interval: int = 4, + eval_env: Optional[GymEnv] = None, + eval_freq: int = -1, + n_eval_episodes: int = 5, + tb_log_name: str = "SAC", + eval_log_path: Optional[str] = None, + reset_num_timesteps: bool = True, + ) -> OffPolicyAlgorithm: + + return super(SAC, self).learn( + total_timesteps=total_timesteps, + callback=callback, + log_interval=log_interval, + eval_env=eval_env, + eval_freq=eval_freq, + n_eval_episodes=n_eval_episodes, + tb_log_name=tb_log_name, + eval_log_path=eval_log_path, + reset_num_timesteps=reset_num_timesteps, + ) def excluded_save_params(self) -> List[str]: """ @@ -281,9 +308,9 @@ def get_torch_variables(self) -> Tuple[List[str], List[str]]: cf base class """ state_dicts = ["policy", "actor.optimizer", "critic.optimizer"] - saved_tensors = ['log_ent_coef'] + saved_tensors = ["log_ent_coef"] if self.ent_coef_optimizer is not None: - state_dicts.append('ent_coef_optimizer') + state_dicts.append("ent_coef_optimizer") else: - saved_tensors.append('ent_coef_tensor') + saved_tensors.append("ent_coef_tensor") return state_dicts, saved_tensors diff --git a/stable_baselines3/td3/__init__.py b/stable_baselines3/td3/__init__.py index cde926384..ed054f0d9 100644 --- a/stable_baselines3/td3/__init__.py +++ b/stable_baselines3/td3/__init__.py @@ -1,2 +1,2 @@ +from stable_baselines3.td3.policies import CnnPolicy, MlpPolicy from stable_baselines3.td3.td3 import TD3 -from stable_baselines3.td3.policies import MlpPolicy, CnnPolicy diff --git a/stable_baselines3/td3/policies.py b/stable_baselines3/td3/policies.py index 325640f5a..bcc64a136 100644 --- a/stable_baselines3/td3/policies.py +++ b/stable_baselines3/td3/policies.py @@ -1,12 +1,12 @@ -from typing import Optional, List, Callable, Union, Type, Any, Dict +from typing import Any, Callable, Dict, List, Optional, Type, Union import gym import torch as th -import torch.nn as nn +from torch import nn as nn +from stable_baselines3.common.policies import BasePolicy, ContinuousCritic, register_policy from stable_baselines3.common.preprocessing import get_action_dim -from stable_baselines3.common.policies import BasePolicy, register_policy, ContinuousCritic -from stable_baselines3.common.torch_layers import create_mlp, 
NatureCNN, BaseFeaturesExtractor, FlattenExtractor +from stable_baselines3.common.torch_layers import BaseFeaturesExtractor, FlattenExtractor, NatureCNN, create_mlp class Actor(BasePolicy): @@ -25,20 +25,25 @@ class Actor(BasePolicy): :param device: (Union[th.device, str]) Device on which the code should run. """ - def __init__(self, - observation_space: gym.spaces.Space, - action_space: gym.spaces.Space, - net_arch: List[int], - features_extractor: nn.Module, - features_dim: int, - activation_fn: Type[nn.Module] = nn.ReLU, - normalize_images: bool = True, - device: Union[th.device, str] = 'auto'): - super(Actor, self).__init__(observation_space, action_space, - features_extractor=features_extractor, - normalize_images=normalize_images, - device=device, - squash_output=True) + def __init__( + self, + observation_space: gym.spaces.Space, + action_space: gym.spaces.Space, + net_arch: List[int], + features_extractor: nn.Module, + features_dim: int, + activation_fn: Type[nn.Module] = nn.ReLU, + normalize_images: bool = True, + device: Union[th.device, str] = "auto", + ): + super(Actor, self).__init__( + observation_space, + action_space, + features_extractor=features_extractor, + normalize_images=normalize_images, + device=device, + squash_output=True, + ) self.features_extractor = features_extractor self.normalize_images = normalize_images @@ -54,12 +59,14 @@ def __init__(self, def _get_data(self) -> Dict[str, Any]: data = super()._get_data() - data.update(dict( - net_arch=self.net_arch, - features_dim=self.features_dim, - activation_fn=self.activation_fn, - features_extractor=self.features_extractor - )) + data.update( + dict( + net_arch=self.net_arch, + features_dim=self.features_dim, + activation_fn=self.activation_fn, + features_extractor=self.features_extractor, + ) + ) return data def forward(self, obs: th.Tensor, deterministic: bool = True) -> th.Tensor: @@ -93,25 +100,31 @@ class TD3Policy(BasePolicy): :param n_critics: (int) Number of critic networks to create. 
""" - def __init__(self, observation_space: gym.spaces.Space, - action_space: gym.spaces.Space, - lr_schedule: Callable, - net_arch: Optional[List[int]] = None, - device: Union[th.device, str] = 'auto', - activation_fn: Type[nn.Module] = nn.ReLU, - features_extractor_class: Type[BaseFeaturesExtractor] = FlattenExtractor, - features_extractor_kwargs: Optional[Dict[str, Any]] = None, - normalize_images: bool = True, - optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam, - optimizer_kwargs: Optional[Dict[str, Any]] = None, - n_critics: int = 2): - super(TD3Policy, self).__init__(observation_space, action_space, - device, - features_extractor_class, - features_extractor_kwargs, - optimizer_class=optimizer_class, - optimizer_kwargs=optimizer_kwargs, - squash_output=True) + def __init__( + self, + observation_space: gym.spaces.Space, + action_space: gym.spaces.Space, + lr_schedule: Callable, + net_arch: Optional[List[int]] = None, + device: Union[th.device, str] = "auto", + activation_fn: Type[nn.Module] = nn.ReLU, + features_extractor_class: Type[BaseFeaturesExtractor] = FlattenExtractor, + features_extractor_kwargs: Optional[Dict[str, Any]] = None, + normalize_images: bool = True, + optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam, + optimizer_kwargs: Optional[Dict[str, Any]] = None, + n_critics: int = 2, + ): + super(TD3Policy, self).__init__( + observation_space, + action_space, + device, + features_extractor_class, + features_extractor_kwargs, + optimizer_class=optimizer_class, + optimizer_kwargs=optimizer_kwargs, + squash_output=True, + ) # Default network architecture, from the original paper if net_arch is None: @@ -120,24 +133,23 @@ def __init__(self, observation_space: gym.spaces.Space, else: net_arch = [] - self.features_extractor = features_extractor_class(self.observation_space, - **self.features_extractor_kwargs) + self.features_extractor = features_extractor_class(self.observation_space, **self.features_extractor_kwargs) self.features_dim = self.features_extractor.features_dim self.net_arch = net_arch self.activation_fn = activation_fn self.net_args = { - 'observation_space': self.observation_space, - 'action_space': self.action_space, - 'features_extractor': self.features_extractor, - 'features_dim': self.features_dim, - 'net_arch': self.net_arch, - 'activation_fn': self.activation_fn, - 'normalize_images': normalize_images, - 'device': device + "observation_space": self.observation_space, + "action_space": self.action_space, + "features_extractor": self.features_extractor, + "features_dim": self.features_dim, + "net_arch": self.net_arch, + "activation_fn": self.activation_fn, + "normalize_images": normalize_images, + "device": device, } self.critic_kwargs = self.net_args.copy() - self.critic_kwargs.update({'n_critics': n_critics}) + self.critic_kwargs.update({"n_critics": n_critics}) self.actor, self.actor_target = None, None self.critic, self.critic_target = None, None @@ -147,27 +159,27 @@ def _build(self, lr_schedule: Callable) -> None: self.actor = self.make_actor() self.actor_target = self.make_actor() self.actor_target.load_state_dict(self.actor.state_dict()) - self.actor.optimizer = self.optimizer_class(self.actor.parameters(), lr=lr_schedule(1), - **self.optimizer_kwargs) + self.actor.optimizer = self.optimizer_class(self.actor.parameters(), lr=lr_schedule(1), **self.optimizer_kwargs) self.critic = self.make_critic() self.critic_target = self.make_critic() self.critic_target.load_state_dict(self.critic.state_dict()) - self.critic.optimizer = 
self.optimizer_class(self.critic.parameters(), lr=lr_schedule(1), - **self.optimizer_kwargs) + self.critic.optimizer = self.optimizer_class(self.critic.parameters(), lr=lr_schedule(1), **self.optimizer_kwargs) def _get_data(self) -> Dict[str, Any]: data = super()._get_data() - data.update(dict( - net_arch=self.net_args['net_arch'], - activation_fn=self.net_args['activation_fn'], - n_critics=self.critic_kwargs['n_critics'], - lr_schedule=self._dummy_schedule, # dummy lr schedule, not needed for loading policy alone - optimizer_class=self.optimizer_class, - optimizer_kwargs=self.optimizer_kwargs, - features_extractor_class=self.features_extractor_class, - features_extractor_kwargs=self.features_extractor_kwargs - )) + data.update( + dict( + net_arch=self.net_args["net_arch"], + activation_fn=self.net_args["activation_fn"], + n_critics=self.critic_kwargs["n_critics"], + lr_schedule=self._dummy_schedule, # dummy lr schedule, not needed for loading policy alone + optimizer_class=self.optimizer_class, + optimizer_kwargs=self.optimizer_kwargs, + features_extractor_class=self.features_extractor_class, + features_extractor_kwargs=self.features_extractor_kwargs, + ) + ) return data def make_actor(self) -> Actor: @@ -208,30 +220,35 @@ class CnnPolicy(TD3Policy): :param n_critics: (int) Number of critic networks to create. """ - def __init__(self, observation_space: gym.spaces.Space, - action_space: gym.spaces.Space, - lr_schedule: Callable, - net_arch: Optional[List[int]] = None, - device: Union[th.device, str] = 'auto', - activation_fn: Type[nn.Module] = nn.ReLU, - features_extractor_class: Type[BaseFeaturesExtractor] = NatureCNN, - features_extractor_kwargs: Optional[Dict[str, Any]] = None, - normalize_images: bool = True, - optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam, - optimizer_kwargs: Optional[Dict[str, Any]] = None, - n_critics: int = 2): - super(CnnPolicy, self).__init__(observation_space, - action_space, - lr_schedule, - net_arch, - device, - activation_fn, - features_extractor_class, - features_extractor_kwargs, - normalize_images, - optimizer_class, - optimizer_kwargs, - n_critics) + def __init__( + self, + observation_space: gym.spaces.Space, + action_space: gym.spaces.Space, + lr_schedule: Callable, + net_arch: Optional[List[int]] = None, + device: Union[th.device, str] = "auto", + activation_fn: Type[nn.Module] = nn.ReLU, + features_extractor_class: Type[BaseFeaturesExtractor] = NatureCNN, + features_extractor_kwargs: Optional[Dict[str, Any]] = None, + normalize_images: bool = True, + optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam, + optimizer_kwargs: Optional[Dict[str, Any]] = None, + n_critics: int = 2, + ): + super(CnnPolicy, self).__init__( + observation_space, + action_space, + lr_schedule, + net_arch, + device, + activation_fn, + features_extractor_class, + features_extractor_kwargs, + normalize_images, + optimizer_class, + optimizer_kwargs, + n_critics, + ) register_policy("MlpPolicy", MlpPolicy) diff --git a/stable_baselines3/td3/td3.py b/stable_baselines3/td3/td3.py index 83b7447aa..368bce16c 100644 --- a/stable_baselines3/td3/td3.py +++ b/stable_baselines3/td3/td3.py @@ -1,10 +1,11 @@ +from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union + import torch as th -import torch.nn.functional as F -from typing import List, Tuple, Type, Union, Callable, Optional, Dict, Any +from torch.nn import functional as F from stable_baselines3.common import logger -from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm from 
stable_baselines3.common.noise import ActionNoise +from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback from stable_baselines3.td3.policies import TD3Policy @@ -55,39 +56,56 @@ class TD3(OffPolicyAlgorithm): :param _init_setup_model: (bool) Whether or not to build the network at the creation of the instance """ - def __init__(self, policy: Union[str, Type[TD3Policy]], - env: Union[GymEnv, str], - learning_rate: Union[float, Callable] = 1e-3, - buffer_size: int = int(1e6), - learning_starts: int = 100, - batch_size: int = 100, - tau: float = 0.005, - gamma: float = 0.99, - train_freq: int = -1, - gradient_steps: int = -1, - n_episodes_rollout: int = 1, - action_noise: Optional[ActionNoise] = None, - optimize_memory_usage: bool = False, - policy_delay: int = 2, - target_policy_noise: float = 0.2, - target_noise_clip: float = 0.5, - tensorboard_log: Optional[str] = None, - create_eval_env: bool = False, - policy_kwargs: Dict[str, Any] = None, - verbose: int = 0, - seed: Optional[int] = None, - device: Union[th.device, str] = 'auto', - _init_setup_model: bool = True): - - super(TD3, self).__init__(policy, env, TD3Policy, learning_rate, - buffer_size, learning_starts, batch_size, - tau, gamma, train_freq, gradient_steps, - n_episodes_rollout, action_noise=action_noise, - policy_kwargs=policy_kwargs, - tensorboard_log=tensorboard_log, - verbose=verbose, device=device, - create_eval_env=create_eval_env, seed=seed, - sde_support=False, optimize_memory_usage=optimize_memory_usage) + def __init__( + self, + policy: Union[str, Type[TD3Policy]], + env: Union[GymEnv, str], + learning_rate: Union[float, Callable] = 1e-3, + buffer_size: int = int(1e6), + learning_starts: int = 100, + batch_size: int = 100, + tau: float = 0.005, + gamma: float = 0.99, + train_freq: int = -1, + gradient_steps: int = -1, + n_episodes_rollout: int = 1, + action_noise: Optional[ActionNoise] = None, + optimize_memory_usage: bool = False, + policy_delay: int = 2, + target_policy_noise: float = 0.2, + target_noise_clip: float = 0.5, + tensorboard_log: Optional[str] = None, + create_eval_env: bool = False, + policy_kwargs: Dict[str, Any] = None, + verbose: int = 0, + seed: Optional[int] = None, + device: Union[th.device, str] = "auto", + _init_setup_model: bool = True, + ): + + super(TD3, self).__init__( + policy, + env, + TD3Policy, + learning_rate, + buffer_size, + learning_starts, + batch_size, + tau, + gamma, + train_freq, + gradient_steps, + n_episodes_rollout, + action_noise=action_noise, + policy_kwargs=policy_kwargs, + tensorboard_log=tensorboard_log, + verbose=verbose, + device=device, + create_eval_env=create_eval_env, + seed=seed, + sde_support=False, + optimize_memory_usage=optimize_memory_usage, + ) self.policy_delay = policy_delay self.target_noise_clip = target_noise_clip @@ -141,8 +159,7 @@ def train(self, gradient_steps: int, batch_size: int = 100) -> None: # Delayed policy updates if gradient_step % self.policy_delay == 0: # Compute actor loss - actor_loss = -self.critic.q1_forward(replay_data.observations, - self.actor(replay_data.observations)).mean() + actor_loss = -self.critic.q1_forward(replay_data.observations, self.actor(replay_data.observations)).mean() # Optimize the actor self.actor.optimizer.zero_grad() @@ -157,23 +174,32 @@ def train(self, gradient_steps: int, batch_size: int = 100) -> None: target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data) self._n_updates += 
gradient_steps - logger.record("train/n_updates", self._n_updates, exclude='tensorboard') - - def learn(self, - total_timesteps: int, - callback: MaybeCallback = None, - log_interval: int = 4, - eval_env: Optional[GymEnv] = None, - eval_freq: int = -1, - n_eval_episodes: int = 5, - tb_log_name: str = "TD3", - eval_log_path: Optional[str] = None, - reset_num_timesteps: bool = True) -> OffPolicyAlgorithm: - - return super(TD3, self).learn(total_timesteps=total_timesteps, callback=callback, log_interval=log_interval, - eval_env=eval_env, eval_freq=eval_freq, n_eval_episodes=n_eval_episodes, - tb_log_name=tb_log_name, eval_log_path=eval_log_path, - reset_num_timesteps=reset_num_timesteps) + logger.record("train/n_updates", self._n_updates, exclude="tensorboard") + + def learn( + self, + total_timesteps: int, + callback: MaybeCallback = None, + log_interval: int = 4, + eval_env: Optional[GymEnv] = None, + eval_freq: int = -1, + n_eval_episodes: int = 5, + tb_log_name: str = "TD3", + eval_log_path: Optional[str] = None, + reset_num_timesteps: bool = True, + ) -> OffPolicyAlgorithm: + + return super(TD3, self).learn( + total_timesteps=total_timesteps, + callback=callback, + log_interval=log_interval, + eval_env=eval_env, + eval_freq=eval_freq, + n_eval_episodes=n_eval_episodes, + tb_log_name=tb_log_name, + eval_log_path=eval_log_path, + reset_num_timesteps=reset_num_timesteps, + ) def excluded_save_params(self) -> List[str]: """ diff --git a/tests/test_callbacks.py b/tests/test_callbacks.py index fd4abb5ba..b1a3339e8 100644 --- a/tests/test_callbacks.py +++ b/tests/test_callbacks.py @@ -1,23 +1,28 @@ import os import shutil -import pytest import gym +import pytest -from stable_baselines3 import A2C, PPO, SAC, TD3, DQN, DDPG -from stable_baselines3.common.callbacks import (CallbackList, CheckpointCallback, EvalCallback, - EveryNTimesteps, StopTrainingOnRewardThreshold) +from stable_baselines3 import A2C, DDPG, DQN, PPO, SAC, TD3 +from stable_baselines3.common.callbacks import ( + CallbackList, + CheckpointCallback, + EvalCallback, + EveryNTimesteps, + StopTrainingOnRewardThreshold, +) @pytest.mark.parametrize("model_class", [A2C, PPO, SAC, TD3, DQN, DDPG]) def test_callbacks(tmp_path, model_class): - log_folder = tmp_path / 'logs/callbacks/' + log_folder = tmp_path / "logs/callbacks/" # Dyn only support discrete actions env_name = select_env(model_class) # Create RL model # Small network for fast test - model = model_class('MlpPolicy', env_name, policy_kwargs=dict(net_arch=[32])) + model = model_class("MlpPolicy", env_name, policy_kwargs=dict(net_arch=[32])) checkpoint_callback = CheckpointCallback(save_freq=1000, save_path=log_folder) @@ -25,14 +30,13 @@ def test_callbacks(tmp_path, model_class): # Stop training if the performance is good enough callback_on_best = StopTrainingOnRewardThreshold(reward_threshold=-1200, verbose=1) - eval_callback = EvalCallback(eval_env, callback_on_new_best=callback_on_best, - best_model_save_path=log_folder, - log_path=log_folder, eval_freq=100) + eval_callback = EvalCallback( + eval_env, callback_on_new_best=callback_on_best, best_model_save_path=log_folder, log_path=log_folder, eval_freq=100 + ) # Equivalent to the `checkpoint_callback` # but here in an event-driven manner - checkpoint_on_event = CheckpointCallback(save_freq=1, save_path=log_folder, - name_prefix='event') + checkpoint_on_event = CheckpointCallback(save_freq=1, save_path=log_folder, name_prefix="event") event_callback = EveryNTimesteps(n_steps=500, callback=checkpoint_on_event) callback = 
CallbackList([checkpoint_callback, eval_callback, event_callback]) @@ -49,6 +53,6 @@ def test_callbacks(tmp_path, model_class): def select_env(model_class) -> str: if model_class is DQN: - return 'CartPole-v0' + return "CartPole-v0" else: - return 'Pendulum-v0' + return "Pendulum-v0" diff --git a/tests/test_cnn.py b/tests/test_cnn.py index 8c1727663..58c80b6af 100644 --- a/tests/test_cnn.py +++ b/tests/test_cnn.py @@ -3,26 +3,24 @@ import numpy as np import pytest -from stable_baselines3 import A2C, PPO, SAC, TD3, DQN +from stable_baselines3 import A2C, DQN, PPO, SAC, TD3 from stable_baselines3.common.identity_env import FakeImageEnv -@pytest.mark.parametrize('model_class', [A2C, PPO, SAC, TD3, DQN]) +@pytest.mark.parametrize("model_class", [A2C, PPO, SAC, TD3, DQN]) def test_cnn(tmp_path, model_class): - SAVE_NAME = 'cnn_model.zip' + SAVE_NAME = "cnn_model.zip" # Fake grayscale with frameskip # Atari after preprocessing: 84x84x1, here we are using lower resolution # to check that the network handle it automatically - env = FakeImageEnv(screen_height=40, screen_width=40, n_channels=1, - discrete=model_class not in {SAC, TD3}) + env = FakeImageEnv(screen_height=40, screen_width=40, n_channels=1, discrete=model_class not in {SAC, TD3}) if model_class in {A2C, PPO}: kwargs = dict(n_steps=100) else: # Avoid memory error when using replay buffer # Reduce the size of the features - kwargs = dict(buffer_size=250, - policy_kwargs=dict(features_extractor_kwargs=dict(features_dim=32))) - model = model_class('CnnPolicy', env, **kwargs).learn(250) + kwargs = dict(buffer_size=250, policy_kwargs=dict(features_extractor_kwargs=dict(features_dim=32))) + model = model_class("CnnPolicy", env, **kwargs).learn(250) obs = env.reset() diff --git a/tests/test_custom_policy.py b/tests/test_custom_policy.py index 9637f4e52..cd379126c 100644 --- a/tests/test_custom_policy.py +++ b/tests/test_custom_policy.py @@ -4,31 +4,31 @@ from stable_baselines3 import A2C, PPO, SAC, TD3 -@pytest.mark.parametrize('net_arch', [ - [12, dict(vf=[16], pi=[8])], - [4], - [], - [4, 4], - [12, dict(vf=[8, 4], pi=[8])], - [12, dict(vf=[8], pi=[8, 4])], - [12, dict(pi=[8])], -]) -@pytest.mark.parametrize('model_class', [A2C, PPO]) +@pytest.mark.parametrize( + "net_arch", + [ + [12, dict(vf=[16], pi=[8])], + [4], + [], + [4, 4], + [12, dict(vf=[8, 4], pi=[8])], + [12, dict(vf=[8], pi=[8, 4])], + [12, dict(pi=[8])], + ], +) +@pytest.mark.parametrize("model_class", [A2C, PPO]) def test_flexible_mlp(model_class, net_arch): - _ = model_class('MlpPolicy', 'CartPole-v1', policy_kwargs=dict(net_arch=net_arch), n_steps=100).learn(1000) + _ = model_class("MlpPolicy", "CartPole-v1", policy_kwargs=dict(net_arch=net_arch), n_steps=100).learn(1000) -@pytest.mark.parametrize('net_arch', [ - [4], - [4, 4], -]) -@pytest.mark.parametrize('model_class', [SAC, TD3]) +@pytest.mark.parametrize("net_arch", [[4], [4, 4],]) +@pytest.mark.parametrize("model_class", [SAC, TD3]) def test_custom_offpolicy(model_class, net_arch): - _ = model_class('MlpPolicy', 'Pendulum-v0', policy_kwargs=dict(net_arch=net_arch)).learn(1000) + _ = model_class("MlpPolicy", "Pendulum-v0", policy_kwargs=dict(net_arch=net_arch)).learn(1000) -@pytest.mark.parametrize('model_class', [A2C, PPO, SAC, TD3]) -@pytest.mark.parametrize('optimizer_kwargs', [None, dict(weight_decay=0.0)]) +@pytest.mark.parametrize("model_class", [A2C, PPO, SAC, TD3]) +@pytest.mark.parametrize("optimizer_kwargs", [None, dict(weight_decay=0.0)]) def test_custom_optimizer(model_class, optimizer_kwargs): policy_kwargs 
= dict(optimizer_class=th.optim.AdamW, optimizer_kwargs=optimizer_kwargs, net_arch=[32]) - _ = model_class('MlpPolicy', 'Pendulum-v0', policy_kwargs=policy_kwargs).learn(1000) + _ = model_class("MlpPolicy", "Pendulum-v0", policy_kwargs=policy_kwargs).learn(1000) diff --git a/tests/test_deterministic.py b/tests/test_deterministic.py index 645c4d3ff..16cdcfacf 100644 --- a/tests/test_deterministic.py +++ b/tests/test_deterministic.py @@ -1,6 +1,6 @@ import pytest -from stable_baselines3 import A2C, PPO, SAC, TD3, DQN +from stable_baselines3 import A2C, DQN, PPO, SAC, TD3 from stable_baselines3.common.noise import NormalActionNoise N_STEPS_TRAINING = 3000 @@ -12,18 +12,17 @@ def test_deterministic_training_common(algo): results = [[], []] rewards = [[], []] # Smaller network - kwargs = {'policy_kwargs': dict(net_arch=[64])} + kwargs = {"policy_kwargs": dict(net_arch=[64])} if algo in [TD3, SAC]: - env_id = 'Pendulum-v0' - kwargs.update({'action_noise': NormalActionNoise(0.0, 0.1), - 'learning_starts': 100}) + env_id = "Pendulum-v0" + kwargs.update({"action_noise": NormalActionNoise(0.0, 0.1), "learning_starts": 100}) else: - env_id = 'CartPole-v1' + env_id = "CartPole-v1" if algo == DQN: - kwargs.update({'learning_starts': 100}) + kwargs.update({"learning_starts": 100}) for i in range(2): - model = algo('MlpPolicy', env_id, seed=SEED, **kwargs) + model = algo("MlpPolicy", env_id, seed=SEED, **kwargs) model.learn(N_STEPS_TRAINING) env = model.get_env() obs = env.reset() diff --git a/tests/test_distributions.py b/tests/test_distributions.py index 0461e17f4..a73b81ede 100644 --- a/tests/test_distributions.py +++ b/tests/test_distributions.py @@ -2,13 +2,17 @@ import torch as th from stable_baselines3 import A2C, PPO -from stable_baselines3.common.distributions import (DiagGaussianDistribution, TanhBijector, - StateDependentNoiseDistribution, - CategoricalDistribution, SquashedDiagGaussianDistribution, - MultiCategoricalDistribution, BernoulliDistribution) +from stable_baselines3.common.distributions import ( + BernoulliDistribution, + CategoricalDistribution, + DiagGaussianDistribution, + MultiCategoricalDistribution, + SquashedDiagGaussianDistribution, + StateDependentNoiseDistribution, + TanhBijector, +) from stable_baselines3.common.utils import set_random_seed - N_ACTIONS = 2 N_FEATURES = 3 N_SAMPLES = int(5e6) @@ -33,7 +37,7 @@ def test_squashed_gaussian(model_class): """ Test run with squashed Gaussian (notably entropy computation) """ - model = model_class('MlpPolicy', 'Pendulum-v0', use_sde=True, n_steps=100, policy_kwargs=dict(squash_output=True)) + model = model_class("MlpPolicy", "Pendulum-v0", use_sde=True, n_steps=100, policy_kwargs=dict(squash_output=True)) model.learn(500) gaussian_mean = th.rand(N_SAMPLES, N_ACTIONS) @@ -62,10 +66,9 @@ def test_sde_distribution(): # TODO: analytical form for squashed Gaussian? 
-@pytest.mark.parametrize("dist", [ - DiagGaussianDistribution(N_ACTIONS), - StateDependentNoiseDistribution(N_ACTIONS, squash_output=False), -]) +@pytest.mark.parametrize( + "dist", [DiagGaussianDistribution(N_ACTIONS), StateDependentNoiseDistribution(N_ACTIONS, squash_output=False),] +) def test_entropy(dist): # The entropy can be approximated by averaging the negative log likelihood # mean negative log likelihood == differential entropy @@ -89,7 +92,7 @@ def test_entropy(dist): categorical_params = [ (CategoricalDistribution(N_ACTIONS), N_ACTIONS), (MultiCategoricalDistribution([2, 3]), sum([2, 3])), - (BernoulliDistribution(N_ACTIONS), N_ACTIONS) + (BernoulliDistribution(N_ACTIONS), N_ACTIONS), ] diff --git a/tests/test_envs.py b/tests/test_envs.py index 98197d397..e5f6bf3a9 100644 --- a/tests/test_envs.py +++ b/tests/test_envs.py @@ -1,18 +1,22 @@ -import pytest import gym -from gym import spaces import numpy as np +import pytest +from gym import spaces -from stable_baselines3.common.env_checker import check_env from stable_baselines3.common.bit_flipping_env import BitFlippingEnv -from stable_baselines3.common.identity_env import (IdentityEnv, IdentityEnvBox, FakeImageEnv, - IdentityEnvMultiBinary, IdentityEnvMultiDiscrete) +from stable_baselines3.common.env_checker import check_env +from stable_baselines3.common.identity_env import ( + FakeImageEnv, + IdentityEnv, + IdentityEnvBox, + IdentityEnvMultiBinary, + IdentityEnvMultiDiscrete, +) -ENV_CLASSES = [BitFlippingEnv, IdentityEnv, IdentityEnvBox, IdentityEnvMultiBinary, - IdentityEnvMultiDiscrete, FakeImageEnv] +ENV_CLASSES = [BitFlippingEnv, IdentityEnv, IdentityEnvBox, IdentityEnvMultiBinary, IdentityEnvMultiDiscrete, FakeImageEnv] -@pytest.mark.parametrize("env_id", ['CartPole-v0', 'Pendulum-v0']) +@pytest.mark.parametrize("env_id", ["CartPole-v0", "Pendulum-v0"]) def test_env(env_id): """ Check that environmnent integrated in Gym pass the test. 
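The test_entropy check in tests/test_distributions.py above relies on the identity stated in its comment: the mean negative log-likelihood of a distribution's own samples approximates its differential entropy. A minimal sketch of that identity, assuming only torch (sample count and tolerance are illustrative):

    import torch as th
    from torch.distributions import Normal

    dist = Normal(th.zeros(2), th.ones(2))
    samples = dist.sample((1_000_000,))
    # Monte-Carlo estimate: E[-log p(x)] for x ~ p approaches the differential entropy
    approx_entropy = -dist.log_prob(samples).mean()
    assert th.isclose(approx_entropy, dist.entropy().mean(), atol=1e-2)

For a unit Gaussian both values sit near 0.5 * ln(2 * pi * e) ≈ 1.419 per dimension.
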
@@ -25,7 +29,7 @@ def test_env(env_id): # Pendulum-v0 will produce a warning because the action space is # in [-2, 2] and not [-1, 1] - if env_id == 'Pendulum-v0': + if env_id == "Pendulum-v0": assert len(record) == 1 else: # The other environments must pass without warning @@ -50,24 +54,28 @@ def test_high_dimension_action_space(): # Patch to avoid error def patched_step(_action): return env.observation_space.sample(), 0.0, False, {} + env.step = patched_step check_env(env) -@pytest.mark.parametrize("new_obs_space", [ - # Small image - spaces.Box(low=0, high=255, shape=(32, 32, 3), dtype=np.uint8), - # Range not in [0, 255] - spaces.Box(low=0, high=1, shape=(64, 64, 3), dtype=np.uint8), - # Wrong dtype - spaces.Box(low=0, high=255, shape=(64, 64, 3), dtype=np.float32), - # Not an image, it should be a 1D vector - spaces.Box(low=-1, high=1, shape=(64, 3), dtype=np.float32), - # Tuple space is not supported by SB - spaces.Tuple([spaces.Discrete(5), spaces.Discrete(10)]), - # Dict space is not supported by SB when env is not a GoalEnv - spaces.Dict({"position": spaces.Discrete(5)}), -]) +@pytest.mark.parametrize( + "new_obs_space", + [ + # Small image + spaces.Box(low=0, high=255, shape=(32, 32, 3), dtype=np.uint8), + # Range not in [0, 255] + spaces.Box(low=0, high=1, shape=(64, 64, 3), dtype=np.uint8), + # Wrong dtype + spaces.Box(low=0, high=255, shape=(64, 64, 3), dtype=np.float32), + # Not an image, it should be a 1D vector + spaces.Box(low=-1, high=1, shape=(64, 3), dtype=np.float32), + # Tuple space is not supported by SB + spaces.Tuple([spaces.Discrete(5), spaces.Discrete(10)]), + # Dict space is not supported by SB when env is not a GoalEnv + spaces.Dict({"position": spaces.Discrete(5)}), + ], +) def test_non_default_spaces(new_obs_space): env = FakeImageEnv() env.observation_space = new_obs_space diff --git a/tests/test_identity.py b/tests/test_identity.py index 72d8c4153..38c6570be 100644 --- a/tests/test_identity.py +++ b/tests/test_identity.py @@ -1,13 +1,11 @@ import numpy as np import pytest -from stable_baselines3 import A2C, PPO, SAC, TD3, DQN, DDPG -from stable_baselines3.common.identity_env import (IdentityEnvBox, IdentityEnv, - IdentityEnvMultiBinary, IdentityEnvMultiDiscrete) - -from stable_baselines3.common.vec_env import DummyVecEnv +from stable_baselines3 import A2C, DDPG, DQN, PPO, SAC, TD3 from stable_baselines3.common.evaluation import evaluate_policy +from stable_baselines3.common.identity_env import IdentityEnv, IdentityEnvBox, IdentityEnvMultiBinary, IdentityEnvMultiDiscrete from stable_baselines3.common.noise import NormalActionNoise +from stable_baselines3.common.vec_env import DummyVecEnv DIM = 4 @@ -25,7 +23,7 @@ def test_discrete(model_class, env): if isinstance(env, (IdentityEnvMultiDiscrete, IdentityEnvMultiBinary)): return - model = model_class('MlpPolicy', env_, gamma=0.5, seed=1, **kwargs).learn(n_steps) + model = model_class("MlpPolicy", env_, gamma=0.5, seed=1, **kwargs).learn(n_steps) evaluate_policy(model, env_, n_eval_episodes=20, reward_threshold=90) obs = env.reset() @@ -37,24 +35,14 @@ def test_discrete(model_class, env): def test_continuous(model_class): env = IdentityEnvBox(eps=0.5) - n_steps = { - A2C: 3500, - PPO: 3000, - SAC: 700, - TD3: 500, - DDPG: 500 - }[model_class] - - kwargs = dict( - policy_kwargs=dict(net_arch=[64, 64]), - seed=0, - gamma=0.95 - ) + n_steps = {A2C: 3500, PPO: 3000, SAC: 700, TD3: 500, DDPG: 500}[model_class] + + kwargs = dict(policy_kwargs=dict(net_arch=[64, 64]), seed=0, gamma=0.95) if model_class in [TD3]: 
n_actions = 1 action_noise = NormalActionNoise(mean=np.zeros(n_actions), sigma=0.1 * np.ones(n_actions)) - kwargs['action_noise'] = action_noise + kwargs["action_noise"] = action_noise - model = model_class('MlpPolicy', env, **kwargs).learn(n_steps) + model = model_class("MlpPolicy", env, **kwargs).learn(n_steps) evaluate_policy(model, env, n_eval_episodes=20, reward_threshold=90) diff --git a/tests/test_logger.py b/tests/test_logger.py index 6975ce5c6..c399c9ee4 100644 --- a/tests/test_logger.py +++ b/tests/test_logger.py @@ -1,9 +1,24 @@ -import pytest import numpy as np +import pytest -from stable_baselines3.common.logger import (make_output_format, read_csv, read_json, DEBUG, ScopedConfigure, - info, debug, set_level, configure, record, record_dict, - dump, record_mean, warn, error, reset) +from stable_baselines3.common.logger import ( + DEBUG, + ScopedConfigure, + configure, + debug, + dump, + error, + info, + make_output_format, + read_csv, + read_json, + record, + record_dict, + record_mean, + reset, + set_level, + warn, +) KEY_VALUES = { "test": 1, @@ -55,23 +70,23 @@ def test_main(tmp_path): record_dict({"test": 1}) -@pytest.mark.parametrize('_format', ['stdout', 'log', 'json', 'csv', 'tensorboard']) +@pytest.mark.parametrize("_format", ["stdout", "log", "json", "csv", "tensorboard"]) def test_make_output(tmp_path, _format): """ test make output :param _format: (str) output format """ - if _format == 'tensorboard': + if _format == "tensorboard": # Skip if no tensorboard installed pytest.importorskip("tensorboard") writer = make_output_format(_format, tmp_path) writer.write(KEY_VALUES, KEY_EXCLUDED) if _format == "csv": - read_csv(tmp_path / 'progress.csv') - elif _format == 'json': - read_json(tmp_path / 'progress.json') + read_csv(tmp_path / "progress.csv") + elif _format == "json": + read_json(tmp_path / "progress.json") writer.close() @@ -80,4 +95,4 @@ def test_make_output_fail(tmp_path): test value error on logger """ with pytest.raises(ValueError): - make_output_format('dummy_format', tmp_path) + make_output_format("dummy_format", tmp_path) diff --git a/tests/test_monitor.py b/tests/test_monitor.py index 00a78026f..d3d041b4d 100644 --- a/tests/test_monitor.py +++ b/tests/test_monitor.py @@ -1,9 +1,9 @@ -import uuid import json import os +import uuid -import pandas import gym +import pandas from stable_baselines3.common.monitor import Monitor, get_monitor_files, load_results @@ -37,15 +37,15 @@ def test_monitor(tmp_path): assert sum(monitor_env.get_episode_rewards()) == sum(ep_rewards) _ = monitor_env.get_episode_times() - with open(monitor_file, 'rt') as file_handler: + with open(monitor_file, "rt") as file_handler: first_line = file_handler.readline() - assert first_line.startswith('#') + assert first_line.startswith("#") metadata = json.loads(first_line[1:]) - assert metadata['env_id'] == "CartPole-v1" - assert set(metadata.keys()) == {'env_id', 't_start'}, "Incorrect keys in monitor metadata" + assert metadata["env_id"] == "CartPole-v1" + assert set(metadata.keys()) == {"env_id", "t_start"}, "Incorrect keys in monitor metadata" last_logline = pandas.read_csv(file_handler, index_col=None) - assert set(last_logline.keys()) == {'l', 't', 'r'}, "Incorrect keys in monitor logline" + assert set(last_logline.keys()) == {"l", "t", "r"}, "Incorrect keys in monitor logline" os.remove(monitor_file) diff --git a/tests/test_predict.py b/tests/test_predict.py index c95e19f77..4f60f1bc0 100644 --- a/tests/test_predict.py +++ b/tests/test_predict.py @@ -1,7 +1,7 @@ import gym import 
pytest -from stable_baselines3 import A2C, PPO, SAC, TD3, DQN +from stable_baselines3 import A2C, DQN, PPO, SAC, TD3 from stable_baselines3.common.vec_env import DummyVecEnv MODEL_LIST = [ @@ -19,26 +19,26 @@ def test_auto_wrap(model_class): # Use different environment for DQN if model_class is DQN: - env_name = 'CartPole-v0' + env_name = "CartPole-v0" else: - env_name = 'Pendulum-v0' + env_name = "Pendulum-v0" env = gym.make(env_name) eval_env = gym.make(env_name) - model = model_class('MlpPolicy', env) + model = model_class("MlpPolicy", env) model.learn(100, eval_env=eval_env) @pytest.mark.parametrize("model_class", MODEL_LIST) -@pytest.mark.parametrize("env_id", ['Pendulum-v0', 'CartPole-v1']) +@pytest.mark.parametrize("env_id", ["Pendulum-v0", "CartPole-v1"]) def test_predict(model_class, env_id): - if env_id == 'CartPole-v1': + if env_id == "CartPole-v1": if model_class in [SAC, TD3]: return elif model_class in [DQN]: return # test detection of different shapes by the predict method - model = model_class('MlpPolicy', env_id) + model = model_class("MlpPolicy", env_id) env = gym.make(env_id) vec_env = DummyVecEnv([lambda: gym.make(env_id), lambda: gym.make(env_id)]) diff --git a/tests/test_run.py b/tests/test_run.py index d22d346c8..368381ee2 100644 --- a/tests/test_run.py +++ b/tests/test_run.py @@ -1,60 +1,97 @@ import numpy as np import pytest -from stable_baselines3 import A2C, PPO, SAC, TD3, DQN, DDPG +from stable_baselines3 import A2C, DDPG, DQN, PPO, SAC, TD3 from stable_baselines3.common.noise import NormalActionNoise, OrnsteinUhlenbeckActionNoise normal_action_noise = NormalActionNoise(np.zeros(1), 0.1 * np.ones(1)) -@pytest.mark.parametrize('model_class', [TD3, DDPG]) -@pytest.mark.parametrize('action_noise', [normal_action_noise, OrnsteinUhlenbeckActionNoise(np.zeros(1), 0.1 * np.ones(1))]) +@pytest.mark.parametrize("model_class", [TD3, DDPG]) +@pytest.mark.parametrize("action_noise", [normal_action_noise, OrnsteinUhlenbeckActionNoise(np.zeros(1), 0.1 * np.ones(1))]) def test_deterministic_pg(model_class, action_noise): """ Test for DDPG and variants (TD3). 
""" - model = model_class('MlpPolicy', 'Pendulum-v0', policy_kwargs=dict(net_arch=[64, 64]), - learning_starts=100, verbose=1, create_eval_env=True, action_noise=action_noise) + model = model_class( + "MlpPolicy", + "Pendulum-v0", + policy_kwargs=dict(net_arch=[64, 64]), + learning_starts=100, + verbose=1, + create_eval_env=True, + action_noise=action_noise, + ) model.learn(total_timesteps=1000, eval_freq=500) -@pytest.mark.parametrize("env_id", ['CartPole-v1', 'Pendulum-v0']) +@pytest.mark.parametrize("env_id", ["CartPole-v1", "Pendulum-v0"]) def test_a2c(env_id): - model = A2C('MlpPolicy', env_id, seed=0, policy_kwargs=dict(net_arch=[16]), verbose=1, create_eval_env=True) + model = A2C("MlpPolicy", env_id, seed=0, policy_kwargs=dict(net_arch=[16]), verbose=1, create_eval_env=True) model.learn(total_timesteps=1000, eval_freq=500) -@pytest.mark.parametrize("env_id", ['CartPole-v1', 'Pendulum-v0']) +@pytest.mark.parametrize("env_id", ["CartPole-v1", "Pendulum-v0"]) @pytest.mark.parametrize("clip_range_vf", [None, 0.2, -0.2]) def test_ppo(env_id, clip_range_vf): if clip_range_vf is not None and clip_range_vf < 0: # Should throw an error with pytest.raises(AssertionError): - model = PPO('MlpPolicy', env_id, seed=0, policy_kwargs=dict(net_arch=[16]), verbose=1, create_eval_env=True, - clip_range_vf=clip_range_vf) + model = PPO( + "MlpPolicy", + env_id, + seed=0, + policy_kwargs=dict(net_arch=[16]), + verbose=1, + create_eval_env=True, + clip_range_vf=clip_range_vf, + ) else: - model = PPO('MlpPolicy', env_id, seed=0, policy_kwargs=dict(net_arch=[16]), verbose=1, create_eval_env=True, - clip_range_vf=clip_range_vf) + model = PPO( + "MlpPolicy", + env_id, + seed=0, + policy_kwargs=dict(net_arch=[16]), + verbose=1, + create_eval_env=True, + clip_range_vf=clip_range_vf, + ) model.learn(total_timesteps=1000, eval_freq=500) -@pytest.mark.parametrize("ent_coef", ['auto', 0.01, 'auto_0.01']) +@pytest.mark.parametrize("ent_coef", ["auto", 0.01, "auto_0.01"]) def test_sac(ent_coef): - model = SAC('MlpPolicy', 'Pendulum-v0', policy_kwargs=dict(net_arch=[64, 64]), - learning_starts=100, verbose=1, create_eval_env=True, ent_coef=ent_coef, - action_noise=NormalActionNoise(np.zeros(1), np.zeros(1))) + model = SAC( + "MlpPolicy", + "Pendulum-v0", + policy_kwargs=dict(net_arch=[64, 64]), + learning_starts=100, + verbose=1, + create_eval_env=True, + ent_coef=ent_coef, + action_noise=NormalActionNoise(np.zeros(1), np.zeros(1)), + ) model.learn(total_timesteps=1000, eval_freq=500) @pytest.mark.parametrize("n_critics", [1, 3]) def test_n_critics(n_critics): # Test SAC with different number of critics, for TD3, n_critics=1 corresponds to DDPG - model = SAC('MlpPolicy', 'Pendulum-v0', policy_kwargs=dict(net_arch=[64, 64], n_critics=n_critics), - learning_starts=100, verbose=1) + model = SAC( + "MlpPolicy", "Pendulum-v0", policy_kwargs=dict(net_arch=[64, 64], n_critics=n_critics), learning_starts=100, verbose=1 + ) model.learn(total_timesteps=1000) def test_dqn(): - model = DQN('MlpPolicy', 'CartPole-v1', policy_kwargs=dict(net_arch=[64, 64]), - learning_starts=500, buffer_size=500, learning_rate=3e-4, verbose=1, create_eval_env=True) + model = DQN( + "MlpPolicy", + "CartPole-v1", + policy_kwargs=dict(net_arch=[64, 64]), + learning_starts=500, + buffer_size=500, + learning_rate=3e-4, + verbose=1, + create_eval_env=True, + ) model.learn(total_timesteps=1000, eval_freq=500) diff --git a/tests/test_save_load.py b/tests/test_save_load.py index 69df31a98..d56c40d5e 100644 --- a/tests/test_save_load.py +++ 
b/tests/test_save_load.py @@ -1,33 +1,21 @@ -import os import io +import os +import pathlib import warnings from copy import deepcopy -import pathlib -import pytest import gym import numpy as np +import pytest import torch as th -from stable_baselines3 import A2C, PPO, SAC, TD3, DQN, DDPG +from stable_baselines3 import A2C, DDPG, DQN, PPO, SAC, TD3 from stable_baselines3.common.base_class import BaseAlgorithm -from stable_baselines3.common.identity_env import IdentityEnvBox, IdentityEnv +from stable_baselines3.common.identity_env import FakeImageEnv, IdentityEnv, IdentityEnvBox +from stable_baselines3.common.save_util import load_from_pkl, open_path, save_to_pkl from stable_baselines3.common.vec_env import DummyVecEnv -from stable_baselines3.common.identity_env import FakeImageEnv -from stable_baselines3.common.save_util import ( - open_path, - save_to_pkl, - load_from_pkl, -) - -MODEL_LIST = [ - PPO, - A2C, - TD3, - SAC, - DQN, - DDPG -] + +MODEL_LIST = [PPO, A2C, TD3, SAC, DQN, DDPG] def select_env(model_class: BaseAlgorithm) -> gym.Env: @@ -58,17 +46,13 @@ def test_save_load(tmp_path, model_class): model.learn(total_timesteps=500, eval_freq=250) env.reset() - observations = np.concatenate( - [env.step([env.action_space.sample()])[0] for _ in range(10)], axis=0 - ) + observations = np.concatenate([env.step([env.action_space.sample()])[0] for _ in range(10)], axis=0) # Get dictionary of current parameters params = deepcopy(model.policy.state_dict()) # Modify all parameters to be random values - random_params = dict( - (param_name, th.rand_like(param)) for param_name, param in params.items() - ) + random_params = dict((param_name, th.rand_like(param)) for param_name, param in params.items()) # Update model parameters with the new random values model.policy.load_state_dict(random_params) @@ -76,9 +60,7 @@ def test_save_load(tmp_path, model_class): new_params = model.policy.state_dict() # Check that all params are different now for k in params: - assert not th.allclose( - params[k], new_params[k] - ), "Parameters did not change as expected." + assert not th.allclose(params[k], new_params[k]), "Parameters did not change as expected." params = new_params @@ -95,9 +77,7 @@ def test_save_load(tmp_path, model_class): # Check that all params are the same as before save load procedure now for key in params: - assert th.allclose( - params[key], new_params[key] - ), "Model parameters not the same after save and load." + assert th.allclose(params[key], new_params[key]), "Model parameters not the same after save and load." 
     # check if model still selects the same actions
     new_selected_actions, _ = model.predict(observations, deterministic=True)
@@ -229,9 +209,7 @@ def test_warn_buffer(recwarn, model_class, optimize_memory_usage):
     if optimize_memory_usage:
         assert len(recwarn) == 1
         warning = recwarn.pop(UserWarning)
-        assert "The last trajectory in the replay buffer will be truncated" in str(
-            warning.message
-        )
+        assert "The last trajectory in the replay buffer will be truncated" in str(warning.message)
     else:
         assert len(recwarn) == 0
@@ -253,22 +231,16 @@ def test_save_load_policy(tmp_path, model_class, policy_str):
         # Avoid memory error when using replay buffer
         # Reduce the size of the features
         kwargs = dict(buffer_size=250)
-        env = FakeImageEnv(
-            screen_height=40, screen_width=40, n_channels=2, discrete=model_class == DQN
-        )
+        env = FakeImageEnv(screen_height=40, screen_width=40, n_channels=2, discrete=model_class == DQN)

     env = DummyVecEnv([lambda: env])

     # create model
-    model = model_class(
-        policy_str, env, policy_kwargs=dict(net_arch=[16]), verbose=1, **kwargs
-    )
+    model = model_class(policy_str, env, policy_kwargs=dict(net_arch=[16]), verbose=1, **kwargs)
     model.learn(total_timesteps=500, eval_freq=250)

     env.reset()
-    observations = np.concatenate(
-        [env.step([env.action_space.sample()])[0] for _ in range(10)], axis=0
-    )
+    observations = np.concatenate([env.step([env.action_space.sample()])[0] for _ in range(10)], axis=0)

     policy = model.policy
     policy_class = policy.__class__
@@ -281,9 +253,7 @@ def test_save_load_policy(tmp_path, model_class, policy_str):
     params = deepcopy(policy.state_dict())

     # Modify all parameters to be random values
-    random_params = dict(
-        (param_name, th.rand_like(param)) for param_name, param in params.items()
-    )
+    random_params = dict((param_name, th.rand_like(param)) for param_name, param in params.items())

     # Update model parameters with the new random values
     policy.load_state_dict(random_params)
@@ -291,9 +261,7 @@ def test_save_load_policy(tmp_path, model_class, policy_str):
     new_params = policy.state_dict()
     # Check that all params are different now
     for k in params:
-        assert not th.allclose(
-            params[k], new_params[k]
-        ), "Parameters did not change as expected."
+        assert not th.allclose(params[k], new_params[k]), "Parameters did not change as expected."

     params = new_params
@@ -320,9 +288,7 @@ def test_save_load_policy(tmp_path, model_class, policy_str):
     # Check that all params are the same as before save load procedure now
     for key in params:
-        assert th.allclose(
-            params[key], new_params[key]
-        ), "Policy parameters not the same after save and load."
+        assert th.allclose(params[key], new_params[key]), "Policy parameters not the same after save and load."

     # check if model still selects the same actions
     new_selected_actions, _ = policy.predict(observations, deterministic=True)
@@ -370,17 +336,11 @@ def test_open_file_str_pathlib(tmp_path, pathtype):
     save_to_pkl(fp1, "foo")
     assert fp1.closed
     with pytest.warns(None) as record:
-        assert (
-            load_from_pkl(open_path(pathtype(f"{tmp_path}/t2"), "r", suffix="pkl"))
-            == "foo"
-        )
+        assert load_from_pkl(open_path(pathtype(f"{tmp_path}/t2"), "r", suffix="pkl")) == "foo"
     assert len(record) == 0

     with pytest.warns(None) as record:
-        assert (
-            load_from_pkl(open_path(pathtype(f"{tmp_path}/t2"), "r", suffix="pkl", verbose=2))
-            == "foo"
-        )
+        assert load_from_pkl(open_path(pathtype(f"{tmp_path}/t2"), "r", suffix="pkl", verbose=2)) == "foo"
     assert len(record) == 1

     fp = pathlib.Path(f"{tmp_path}/t2").open("w")
diff --git a/tests/test_sde.py b/tests/test_sde.py
index dbf196368..d26e06cec 100644
--- a/tests/test_sde.py
+++ b/tests/test_sde.py
@@ -2,7 +2,7 @@
 import torch as th
 from torch.distributions import Normal

-from stable_baselines3 import A2C, SAC, PPO
+from stable_baselines3 import A2C, PPO, SAC


 def test_state_dependent_exploration_grad():
@@ -58,6 +58,13 @@ def test_state_dependent_exploration_grad():
 @pytest.mark.parametrize("sde_net_arch", [None, [32, 16], []])
 @pytest.mark.parametrize("use_expln", [False, True])
 def test_state_dependent_offpolicy_noise(model_class, sde_net_arch, use_expln):
-    model = model_class('MlpPolicy', 'Pendulum-v0', use_sde=True, seed=None, create_eval_env=True,
-                        verbose=1, policy_kwargs=dict(log_std_init=-2, sde_net_arch=sde_net_arch, use_expln=use_expln))
+    model = model_class(
+        "MlpPolicy",
+        "Pendulum-v0",
+        use_sde=True,
+        seed=None,
+        create_eval_env=True,
+        verbose=1,
+        policy_kwargs=dict(log_std_init=-2, sde_net_arch=sde_net_arch, use_expln=use_expln),
+    )
     model.learn(total_timesteps=int(500), eval_freq=250)
diff --git a/tests/test_spaces.py b/tests/test_spaces.py
index e057f97b3..98a1953fc 100644
--- a/tests/test_spaces.py
+++ b/tests/test_spaces.py
@@ -1,6 +1,6 @@
+import gym
 import numpy as np
 import pytest
-import gym

 from stable_baselines3 import DQN, SAC, TD3
 from stable_baselines3.common.evaluation import evaluate_policy
diff --git a/tests/test_tensorboard.py b/tests/test_tensorboard.py
index e5441a85c..3f755a7aa 100644
--- a/tests/test_tensorboard.py
+++ b/tests/test_tensorboard.py
@@ -1,13 +1,14 @@
 import os
+
 import pytest

 from stable_baselines3 import A2C, PPO, SAC, TD3

 MODEL_DICT = {
-    'a2c': (A2C, 'CartPole-v1'),
-    'ppo': (PPO, 'CartPole-v1'),
-    'sac': (SAC, 'Pendulum-v0'),
-    'td3': (TD3, 'Pendulum-v0'),
+    "a2c": (A2C, "CartPole-v1"),
+    "ppo": (PPO, "CartPole-v1"),
+    "sac": (SAC, "Pendulum-v0"),
+    "td3": (TD3, "Pendulum-v0"),
 }

 N_STEPS = 100
@@ -20,7 +21,7 @@ def test_tensorboard(tmp_path, model_name):
     logname = model_name.upper()
     algo, env_id = MODEL_DICT[model_name]
-    model = algo('MlpPolicy', env_id, verbose=1, tensorboard_log=tmp_path)
+    model = algo("MlpPolicy", env_id, verbose=1, tensorboard_log=tmp_path)
     model.learn(N_STEPS)
     model.learn(N_STEPS, reset_num_timesteps=False)
diff --git a/tests/test_utils.py b/tests/test_utils.py
index 4f0473871..7f03c1d0c 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -1,27 +1,25 @@
 import os
 import shutil

-import pytest
 import gym
 import numpy as np
+import pytest

 from stable_baselines3 import A2C
-from stable_baselines3.common.monitor import Monitor
 from stable_baselines3.common.atari_wrappers import ClipRewardEnv
+from stable_baselines3.common.cmd_util import make_atari_env, make_vec_env
 from stable_baselines3.common.evaluation import evaluate_policy
-from stable_baselines3.common.cmd_util import make_vec_env, make_atari_env
+from stable_baselines3.common.monitor import Monitor
+from stable_baselines3.common.noise import ActionNoise, OrnsteinUhlenbeckActionNoise, VectorizedActionNoise
 from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv
-from stable_baselines3.common.noise import (VectorizedActionNoise,
-                                            OrnsteinUhlenbeckActionNoise, ActionNoise)


-@pytest.mark.parametrize("env_id", ['CartPole-v1', lambda: gym.make('CartPole-v1')])
+@pytest.mark.parametrize("env_id", ["CartPole-v1", lambda: gym.make("CartPole-v1")])
 @pytest.mark.parametrize("n_envs", [1, 2])
 @pytest.mark.parametrize("vec_env_cls", [None, SubprocVecEnv])
 @pytest.mark.parametrize("wrapper_class", [None, gym.wrappers.TimeLimit])
 def test_make_vec_env(env_id, n_envs, vec_env_cls, wrapper_class):
-    env = make_vec_env(env_id, n_envs, vec_env_cls=vec_env_cls,
-                       wrapper_class=wrapper_class, monitor_dir=None, seed=0)
+    env = make_vec_env(env_id, n_envs, vec_env_cls=vec_env_cls, wrapper_class=wrapper_class, monitor_dir=None, seed=0)

     assert env.num_envs == n_envs
@@ -37,13 +35,12 @@ def test_make_vec_env(env_id, n_envs, vec_env_cls, wrapper_class):
     env.close()


-@pytest.mark.parametrize("env_id", ['BreakoutNoFrameskip-v4'])
+@pytest.mark.parametrize("env_id", ["BreakoutNoFrameskip-v4"])
 @pytest.mark.parametrize("n_envs", [1, 2])
 @pytest.mark.parametrize("wrapper_kwargs", [None, dict(clip_reward=False, screen_size=60)])
 def test_make_atari_env(env_id, n_envs, wrapper_kwargs):
-    env_id = 'BreakoutNoFrameskip-v4'
-    env = make_atari_env(env_id, n_envs,
-                         wrapper_kwargs=wrapper_kwargs, monitor_dir=None, seed=0)
+    env_id = "BreakoutNoFrameskip-v4"
+    env = make_atari_env(env_id, n_envs, wrapper_kwargs=wrapper_kwargs, monitor_dir=None, seed=0)

     assert env.num_envs == n_envs
@@ -70,10 +67,15 @@ def test_custom_vec_env(tmp_path):
     """
     Stand alone test for a special case (passing a custom VecEnv class) to avoid doubling the number of tests.
     """
-    monitor_dir = tmp_path / 'test_make_vec_env/'
-    env = make_vec_env('CartPole-v1', n_envs=1,
-                       monitor_dir=monitor_dir, seed=0,
-                       vec_env_cls=SubprocVecEnv, vec_env_kwargs={'start_method': None})
+    monitor_dir = tmp_path / "test_make_vec_env/"
+    env = make_vec_env(
+        "CartPole-v1",
+        n_envs=1,
+        monitor_dir=monitor_dir,
+        seed=0,
+        vec_env_cls=SubprocVecEnv,
+        vec_env_kwargs={"start_method": None},
+    )

     assert env.num_envs == 1
     assert isinstance(env, SubprocVecEnv)
@@ -85,20 +87,27 @@ def test_custom_vec_env(tmp_path):
     # This should fail because DummyVecEnv does not have any keyword argument
     with pytest.raises(TypeError):
-        make_vec_env('CartPole-v1', n_envs=1, vec_env_kwargs={'dummy': False})
+        make_vec_env("CartPole-v1", n_envs=1, vec_env_kwargs={"dummy": False})


 def test_evaluate_policy():
-    model = A2C('MlpPolicy', 'Pendulum-v0', seed=0)
+    model = A2C("MlpPolicy", "Pendulum-v0", seed=0)
     n_steps_per_episode, n_eval_episodes = 200, 2
     model.n_callback_calls = 0

     def dummy_callback(locals_, _globals):
-        locals_['model'].n_callback_calls += 1
-
-    _, episode_lengths = evaluate_policy(model, model.get_env(), n_eval_episodes, deterministic=True,
-                                         render=False, callback=dummy_callback, reward_threshold=None,
-                                         return_episode_rewards=True)
+        locals_["model"].n_callback_calls += 1
+
+    _, episode_lengths = evaluate_policy(
+        model,
+        model.get_env(),
+        n_eval_episodes,
+        deterministic=True,
+        render=False,
+        callback=dummy_callback,
+        reward_threshold=None,
+        return_episode_rewards=True,
+    )

     n_steps = sum(episode_lengths)
     assert n_steps == n_steps_per_episode * n_eval_episodes
diff --git a/tests/test_vec_check_nan.py b/tests/test_vec_check_nan.py
index c6db91a94..265da2ed9 100644
--- a/tests/test_vec_check_nan.py
+++ b/tests/test_vec_check_nan.py
@@ -1,14 +1,15 @@
 import gym
-from gym import spaces
 import numpy as np
 import pytest
+from gym import spaces

 from stable_baselines3.common.vec_env import DummyVecEnv, VecCheckNan


 class NanAndInfEnv(gym.Env):
     """Custom Environment that raised NaNs and Infs"""
-    metadata = {'render.modes': ['human']}
+
+    metadata = {"render.modes": ["human"]}

     def __init__(self):
         super(NanAndInfEnv, self).__init__()
@@ -18,9 +19,9 @@ def __init__(self):
     @staticmethod
     def step(action):
         if np.all(np.array(action) > 0):
-            obs = float('NaN')
+            obs = float("NaN")
         elif np.all(np.array(action) < 0):
-            obs = float('inf')
+            obs = float("inf")
         else:
             obs = 0
         return [obs], 0.0, False, {}
@@ -29,7 +30,7 @@ def step(action):
     def reset():
         return [0.0]

-    def render(self, mode='human', close=False):
+    def render(self, mode="human", close=False):
         pass
@@ -42,10 +43,10 @@ def test_check_nan():
     env.step([[0]])

     with pytest.raises(ValueError):
-        env.step([[float('NaN')]])
+        env.step([[float("NaN")]])

     with pytest.raises(ValueError):
-        env.step([[float('inf')]])
+        env.step([[float("inf")]])

     with pytest.raises(ValueError):
         env.step([[-1]])
diff --git a/tests/test_vec_envs.py b/tests/test_vec_envs.py
index 5e559dd19..8c33341c5 100644
--- a/tests/test_vec_envs.py
+++ b/tests/test_vec_envs.py
@@ -3,11 +3,11 @@
 import itertools
 import multiprocessing

-import pytest
 import gym
 import numpy as np
+import pytest

-from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv, VecNormalize, VecFrameStack
+from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv, VecFrameStack, VecNormalize

 N_ENVS = 3
 VEC_ENV_CLASSES = [DummyVecEnv, SubprocVecEnv]
@@ -39,8 +39,8 @@ def step(self, action):
     def _choose_next_state(self):
         self.state = self.observation_space.sample()

-    def render(self, mode='human'):
-        if mode == 'rgb_array':
+    def render(self, mode="human"):
+        if mode == "rgb_array":
             return np.zeros((4, 4, 3))

     def seed(self, seed=None):
@@ -59,8 +59,8 @@ def custom_method(dim_0=1, dim_1=1):
         return np.ones((dim_0, dim_1))


-@pytest.mark.parametrize('vec_env_class', VEC_ENV_CLASSES)
-@pytest.mark.parametrize('vec_env_wrapper', VEC_ENV_WRAPPERS)
+@pytest.mark.parametrize("vec_env_class", VEC_ENV_CLASSES)
+@pytest.mark.parametrize("vec_env_wrapper", VEC_ENV_WRAPPERS)
 def test_vecenv_custom_calls(vec_env_class, vec_env_wrapper):
     """Test access to methods/attributes of vectorized environments"""
@@ -79,14 +79,14 @@ def make_env():
     vec_env.seed(0)
     # Test render method call
     # vec_env.render()  # we need a X server to test the "human" mode
-    vec_env.render(mode='rgb_array')
-    env_method_results = vec_env.env_method('custom_method', 1, indices=None, dim_1=2)
+    vec_env.render(mode="rgb_array")
+    env_method_results = vec_env.env_method("custom_method", 1, indices=None, dim_1=2)
     setattr_results = []
     # Set current_step to an arbitrary value
     for env_idx in range(N_ENVS):
-        setattr_results.append(vec_env.set_attr('current_step', env_idx, indices=env_idx))
+        setattr_results.append(vec_env.set_attr("current_step", env_idx, indices=env_idx))
     # Retrieve the value for each environment
-    getattr_results = vec_env.get_attr('current_step')
+    getattr_results = vec_env.get_attr("current_step")

     assert len(env_method_results) == N_ENVS
     assert len(setattr_results) == N_ENVS
@@ -98,34 +98,34 @@ def make_env():
         assert getattr_results[env_idx] == env_idx

     # Call env_method on a subset of the VecEnv
-    env_method_subset = vec_env.env_method('custom_method', 1, indices=[0, 2], dim_1=3)
+    env_method_subset = vec_env.env_method("custom_method", 1, indices=[0, 2], dim_1=3)
     assert (env_method_subset[0] == np.ones((1, 3))).all()
     assert (env_method_subset[1] == np.ones((1, 3))).all()
     assert len(env_method_subset) == 2

     # Test to change value for all the environments
-    setattr_result = vec_env.set_attr('current_step', 42, indices=None)
-    getattr_result = vec_env.get_attr('current_step')
+    setattr_result = vec_env.set_attr("current_step", 42, indices=None)
+    getattr_result = vec_env.get_attr("current_step")
     assert setattr_result is None
     assert getattr_result == [42 for _ in range(N_ENVS)]

     # Additional tests for setattr that does not affect all the environments
     vec_env.reset()
-    setattr_result = vec_env.set_attr('current_step', 12, indices=[0, 1])
-    getattr_result = vec_env.get_attr('current_step')
-    getattr_result_subset = vec_env.get_attr('current_step', indices=[0, 1])
+    setattr_result = vec_env.set_attr("current_step", 12, indices=[0, 1])
+    getattr_result = vec_env.get_attr("current_step")
+    getattr_result_subset = vec_env.get_attr("current_step", indices=[0, 1])
     assert setattr_result is None
     assert getattr_result == [12 for _ in range(2)] + [0 for _ in range(N_ENVS - 2)]
     assert getattr_result_subset == [12, 12]
-    assert vec_env.get_attr('current_step', indices=[0, 2]) == [12, 0]
+    assert vec_env.get_attr("current_step", indices=[0, 2]) == [12, 0]

     vec_env.reset()
     # Change value only for first and last environment
-    setattr_result = vec_env.set_attr('current_step', 12, indices=[0, -1])
-    getattr_result = vec_env.get_attr('current_step')
+    setattr_result = vec_env.set_attr("current_step", 12, indices=[0, -1])
+    getattr_result = vec_env.get_attr("current_step")
     assert setattr_result is None
     assert getattr_result == [12] + [0 for _ in range(N_ENVS - 2)] + [12]
-    assert vec_env.get_attr('current_step', indices=[-1]) == [12]
+    assert vec_env.get_attr("current_step", indices=[-1]) == [12]

     vec_env.close()
@@ -135,24 +135,23 @@ def __init__(self, max_steps):
         """Gym environment for testing that terminal observation
         is inserted correctly."""
         self.action_space = gym.spaces.Discrete(2)
-        self.observation_space = gym.spaces.Box(np.array([0]), np.array([999]),
-                                                dtype='int')
+        self.observation_space = gym.spaces.Box(np.array([0]), np.array([999]), dtype="int")
         self.max_steps = max_steps
         self.current_step = 0

     def reset(self):
         self.current_step = 0
-        return np.array([self.current_step], dtype='int')
+        return np.array([self.current_step], dtype="int")

     def step(self, action):
         prev_step = self.current_step
         self.current_step += 1
         done = self.current_step >= self.max_steps
-        return np.array([prev_step], dtype='int'), 0.0, done, {}
+        return np.array([prev_step], dtype="int"), 0.0, done, {}


-@pytest.mark.parametrize('vec_env_class', VEC_ENV_CLASSES)
-@pytest.mark.parametrize('vec_env_wrapper', VEC_ENV_WRAPPERS)
+@pytest.mark.parametrize("vec_env_class", VEC_ENV_CLASSES)
+@pytest.mark.parametrize("vec_env_wrapper", VEC_ENV_WRAPPERS)
 def test_vecenv_terminal_obs(vec_env_class, vec_env_wrapper):
     """Test that 'terminal_observation' gets added to info dict upon termination."""
@@ -165,7 +164,7 @@ def test_vecenv_terminal_obs(vec_env_class, vec_env_wrapper):
     else:
         vec_env = vec_env_wrapper(vec_env)

-    zero_acts = np.zeros((N_ENVS,), dtype='int')
+    zero_acts = np.zeros((N_ENVS,), dtype="int")
     prev_obs_b = vec_env.reset()
     for step_num in range(1, max(step_nums) + 1):
         obs_b, _, done_b, info_b = vec_env.step(zero_acts)
@@ -176,9 +175,9 @@ def test_vecenv_terminal_obs(vec_env_class, vec_env_wrapper):
         for prev_obs, obs, done, info, final_step_num in env_iter:
             assert done == (step_num == final_step_num)
             if not done:
-                assert 'terminal_observation' not in info
+                assert "terminal_observation" not in info
             else:
-                terminal_obs = info['terminal_observation']
+                terminal_obs = info["terminal_observation"]

                 # do some rough ordering checks that should work for all
                 # wrappers, including VecNormalize
@@ -196,12 +195,14 @@ def test_vecenv_terminal_obs(vec_env_class, vec_env_wrapper):
     vec_env.close()


-SPACES = collections.OrderedDict([
-    ('discrete', gym.spaces.Discrete(2)),
-    ('multidiscrete', gym.spaces.MultiDiscrete([2, 3])),
-    ('multibinary', gym.spaces.MultiBinary(3)),
-    ('continuous', gym.spaces.Box(low=np.zeros(2), high=np.ones(2))),
-])
+SPACES = collections.OrderedDict(
+    [
+        ("discrete", gym.spaces.Discrete(2)),
+        ("multidiscrete", gym.spaces.MultiDiscrete([2, 3])),
+        ("multibinary", gym.spaces.MultiBinary(3)),
+        ("continuous", gym.spaces.Box(low=np.zeros(2), high=np.ones(2))),
+    ]
+)


 def check_vecenv_spaces(vec_env_class, space, obs_assert):
@@ -230,7 +231,7 @@ def check_vecenv_obs(obs, space):
         assert space.contains(value)


-@pytest.mark.parametrize('vec_env_class,space', itertools.product(VEC_ENV_CLASSES, SPACES.values()))
+@pytest.mark.parametrize("vec_env_class,space", itertools.product(VEC_ENV_CLASSES, SPACES.values()))
 def test_vecenv_single_space(vec_env_class, space):
     def obs_assert(obs):
         return check_vecenv_obs(obs, space)
@@ -245,7 +246,7 @@ def sample(self):
         return dict(super().sample())


-@pytest.mark.parametrize('vec_env_class', VEC_ENV_CLASSES)
+@pytest.mark.parametrize("vec_env_class", VEC_ENV_CLASSES)
 def test_vecenv_dict_spaces(vec_env_class):
     """Test dictionary observation spaces with vectorized environments."""
     space = gym.spaces.Dict(SPACES)
@@ -263,7 +264,7 @@ def obs_assert(obs):
     check_vecenv_spaces(vec_env_class, unordered_space, obs_assert)


-@pytest.mark.parametrize('vec_env_class', VEC_ENV_CLASSES)
+@pytest.mark.parametrize("vec_env_class", VEC_ENV_CLASSES)
 def test_vecenv_tuple_spaces(vec_env_class):
     """Test tuple observation spaces with vectorized environments."""
     space = gym.spaces.Tuple(tuple(SPACES.values()))
@@ -281,7 +282,7 @@ def test_subproc_start_method():
     start_methods = [None]
     # Only test thread-safe methods. Others may deadlock tests! (gh/428)
     # Note: adding unsafe `fork` method as we are now using PyTorch
-    all_methods = {'forkserver', 'spawn', 'fork'}
+    all_methods = {"forkserver", "spawn", "fork"}
     available_methods = multiprocessing.get_all_start_methods()
     start_methods += list(all_methods.intersection(available_methods))
     space = gym.spaces.Discrete(2)
@@ -294,20 +295,20 @@ def obs_assert(obs):
         check_vecenv_spaces(vec_env_class, space, obs_assert)

     with pytest.raises(ValueError, match="cannot find context for 'illegal_method'"):
-        vec_env_class = functools.partial(SubprocVecEnv, start_method='illegal_method')
+        vec_env_class = functools.partial(SubprocVecEnv, start_method="illegal_method")
         check_vecenv_spaces(vec_env_class, space, obs_assert)


 class CustomWrapperA(VecNormalize):
     def __init__(self, venv):
         VecNormalize.__init__(self, venv)
-        self.var_a = 'a'
+        self.var_a = "a"


 class CustomWrapperB(VecNormalize):
     def __init__(self, venv):
         VecNormalize.__init__(self, venv)
-        self.var_b = 'b'
+        self.var_b = "b"

     def func_b(self):
         return self.var_b
@@ -319,7 +320,7 @@ def name_test(self):
 class CustomWrapperBB(CustomWrapperB):
     def __init__(self, venv):
         CustomWrapperB.__init__(self, venv)
-        self.var_bb = 'bb'
+        self.var_bb = "bb"


 def test_vecenv_wrapper_getattr():
@@ -328,10 +329,10 @@ def make_env():
     vec_env = DummyVecEnv([make_env for _ in range(N_ENVS)])

     wrapped = CustomWrapperA(CustomWrapperBB(vec_env))
-    assert wrapped.var_a == 'a'
-    assert wrapped.var_b == 'b'
-    assert wrapped.var_bb == 'bb'
-    assert wrapped.func_b() == 'b'
+    assert wrapped.var_a == "a"
+    assert wrapped.var_b == "b"
+    assert wrapped.var_bb == "bb"
+    assert wrapped.func_b() == "b"
     assert wrapped.name_test() == CustomWrapperBB

     double_wrapped = CustomWrapperA(CustomWrapperB(wrapped))
diff --git a/tests/test_vec_normalize.py b/tests/test_vec_normalize.py
index 979ea67e2..311e3c92e 100644
--- a/tests/test_vec_normalize.py
+++ b/tests/test_vec_normalize.py
@@ -1,13 +1,18 @@
 import gym
-import pytest
 import numpy as np
+import pytest

-from stable_baselines3.common.running_mean_std import RunningMeanStd
-from stable_baselines3.common.vec_env import (DummyVecEnv, VecNormalize, VecFrameStack, sync_envs_normalization,
-                                              unwrap_vec_normalize)
 from stable_baselines3 import SAC, TD3
+from stable_baselines3.common.running_mean_std import RunningMeanStd
+from stable_baselines3.common.vec_env import (
+    DummyVecEnv,
+    VecFrameStack,
+    VecNormalize,
+    sync_envs_normalization,
+    unwrap_vec_normalize,
+)

-ENV_ID = 'Pendulum-v0'
+ENV_ID = "Pendulum-v0"


 def make_env():
@@ -54,8 +59,9 @@ def _make_warmstart_cartpole():
 def test_runningmeanstd():
     """Test RunningMeanStd object"""
     for (x_1, x_2, x_3) in [
-            (np.random.randn(3), np.random.randn(4), np.random.randn(5)),
-            (np.random.randn(3, 2), np.random.randn(4, 2), np.random.randn(5, 2))]:
+        (np.random.randn(3), np.random.randn(4), np.random.randn(5)),
+        (np.random.randn(3, 2), np.random.randn(4, 2), np.random.randn(5, 2)),
+    ]:
         rms = RunningMeanStd(epsilon=0.0, shape=x_1.shape[1:])

         x_cat = np.concatenate([x_1, x_2, x_3], axis=0)
@@ -120,12 +126,12 @@ def test_normalize_external():
 @pytest.mark.parametrize("model_class", [SAC, TD3])
 def test_offpolicy_normalization(model_class):
     env = DummyVecEnv([make_env])
-    env = VecNormalize(env, norm_obs=True, norm_reward=True, clip_obs=10., clip_reward=10.)
+    env = VecNormalize(env, norm_obs=True, norm_reward=True, clip_obs=10.0, clip_reward=10.0)

     eval_env = DummyVecEnv([make_env])
-    eval_env = VecNormalize(eval_env, training=False, norm_obs=True, norm_reward=False, clip_obs=10., clip_reward=10.)
+    eval_env = VecNormalize(eval_env, training=False, norm_obs=True, norm_reward=False, clip_obs=10.0, clip_reward=10.0)

-    model = model_class('MlpPolicy', env, verbose=1, policy_kwargs=dict(net_arch=[64]))
+    model = model_class("MlpPolicy", env, verbose=1, policy_kwargs=dict(net_arch=[64]))
     model.learn(total_timesteps=1000, eval_env=eval_env, eval_freq=500)
     # Check getter
     assert isinstance(model.get_vec_normalize_env(), VecNormalize)
@@ -136,7 +142,7 @@ def test_sync_vec_normalize():
     assert unwrap_vec_normalize(env) is None

-    env = VecNormalize(env, norm_obs=True, norm_reward=True, clip_obs=100., clip_reward=100.)
+    env = VecNormalize(env, norm_obs=True, norm_reward=True, clip_obs=100.0, clip_reward=100.0)

     assert isinstance(unwrap_vec_normalize(env), VecNormalize)
@@ -145,8 +151,7 @@ def test_sync_vec_normalize():
     assert isinstance(unwrap_vec_normalize(env), VecNormalize)

     eval_env = DummyVecEnv([make_env])
-    eval_env = VecNormalize(eval_env, training=False, norm_obs=True, norm_reward=True,
-                            clip_obs=100., clip_reward=100.)
+    eval_env = VecNormalize(eval_env, training=False, norm_obs=True, norm_reward=True, clip_obs=100.0, clip_reward=100.0)
     eval_env = VecFrameStack(eval_env, 1)

     env.seed(0)