diff --git a/tensorflow_probability/python/experimental/mcmc/nuts_unrolled.py b/tensorflow_probability/python/experimental/mcmc/nuts_unrolled.py
index 73975cdfc8..83a6ddd168 100644
--- a/tensorflow_probability/python/experimental/mcmc/nuts_unrolled.py
+++ b/tensorflow_probability/python/experimental/mcmc/nuts_unrolled.py
@@ -67,58 +67,47 @@
 ### END STATIC CONFIGURATION #################################
 ##############################################################
 
-
 __all__ = [
     'NoUTurnSamplerUnrolled',
 ]
 
-
-NUTSKernelResults = collections.namedtuple(
-    'NUTSKernelResults',
-    [
-        'target_log_prob',
-        'grads_target_log_prob',
-        'leapfrogs_computed',
-        'is_accepted',
-        'reach_max_depth',
-        # TODO(junpenglao): expose divergence diagnostic
-        # 'has_divergence',
-    ])
-
-
-TraceArrays = collections.namedtuple(
-    'TraceArrays',
-    [
-        'momentum_swap',
-        'state_swap',
-    ])
-
-
-TreeDoublingState = collections.namedtuple(
-    'TreeDoublingState',
-    [
-        'momentum',
-        'state',
-        'target',
-        'target_grad_parts',
-    ])
+NUTSKernelResults = collections.namedtuple('NUTSKernelResults', [
+    'target_log_prob',
+    'grads_target_log_prob',
+    'momentum_state_memory',
+    'leapfrogs_computed',
+    'is_accepted',
+    'reach_max_depth',
+    'has_divergence',
+])
+
+MomentumStateSwap = collections.namedtuple('MomentumStateSwap', [
+    'momentum_swap',
+    'state_swap',
+])
+
+TreeDoublingState = collections.namedtuple('TreeDoublingState', [
+    'momentum',
+    'state',
+    'target',
+    'target_grad_parts',
+])
 
 TreeDoublingStateCandidate = collections.namedtuple(
-    'TreeDoublingStateCandidate',
-    [
+    'TreeDoublingStateCandidate', [
         'state',
         'target',
         'target_grad_parts',
         'weight',
     ])
 
-
 TreeDoublingMetaState = collections.namedtuple(
     'TreeDoublingMetaState',
     [
         'leapfrog_count',
        'candidate_state',  # A namedtuple of TreeDoublingStateCandidate
        'continue_tree',
+        'not_divergence',
    ])
@@ -157,7 +146,6 @@ def __init__(self,
                max_tree_depth=6,
                max_energy_diff=1000.,
                unrolled_leapfrog_steps=1,
-               trace_arrays=None,
                seed=None,
                name=None):
     """Initializes this transition kernel.
@@ -176,19 +164,15 @@ def __init__(self,
         possible, it's often helpful to match per-variable step sizes to the
         standard deviations of the target distribution in each variable.
       max_tree_depth: Maximum depth of the tree implicitly built by NUTS. The
-        maximum number of leapfrog steps is bounded by `2**max_tree_depth`
-        i.e. the number of nodes in a binary tree `max_tree_depth` nodes deep.
-        The default setting of 6 takes up to 64 leapfrog steps.
+        maximum number of leapfrog steps is bounded by `2**max_tree_depth` i.e.
+        the number of nodes in a binary tree `max_tree_depth` nodes deep. The
+        default setting of 6 takes up to 64 leapfrog steps.
       max_energy_diff: Scaler threshold of energy differences at each leapfrog,
         divergence samples are defined as leapfrog steps that exceed this
         threshold. Default to 1000.
       unrolled_leapfrog_steps: The number of leapfrogs to unroll per tree
         expansion step. Applies a direct linear multipler to the maximum
         trajectory length implied by max_tree_depth. Defaults to 1.
-      trace_arrays: `tf.TensorArray` that contains the important leapfrog
-        information within the same tree doubling. It is a swap memory that will
-        be overwrite at each tree doubling. Default to None and allocate at run
-        time.
       seed: Python integer to seed the random number generator.
       name: Python `str` name prefixed to Ops created by this function.
         Default value: `None` (i.e., 'nuts_kernel').
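
For orientation before the remaining hunks, here is a minimal usage sketch of the kernel this constructor configures. The import path is an assumption that simply mirrors this file's location in the tree; the class is experimental, so the path may move.

import tensorflow as tf
from tensorflow_probability.python.experimental.mcmc import nuts_unrolled

# Standard normal target, 4 chains over R^2.
kernel = nuts_unrolled.NoUTurnSamplerUnrolled(
    target_log_prob_fn=lambda x: -0.5 * tf.reduce_sum(x**2, axis=-1),
    step_size=0.3,
    max_tree_depth=6)  # at most 2**6 = 64 leapfrog steps per draw
pkr = kernel.bootstrap_results(tf.zeros([4, 2]))  # now holds momentum_state_memory
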
@@ -214,16 +198,18 @@ def __init__(self,
       self._read_instruction = tf.ragged.constant(read_instruction)
     else:
       f = lambda int_iter: write_instruction[int_iter]
-      self._write_instruction = {x: functools.partial(f, x)
-                                 for x in range(len(write_instruction))}
+      self._write_instruction = {
+          x: functools.partial(f, x) for x in range(len(write_instruction))
+      }
       self._read_instruction = read_instruction
 
     # Process all other arguments.
     self._target_log_prob_fn = target_log_prob_fn
     if not tf.nest.is_nested(step_size):
       step_size = [step_size]
-    step_size = [tf.convert_to_tensor(s, dtype_hint=tf.float32)
-                 for s in step_size]
+    step_size = [
+        tf.convert_to_tensor(s, dtype_hint=tf.float32) for s in step_size
+    ]
     self._step_size = step_size
 
     self._parameters = dict(
@@ -232,14 +218,12 @@ def __init__(self,
         max_tree_depth=max_tree_depth,
         max_energy_diff=max_energy_diff,
         unrolled_leapfrog_steps=unrolled_leapfrog_steps,
-        trace_arrays=trace_arrays,
         seed=seed,
         name=name,
     )
     self._seed_stream = SeedStream(seed, salt='nuts_one_step')
     self._unrolled_leapfrog_steps = unrolled_leapfrog_steps
     self._name = name
-    self._trace_arrays = trace_arrays
     self._max_energy_diff = max_energy_diff
 
   @property
@@ -262,10 +246,6 @@ def max_energy_diff(self):
   def unrolled_leapfrog_steps(self):
     return self._unrolled_leapfrog_steps
 
-  @property
-  def trace_arrays(self):
-    return self._trace_arrays
-
   @property
   def name(self):
     return self._name
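
The functools.partial table above feeds tf.switch_case, which dispatches on a traced integer index to one of N branches built at trace time. A standalone sketch of the same pattern, with a toy instruction table:

import functools
import tensorflow as tf

write_instruction = [0, 1, 0, 2]  # toy table; values made up
f = lambda int_iter: tf.constant(write_instruction[int_iter])
branches = {x: functools.partial(f, x) for x in range(len(write_instruction))}
tf.switch_case(tf.constant(3), branches)  # ==> tf.Tensor(2)
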
@@ -296,31 +276,10 @@ def one_step(self, current_state, previous_kernel_results):
       init_momentum, log_slice_sample = self._start_trajectory_batched(
           current_state, current_target_log_prob)
       batch_size = prefer_static.size(current_target_log_prob)
-      if self.trace_arrays is None:
-        def _init(shape_and_dtype):
-          if USE_TENSORARRAY:
-            return [
-                tf.TensorArray(dtype=d,  # pylint: disable=g-complex-comprehension
-                               size=self.max_tree_depth + 1,
-                               element_shape=s,
-                               clear_after_read=False)
-                for (s, d) in shape_and_dtype]
-          else:
-            return [
-                tf.zeros(  # pylint: disable=g-complex-comprehension
-                    tf.TensorShape([self.max_tree_depth + 1]).concatenate(s),
-                    dtype=d)
-                for (s, d) in shape_and_dtype]
-        get_shapes_and_dtypes = lambda x: [(x_.shape, x_.dtype) for x_ in x]
-        # TODO(jvdillon): We don't want to rely on mutating TransitionKernel
-        # state. This can be rectified by including trace_arrays in the
-        # previous_kernel_results.
-        self._trace_arrays = TraceArrays(
-            momentum_swap=_init(get_shapes_and_dtypes(init_momentum)),
-            state_swap=_init(get_shapes_and_dtypes(current_state)))
-      trace_arrays = self.trace_arrays
+      momentum_state_memory = previous_kernel_results.momentum_state_memory
       init_weight = tf.ones(batch_size, dtype=TREE_COUNT_DTYPE)
       continue_tree = tf.ones(batch_size, dtype=tf.bool)
+      not_divergence = tf.ones([batch_size], dtype=tf.bool)
 
       def _copy(v):
         return v * prefer_static.ones(
@@ -344,15 +303,15 @@ def _copy(v):
       initial_step_metastate = TreeDoublingMetaState(
           leapfrog_count=tf.zeros([], dtype=tf.int32, name='leapfrog_count'),
           candidate_state=candidate_state,
-          continue_tree=continue_tree)
+          continue_tree=continue_tree,
+          not_divergence=not_divergence)
 
       _, _, new_step_metastate = tf.while_loop(
           cond=lambda iter_, state, metastate: (  # pylint: disable=g-long-lambda
              ((iter_ < self.max_tree_depth) &
               tf.reduce_any(metastate.continue_tree))),
          body=lambda iter_, state, metastate: self._loop_one_step(  # pylint: disable=g-long-lambda
-              log_slice_sample, trace_arrays,
-              iter_, state, metastate),
+              log_slice_sample, momentum_state_memory, iter_, state, metastate),
          loop_vars=(
              tf.zeros([], dtype=tf.int32, name='iter'),
              initial_step_state,
@@ -376,9 +335,12 @@ def _copy(v):
      kernel_results = NUTSKernelResults(
          target_log_prob=candidate_state.target,
          grads_target_log_prob=candidate_state.target_grad_parts,
+          momentum_state_memory=momentum_state_memory,
          leapfrogs_computed=leapfrogs_computed,
          is_accepted=is_accepted,
-          reach_max_depth=reach_max_depth)
+          reach_max_depth=reach_max_depth,
+          has_divergence=~new_step_metastate.not_divergence,
+      )
 
       return result_state, kernel_results
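
The momentum_state_memory read above is a fixed-shape scratch buffer, one slot per tree depth plus a spare, that bootstrap_results (next hunk) now allocates up front. A sketch of the non-TensorArray variant, with made-up shapes:

import tensorflow as tf

max_tree_depth = 3
state = tf.ones([4, 2])                        # [batch, event]
memory = tf.zeros([max_tree_depth + 1, 4, 2])  # one row per depth + scratch
memory = tf.tensor_scatter_nd_update(memory, [[1]], [state])  # write row 1
row = tf.gather(memory, [1], axis=0)           # read back, shape [1, 4, 2]
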
@@ -396,38 +358,66 @@ def bootstrap_results(self, init_state):
         raise ValueError('Expected either one step size or {} (size of '
                          '`init_state`), but found {}'.format(
                              len(init_state), len(step_size)))
+      dummy_momentum = [tf.ones_like(state) for state in init_state]
+
+      def _init(shape_and_dtype):
+        """Allocate TensorArray for storing state and momentum."""
+        if USE_TENSORARRAY:
+          return [  # pylint: disable=g-complex-comprehension
+              tf.TensorArray(
+                  dtype=d,
+                  size=self.max_tree_depth + 1,
+                  element_shape=s,
+                  clear_after_read=False) for (s, d) in shape_and_dtype
+          ]
+        else:
+          return [  # pylint: disable=g-complex-comprehension
+              tf.zeros(
+                  tf.TensorShape([self.max_tree_depth + 1]).concatenate(s),
+                  dtype=d) for (s, d) in shape_and_dtype
+          ]
+
+      get_shapes_and_dtypes = lambda x: [(x_.shape, x_.dtype) for x_ in x]
+      momentum_state_memory = MomentumStateSwap(
+          momentum_swap=_init(get_shapes_and_dtypes(dummy_momentum)),
+          state_swap=_init(get_shapes_and_dtypes(init_state)))
       [
           _,
           _,
           current_target_log_prob,
           current_grads_log_prob,
-      ] = leapfrog_impl.process_args(self.target_log_prob_fn,
-                                     dummy_momentum,
+      ] = leapfrog_impl.process_args(self.target_log_prob_fn, dummy_momentum,
                                      init_state)
       batch_size = prefer_static.size(current_target_log_prob)
       return NUTSKernelResults(
           target_log_prob=current_target_log_prob,
           grads_target_log_prob=current_grads_log_prob,
-          leapfrogs_computed=tf.zeros([], dtype=tf.int32,
+          momentum_state_memory=momentum_state_memory,
+          leapfrogs_computed=tf.zeros([],
+                                      dtype=tf.int32,
                                       name='leapfrogs_computed'),
-          is_accepted=tf.zeros([batch_size], dtype=tf.bool,
-                               name='is_accepted'),
-          reach_max_depth=tf.zeros([batch_size], dtype=tf.bool,
-                                   name='is_accepted'),
-      )
+          is_accepted=tf.zeros([batch_size], dtype=tf.bool, name='is_accepted'),
+          reach_max_depth=tf.zeros([batch_size],
+                                   dtype=tf.bool,
+                                   name='reach_max_depth'),
+          has_divergence=tf.zeros([batch_size],
+                                  dtype=tf.bool,
+                                  name='has_divergence'),
+      )
 
   def _start_trajectory_batched(self, state, target_log_prob):
     """Computations needed to start a trajectory."""
     with tf.name_scope('start_trajectory_batched'):
-      seed_stream = SeedStream(self._seed_stream,
-                               salt='start_trajectory_batched')
+      seed_stream = SeedStream(
+          self._seed_stream, salt='start_trajectory_batched')
       momentum = [
-          tf.random.normal(shape=prefer_static.shape(x),  # pylint: disable=g-complex-comprehension
-                           dtype=x.dtype,
-                           seed=seed_stream())
-          for x in state]
+          tf.random.normal(  # pylint: disable=g-complex-comprehension
+              shape=prefer_static.shape(x),
+              dtype=x.dtype,
+              seed=seed_stream()) for x in state
+      ]
       # Draw a slice variable u ~ Uniform(0, p(initial state, initial
       # momentum)) and compute log u. For numerical stability, we perform this
       # in log space where log u = log (u' * p(...)) = log u' + log
@@ -439,8 +429,8 @@ def _start_trajectory_batched(self, state, target_log_prob):
       log_slice_sample += compute_hamiltonian(target_log_prob, momentum)
       return momentum, log_slice_sample
 
-  def _loop_one_step(self, log_slice_sample, trace_arrays,
-                     iter_, initial_step_state, initial_step_metastate):
+  def _loop_one_step(self, log_slice_sample, momentum_state_memory, iter_,
+                     initial_step_state, initial_step_metastate):
     """Main loop for tree doubling."""
     with tf.name_scope('loop_tree_doubling'):
       batch_size = prefer_static.size(log_slice_sample)
@@ -465,15 +455,18 @@ def _loop_one_step(self, log_slice_sample, trace_arrays,
       [
           candidate_tree_state,
           tree_final_states,
+          final_not_divergence,
          continue_tree_final,
          leapfrogs_computed,
-      ] = self._build_sub_tree(direction,
-                               log_slice_sample,
-                               # num_steps_at_this_depth = 2**iter_ = 1 << iter_
-                               tf.bitwise.left_shift(1, iter_),
-                               tree_start_states,
-                               initial_step_metastate.continue_tree,
-                               trace_arrays)
+      ] = self._build_sub_tree(
+          direction,
+          log_slice_sample,
+          # num_steps_at_this_depth = 2**iter_ = 1 << iter_
+          tf.bitwise.left_shift(1, iter_),
+          tree_start_states,
+          initial_step_metastate.continue_tree,
+          initial_step_metastate.not_divergence,
+          momentum_state_memory)
 
       last_candidate_state = initial_step_metastate.candidate_state
       tree_weight = candidate_tree_state.weight
@@ -546,7 +539,8 @@ def _loop_one_step(self, log_slice_sample, trace_arrays,
           leapfrog_count=(initial_step_metastate.leapfrog_count +
                           leapfrogs_computed),
           candidate_state=new_candidate_state,
-          continue_tree=continue_next_tree)
+          continue_tree=continue_next_tree,
+          not_divergence=final_not_divergence)
 
       return iter_ + 1, new_step_state, new_step_metastate
@@ -556,7 +550,8 @@ def _build_sub_tree(self,
                       nsteps,
                       initial_state,
                       continue_tree,
-                      trace_arrays,
+                      not_divergence,
+                      momentum_state_memory,
                       name=None):
     with tf.name_scope('build_sub_tree'):
       batch_size = prefer_static.size(log_slice_sample)
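
As a sanity check on the tf.bitwise.left_shift(1, iter_) call above: each doubling simulates 2**iter_ new leapfrog steps, so a full draw costs just under 2**max_tree_depth leapfrogs (times unrolled_leapfrog_steps). Assuming eager execution:

import tensorflow as tf

max_tree_depth = 6
per_doubling = [int(tf.bitwise.left_shift(1, i)) for i in range(max_tree_depth)]
assert per_doubling == [1, 2, 4, 8, 16, 32]
assert sum(per_doubling) == 2**max_tree_depth - 1  # 63 leapfrogs worst case
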
@@ -571,20 +566,25 @@ def _build_sub_tree(self,
           final_state,
           candidate_tree_state,
           final_continue_tree,
-          trace_arrays,
+          final_not_divergence,
+          momentum_state_memory,
       ] = tf.while_loop(
-          cond=lambda iter_, state, state_c, continue_tree, trace_arrays: (  # pylint: disable=g-long-lambda
-              (iter_ < nsteps) & tf.reduce_any(continue_tree)),
-          body=lambda iter_, state, state_c, continue_tree, trace_arrays: (  # pylint: disable=g-long-lambda
-              self._loop_build_sub_tree(
-                  direction, log_slice_sample,
-                  iter_, state, state_c, continue_tree, trace_arrays)),
+          cond=lambda iter_, state, state_c, continue_tree, not_divergence,  # pylint: disable=g-long-lambda
+          momentum_state_memory: (
+              (iter_ < nsteps) & tf.reduce_any(continue_tree)),
+          body=lambda iter_, state, state_c, continue_tree, not_divergence,  # pylint: disable=g-long-lambda
+          momentum_state_memory: (
+              self._loop_build_sub_tree(
+                  direction, log_slice_sample, iter_, state,
+                  state_c, continue_tree, not_divergence,
+                  momentum_state_memory)),
           loop_vars=(
               tf.zeros([], dtype=tf.int32, name='iter'),
               initial_state,
               initial_state_candidate,
               continue_tree,
-              trace_arrays,
+              not_divergence,
+              momentum_state_memory,
           ),
           parallel_iterations=TF_WHILE_PARALLEL_ITERATIONS,
       )
@@ -592,25 +592,28 @@ def _build_sub_tree(self,
       return (
           candidate_tree_state,
          final_state,
+          final_not_divergence,
          final_continue_tree,
          leapfrogs_computed,
      )
 
-  def _loop_build_sub_tree(
-      self, direction, log_slice_sample,
-      iter_, prev_tree_state, candidate_tree_state,
-      continue_tree_previous, trace_arrays):
+  def _loop_build_sub_tree(self, direction, log_slice_sample, iter_,
+                           prev_tree_state, candidate_tree_state,
+                           continue_tree_previous, not_divergent_previous,
+                           momentum_state_memory):
     """Base case in tree doubling."""
     with tf.name_scope('loop_build_sub_tree'):
       # Take one leapfrog step in the direction v and check divergence
       directions_expanded = [
           _expand_dims_under_batch_dim(direction, prefer_static.rank(state))
-          for state in prev_tree_state.state]
+          for state in prev_tree_state.state
+      ]
 
       integrator = leapfrog_impl.SimpleLeapfrogIntegrator(
           self.target_log_prob_fn,
-          step_sizes=[tf.where(direction, ss, -ss)
-                      for direction, ss in zip(
-                          directions_expanded, self.step_size)],
+          step_sizes=[
+              tf.where(direction, ss, -ss)
+              for direction, ss in zip(directions_expanded, self.step_size)
+          ],
           num_steps=self.unrolled_leapfrog_steps)
 
       [
           next_momentum_parts,
@@ -637,27 +640,29 @@ def _loop_build_sub_tree(
       else:
         write_index_ = tf.switch_case(index, self.write_instruction)
 
-      write_index = tf.where(tf.equal(iter_ % 2, 0),
-                             write_index_, self.max_tree_depth)
+      write_index = tf.where(
+          tf.equal(iter_ % 2, 0), write_index_, self.max_tree_depth)
 
       if USE_TENSORARRAY:
-        trace_arrays = TraceArrays(
+        momentum_state_memory = MomentumStateSwap(
             momentum_swap=[
                 old.write(write_index, new) for old, new in
-                zip(trace_arrays.momentum_swap, next_momentum_parts)],
+                zip(momentum_state_memory.momentum_swap, next_momentum_parts)],
             state_swap=[
                 old.write(write_index, new) for old, new in
-                zip(trace_arrays.state_swap, next_state_parts)])
+                zip(momentum_state_memory.state_swap, next_state_parts)])
       else:
-        trace_arrays = TraceArrays(
+        momentum_state_memory = MomentumStateSwap(
             momentum_swap=[
                 tf.tensor_scatter_nd_update(old, [[write_index]], [new])
-                for old, new in zip(
-                    trace_arrays.momentum_swap, next_momentum_parts)],
+                for old, new in zip(momentum_state_memory.momentum_swap,
+                                    next_momentum_parts)
+            ],
             state_swap=[
                 tf.tensor_scatter_nd_update(old, [[write_index]], [new])
-                for old, new in zip(
-                    trace_arrays.state_swap, next_state_parts)])
+                for old, new in zip(momentum_state_memory.state_swap,
+                                    next_state_parts)
+            ])
 
       batch_size = prefer_static.size(next_target)
       has_not_u_turn_at_even_step = tf.ones([batch_size], dtype=tf.bool)
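
A sketch of the even/odd write rule ending the hunk above: a state computed at an even step is stored in the slot the write instruction assigns, so later U-turn checks can read it back; an odd step is parked in the final scratch row, keeping the buffer a fixed size.

import tensorflow as tf

max_tree_depth = 3  # index of the scratch row

def write_slot(iter_, instruction_slot):
  return tf.where(tf.equal(iter_ % 2, 0), instruction_slot, max_tree_depth)

write_slot(tf.constant(2), tf.constant(1))  # ==> 1, retained for U-turn checks
write_slot(tf.constant(3), tf.constant(1))  # ==> 3, free to overwrite
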
@@ -667,16 +672,17 @@
             lambda: has_not_u_turn_at_even_step,
             lambda: has_not_u_turn_at_odd_step(  # pylint: disable=g-long-lambda
                 self.read_instruction, iter_ // 2, directions_expanded,
-                trace_arrays, next_momentum_parts, next_state_parts))
+                momentum_state_memory, next_momentum_parts, next_state_parts))
       else:
         f = lambda int_iter: has_not_u_turn_at_odd_step(  # pylint: disable=g-long-lambda
-            self.read_instruction, int_iter, directions_expanded, trace_arrays,
-            next_momentum_parts, next_state_parts)
-        branch_excution = {x: functools.partial(f, x)
-                           for x in range(len(self.read_instruction))}
+            self.read_instruction, int_iter, directions_expanded,
+            momentum_state_memory, next_momentum_parts, next_state_parts)
+        branch_excution = {
+            x: functools.partial(f, x)
+            for x in range(len(self.read_instruction))
+        }
         no_u_turns_within_tree = tf.cond(
-            tf.equal(iter_ % 2, 0),
-            lambda: has_not_u_turn_at_even_step,
+            tf.equal(iter_ % 2, 0), lambda: has_not_u_turn_at_even_step,
             lambda: tf.switch_case(iter_ // 2, branch_excution))
 
       energy = compute_hamiltonian(next_target, next_momentum_parts)
@@ -686,33 +692,27 @@
       sample_weight = tf.cast(valid_candidate, TREE_COUNT_DTYPE)
       weight_sum = candidate_tree_state.weight + sample_weight
       log_accept_thresh = tf.math.log(
-          tf.cast(sample_weight, tf.float32) /
-          tf.cast(weight_sum, tf.float32))
+          tf.cast(sample_weight, tf.float32) / tf.cast(weight_sum, tf.float32))
       log_accept_thresh = tf.where(
           tf.math.is_nan(log_accept_thresh),
-          tf.zeros([], log_accept_thresh.dtype),
-          log_accept_thresh)
+          tf.zeros([], log_accept_thresh.dtype), log_accept_thresh)
       u = tf.math.log1p(-tf.random.uniform(
-          shape=[batch_size],
-          dtype=tf.float32,
-          seed=self._seed_stream()))
+          shape=[batch_size], dtype=tf.float32, seed=self._seed_stream()))
       is_sample_accepted = u <= log_accept_thresh
 
       next_candidate_tree_state = TreeDoublingStateCandidate(
           state=[
               tf.where(  # pylint: disable=g-complex-comprehension
-                  _expand_dims_under_batch_dim(
-                      is_sample_accepted, prefer_static.rank(s0)), s0, s1)
-              for s0, s1 in zip(next_state_parts,
-                                candidate_tree_state.state)
+                  _expand_dims_under_batch_dim(is_sample_accepted,
+                                               prefer_static.rank(s0)), s0, s1)
+              for s0, s1 in zip(next_state_parts, candidate_tree_state.state)
           ],
-          target=tf.where(is_sample_accepted,
-                          next_target,
+          target=tf.where(is_sample_accepted, next_target,
                           candidate_tree_state.target),
           target_grad_parts=[
               tf.where(  # pylint: disable=g-complex-comprehension
-                  _expand_dims_under_batch_dim(
-                      is_sample_accepted, prefer_static.rank(grad0)),
+                  _expand_dims_under_batch_dim(is_sample_accepted,
+                                               prefer_static.rank(grad0)),
                   grad0, grad1)
               for grad0, grad1 in zip(next_target_grad_parts,
                                       candidate_tree_state.target_grad_parts)
@@ -723,35 +723,42 @@
       continue_tree = not_divergent & no_u_turns_within_tree
       continue_tree_next = continue_tree_previous & continue_tree
 
+      not_divergent_tokeep = tf.where(continue_tree_previous, not_divergent,
+                                      tf.ones([batch_size], dtype=tf.bool))
+
       return (
           iter_ + 1,
          next_tree_state,
          next_candidate_tree_state,
          continue_tree_next,
-          trace_arrays,
+          not_divergent_previous & not_divergent_tokeep,
+          momentum_state_memory,
      )
 
 
-def has_not_u_turn_at_odd_step(
-    instruction,
-    iter_,
-    direction,
-    trace_arrays,
-    momentum_right,
-    state_right):
+def has_not_u_turn_at_odd_step(instruction, iter_, direction,
+                               momentum_state_memory, momentum_right,
+                               state_right):
   """Check u turn for early stopping."""
   # Note that here iter_ is actually iter_ // 2
   left_current_index = instruction[iter_]
   if USE_TENSORARRAY:
-    momentum_left = [x.gather(left_current_index)
-                     for x in trace_arrays.momentum_swap]
-    state_left = [x.gather(left_current_index)
-                  for x in trace_arrays.state_swap]
+    momentum_left = [
+        x.gather(left_current_index)
+        for x in momentum_state_memory.momentum_swap
+    ]
+    state_left = [
+        x.gather(left_current_index) for x in momentum_state_memory.state_swap
+    ]
   else:
-    momentum_left = [tf.gather(x, left_current_index, axis=0)
-                     for x in trace_arrays.momentum_swap]
-    state_left = [tf.gather(x, left_current_index, axis=0)
-                  for x in trace_arrays.state_swap]
+    momentum_left = [
+        tf.gather(x, left_current_index, axis=0)
+        for x in momentum_state_memory.momentum_swap
+    ]
+    state_left = [
+        tf.gather(x, left_current_index, axis=0)
+        for x in momentum_state_memory.state_swap
+    ]
 
   no_u_turns_within_tree_ = has_not_u_turn(
       state_left,
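
The has_not_u_turn helper reformatted below encodes the usual NUTS termination test: keep expanding only while both endpoint momenta have a non-negative dot product with the displacement between the trajectory endpoints. In minimal form:

import tensorflow as tf

state_left = tf.constant([0., 0.])
state_right = tf.constant([1., 1.])
momentum_left = tf.constant([1., 0.])
momentum_right = tf.constant([0., 1.])
displacement = state_right - state_left
no_u_turn = (tf.reduce_sum(displacement * momentum_left) >= 0) & (
    tf.reduce_sum(displacement * momentum_right) >= 0)  # ==> True
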
@@ -771,21 +778,23 @@ def _batchwise_reduce_sum(x, rank_diff):
 def has_not_u_turn(state_left, momentum_left, state_right, momentum_right):
   """If two given states and momentum do not exhibit a U-turn pattern."""
   with tf.name_scope('has_not_u_turn'):
-    batch_dot_product_left = sum(
-        [_batchwise_reduce_sum(
-            (s1 - s2) * m, prefer_static.rank(s2) - prefer_static.rank(s1))
-         for s1, s2, m in zip(state_right, state_left, momentum_left)])
-    batch_dot_product_right = sum(
-        [_batchwise_reduce_sum(
-            (s1 - s2) * m, prefer_static.rank(s2) - prefer_static.rank(s1))
-         for s1, s2, m in zip(state_right, state_left, momentum_right)])
+    batch_dot_product_left = sum([
+        _batchwise_reduce_sum((s1 - s2) * m,
+                              prefer_static.rank(s2) - prefer_static.rank(s1))
+        for s1, s2, m in zip(state_right, state_left, momentum_left)
+    ])
+    batch_dot_product_right = sum([
+        _batchwise_reduce_sum((s1 - s2) * m,
+                              prefer_static.rank(s2) - prefer_static.rank(s1))
+        for s1, s2, m in zip(state_right, state_left, momentum_right)
+    ])
     return (batch_dot_product_left >= 0) & (batch_dot_product_right >= 0)
 
 
 def _expand_dims_under_batch_dim(tensor, new_rank):
   """Adds size-1 dimensions below the first until `tensor` has `new_rank`."""
-  ones = prefer_static.ones(
-      [new_rank - prefer_static.rank(tensor)], dtype=tf.int32)
+  ones = prefer_static.ones([new_rank - prefer_static.rank(tensor)],
+                            dtype=tf.int32)
   shape = prefer_static.shape(tensor)
   new_shape = prefer_static.concat([shape[:1], ones, shape[1:]], axis=0)
   return tf.reshape(tensor, new_shape)
@@ -793,6 +802,7 @@ def _expand_dims_under_batch_dim(tensor, new_rank):
 
 def build_tree_uturn_instruction(max_depth, init_memory=0):
   """Run build tree and output the u turn checking input instruction."""
+
   def _buildtree(address, depth):
     if depth == 0:
       address += 1
@@ -802,6 +812,7 @@ def _buildtree(address, depth):
     _, address_right = _buildtree(address_right, depth - 1)
     instruction.append((address_left, address_right))
     return address_left, address_right
+
   instruction = []
   _, _ = _buildtree(init_memory, max_depth)
   return np.array(instruction, dtype=np.int32)
@@ -822,8 +833,8 @@ def generate_efficient_write_read_instruction(instruction_array):
   # 0 : still in memory but not needed for check u turn
   instruction_mat2 = np.zeros(instruction_mat.shape)
   instruction_mat2[instruction_mat == 0] = -1
-  instruction_mat2[(instruction_mat_cumsum < max_to_retain) &
-                   (instruction_mat_cumsum > 0)] = 0
+  instruction_mat2[(instruction_mat_cumsum < max_to_retain)
+                   & (instruction_mat_cumsum > 0)] = 0
   instruction_mat2[instruction_mat == 1] = 1
   np.fill_diagonal(instruction_mat2, (max_to_retain > 0) - 1)
   # plt.imshow(instruction_mat2, interpolation='None')
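
A worked example for the instruction builder above, traced by hand from the recursion: for max_depth=2 it visits memory addresses 1 through 4 and emits one (left, right) address pair per U-turn check, in the order the checks occur during forward simulation.

build_tree_uturn_instruction(2)
# ==> array([[1, 2],
#            [3, 4],
#            [1, 4]], dtype=int32)
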
@@ -843,12 +854,14 @@ def generate_efficient_write_read_instruction(instruction_array):
 
 
 def compute_hamiltonian(target_log_prob, momentum_parts):
+  """Compute the Hamiltonian of the current system."""
   independent_chain_ndims = prefer_static.rank(target_log_prob)
   momentum_sq_parts = (
-      tf.cast(tf.reduce_sum(tf.square(m),  # pylint: disable=g-complex-comprehension
-                            axis=prefer_static.range(independent_chain_ndims,
-                                                     prefer_static.rank(m))),
-              dtype=target_log_prob.dtype)
-      for m in momentum_parts)
+      tf.cast(  # pylint: disable=g-complex-comprehension
+          tf.reduce_sum(
+              tf.square(m),
+              axis=prefer_static.range(independent_chain_ndims,
+                                       prefer_static.rank(m))),
+          dtype=target_log_prob.dtype) for m in momentum_parts)
   # TODO(jvdillon): Verify no broadcasting happening.
   return target_log_prob - 0.5 * sum(momentum_sq_parts)
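
A numerical check of compute_hamiltonian as reformatted above: with unit mass it returns target_log_prob - 0.5 * sum_i ||m_i||**2 per chain, the kinetic term reduced over every non-batch dimension.

import tensorflow as tf

target_log_prob = tf.constant([-1., -2.])             # shape [2]: two chains
momentum_parts = [tf.constant([[1., 1.], [2., 0.]])]  # one part, shape [2, 2]
# compute_hamiltonian(target_log_prob, momentum_parts)
# ==> [-1. - 0.5 * 2., -2. - 0.5 * 4.] = [-2., -4.]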