
Commit

dpkingma committed Oct 18, 2013
1 parent 66c1910 commit 29d7e2e
Showing 10 changed files with 46 additions and 28 deletions.
36 changes: 27 additions & 9 deletions BNModel.py
@@ -29,6 +29,7 @@ class BNModel(object):
def __init__(self, theano_warning='raise', hessian=True):

theanofunction = lazytheanofunc('warn', mode='FAST_RUN')
theanofunction_silent = lazytheanofunc('ignore', mode='FAST_RUN')

# Create theano expressions
# TODO: change order to (w, x, z) everywhere
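`lazytheanofunc` is anglepy's own wrapper around `theano.function`; from its use here it appears to defer compilation and to control warning behaviour via its first argument. A rough, hypothetical stand-in for the lazy-compilation idea only (mapping `'warn'`/`'ignore'` onto Theano's `on_unused_input` is an assumption, not the repo's actual implementation):

import theano

def lazy_theano_func(on_unused_input='warn', mode='FAST_RUN'):
    # Returns a factory: building a wrapper is cheap, and theano.function
    # is compiled only the first time the wrapper is actually called.
    def make(inputs, outputs):
        compiled = [None]
        def call(*args):
            if compiled[0] is None:
                compiled[0] = theano.function(inputs, outputs, mode=mode,
                                              on_unused_input=on_unused_input)
            return compiled[0](*args)
        return call
    return make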
@@ -41,6 +42,7 @@ def __init__(self, theano_warning='raise', hessian=True):
# Get gradient symbols
allvars = w.values() + x.values() + z.values() + [A] # note: '+' concatenates lists

# TODO: Split Hessian code from the core code (it's too rarely used), e.g. just in experiment script.
if False and hessian:
# Hessian of logpxz wrt z
raise Exception("Needs fix: assumes fixed n_batch, which is not true anymore")
@@ -73,7 +75,7 @@ def __init__(self, theano_warning='raise', hessian=True):
for i in _x: x[i].tag.test_value = _x[i]
for i in _z: z[i].tag.test_value = _z[i]

logpw, logpx, logpz = self.factors(w, x, z, A)
logpw, logpx, logpz, dists = self.factors(w, x, z, A)

# Complete-data likelihood estimate
logpxz = logpx.sum() + logpz.sum()
@@ -91,6 +93,12 @@ def __init__(self, theano_warning='raise', hessian=True):
self.f_logpw = theanofunction(w.values(), logpw)
self.f_dlogpw_dw = theanofunction(w.values(), [logpw] + dlogpw_dw)

# distributions
self.f_dists = {}
for name in dists:
_vars, dist = dists[name]
self.f_dists[name] = theanofunction_silent(_vars, dist)
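Each `dists` entry returned by `factors` is unpacked above as an `(inputs, expression)` pair and compiled with the silent wrapper. A hypothetical, minimal example of what a subclass could register (the name `'px'`, the toy Gaussian term, and the use of plain `theano.function` in place of `theanofunction_silent` are all illustrative assumptions):

import theano
import theano.tensor as T

w0 = T.dmatrix('w0')
x0 = T.dmatrix('x0')
# toy unit-variance Gaussian log-density, up to an additive constant
logpx = -0.5 * ((x0 - w0) ** 2).sum(axis=0)

dists = {'px': ([w0, x0], logpx)}

f_dists = {}
for name, (inputs, expr) in dists.items():
    f_dists[name] = theano.function(inputs, expr, on_unused_input='ignore')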

if False:
raise Exception("Code block needs refactoring: n_batch no longer a field of the model")
# MC-LIKELIHOOD
@@ -100,7 +108,7 @@ def __init__(self, theano_warning='raise', hessian=True):
dlogpxmc_dw = T.grad(logpxmc, w.values(), disconnected_inputs=theano_warning)
self.f_dlogpxmc_dw = theanofunction(allvars, [logpxmc] + dlogpxmc_dw)

if True:
if True and len(z) > 0:
# Fisher divergence (FD)
gz = T.grad(logpxz, z.values())
gz2 = [T.dmatrix() for _ in gz]
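The block above differentiates the complete-data log-likelihood with respect to the latents and introduces matching placeholder matrices `gz2`, presumably the externally supplied scores that the Fisher divergence is computed against. A hedged numpy sketch of that quantity (the exact weighting and batching in the repo may differ):

import numpy as np

def fisher_divergence_estimate(score_model, score_other):
    # Monte Carlo estimate of E[ ||grad_z log p - grad_z log q||^2 ],
    # given one (n_z, n_batch) gradient array per latent variable.
    per_sample = 0.0
    for g_p, g_q in zip(score_model, score_other):
        per_sample = per_sample + ((g_p - g_q) ** 2).sum(axis=0)
    return per_sample.mean()

gz_np  = [np.random.randn(2, 5)]
gz2_np = [np.random.randn(2, 5)]
fd = fisher_divergence_estimate(gz_np, gz2_np)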
@@ -123,6 +131,14 @@ def factors(self): raise NotImplementedError()
def gen_xz(self): raise NotImplementedError()
def init_w(self): raise NotImplementedError()

# Prediction
def distribution(self, w, x, z, name):
x, z = self.xz_to_theano(x, z)
w, z, x = ordereddicts((w, z, x))
A = self.get_A(x)
allvars = w.values() + x.values() + z.values() + [A]
return self.f_dists[name](*allvars)

# Numpy <-> Theano var conversion
def xz_to_theano(self, x, z): return x, z
def gwgz_to_numpy(self, gw, gz): return gw, gz
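The new `distribution` method, like `logpxz` and `dlogpxz_dwz` below, builds its argument list in the fixed order `w` values, `x` values, `z` values, then the auxiliary row `A`. A small numpy-only sketch of that convention, with `ordereddicts` approximated by key-sorting (an assumption; the real helper lives in anglepy's ndict module) and `'px'` a hypothetical distribution name:

import numpy as np
from collections import OrderedDict

def as_ordered(d):
    # stand-in for ordereddicts(): fix a deterministic key order
    return OrderedDict(sorted(d.items()))

def build_allvars(w, x, z):
    w, x, z = as_ordered(w), as_ordered(x), as_ordered(z)
    A = np.ones((1, list(x.values())[0].shape[1]))  # mirrors get_A()
    return list(w.values()) + list(x.values()) + list(z.values()) + [A]

w = {'w0': np.zeros((2, 3))}
x = {'x': np.random.randn(3, 5)}
z = {'z': np.random.randn(2, 5)}
args = build_allvars(w, x, z)
# result = model.f_dists['px'](*args)   # what distribution() does internally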
@@ -133,8 +149,9 @@ def get_A(self, x): return np.ones((1, x.itervalues().next().shape[1]))
# Likelihood: logp(x,z|w)
def logpxz(self, w, z, x):
_x, _z = self.xz_to_theano(x, z)
A = self.get_A(x)
logpx, logpz = self.f_logpxz(*orderedvals((w, _x, _z))+[A])
A = self.get_A(_x)
allvars = w.values() + _x.values() + _z.values() + [A]
logpx, logpz = self.f_logpxz(*allvars)
if np.isnan(logpx).any() or np.isnan(logpz).any():
print 'v: ', logpx, logpz
print 'Values:'
@@ -149,7 +166,8 @@ def dlogpxz_dwz(self, w, z, x):
x, z = self.xz_to_theano(x, z)
w, z, x = ordereddicts((w, z, x))
A = self.get_A(x)
r = self.f_dlogpxz_dwz(*(w.values() + x.values() + z.values() + [A]))
allvars = w.values() + x.values() + z.values() + [A]
r = self.f_dlogpxz_dwz(*allvars)
logpx, logpz, gw, gz = r[0], r[1], dict(zip(w.keys(), r[2:2+len(w)])), dict(zip(z.keys(), r[2+len(w):]))

if ndict.hasNaN(gw) or ndict.hasNaN(gz):
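The compiled gradient function returns a flat list: the two log-likelihood terms first, then one gradient per weight and one per latent, in input order; the `dict(zip(...))` slicing above maps them back to named entries. A self-contained illustration of that layout (the key names are made up):

import numpy as np

w_keys = ['w0', 'w1']
z_keys = ['eps']
# layout: [logpx, logpz, dlogp/dw0, dlogp/dw1, dlogp/deps]
r = [np.float64(-1.5), np.float64(-0.7),
     np.zeros((2, 2)), np.zeros((3, 1)), np.zeros((4, 5))]

logpx, logpz = r[0], r[1]
gw = dict(zip(w_keys, r[2:2 + len(w_keys)]))
gz = dict(zip(z_keys, r[2 + len(w_keys):]))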
@@ -325,8 +343,8 @@ def val(self, w, z):
return np.hstack(logpx), np.hstack(logpz)

def grad(self, w, z):
if self.cardinality==1: return self.grad(0, w, z)
logpxi, logpzi, gwi, _ = tuple(zip(*[self.grad(i, w, z) for i in range(self.cardinality)]))
if self.cardinality==1: return self.subgrad(0, w, z)
logpxi, logpzi, gwi, _ = tuple(zip(*[self.subgrad(i, w, z) for i in range(self.cardinality)]))
return np.hstack(logpxi), np.hstack(logpzi), ndict.sum(gwi)
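`grad` stacks the per-partition likelihood terms and combines the per-partition weight gradients with `ndict.sum`, which is assumed here to sum dicts of arrays key-wise. A stand-in sketch of that assumed behaviour:

import numpy as np

def sum_gradient_dicts(dicts):
    # assumed behaviour of ndict.sum: key-wise sum over dicts of arrays
    out = {k: np.zeros_like(v) for k, v in dicts[0].items()}
    for d in dicts:
        for k, v in d.items():
            out[k] += v
    return out

gwi = [{'w0': np.ones((2, 2))}, {'w0': 2 * np.ones((2, 2))}]
gw = sum_gradient_dicts(gwi)   # {'w0': array filled with 3.0}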

# Parallel version of likelihood
@@ -482,11 +500,11 @@ def subgrad(self, i, w, z):
for j in gw: gw[j] += prior_weight * gw_prior[j]
return logpx.sum() + logpz.sum() + prior_weight * prior, gw

def val(self, w, z=None):
def val(self, w, z={}):
logpx, logpz = self.ll.val(w, z)
return logpx.sum() + logpz.sum() + self.model.logpw(w)

def grad(self, w, z=None):
def grad(self, w, z={}):
logpx, logpz, gw = self.ll.grad(w, z)
prior, gw_prior = self.model.dlogpw_dw(w)
for i in gw: gw[i] += gw_prior[i]
Binary file modified BNModel.pyc
3 changes: 3 additions & 0 deletions hmc.py
@@ -169,6 +169,9 @@ def compute_logq_kde(_z):
logpdf = np.vstack(logpdf)
return logpdf.reshape((-1, n_batch*n_est))

# TODO: implement Infinite Gaussian Mixture Model from scikit learn
# http://scikit-learn.org/stable/modules/generated/sklearn.mixture.DPGMM.html#sklearn.mixture.DPGMM

logq = 0
for i in z:
_z = z[i].reshape((-1, n_samples))
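The TODO above points at scikit-learn's Dirichlet-process Gaussian mixture as a richer estimator of log q(z) than the current per-dimension scheme. `sklearn.mixture.DPGMM` has since been removed from scikit-learn; a rough sketch of the same idea with its replacement, `BayesianGaussianMixture` (all parameter values here are illustrative, not tuned for this repo):

import numpy as np
from sklearn.mixture import BayesianGaussianMixture

samples = np.random.randn(500, 4)   # (n_samples, n_dim) posterior samples of z
dpgmm = BayesianGaussianMixture(
    n_components=20,                 # truncation level of the DP mixture
    weight_concentration_prior_type='dirichlet_process',
    covariance_type='full',
    max_iter=200,
)
dpgmm.fit(samples)
logq = dpgmm.score_samples(samples)  # per-sample log-density estimates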
Binary file modified hmc.pyc
8 changes: 4 additions & 4 deletions models/DBN.py
@@ -9,15 +9,15 @@

class DBN(BNModel):

def __init__(self, n_z, n_x, n_steps, n_batch, prior_sd=0.1):
def __init__(self, n_z, n_x, n_steps, prior_sd=0.1):
self.constr = (__name__, inspect.stack()[0][3], locals())
self.n_z, self.n_x, self.n_steps, self.n_batch = n_z, n_x, n_steps, n_batch
self.n_z, self.n_x, self.n_steps = n_z, n_x, n_steps
self.prior_sd = prior_sd

theano_warning = 'raise'
if n_steps == 1: theano_warning = 'warn'

super(DBN, self).__init__(n_batch, theano_warning)
super(DBN, self).__init__(theano_warning)


def factors(self, w, x, z, A):
@@ -53,7 +53,7 @@ def f_xi(zi, xi):
for i in w:
logpw += anglepy.logpdfs.normal(w[i], 0, self.prior_sd).sum() # logp(w)

return logpw, logpx, logpz
return logpw, logpx, logpz, {}

# Confabulate hidden states 'z'
def gen_xz(self, w, x, z):
2 changes: 1 addition & 1 deletion models/DBN_scan.py
@@ -84,7 +84,7 @@ def f_step(eps_t, x_t, _z_prev, logpx_prev):
for i in w:
logpw += anglepy.logpdfs.normal(w[i], 0, self.prior_sd).sum() # logp(w)

return logpw, logpx, logpz
return logpw, logpx, logpz, {}

# Numpy <-> Theano var conversion
def xz_to_theano(self, x, z):
2 changes: 1 addition & 1 deletion models/MLBN.py
@@ -49,7 +49,7 @@ def f_prior(_w):
logpw += f_prior(w['w%i'%i])
logpw += f_prior(w['wout'])

return logpw, logpx, logpz
return logpw, logpx, logpz, {}

# Confabulate latent variables
def gen_xz(self, w, x, z, n_batch):
2 changes: 1 addition & 1 deletion models/MLBN_Inverse.py
@@ -46,7 +46,7 @@ def f_prior(_w):
logpw += f_prior(w['w_mean'])
logpw += f_prior(w['w_logvar'])

return logpw, logpx, logpz
return logpw, logpx, logpz, {}

# Confabulate hidden states 'z'
def gen_xz(self, w, x, z, n_batch):
21 changes: 9 additions & 12 deletions woptim.py
@@ -103,12 +103,12 @@ def eval(y):
# maybe just execute a remote np.random.seed(0) there?
np.random.seed(0)
_w = ndict.unflatten(y, w)
logLik, gw = cache[1], cache[2]
logpw, gw = cache[1], cache[2]
if np.linalg.norm(y) != cache[0]:
logLik, gw = posterior.grad(_w)
print logLik, np.linalg.norm(y)
cache[0], cache[1], cache[2] = [np.linalg.norm(y), logLik, gw]
return logLik, gw
logpw, gw = posterior.grad(_w)
#print logpw, np.linalg.norm(y)
cache[0], cache[1], cache[2] = [np.linalg.norm(y), logpw, gw]
return logpw, gw

def f(y):
logLik, gw = eval(y)
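`fmin_l_bfgs_b` queries the objective and the gradient separately, often at the same point, so the cache above avoids running `posterior.grad` twice per iterate. A minimal sketch of the same memoization pattern, keyed on the parameter vector itself rather than on its norm (which can in principle collide for distinct points of equal norm):

import numpy as np

def make_cached_objective(value_and_grad):
    cache = {'y': None, 'val': None, 'grad': None}

    def evaluate(y):
        # recompute only when the query point actually changes
        if cache['y'] is None or not np.array_equal(y, cache['y']):
            cache['val'], cache['grad'] = value_and_grad(y)
            cache['y'] = np.array(y, copy=True)
        return cache['val'], cache['grad']

    f = lambda y: -evaluate(y)[0]      # minimize the negative log-posterior
    fprime = lambda y: -evaluate(y)[1]
    return f, fprime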
@@ -129,23 +129,20 @@ def callback(wz):
t[1] += 1
if time.time() - t[2] > hook_wavelength:
_w = ndict.unflatten(wz, w)
hook(t[1], _w)
hook(t[1], _w, cache[1]) # num_its, w, logpw
t[2] = time.time()
#t[0] += 1
#if t[0]%5 is not 0: return
#if time.time() - t[2] < 1: return
#t[2] = time.time()

x0 = ndict.flatten(w)
xn, f, d = scipy.optimize.fmin_l_bfgs_b(func=f, x0=x0, fprime=fprime, m=m, iprint=0, callback=callback, maxiter=maxiter)

#scipy.optimize.fmin_cg(f=f, x0=x0, fprime=fprime, full_output=True, callback=hook)
#scipy.optimize.fmin_ncg(f=f, x0=x0, fprime=fprime, full_output=True, callback=hook)
w = ndict.unflatten(xn, w)
print 'd: ', d
#print 'd: ', d
if d['warnflag'] is 2:
print 'warnflag:', d['warnflag']
print d['task']
return w
info = d
return w, info
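For reference, `scipy.optimize.fmin_l_bfgs_b` returns the minimizer, the final objective value, and an info dict whose `warnflag` and `task` fields are what this function now hands back to the caller as `info`. A self-contained toy run of the same call shape:

import numpy as np
import scipy.optimize

def f(y):
    return float(((y - 3.0) ** 2).sum())

def fprime(y):
    return 2.0 * (y - 3.0)

def callback(y):
    # invoked once per iteration with the current point, like callback(wz) above
    pass

x0 = np.zeros(5)
xn, fval, d = scipy.optimize.fmin_l_bfgs_b(func=f, x0=x0, fprime=fprime,
                                            m=10, iprint=0,
                                            callback=callback, maxiter=100)
print(d['warnflag'], d['task'])   # warnflag 0 on normal convergence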


Binary file modified woptim.pyc
