Skip to content

Commit

Permalink
Avoid transposing channel-first envs (DLR-RM#213)
Browse files Browse the repository at this point in the history
* Add test for channel-first environments

* Add support for channel-first envs, including more tests

* Update changelog

* Run black

* Run black, again

* Improve NatureCNN error message

* Update image checks and FrameStack wrapper

* Update tests

* Update docs

* Run isort

* Reformat

* Fixes: avoid breaking changes for non-image env

* Add additional checks

* Update docstring

Co-authored-by: Antonin RAFFIN <[email protected]>
  • Loading branch information
Miffyli and araffin authored Nov 3, 2020
1 parent 9d463bc commit e2b6f54
Show file tree
Hide file tree
Showing 11 changed files with 255 additions and 24 deletions.
7 changes: 4 additions & 3 deletions docs/guide/custom_env.rst
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,9 @@ That is to say, your environment must implement the following methods (and inher


.. note::
If you are using images as input, the input values must be in [0, 255] as the observation
is normalized (dividing by 255 to have values in [0, 1]) when using CNN policies.
If you are using images as input, the input values must be in [0, 255] and np.uint8 as the observation
is normalized (dividing by 255 to have values in [0, 1]) when using CNN policies. Images can be either
channel-first or channel-last.



Expand All @@ -28,7 +29,7 @@ That is to say, your environment must implement the following methods (and inher
# They must be gym.spaces objects
# Example when using discrete actions:
self.action_space = spaces.Discrete(N_DISCRETE_ACTIONS)
# Example for using image as input:
# Example for using image as input (can be channel-first or channel-last):
self.observation_space = spaces.Box(low=0, high=255,
shape=(HEIGHT, WIDTH, N_CHANNELS), dtype=np.uint8)
Expand Down
3 changes: 2 additions & 1 deletion docs/guide/developer.rst
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,8 @@ Pre-Processing
To handle different observation spaces, some pre-processing needs to be done (e.g. one-hot encoding for discrete observation).
Most of the code for pre-processing is in ``common/preprocessing.py`` and ``common/policies.py``.

For images, we make use of an additional wrapper ``VecTransposeImage`` because PyTorch uses the "channel-first" convention.
For images, environment is automatically wrapped with ``VecTransposeImage`` if observations are detected to be images with
channel-last convention to transform it to PyTorch's channel-first convention.


Policy Structure
Expand Down
8 changes: 7 additions & 1 deletion docs/misc/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -6,22 +6,28 @@ Changelog
Pre-Release 0.11.0a0 (WIP)
-------------------------------


Breaking Changes:
^^^^^^^^^^^^^^^^^


New Features:
^^^^^^^^^^^^^
- Add support for ``VecFrameStack`` to stack on first or last observation dimension, along with
automatic check for image spaces.
- ``VecFrameStack`` now has a ``channels_order`` argument to tell if observations should be stacked
on the first or last observation dimension (originally always stacked on last).

Bug Fixes:
^^^^^^^^^^
- Fixed bug where code added VecTranspose on channel-first image environments (thanks @qxcv)

Deprecations:
^^^^^^^^^^^^^

Others:
^^^^^^^
- Add more issue templates
- Improve error message in ``NatureCNN``

Documentation:
^^^^^^^^^^^^^^
Expand Down
2 changes: 1 addition & 1 deletion docs/modules/dqn.rst
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ Run the benchmark (replace ``$ENV_ID`` by the env id, for instance ``BreakoutNoF

.. code-block:: bash
python train.py --algo a2c --env $ENV_ID --eval-episodes 10 --eval-freq 10000
python train.py --algo dqn --env $ENV_ID --eval-episodes 10 --eval-freq 10000
Plot the results:
Expand Down
8 changes: 6 additions & 2 deletions stable_baselines3/common/base_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.noise import ActionNoise
from stable_baselines3.common.policies import BasePolicy, get_policy_from_name
from stable_baselines3.common.preprocessing import is_image_space
from stable_baselines3.common.preprocessing import is_image_space, is_image_space_channels_first
from stable_baselines3.common.save_util import load_from_zip_file, recursive_getattr, recursive_setattr, save_to_zip_file
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback
from stable_baselines3.common.utils import (
Expand Down Expand Up @@ -176,7 +176,11 @@ def _wrap_env(env: GymEnv, verbose: int = 0) -> VecEnv:
print("Wrapping the env in a DummyVecEnv.")
env = DummyVecEnv([lambda: env])

if is_image_space(env.observation_space) and not is_wrapped(env, VecTransposeImage):
if (
is_image_space(env.observation_space)
and not is_wrapped(env, VecTransposeImage)
and not is_image_space_channels_first(env.observation_space)
):
if verbose >= 1:
print("Wrapping the env in a VecTransposeImage.")
env = VecTransposeImage(env)
Expand Down
17 changes: 13 additions & 4 deletions stable_baselines3/common/identity_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,14 +112,23 @@ class FakeImageEnv(Env):
:param screen_height: Height of the image
:param screen_width: Width of the image
:param n_channels: Number of color channels
:param discrete:
:param discrete: Create discrete action space instead of continuous
:param channel_first: Put channels on first axis instead of last
"""

def __init__(
self, action_dim: int = 6, screen_height: int = 84, screen_width: int = 84, n_channels: int = 1, discrete: bool = True
self,
action_dim: int = 6,
screen_height: int = 84,
screen_width: int = 84,
n_channels: int = 1,
discrete: bool = True,
channel_first: bool = False,
):

self.observation_space = Box(low=0, high=255, shape=(screen_height, screen_width, n_channels), dtype=np.uint8)
self.observation_shape = (screen_height, screen_width, n_channels)
if channel_first:
self.observation_shape = (n_channels, screen_height, screen_width)
self.observation_space = Box(low=0, high=255, shape=self.observation_shape, dtype=np.uint8)
if discrete:
self.action_space = Discrete(action_dim)
else:
Expand Down
18 changes: 18 additions & 0 deletions stable_baselines3/common/preprocessing.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import warnings
from typing import Tuple

import numpy as np
Expand All @@ -6,6 +7,23 @@
from torch.nn import functional as F


def is_image_space_channels_first(observation_space: spaces.Box) -> bool:
"""
Check if an image observation space (see ``is_image_space``)
is channels-first (CxHxW, True) or channels-last (HxWxC, False).
Use a heuristic that channel dimension is the smallest of the three.
If second dimension is smallest, raise an exception (no support).
:param observation_space:
:return: True if observation space is channels-first image, False if channels-last.
"""
smallest_dimension = np.argmin(observation_space.shape).item()
if smallest_dimension == 1:
warnings.warn("Treating image space as channels-last, while second dimension was smallest of the three.")
return smallest_dimension == 0


def is_image_space(observation_space: spaces.Space, channels_last: bool = True, check_channels: bool = False) -> bool:
"""
Check if a observation space has the shape, limits and dtype
Expand Down
7 changes: 5 additions & 2 deletions stable_baselines3/common/torch_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,11 @@ def __init__(self, observation_space: gym.spaces.Box, features_dim: int = 512):
# Re-ordering will be done by pre-preprocessing or wrapper
assert is_image_space(observation_space), (
"You should use NatureCNN "
f"only with images not with {observation_space} "
"(you are probably using `CnnPolicy` instead of `MlpPolicy`)"
f"only with images not with {observation_space}\n"
"(you are probably using `CnnPolicy` instead of `MlpPolicy`)\n"
"If you are using a custom environment,\n"
"please check it using our env checker:\n"
"https://stable-baselines3.readthedocs.io/en/master/common/env_checker.html"
)
n_input_channels = observation_space.shape[0]
self.cnn = nn.Sequential(
Expand Down
59 changes: 49 additions & 10 deletions stable_baselines3/common/vec_env/vec_frame_stack.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,52 @@
import warnings
from typing import Any, Dict, List, Tuple
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
from gym import spaces

from stable_baselines3.common.preprocessing import is_image_space, is_image_space_channels_first
from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvWrapper


class VecFrameStack(VecEnvWrapper):
"""
Frame stacking wrapper for vectorized environment
Frame stacking wrapper for vectorized environment. Designed for image observations.
Dimension to stack over is either first (channels-first) or
last (channels-last), which is detected automatically using
``common.preprocessing.is_image_space_channels_first`` if
observation is an image space.
:param venv: the vectorized environment to wrap
:param n_stack: Number of frames to stack
:param channels_order: If "first", stack on first image dimension. If "last", stack on last dimension.
If None, automatically detect channel to stack over in case of image observation or default to "last" (default).
"""

def __init__(self, venv: VecEnv, n_stack: int):
def __init__(self, venv: VecEnv, n_stack: int, channels_order: Optional[str] = None):
self.venv = venv
self.n_stack = n_stack

wrapped_obs_space = venv.observation_space
assert isinstance(wrapped_obs_space, spaces.Box), "VecFrameStack only work with gym.spaces.Box observation space"
low = np.repeat(wrapped_obs_space.low, self.n_stack, axis=-1)
high = np.repeat(wrapped_obs_space.high, self.n_stack, axis=-1)

if channels_order is None:
# Detect channel location automatically for images
if is_image_space(wrapped_obs_space):
self.channels_first = is_image_space_channels_first(wrapped_obs_space)
else:
# Default behavior for non-image space, stack on the last axis
self.channels_first = False
else:
assert channels_order in {"last", "first"}, "`channels_order` must be one of following: 'last', 'first'"

self.channels_first = channels_order == "first"

# This includes the vec-env dimension (first)
self.stack_dimension = 1 if self.channels_first else -1
repeat_axis = 0 if self.channels_first else -1
low = np.repeat(wrapped_obs_space.low, self.n_stack, axis=repeat_axis)
high = np.repeat(wrapped_obs_space.high, self.n_stack, axis=repeat_axis)
self.stackedobs = np.zeros((venv.num_envs,) + low.shape, low.dtype)
observation_space = spaces.Box(low=low, high=high, dtype=venv.observation_space.dtype)
VecEnvWrapper.__init__(self, venv, observation_space=observation_space)
Expand All @@ -30,18 +55,29 @@ def step_wait(self) -> Tuple[np.ndarray, np.ndarray, np.ndarray, List[Dict[str,
observations, rewards, dones, infos = self.venv.step_wait()
# Let pytype know that observation is not a dict
assert isinstance(observations, np.ndarray)
last_ax_size = observations.shape[-1]
self.stackedobs = np.roll(self.stackedobs, shift=-last_ax_size, axis=-1)
stack_ax_size = observations.shape[self.stack_dimension]
self.stackedobs = np.roll(self.stackedobs, shift=-stack_ax_size, axis=self.stack_dimension)
for i, done in enumerate(dones):
if done:
if "terminal_observation" in infos[i]:
old_terminal = infos[i]["terminal_observation"]
new_terminal = np.concatenate((self.stackedobs[i, ..., :-last_ax_size], old_terminal), axis=-1)
if self.channels_first:
new_terminal = np.concatenate(
(self.stackedobs[i, :-stack_ax_size, ...], old_terminal), axis=self.stack_dimension
)
else:
new_terminal = np.concatenate(
(self.stackedobs[i, ..., :-stack_ax_size], old_terminal), axis=self.stack_dimension
)
infos[i]["terminal_observation"] = new_terminal
else:
warnings.warn("VecFrameStack wrapping a VecEnv without terminal_observation info")
self.stackedobs[i] = 0
self.stackedobs[..., -observations.shape[-1] :] = observations
if self.channels_first:
self.stackedobs[:, -observations.shape[self.stack_dimension] :, ...] = observations
else:
self.stackedobs[..., -observations.shape[self.stack_dimension] :] = observations

return self.stackedobs, rewards, dones, infos

def reset(self) -> np.ndarray:
Expand All @@ -50,7 +86,10 @@ def reset(self) -> np.ndarray:
"""
obs: np.ndarray = self.venv.reset() # pytype:disable=annotation-type-mismatch
self.stackedobs[...] = 0
self.stackedobs[..., -obs.shape[-1] :] = obs
if self.channels_first:
self.stackedobs[:, -obs.shape[self.stack_dimension] :, ...] = obs
else:
self.stackedobs[..., -obs.shape[self.stack_dimension] :] = obs
return self.stackedobs

def close(self) -> None:
Expand Down
76 changes: 76 additions & 0 deletions tests/test_cnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,13 @@
import numpy as np
import pytest
import torch as th
from gym import spaces

from stable_baselines3 import A2C, DQN, PPO, SAC, TD3
from stable_baselines3.common.identity_env import FakeImageEnv
from stable_baselines3.common.preprocessing import is_image_space, is_image_space_channels_first
from stable_baselines3.common.utils import zip_strict
from stable_baselines3.common.vec_env import VecTransposeImage, is_wrapped


@pytest.mark.parametrize("model_class", [A2C, PPO, SAC, TD3, DQN])
Expand All @@ -25,6 +28,9 @@ def test_cnn(tmp_path, model_class):
kwargs = dict(buffer_size=250, policy_kwargs=dict(features_extractor_kwargs=dict(features_dim=32)))
model = model_class("CnnPolicy", env, **kwargs).learn(250)

# FakeImageEnv is channel last by default and should be wrapped
assert is_wrapped(model.get_env(), VecTransposeImage)

obs = env.reset()

action, _ = model.predict(obs, deterministic=True)
Expand Down Expand Up @@ -174,3 +180,73 @@ def test_features_extractor_target_net(model_class, share_features_extractor):
params_should_match(original_actor_param, model.actor.parameters())

td3_features_extractor_check(model)


def test_channel_first_env(tmp_path):
# test_cnn uses environment with HxWxC setup that is transposed, but we
# also want to work with CxHxW envs directly without transposing wrapper.
SAVE_NAME = "cnn_model.zip"

# Create environment with transposed images (CxHxW).
# If underlying CNN processes the data in wrong format,
# it will raise an error of negative dimension sizes while creating convolutions
env = FakeImageEnv(screen_height=40, screen_width=40, n_channels=1, discrete=True, channel_first=True)

model = A2C("CnnPolicy", env, n_steps=100).learn(250)

assert not is_wrapped(model.get_env(), VecTransposeImage)

obs = env.reset()

action, _ = model.predict(obs, deterministic=True)

model.save(tmp_path / SAVE_NAME)
del model

model = A2C.load(tmp_path / SAVE_NAME)

# Check that the prediction is the same
assert np.allclose(action, model.predict(obs, deterministic=True)[0])

os.remove(str(tmp_path / SAVE_NAME))


def test_image_space_checks():
not_image_space = spaces.Box(0, 1, shape=(10,))
assert not is_image_space(not_image_space)

# Not uint8
not_image_space = spaces.Box(0, 255, shape=(10, 10, 3))
assert not is_image_space(not_image_space)

# Not correct shape
not_image_space = spaces.Box(0, 255, shape=(10, 10), dtype=np.uint8)
assert not is_image_space(not_image_space)

# Not correct low/high
not_image_space = spaces.Box(0, 10, shape=(10, 10, 3), dtype=np.uint8)
assert not is_image_space(not_image_space)

# Not correct space
not_image_space = spaces.Discrete(n=10)
assert not is_image_space(not_image_space)

an_image_space = spaces.Box(0, 255, shape=(10, 10, 3), dtype=np.uint8)
assert is_image_space(an_image_space)

an_image_space_with_odd_channels = spaces.Box(0, 255, shape=(10, 10, 5), dtype=np.uint8)
assert is_image_space(an_image_space_with_odd_channels)
# Should not pass if we check if channels are valid for an image
assert not is_image_space(an_image_space_with_odd_channels, check_channels=True)

# Test if channel-check works
channel_first_space = spaces.Box(0, 255, shape=(3, 10, 10), dtype=np.uint8)
assert is_image_space_channels_first(channel_first_space)

channel_last_space = spaces.Box(0, 255, shape=(10, 10, 3), dtype=np.uint8)
assert not is_image_space_channels_first(channel_last_space)

channel_mid_space = spaces.Box(0, 255, shape=(10, 3, 10), dtype=np.uint8)
# Should raise a warning
with pytest.warns(Warning):
assert not is_image_space_channels_first(channel_mid_space)
Loading

0 comments on commit e2b6f54

Please sign in to comment.