-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add CategoricalConvPolicy to TF sandbox to align with Theano implementation (rll#133)
- Loading branch information
1 parent
6be5750
commit 38a2a41
Showing
1 changed file
with
96 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,96 @@ | ||
from sandbox.rocky.tf.core.layers_powered import LayersPowered | ||
import sandbox.rocky.tf.core.layers as L | ||
from sandbox.rocky.tf.core.network import ConvNetwork | ||
from rllab.core.serializable import Serializable | ||
from sandbox.rocky.tf.distributions.categorical import Categorical | ||
from sandbox.rocky.tf.policies.base import StochasticPolicy | ||
from rllab.misc import ext | ||
from sandbox.rocky.tf.misc import tensor_utils | ||
from rllab.misc.overrides import overrides | ||
from sandbox.rocky.tf.spaces.discrete import Discrete | ||
import tensorflow as tf | ||
|
||
|
||
class CategoricalConvPolicy(StochasticPolicy, LayersPowered, Serializable):
    """Stochastic policy over a discrete action space backed by a conv network.

    The network's output layer is used as the vector of action probabilities
    (the default output nonlinearity is a softmax), and actions are drawn
    with ``action_space.weighted_sample``.
    """

    def __init__(
            self,
            name,
            env_spec,
            conv_filters, conv_filter_sizes, conv_strides, conv_pads,
            hidden_sizes=None,
            hidden_nonlinearity=tf.nn.relu,
            output_nonlinearity=tf.nn.softmax,
            prob_network=None,
    ):
        """
        :param name: name of the policy.
            NOTE(review): this constructor does not use ``name`` itself (the
            network is always created as "prob_network") -- confirm whether a
            variable scope was intended.
        :param env_spec: A spec for the mdp. Its action space must be Discrete.
        :param conv_filters: forwarded unchanged to ConvNetwork
        :param conv_filter_sizes: forwarded unchanged to ConvNetwork
        :param conv_strides: forwarded unchanged to ConvNetwork
        :param conv_pads: forwarded unchanged to ConvNetwork
        :param hidden_sizes: list of sizes for the fully connected hidden
            layers; None (the default) means no hidden layers
        :param hidden_nonlinearity: nonlinearity used for each hidden layer
        :param output_nonlinearity: nonlinearity applied to the output layer
        :param prob_network: manually specified network for this policy, other
            network params are ignored
        :return:
        """
        # Must run before any argument is rebound: it records the constructor
        # arguments from locals() for serialization.
        Serializable.quick_init(self, locals())

        assert isinstance(env_spec.action_space, Discrete)

        self._env_spec = env_spec

        if prob_network is None:
            # Resolve the None default to a fresh list here rather than using
            # a shared mutable default in the signature.
            if hidden_sizes is None:
                hidden_sizes = []
            prob_network = ConvNetwork(
                input_shape=env_spec.observation_space.shape,
                output_dim=env_spec.action_space.n,
                conv_filters=conv_filters,
                conv_filter_sizes=conv_filter_sizes,
                conv_strides=conv_strides,
                conv_pads=conv_pads,
                hidden_sizes=hidden_sizes,
                hidden_nonlinearity=hidden_nonlinearity,
                output_nonlinearity=output_nonlinearity,
                name="prob_network",
            )

        self._l_prob = prob_network.output_layer
        self._l_obs = prob_network.input_layer
        # Compiled function mapping a batch of observations to the network
        # output (the action probabilities).
        self._f_prob = tensor_utils.compile_function(
            [prob_network.input_layer.input_var],
            L.get_output(prob_network.output_layer)
        )

        self._dist = Categorical(env_spec.action_space.n)

        super(CategoricalConvPolicy, self).__init__(env_spec)
        LayersPowered.__init__(self, [prob_network.output_layer])

    @property
    def vectorized(self):
        """Always True for this policy."""
        return True

    @overrides
    def dist_info_sym(self, obs_var, state_info_vars=None):
        """Return ``dict(prob=...)`` with the symbolic network output.

        ``obs_var`` is cast to float32 and substituted for the network's
        input layer. ``state_info_vars`` is not used.
        """
        return dict(prob=L.get_output(self._l_prob, {self._l_obs: tf.cast(obs_var, tf.float32)}))

    @overrides
    def dist_info(self, obs, state_infos=None):
        """Return ``dict(prob=...)`` evaluated on ``obs``.

        ``state_infos`` is not used.
        """
        return dict(prob=self._f_prob(obs))

    @overrides
    def get_action(self, observation):
        """Sample one action for a single observation.

        The observation is flattened, evaluated as a batch of one, and the
        action is drawn from the resulting probabilities.

        :return: a pair ``(action, dict(prob=prob))`` where ``prob`` is the
            first (only) row of the network output.
        """
        flat_obs = self.observation_space.flatten(observation)
        prob = self._f_prob([flat_obs])[0]
        action = self.action_space.weighted_sample(prob)
        return action, dict(prob=prob)

    def get_actions(self, observations):
        """Sample one action per observation in a batch.

        :return: a pair ``(actions, dict(prob=probs))``: a list with one
            sampled action per row of ``probs``, and the network output for
            the flattened batch.
        """
        flat_obs = self.observation_space.flatten_n(observations)
        probs = self._f_prob(flat_obs)
        actions = [self.action_space.weighted_sample(p) for p in probs]
        return actions, dict(prob=probs)

    @property
    def distribution(self):
        """The Categorical distribution object built in the constructor."""
        return self._dist