Skip to content

Commit

Permalink
Merge pull request tensorflow#200 from tensorflow/fast-fft-param
Browse files Browse the repository at this point in the history
change fft image parametrization to use only one op for the batch
  • Loading branch information
michaelpetrov authored Oct 2, 2019
2 parents 8fb9374 + 90f23d2 commit 67d3e73
Showing 1 changed file with 21 additions and 27 deletions.
48 changes: 21 additions & 27 deletions lucid/optvis/param/spatial.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,33 +64,27 @@ def fft_image(shape, sd=None, decay_power=1):
sd = sd or 0.01
batch, h, w, ch = shape
freqs = rfft2d_freqs(h, w)
init_val_size = (2, ch) + freqs.shape

images = []
for _ in range(batch):
# Create a random variable holding the actual 2D fourier coefficients
init_val = np.random.normal(size=init_val_size, scale=sd).astype(np.float32)
spectrum_real_imag_t = tf.Variable(init_val)
spectrum_t = tf.complex(spectrum_real_imag_t[0], spectrum_real_imag_t[1])

# Scale the spectrum. First normalize energy, then scale by the square-root
# of the number of pixels to get a unitary transformation.
# This allows to use similar leanring rates to pixel-wise optimisation.
scale = 1.0 / np.maximum(freqs, 1.0 / max(w, h)) ** decay_power
scale *= np.sqrt(w * h)
scaled_spectrum_t = scale * spectrum_t

# convert complex scaled spectrum to shape (h, w, ch) image tensor
# needs to transpose because irfft2d returns channels first
image_t = tf.transpose(tf.spectral.irfft2d(scaled_spectrum_t), (1, 2, 0))

# in case of odd spatial input dimensions we need to crop
image_t = image_t[:h, :w, :ch]

images.append(image_t)

batched_image_t = tf.stack(images) / 4.0 # TODO: is that a magic constant?
return batched_image_t
init_val_size = (2, batch, ch) + freqs.shape

init_val = np.random.normal(size=init_val_size, scale=sd).astype(np.float32)
spectrum_real_imag_t = tf.Variable(init_val)
spectrum_t = tf.complex(spectrum_real_imag_t[0], spectrum_real_imag_t[1])

# Scale the spectrum. First normalize energy, then scale by the square-root
# of the number of pixels to get a unitary transformation.
# This allows to use similar leanring rates to pixel-wise optimisation.
scale = 1.0 / np.maximum(freqs, 1.0 / max(w, h)) ** decay_power
scale *= np.sqrt(w * h)
scaled_spectrum_t = scale * spectrum_t

# convert complex scaled spectrum to shape (h, w, ch) image tensor
# needs to transpose because irfft2d returns channels first
image_t = tf.transpose(tf.spectral.irfft2d(scaled_spectrum_t), (0, 2, 3, 1))

# in case of odd spatial input dimensions we need to crop
image_t = image_t[:batch, :h, :w, :ch]
image_t = image_t / 4.0 # TODO: is that a magic constant?
return image_t


def laplacian_pyramid_image(shape, n_levels=4, sd=None):
Expand Down

0 comments on commit 67d3e73

Please sign in to comment.