Add custom objects support + bug fix (DLR-RM#336)
* Add support for custom objects

* Add python 3.8 to the CI

* Bump version

* PyType fixes

* [ci skip] Fix typo

* Add note about slow-down + fix typos

* Minor edits to the doc

* Bug fix for DQN

* Update test

* Add test for custom objects
araffin authored Mar 6, 2021
1 parent f13de5b commit c62e925
Showing 27 changed files with 118 additions and 60 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
@@ -16,7 +16,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: [3.6, 3.7] # 3.8 not supported yet by pytype
python-version: [3.6, 3.7, 3.8]

steps:
- uses: actions/checkout@v2
2 changes: 1 addition & 1 deletion Dockerfile
@@ -14,7 +14,7 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
libglib2.0-0 && \
rm -rf /var/lib/apt/lists/*

# Install anaconda abd dependencies
# Install Anaconda and dependencies
RUN curl -o ~/miniconda.sh https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh && \
chmod +x ~/miniconda.sh && \
~/miniconda.sh -b -p /opt/conda && \
7 changes: 7 additions & 0 deletions docs/guide/migration.rst
@@ -33,6 +33,13 @@ You can also take a look at the `rl-zoo3 <https://github.com/DLR-RM/rl-baselines
to the `rl-zoo <https://github.com/araffin/rl-baselines-zoo>`_ of SB2 to have a concrete example of successful migration.


.. note::

If you experience massive slow-down switching to PyTorch, you may need to play with the number of threads used,
using ``torch.set_num_threads(1)`` or ``OMP_NUM_THREADS=1``, see `issue #122 <https://github.com/DLR-RM/stable-baselines3/issues/122>`_
and `issue #90 <https://github.com/DLR-RM/stable-baselines3/issues/90>`_.


Breaking Changes
================

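For reference, a minimal sketch of the thread-limiting workaround described in the note added above; the value of 1 is an illustrative starting point to tune for your machine, not a general recommendation:

import torch as th

# Restrict PyTorch to a single intra-op thread (call this before training starts).
th.set_num_threads(1)

# Equivalent environment-variable form, set before launching Python:
#   OMP_NUM_THREADS=1 python train.py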
14 changes: 7 additions & 7 deletions docs/guide/rl_tips.rst
@@ -119,14 +119,14 @@ Discrete Actions
Discrete Actions - Single Process
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

DQN with extensions (double DQN, prioritized replay, ...) are the recommended algorithms.
We notably provide QR-DQN in our :ref:`contrib repo <sb3_contrib>`.
DQN is usually slower to train (regarding wall clock time) but is the most sample efficient (because of its replay buffer).
``DQN`` with extensions (double DQN, prioritized replay, ...) are the recommended algorithms.
We notably provide ``QR-DQN`` in our :ref:`contrib repo <sb3_contrib>`.
``DQN`` is usually slower to train (regarding wall clock time) but is the most sample efficient (because of its replay buffer).

Discrete Actions - Multiprocessed
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

You should give a try to PPO or A2C.
You should give a try to ``PPO`` or ``A2C``.


Continuous Actions
@@ -142,7 +142,7 @@ Please use the hyperparameters in the `RL zoo <https://github.com/DLR-RM/rl-base
Continuous Actions - Multiprocessed
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Take a look at PPO, TRPO or A2C. Again, don't forget to take the hyperparameters from the `RL zoo <https://github.com/DLR-RM/rl-baselines3-zoo>`_
Take a look at ``PPO`` or ``A2C``. Again, don't forget to take the hyperparameters from the `RL zoo <https://github.com/DLR-RM/rl-baselines3-zoo>`_
for continuous actions problems (cf *Bullet* envs).

.. note::
@@ -155,12 +155,12 @@ Goal Environment
-----------------

If your environment follows the ``GoalEnv`` interface (cf :ref:`HER <her>`), then you should use
HER + (SAC/TD3/DDPG/DQN/TQC) depending on the action space.
HER + (SAC/TD3/DDPG/DQN/QR-DQN/TQC) depending on the action space.


.. note::

The number of workers is an important hyperparameters for experiments with HER
The ``batch_size`` is an important hyperparameter for experiments with :ref:`HER <her>`



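As a rough illustration of the contrib algorithm referenced in the section above, a minimal sketch assuming the separate sb3-contrib package is installed; hyperparameters are left at their defaults and are not tuned values:

import gym

from sb3_contrib import QRDQN

env = gym.make("CartPole-v1")
# Quantile Regression DQN, provided by the contrib repository rather than stable-baselines3 itself
model = QRDQN("MlpPolicy", env, verbose=1)
model.learn(total_timesteps=10_000)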
11 changes: 10 additions & 1 deletion docs/misc/changelog.rst
@@ -3,17 +3,26 @@
Changelog
==========

Release 1.0rc0 (2021-02-28)
Release 1.0rc1 (WIP)
-------------------------------

Breaking Changes:
^^^^^^^^^^^^^^^^^
- Removed ``stable_baselines3.common.cmd_util`` (already deprecated), please use ``env_util`` instead

New Features:
^^^^^^^^^^^^^
- Added support for ``custom_objects`` when loading models

Bug Fixes:
^^^^^^^^^^
- Fixed a bug with ``DQN`` predict method when using ``deterministic=False`` with image space

Documentation:
^^^^^^^^^^^^^^
- Fixed examples
- Added new project using SB3: rl_reach (@PierreExeter)
- Added note about slow-down when switching to PyTorch
- Add a note on continual learning and resetting environment


5 changes: 2 additions & 3 deletions docs/modules/a2c.rst
@@ -53,13 +53,12 @@ Train a A2C agent on ``CartPole-v1`` using 4 environments.
import gym
from stable_baselines3 import A2C
from stable_baselines3.a2c import MlpPolicy
from stable_baselines3.common.env_util import make_vec_env
# Parallel environments
env = make_vec_env('CartPole-v1', n_envs=4)
env = make_vec_env("CartPole-v1", n_envs=4)
model = A2C(MlpPolicy, env, verbose=1)
model = A2C("MlpPolicy", env, verbose=1)
model.learn(total_timesteps=25000)
model.save("a2c_cartpole")
4 changes: 2 additions & 2 deletions docs/modules/ddpg.rst
@@ -63,13 +63,13 @@ Example
from stable_baselines3 import DDPG
from stable_baselines3.common.noise import NormalActionNoise, OrnsteinUhlenbeckActionNoise
env = gym.make('Pendulum-v0')
env = gym.make("Pendulum-v0")
# The noise objects for DDPG
n_actions = env.action_space.shape[-1]
action_noise = NormalActionNoise(mean=np.zeros(n_actions), sigma=0.1 * np.ones(n_actions))
model = DDPG('MlpPolicy', env, action_noise=action_noise, verbose=1)
model = DDPG("MlpPolicy", env, action_noise=action_noise, verbose=1)
model.learn(total_timesteps=10000, log_interval=10)
model.save("ddpg_pendulum")
env = model.get_env()
5 changes: 2 additions & 3 deletions docs/modules/dqn.rst
@@ -56,11 +56,10 @@ Example
import numpy as np
from stable_baselines3 import DQN
from stable_baselines3.dqn import MlpPolicy
env = gym.make('CartPole-v0')
env = gym.make("CartPole-v0")
model = DQN(MlpPolicy, env, verbose=1)
model = DQN("MlpPolicy", env, verbose=1)
model.learn(total_timesteps=10000, log_interval=4)
model.save("dqn_pendulum")
2 changes: 1 addition & 1 deletion docs/modules/her.rst
@@ -52,7 +52,7 @@ Notes
Can I use?
----------

Please refer to the used model (DQN, SAC, TD3 or DDPG) for that section.
Please refer to the used model (DQN, QR-DQN, SAC, TQC, TD3, or DDPG) for that section.

Example
-------
5 changes: 2 additions & 3 deletions docs/modules/ppo.rst
@@ -54,13 +54,12 @@ Train a PPO agent on ``Pendulum-v0`` using 4 environments.
import gym
from stable_baselines3 import PPO
from stable_baselines3.ppo import MlpPolicy
from stable_baselines3.common.env_util import make_vec_env
# Parallel environments
env = make_vec_env('CartPole-v1', n_envs=4)
env = make_vec_env("CartPole-v1", n_envs=4)
model = PPO(MlpPolicy, env, verbose=1)
model = PPO("MlpPolicy", env, verbose=1)
model.learn(total_timesteps=25000)
model.save("ppo_cartpole")
5 changes: 2 additions & 3 deletions docs/modules/sac.rst
@@ -68,11 +68,10 @@ Example
import numpy as np
from stable_baselines3 import SAC
from stable_baselines3.sac import MlpPolicy
env = gym.make('Pendulum-v0')
env = gym.make("Pendulum-v0")
model = SAC(MlpPolicy, env, verbose=1)
model = SAC("MlpPolicy", env, verbose=1)
model.learn(total_timesteps=10000, log_interval=4)
model.save("sac_pendulum")
5 changes: 2 additions & 3 deletions docs/modules/td3.rst
@@ -61,16 +61,15 @@ Example
import numpy as np
from stable_baselines3 import TD3
from stable_baselines3.td3.policies import MlpPolicy
from stable_baselines3.common.noise import NormalActionNoise, OrnsteinUhlenbeckActionNoise
env = gym.make('Pendulum-v0')
env = gym.make("Pendulum-v0")
# The noise objects for TD3
n_actions = env.action_space.shape[-1]
action_noise = NormalActionNoise(mean=np.zeros(n_actions), sigma=0.1 * np.ones(n_actions))
model = TD3(MlpPolicy, env, action_noise=action_noise, verbose=1)
model = TD3("MlpPolicy", env, action_noise=action_noise, verbose=1)
model.learn(total_timesteps=10000, log_interval=10)
model.save("td3_pendulum")
env = model.get_env()
5 changes: 5 additions & 0 deletions docs/spelling_wordlist.txt
@@ -119,3 +119,8 @@ cuda
Polyak
gSDE
rollouts
Pyro
softmax
stdout
Contrib
Quantile
11 changes: 9 additions & 2 deletions stable_baselines3/common/base_class.py
@@ -586,6 +586,7 @@ def load(
path: Union[str, pathlib.Path, io.BufferedIOBase],
env: Optional[GymEnv] = None,
device: Union[th.device, str] = "auto",
custom_objects: Optional[Dict[str, Any]] = None,
**kwargs,
) -> "BaseAlgorithm":
"""
@@ -596,9 +597,15 @@ def load(
:param env: the new environment to run the loaded model on
(can be None if you only need prediction from a trained model) has priority over any saved environment
:param device: Device on which the code should run.
:param custom_objects: Dictionary of objects to replace
upon loading. If a variable is present in this dictionary as a
key, it will not be deserialized and the corresponding item
will be used instead. Similar to custom_objects in
``keras.models.load_model``. Useful when you have an object in
file that can not be deserialized.
:param kwargs: extra arguments to change the model when loading
"""
data, params, pytorch_variables = load_from_zip_file(path, device=device)
data, params, pytorch_variables = load_from_zip_file(path, device=device, custom_objects=custom_objects)

# Remove stored device information and replace with ours
if "policy_kwargs" in data:
@@ -625,7 +632,7 @@ def load(
env = data["env"]

# noinspection PyArgumentList
model = cls(
model = cls( # pytype: disable=not-instantiable,wrong-keyword-args
policy=data["policy_class"],
env=env,
device=device,
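A minimal sketch of how the new custom_objects argument is used from the public load API; the saved model path is a placeholder, and the keys below are illustrative, covering the common case of replacing schedules that fail to unpickle across Python versions:

from stable_baselines3 import PPO

# Keys listed here are not deserialized from the archive;
# the provided values are used in their place.
custom_objects = {
    "learning_rate": 0.0,
    "lr_schedule": lambda _: 0.0,
    "clip_range": lambda _: 0.0,
}
model = PPO.load("ppo_cartpole", custom_objects=custom_objects)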
2 changes: 1 addition & 1 deletion stable_baselines3/common/distributions.py
@@ -623,7 +623,7 @@ def atanh(x: th.Tensor) -> th.Tensor:
"""
Inverse of Tanh
Taken from pyro: https://github.com/pyro-ppl/pyro
Taken from Pyro: https://github.com/pyro-ppl/pyro
0.5 * torch.log((1 + x ) / (1 - x))
"""
return 0.5 * (x.log1p() - (-x).log1p())
2 changes: 1 addition & 1 deletion stable_baselines3/common/evaluation.py
@@ -41,7 +41,7 @@ def evaluate_policy(
called after each step. Gets locals() and globals() passed as parameters.
:param reward_threshold: Minimum expected reward per episode,
this will raise an error if the performance is not met
:param return_episode_rewards: If True, a list of rewards and episde lengths
:param return_episode_rewards: If True, a list of rewards and episode lengths
per episode will be returned instead of the mean.
:param warn: If True (default), warns user about lack of a Monitor wrapper in the
evaluation environment.
2 changes: 1 addition & 1 deletion stable_baselines3/common/off_policy_algorithm.py
@@ -174,7 +174,7 @@ def _setup_model(self) -> None:
self.device,
optimize_memory_usage=self.optimize_memory_usage,
)
self.policy = self.policy_class(
self.policy = self.policy_class( # pytype:disable=not-instantiable
self.observation_space,
self.action_space,
self.lr_schedule,
2 changes: 1 addition & 1 deletion stable_baselines3/common/on_policy_algorithm.py
@@ -114,7 +114,7 @@ def _setup_model(self) -> None:
gae_lambda=self.gae_lambda,
n_envs=self.n_envs,
)
self.policy = self.policy_class(
self.policy = self.policy_class( # pytype:disable=not-instantiable
self.observation_space,
self.action_space,
self.lr_schedule,
15 changes: 2 additions & 13 deletions stable_baselines3/common/policies.py
@@ -19,11 +19,10 @@
StateDependentNoiseDistribution,
make_proba_distribution,
)
from stable_baselines3.common.preprocessing import get_action_dim, is_image_space, preprocess_obs
from stable_baselines3.common.preprocessing import get_action_dim, maybe_transpose, preprocess_obs
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor, FlattenExtractor, MlpExtractor, NatureCNN, create_mlp
from stable_baselines3.common.type_aliases import Schedule
from stable_baselines3.common.utils import get_device, is_vectorized_observation
from stable_baselines3.common.vec_env import VecTransposeImage
from stable_baselines3.common.vec_env.obs_dict_wrapper import ObsDictWrapper


@@ -266,17 +265,7 @@ def predict(

# Handle the different cases for images
# as PyTorch use channel first format
if is_image_space(self.observation_space):
if not (
observation.shape == self.observation_space.shape or observation.shape[1:] == self.observation_space.shape
):
# Try to re-order the channels
transpose_obs = VecTransposeImage.transpose_image(observation)
if (
transpose_obs.shape == self.observation_space.shape
or transpose_obs.shape[1:] == self.observation_space.shape
):
observation = transpose_obs
observation = maybe_transpose(observation, self.observation_space)

vectorized_env = is_vectorized_observation(observation, self.observation_space)

20 changes: 20 additions & 0 deletions stable_baselines3/common/preprocessing.py
@@ -61,6 +61,26 @@ def is_image_space(observation_space: spaces.Space, channels_last: bool = True,
return False


def maybe_transpose(observation: np.ndarray, observation_space: spaces.Space) -> np.ndarray:
"""
Handle the different cases for images as PyTorch use channel first format.
:param observation:
:param observation_space:
:return: channel first observation if observation is an image
"""
# Avoid circular import
from stable_baselines3.common.vec_env import VecTransposeImage

if is_image_space(observation_space):
if not (observation.shape == observation_space.shape or observation.shape[1:] == observation_space.shape):
# Try to re-order the channels
transpose_obs = VecTransposeImage.transpose_image(observation)
if transpose_obs.shape == observation_space.shape or transpose_obs.shape[1:] == observation_space.shape:
observation = transpose_obs
return observation


def preprocess_obs(obs: th.Tensor, observation_space: spaces.Space, normalize_images: bool = True) -> th.Tensor:
"""
Preprocess observation to be to a neural network.
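A short sketch of what the new maybe_transpose helper does, assuming a channel-first observation space and an incoming channel-last image observation (the 84x84x3 shape is only an example):

import numpy as np
from gym import spaces

from stable_baselines3.common.preprocessing import maybe_transpose

# Channel-first space, as expected by the PyTorch CNN policies
observation_space = spaces.Box(low=0, high=255, shape=(3, 84, 84), dtype=np.uint8)
# Observation arrives channel-last (H, W, C), e.g. straight from the environment
obs = np.zeros((84, 84, 3), dtype=np.uint8)

obs = maybe_transpose(obs, observation_space)
print(obs.shape)  # (3, 84, 84): channels re-ordered to match the observation space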
13 changes: 10 additions & 3 deletions stable_baselines3/common/save_util.py
@@ -137,7 +137,7 @@ def json_to_data(json_string: str, custom_objects: Optional[Dict[str, Any]] = No
upon loading. If a variable is present in this dictionary as a
key, it will not be deserialized and the corresponding item
will be used instead. Similar to custom_objects in
`keras.models.load_model`. Useful when you have an object in
``keras.models.load_model``. Useful when you have an object in
file that can not be deserialized.
:return: Loaded class parameters.
"""
@@ -162,7 +162,7 @@ def json_to_data(json_string: str, custom_objects: Optional[Dict[str, Any]] = No
try:
base64_object = base64.b64decode(serialization.encode())
deserialized_object = cloudpickle.loads(base64_object)
except RuntimeError:
except (RuntimeError, TypeError):
warnings.warn(
f"Could not deserialize object {data_key}. "
+ "Consider using `custom_objects` argument to replace "
@@ -359,6 +359,7 @@ def load_from_pkl(path: Union[str, pathlib.Path, io.BufferedIOBase], verbose: in
def load_from_zip_file(
load_path: Union[str, pathlib.Path, io.BufferedIOBase],
load_data: bool = True,
custom_objects: Optional[Dict[str, Any]] = None,
device: Union[th.device, str] = "auto",
verbose: int = 0,
) -> (Tuple[Optional[Dict[str, Any]], Optional[TensorDict], Optional[TensorDict]]):
@@ -368,6 +369,12 @@
:param load_path: Where to load the model from
:param load_data: Whether we should load and return data
(class parameters). Mainly used by 'load_parameters' to only load model parameters (weights)
:param custom_objects: Dictionary of objects to replace
upon loading. If a variable is present in this dictionary as a
key, it will not be deserialized and the corresponding item
will be used instead. Similar to custom_objects in
``keras.models.load_model``. Useful when you have an object in
file that can not be deserialized.
:param device: Device on which the code should run.
:return: Class parameters, model state_dicts (aka "params", dict of state_dict)
and dict of pytorch variables
@@ -392,7 +399,7 @@
# Load class parameters that are stored
# with either JSON or pickle (not PyTorch variables).
json_data = archive.read("data").decode()
data = json_to_data(json_data)
data = json_to_data(json_data, custom_objects=custom_objects)

# Check for all .pth files and load them using th.load.
# "pytorch_variables.pth" stores PyTorch variables, and any other .pth
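For completeness, a hedged sketch of the lower-level loading entry point that now also accepts custom_objects; the archive name is a placeholder, and most users should go through model.load() instead:

from stable_baselines3.common.save_util import load_from_zip_file

# Returns the class parameters (data), the state_dicts of the policy/optimizers (params)
# and any extra PyTorch variables stored in the archive.
data, params, pytorch_variables = load_from_zip_file(
    "ppo_cartpole.zip",
    custom_objects={"lr_schedule": lambda _: 0.0},
    device="cpu",
)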