Skip to content

Commit

Permalink
Reshape input in tf/cnn (rlworkgroup#2168)
Browse files Browse the repository at this point in the history
* Reshape input in tf/cnn

* Reshape input in tf/cnn

* Fix test

* Pass input_dim in constructor of CNN model

* Pass input_dim in constructor of CNN model

* Make tf/categorical_cnn_policy's obs_ph flattened

* Fix precommit
  • Loading branch information
yeukfu authored Dec 17, 2020
1 parent 75e4c9a commit e395ba6
Show file tree
Hide file tree
Showing 20 changed files with 181 additions and 57 deletions.
4 changes: 2 additions & 2 deletions src/garage/tf/baselines/gaussian_cnn_baseline.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ def __init__(self,
else:
self._optimizer = make_optimizer(optimizer, **optimizer_args)

super().__init__(input_shape=env_spec.observation_space.shape,
super().__init__(input_dim=env_spec.observation_space.shape,
output_dim=1,
filters=filters,
strides=strides,
Expand Down Expand Up @@ -318,7 +318,7 @@ def clone_model(self, name):
"""
new_baseline = GaussianCNNBaselineModel(
name=name,
input_shape=self._env_spec.observation_space.shape,
input_dim=self._env_spec.observation_space.shape,
output_dim=1,
filters=self._filters,
strides=self._strides,
Expand Down
13 changes: 8 additions & 5 deletions src/garage/tf/baselines/gaussian_cnn_baseline_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,9 @@ class GaussianCNNBaselineModel(GaussianCNNModel):
distribution to the outputs.
Args:
input_shape(tuple[int]): Input shape of the model (without the batch
dimension).
input_dim (Tuple[int, int, int]): Dimensions of unflattened input,
i.e. [in_height, in_width, in_channels]. If the last 3 dimensions
of input_var do not match this shape, input_var will be reshaped.
filters (Tuple[Tuple[int, Tuple[int, int]], ...]): Number and dimension
of filters. For example, ((3, (3, 5)), (32, (3, 3))) means there
are two convolutional layers. The filter for the first layer have 3
Expand Down Expand Up @@ -93,7 +94,7 @@ class GaussianCNNBaselineModel(GaussianCNNModel):
"""

def __init__(self,
input_shape,
input_dim,
output_dim,
filters,
strides,
Expand Down Expand Up @@ -127,7 +128,8 @@ def __init__(self,
seed=deterministic.get_tf_seed_stream()),
std_parameterization='exp',
layer_normalization=False):
super().__init__(output_dim=output_dim,
super().__init__(input_dim=input_dim,
output_dim=output_dim,
filters=filters,
strides=strides,
padding=padding,
Expand Down Expand Up @@ -156,7 +158,7 @@ def __init__(self,
std_parameterization=std_parameterization,
layer_normalization=layer_normalization,
name=name)
self._input_shape = input_shape
self._input_shape = input_dim

def network_output_spec(self):
"""Network output spec.
Expand All @@ -171,6 +173,7 @@ def network_output_spec(self):
'y_mean', 'y_std'
]

# pylint: disable=arguments-differ
def _build(self, state_input, name=None):
"""Build model given input placeholder(s).
Expand Down
7 changes: 6 additions & 1 deletion src/garage/tf/models/categorical_cnn_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@ class CategoricalCNNModel(Model):
by a multilayer perceptron (MLP).
Args:
input_dim (Tuple[int, int, int]): Dimensions of unflattened input,
i.e. [in_height, in_width, in_channels]. If the last 3 dimensions
of input_var do not match this shape, input_var will be reshaped.
output_dim (int): Dimension of the network output.
filters (Tuple[Tuple[int, Tuple[int, int]], ...]): Number and dimension
of filters. For example, ((3, (3, 5)), (32, (3, 3))) means there
Expand Down Expand Up @@ -59,6 +62,7 @@ class CategoricalCNNModel(Model):
"""

def __init__(self,
input_dim,
output_dim,
filters,
strides,
Expand All @@ -77,7 +81,8 @@ def __init__(self,
layer_normalization=False):
super().__init__(name)
self._is_image = is_image
self._cnn_model = CNNModel(filters=filters,
self._cnn_model = CNNModel(input_dim=input_dim,
filters=filters,
strides=strides,
padding=padding,
hidden_nonlinearity=hidden_nonlinearity,
Expand Down
14 changes: 14 additions & 0 deletions src/garage/tf/models/cnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@


def cnn(input_var,
input_dim,
filters,
strides,
name,
Expand All @@ -21,6 +22,9 @@ def cnn(input_var,
Args:
input_var (tf.Tensor): Input tf.Tensor to the CNN.
input_dim (Tuple[int, int, int]): Dimensions of unflattened input,
i.e. [in_height, in_width, in_channels]. If the last 3 dimensions
of input_var do not match this shape, input_var will be reshaped.
filters (Tuple[Tuple[int, Tuple[int, int]], ...]): Number and dimension
of filters. For example, ((3, (3, 5)), (32, (3, 3))) means there
are two convolutional layers. The filter for the first layer have 3
Expand All @@ -47,6 +51,9 @@ def cnn(input_var,
"""
with tf.compat.v1.variable_scope(name):
# unflatten
input_var = tf.reshape(input_var, [-1, *input_dim])

h = input_var
for index, (filter_iter, stride) in enumerate(zip(filters, strides)):
_stride = [1, stride, stride, 1]
Expand All @@ -61,6 +68,7 @@ def cnn(input_var,


def cnn_with_max_pooling(input_var,
input_dim,
filters,
strides,
name,
Expand All @@ -78,6 +86,9 @@ def cnn_with_max_pooling(input_var,
Args:
input_var (tf.Tensor): Input tf.Tensor to the CNN.
input_dim (Tuple[int, int, int]): Dimensions of unflattened input,
i.e. [in_height, in_width, in_channels]. If the last 3 dimensions
of input_var do not match this shape, input_var will be reshaped.
filters (Tuple[Tuple[int, Tuple[int, int]], ...]): Number and dimension
of filters. For example, ((3, (3, 5)), (32, (3, 3))) means there
are two convolutional layers. The filter for the first layer have 3
Expand Down Expand Up @@ -113,6 +124,9 @@ def cnn_with_max_pooling(input_var,
pool_shapes = [1, pool_shapes[0], pool_shapes[1], 1]

with tf.compat.v1.variable_scope(name):
# unflatten
input_var = tf.reshape(input_var, [-1, *input_dim])

h = input_var
for index, (filter_iter, stride) in enumerate(zip(filters, strides)):
_stride = [1, stride, stride, 1]
Expand Down
6 changes: 6 additions & 0 deletions src/garage/tf/models/cnn_mlp_merge_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@ class CNNMLPMergeModel(Model):
the MLP accepts the CNN's output and the action as inputs.
Args:
input_dim (Tuple[int, int, int]): Dimensions of unflattened input,
i.e. [in_height, in_width, in_channels]. If the last 3 dimensions
of input_var do not match this shape, input_var will be reshaped.
filters (Tuple[Tuple[int, Tuple[int, int]], ...]): Number and dimension
of filters. For example, ((3, (3, 5)), (32, (3, 3))) means there
are two convolutional layers. The filter for the first layer have 3
Expand Down Expand Up @@ -76,6 +79,7 @@ class CNNMLPMergeModel(Model):
"""

def __init__(self,
input_dim,
filters,
strides,
hidden_sizes=(256, ),
Expand Down Expand Up @@ -103,6 +107,7 @@ def __init__(self,

if not max_pooling:
self.cnn_model = CNNModel(
input_dim=input_dim,
filters=filters,
hidden_w_init=cnn_hidden_w_init,
hidden_b_init=cnn_hidden_b_init,
Expand All @@ -111,6 +116,7 @@ def __init__(self,
hidden_nonlinearity=cnn_hidden_nonlinearity)
else:
self.cnn_model = CNNModelWithMaxPooling(
input_dim=input_dim,
filters=filters,
hidden_w_init=cnn_hidden_w_init,
hidden_b_init=cnn_hidden_b_init,
Expand Down
6 changes: 6 additions & 0 deletions src/garage/tf/models/cnn_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@ class CNNModel(Model):
"""CNN Model.
Args:
input_dim (Tuple[int, int, int]): Dimensions of unflattened input,
i.e. [in_height, in_width, in_channels]. If the last 3 dimensions
of input_var do not match this shape, input_var will be reshaped.
filters (Tuple[Tuple[int, Tuple[int, int]], ...]): Number and dimension
of filters. For example, ((3, (3, 5)), (32, (3, 3))) means there
are two convolutional layers. The filter for the first layer have 3
Expand All @@ -34,6 +37,7 @@ class CNNModel(Model):
"""

def __init__(self,
input_dim,
filters,
strides,
padding,
Expand All @@ -43,6 +47,7 @@ def __init__(self,
seed=deterministic.get_tf_seed_stream()),
hidden_b_init=tf.zeros_initializer()):
super().__init__(name)
self._input_dim = input_dim
self._filters = filters
self._strides = strides
self._padding = padding
Expand All @@ -66,6 +71,7 @@ def _build(self, state_input, name=None):
"""
del name
return cnn(input_var=state_input,
input_dim=self._input_dim,
filters=self._filters,
hidden_nonlinearity=self._hidden_nonlinearity,
hidden_w_init=self._hidden_w_init,
Expand Down
6 changes: 6 additions & 0 deletions src/garage/tf/models/cnn_model_max_pooling.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@ class CNNModelWithMaxPooling(Model):
"""CNN Model with max pooling.
Args:
input_dim (Tuple[int, int, int]): Dimensions of unflattened input,
i.e. [in_height, in_width, in_channels]. If the last 3 dimensions
of input_var do not match this shape, input_var will be reshaped.
filters (Tuple[Tuple[int, Tuple[int, int]], ...]): Number and dimension
of filters. For example, ((3, (3, 5)), (32, (3, 3))) means there
are two convolutional layers. The filter for the first layer have 3
Expand Down Expand Up @@ -40,6 +43,7 @@ class CNNModelWithMaxPooling(Model):
"""

def __init__(self,
input_dim,
filters,
strides,
name=None,
Expand All @@ -51,6 +55,7 @@ def __init__(self,
seed=deterministic.get_tf_seed_stream()),
hidden_b_init=tf.zeros_initializer()):
super().__init__(name)
self._input_dim = input_dim
self._filters = filters
self._strides = strides
self._padding = padding
Expand All @@ -77,6 +82,7 @@ def _build(self, state_input, name=None):
del name
return cnn_with_max_pooling(
input_var=state_input,
input_dim=self._input_dim,
filters=self._filters,
hidden_nonlinearity=self._hidden_nonlinearity,
hidden_w_init=self._hidden_w_init,
Expand Down
8 changes: 8 additions & 0 deletions src/garage/tf/models/gaussian_cnn_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@ class GaussianCNNModel(Model):
"""GaussianCNNModel.
Args:
input_dim (Tuple[int, int, int]): Dimensions of unflattened input,
i.e. [in_height, in_width, in_channels]. If the last 3 dimensions
of input_var do not match this shape, input_var will be reshaped.
filters (Tuple[Tuple[int, Tuple[int, int]], ...]): Number and dimension
of filters. For example, ((3, (3, 5)), (32, (3, 3))) means there
are two convolutional layers. The filter for the first layer have 3
Expand Down Expand Up @@ -91,6 +94,7 @@ class GaussianCNNModel(Model):
"""

def __init__(self,
input_dim,
output_dim,
filters,
strides,
Expand Down Expand Up @@ -126,6 +130,7 @@ def __init__(self,
layer_normalization=False):
# Network parameters
super().__init__(name)
self._input_dim = input_dim
self._output_dim = output_dim
self._filters = filters
self._strides = strides
Expand Down Expand Up @@ -214,6 +219,7 @@ def _build(self, state_input, name=None):

mean_std_conv = cnn(
input_var=state_input,
input_dim=self._input_dim,
filters=self._filters,
hidden_nonlinearity=self._hidden_nonlinearity,
hidden_w_init=self._hidden_w_init,
Expand Down Expand Up @@ -242,6 +248,7 @@ def _build(self, state_input, name=None):
# separate MLPs for mean and std networks
# mean network
mean_conv = cnn(input_var=state_input,
input_dim=self._input_dim,
filters=self._filters,
hidden_nonlinearity=self._hidden_nonlinearity,
hidden_w_init=self._hidden_w_init,
Expand All @@ -267,6 +274,7 @@ def _build(self, state_input, name=None):
if self._adaptive_std:
log_std_conv = cnn(
input_var=state_input,
input_dim=self._input_dim,
filters=self._std_filters,
hidden_nonlinearity=self._std_hidden_nonlinearity,
hidden_w_init=self._std_hidden_w_init,
Expand Down
15 changes: 7 additions & 8 deletions src/garage/tf/policies/categorical_cnn_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,8 @@ def __init__(self,

is_image = isinstance(self.env_spec.observation_space, akro.Image)

super().__init__(output_dim=self._action_dim,
super().__init__(input_dim=self._obs_dim,
output_dim=self._action_dim,
filters=filters,
strides=strides,
padding=padding,
Expand All @@ -123,9 +124,9 @@ def __init__(self,

def _initialize(self):
"""Initialize policy."""
flat_dim = np.prod(self._obs_dim)
state_input = tf.compat.v1.placeholder(tf.float32,
shape=(None, None) +
self._obs_dim)
shape=(None, None, flat_dim))
if isinstance(self.env_spec.observation_space, akro.Image):
augmented_state_input = tf.cast(state_input, tf.float32)
augmented_state_input /= 255.0
Expand Down Expand Up @@ -169,11 +170,9 @@ def get_actions(self, observations):
dict(numpy.ndarray): Distribution parameters.
"""
if isinstance(self.env_spec.observation_space, akro.Image) and \
len(observations[0].shape) < \
len(self.env_spec.observation_space.shape):
observations = self.env_spec.observation_space.unflatten_n(
observations)
if not isinstance(observations[0],
np.ndarray) or len(observations[0].shape) > 1:
observations = self.observation_space.flatten_n(observations)
samples, probs = self._f_prob(np.expand_dims(observations, 1))
return np.squeeze(samples), dict(prob=np.squeeze(probs, axis=1))

Expand Down
1 change: 1 addition & 0 deletions src/garage/tf/q_functions/continuous_cnn_q_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ def __init__(self,
self._action_dim = self._env_spec.action_space.shape

super().__init__(name=name,
input_dim=self._obs_dim,
filters=self._filters,
strides=self._strides,
hidden_sizes=self._hidden_sizes,
Expand Down
11 changes: 5 additions & 6 deletions src/garage/tf/q_functions/discrete_cnn_q_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,8 @@
import tensorflow as tf

from garage.experiment import deterministic
from garage.tf.models import (CNNModel,
CNNModelWithMaxPooling,
MLPDuelingModel,
MLPModel,
Sequential)
from garage.tf.models import (CNNModel, CNNModelWithMaxPooling,
MLPDuelingModel, MLPModel, Sequential)

# yapf: enable

Expand Down Expand Up @@ -123,12 +120,14 @@ def __init__(self,
action_dim = self._env_spec.action_space.flat_dim

if not max_pooling:
cnn_model = CNNModel(filters=filters,
cnn_model = CNNModel(input_dim=self.obs_dim,
filters=filters,
strides=strides,
padding=padding,
hidden_nonlinearity=cnn_hidden_nonlinearity)
else:
cnn_model = CNNModelWithMaxPooling(
input_dim=self.obs_dim,
filters=filters,
strides=strides,
padding=padding,
Expand Down
1 change: 1 addition & 0 deletions src/garage/torch/modules/discrete_cnn_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ def __init__(self,
self._module = nn.Sequential(cnn_module, mlp_hidden_nonlinearity(),
nn.Flatten(), mlp_module)

# pylint: disable=arguments-differ
def forward(self, inputs):
"""Forward method.
Expand Down
1 change: 1 addition & 0 deletions src/garage/torch/modules/discrete_dueling_cnn_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,7 @@ def __init__(self,
self._module = nn.Sequential(cnn_module, mlp_hidden_nonlinearity(),
nn.Flatten())

# pylint: disable=arguments-differ
def forward(self, inputs):
"""Forward method.
Expand Down
Loading

0 comments on commit e395ba6

Please sign in to comment.