Skip to content

Commit

Permalink
minor refactorings; removed some unreachable code
Browse files Browse the repository at this point in the history
Double-checked exception cases and added some doctests in
diagnostics/sample_chains.py
  • Loading branch information
MFreidank committed Oct 14, 2017
1 parent 8bce73d commit 06bbc67
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 56 deletions.
38 changes: 29 additions & 9 deletions pysgmcmc/diagnostics/sample_chains.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,14 @@ def __init__(self, chain_id, samples, varnames=None):
>>> trace.n_vars, trace.varnames == names, len(trace.varnames) == trace.n_vars
(2, True, True)
If `varnames` is `None`, anonymous names resulting from enumerating all
target parameters are used:
>>> dummy_samples = [[0., 0.], [0.2, -0.2], [0.3, -0.5], [0.1, 0.]]
>>> trace = PYSGMCMCTrace(chain_id=0, samples=dummy_samples, varnames=None)
>>> trace.varnames
['0', '1']
"""
self.chain = chain_id

Expand Down Expand Up @@ -156,16 +164,16 @@ def from_sampler(cls, chain_id, sampler, n_samples, varnames=None):
sample for sample, _ in islice(sampler, n_samples)
]

# try to read variable names from sampler parameters
# ensure all sampler target parameters have a name
# => tensorflow names variables automatically, so this assumption
# is fair
assert all(hasattr(param, "name") for param in sampler.params)

# read variable names from sampler parameters
if varnames is None:
try:
varnames = [
param.name for param in sampler.params
]
except AttributeError:
# could not read sampler parameters, passing `None`
# which will use enumerated names for the parameters
varnames = None
varnames = [
param.name for param in sampler.params
]
return PYSGMCMCTrace(chain_id, samples, varnames)

def __len__(self):
Expand Down Expand Up @@ -206,6 +214,7 @@ def get_values(self, varname, burn=0, thin=1):
This method makes each variable in a trace accessible by its name:
>>> import tensorflow as tf
>>> graph = tf.Graph()
>>> params = [tf.Variable(0., name="x"), tf.Variable(0., name="y")]
>>> params[0].name, params[1].name
('x_1:0', 'y_1:0')
Expand All @@ -219,6 +228,17 @@ def get_values(self, varname, burn=0, thin=1):
>>> trace.get_values(varname="x_1:0"), trace.get_values(varname="y_1:0")
(array([ 0. , 0.2, 0.3, 0.1]), array([ 0. , -0.2, -0.5, 0. ]))
If a queried name does not correspond to any parameter in the trace,
a `ValueError` is raised:
>>> names = [variable.name for variable in params]
>>> dummy_samples = [[0., 0.], [0.2, -0.2], [0.3, -0.5], [0.1, 0.]]
>>> trace = PYSGMCMCTrace(chain_id=0, samples=dummy_samples, varnames=names)
>>> trace.get_values(varname="FANTASYVARNAME")
Traceback (most recent call last):
...
ValueError: Queried `PYSGMCMCTrace` for values of parameter with name 'FANTASYVARNAME' but the trace does not contain any parameter of that name. Known variable names were: '['x_1:0', 'y_1:0']'
"""

if varname not in self.varnames:
Expand Down
11 changes: 2 additions & 9 deletions pysgmcmc/samplers/base_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,19 +273,12 @@ def __next__(self, feed_dict=None):
"""
assert (feed_dict is None or hasattr(feed_dict, "update"))
# Ensure self.Theta_t and self.Cost are defined
assert hasattr(self, "Theta_t") or not hasattr(self, "Cost")

if feed_dict is None:
feed_dict = dict()

if not hasattr(self, "Theta_t") or not hasattr(self, "Cost"):
# Ensure self.Theta_t and self.Cost are defined
raise ValueError(
"MCMCSampler subclass attempted to compute the next sample "
"with corresponding costs, but one of the "
"two necessary sampler member variables 'Theta_t' and 'Cost' "
"were not found in the samplers instance dictionary."
)

feed_dict.update(self._next_batch())
feed_dict.update(self._next_stepsize())
params, cost = self.session.run(
Expand Down
48 changes: 10 additions & 38 deletions pysgmcmc/tensor_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,30 +80,27 @@ def vectorize(tensor):
>>> v = vectorize([1.0])
Traceback (most recent call last):
...
AssertionError: Unsupported input to tensor_utils.vectorize: [1.0] is not a tensorflow.Tensor subclass
ValueError: Unsupported input to tensor_utils.vectorize: [1.0] is not a tensorflow.Tensor subclass
"""

error_msg = ("Unsupported input to tensor_utils.vectorize: "
"{value} is not a tensorflow.Tensor subclass".format(value=tensor))
def vectorized_shape(tensor):
# Compute vectorized shape
n_elements = np.prod(np.asarray(tensor.shape, dtype=np.int))
return (n_elements, 1)

assert isinstance(tensor, (tf.Variable, tf.Tensor,)), error_msg

# Compute vectorized shape
n_elements = np.prod(np.asarray(tensor.shape, dtype=np.int))
vectorized_shape = (n_elements, 1)

if type(tensor) == tf.Variable:
if isinstance(tensor, tf.Variable):
return tf.Variable(
tf.reshape(tensor.initialized_value(), shape=vectorized_shape)
tf.reshape(tensor.initialized_value(), shape=vectorized_shape(tensor))
)

elif isinstance(tensor, tf.Tensor):
return tf.reshape(tensor, shape=vectorized_shape)
return tf.reshape(tensor, shape=vectorized_shape(tensor))

else:
raise ValueError(
error_msg
"Unsupported input to tensor_utils.vectorize: "
"{value} is not a tensorflow.Tensor subclass".format(value=tensor)
)


Expand Down Expand Up @@ -607,28 +604,3 @@ def uninitialized_params(params, session):
)

return [param for param, flag in zip(params, init_flag) if not flag]


def all_uninitialized_variables(session, scope=None):
    """
    Collect every `tensorflow.Variable` in the current default graph
    that has not yet been initialized, using `session` to check each
    variable's initialization state.

    Parameters
    ----------
    session : tf.Session
        Session used to determine which variables are uninitialized.
    scope : string, optional
        Optional name scope passed through to the variable collection
        lookup; `None` (the default) considers all global variables.

    Returns
    ----------
    params_uninitialized: list of tensorflow.Variable objects
        All `tensorflow.Variable` objects in the current default graph
        that were not yet initialized.
    """
    # Fetch the graph's global variables (optionally scope-filtered),
    # then delegate the initialization check to `uninitialized_params`.
    global_variables = tf.get_collection(
        tf.GraphKeys.GLOBAL_VARIABLES, scope=scope
    )
    return uninitialized_params(global_variables, session=session)

0 comments on commit 06bbc67

Please sign in to comment.