Skip to content

Commit

Permalink
Variational Autoencoder generate() function fixed (z fed in rather th…
Browse files Browse the repository at this point in the history
…an z_mean)
  • Loading branch information
joshthoward authored and nealwu committed Mar 27, 2017
1 parent ae50fa9 commit dec7c89
Showing 1 changed file with 2 additions and 3 deletions.
5 changes: 2 additions & 3 deletions autoencoder/autoencoder_models/VariationalAutoencoder.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import tensorflow as tf
import numpy as np

class VariationalAutoencoder(object):

Expand Down Expand Up @@ -57,8 +56,8 @@ def transform(self, X):

def generate(self, hidden = None):
if hidden is None:
hidden = np.random.normal(size=self.weights["b1"])
return self.sess.run(self.reconstruction, feed_dict={self.z_mean: hidden})
hidden = self.sess.run(tf.random_normal([1, self.n_hidden]))
return self.sess.run(self.reconstruction, feed_dict={self.z: hidden})

def reconstruct(self, X):
return self.sess.run(self.reconstruction, feed_dict={self.x: X})
Expand Down

0 comments on commit dec7c89

Please sign in to comment.