Skip to content

Commit

Permalink
Add comments to link convolutional factors code with definitions in t…
Browse files Browse the repository at this point in the history
…he KFC paper.

PiperOrigin-RevId: 179925679
  • Loading branch information
tensorflower-gardener committed Dec 22, 2017
1 parent fd29e95 commit d1ca27b
Showing 1 changed file with 27 additions and 7 deletions.
34 changes: 27 additions & 7 deletions tensorflow/contrib/kfac/python/ops/fisher_factors.py
Original file line number Diff line number Diff line change
Expand Up @@ -538,7 +538,7 @@ class FullyConnectedDiagonalFactor(DiagonalFactor):
Given in = [batch_size, input_size] and out_grad = [batch_size, output_size],
approximates the covariance as,
Cov(in, out) = (1/batch_size) \sum_{i} outer(in[i], out_grad[i]) ** 2.0
Cov(in, out) = (1/batch_size) sum_{i} outer(in[i], out_grad[i]) ** 2.0
where the square is taken element-wise.
"""
Expand Down Expand Up @@ -765,7 +765,7 @@ class ConvInputKroneckerFactor(InverseProvidingFactor):
Estimates E[ a a^T ] where a is the inputs to a convolutional layer given
example x. Expectation is taken over all examples and locations.
Equivalent to \Omega in https://arxiv.org/abs/1602.01407 for details. See
Equivalent to Omega in https://arxiv.org/abs/1602.01407 for details. See
Section 3.1 Estimating the factors.
"""

Expand Down Expand Up @@ -837,11 +837,23 @@ def _compute_new_cov(self, idx=0):
padding=self._padding)

flatten_size = (filter_height * filter_width * in_channels)
# patches_flat below is the matrix [[A_l]] from the KFC paper (tilde
# omitted over A for clarity). It has shape M|T| x J|Delta| (eq. 14),
# where M = minibatch size, |T| = number of spatial locations,
# |Delta| = number of spatial offsets, and J = number of input maps
# for convolutional layer l.
patches_flat = array_ops.reshape(patches, [-1, flatten_size])

# We append a homogenous coordinate to patches_flat if the layer has
# bias parameters. This gives us [[A_l]]_H from the paper.
if self._has_bias:
patches_flat = _append_homog(patches_flat)

# We call _compute_cov without passing in a normalizer. _compute_cov uses
# the first dimension of patches_flat i.e. M|T| as the normalizer by
# default. Hence we end up computing 1/M|T| * [[A_l]]^T [[A_l]], with
# shape J|Delta| x J|Delta|. This is related to hat{Omega}_l from
# the paper but has a different scale here for consistency with
# ConvOutputKroneckerFactor.
# (Tilde omitted over A for clarity.)
return _compute_cov(patches_flat)


Expand All @@ -852,7 +864,7 @@ class ConvOutputKroneckerFactor(InverseProvidingFactor):
given example x and ds = (d / d s) log(p(y|x, w)). Expectation is taken over
all examples and locations.
Equivalent to \Gamma in https://arxiv.org/abs/1602.01407 for details. See
Equivalent to Gamma in https://arxiv.org/abs/1602.01407 for details. See
Section 3.1 Estimating the factors.
"""

Expand Down Expand Up @@ -890,8 +902,16 @@ def _dtype(self):
def _compute_new_cov(self, idx=0):
with _maybe_colocate_with(self._outputs_grads[idx],
self._colocate_cov_ops_with_inputs):
# reshaped_tensor below is the matrix DS_l defined in the KFC paper
# (tilde omitted over S for clarity). It has shape M|T| x I, where
# M = minibatch size, |T| = number of spatial locations, and
# I = number of output maps for convolutional layer l.
reshaped_tensor = array_ops.reshape(self._outputs_grads[idx],
[-1, self._out_channels])
# Following the reasoning in ConvInputKroneckerFactor._compute_new_cov,
# _compute_cov here returns 1/M|T| * DS_l^T DS_l = hat{Gamma}_l
# as defined in the paper, with shape I x I.
# (Tilde omitted over S for clarity.)
return _compute_cov(reshaped_tensor)


Expand Down Expand Up @@ -1109,7 +1129,7 @@ def make_inverse_update_ops(self):
# depending on how psd_eig is defined. I'm not sure why.
C1 = (C1 + array_ops.transpose(C1)) / 2.0

# hPsi = C0^(-1/2) * C1 * C0^(-1/2) (hPsi means \hat{Psi})
# hPsi = C0^(-1/2) * C1 * C0^(-1/2) (hPsi means hat{Psi})
hPsi = math_ops.matmul(math_ops.matmul(invsqrtC0, C1), invsqrtC0)

# Compute the decomposition U*diag(psi)*U^T = hPsi
Expand All @@ -1134,7 +1154,7 @@ def make_inverse_update_ops(self):
# Compute the product C0^(-1/2) * C1
invsqrtC0C1 = math_ops.matmul(invsqrtC0, C1)

# hPsi = C0^(-1/2) * C1 * C0^(-1/2) (hPsi means \hat{Psi})
# hPsi = C0^(-1/2) * C1 * C0^(-1/2) (hPsi means hat{Psi})
hPsi = math_ops.matmul(invsqrtC0C1, invsqrtC0)

# Compute the decomposition E*diag(mu)*E^T = hPsi^T * hPsi
Expand Down

0 comments on commit d1ca27b

Please sign in to comment.