forked from mattjj/pyhsmm
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmeanfield.py
39 lines (26 loc) · 960 Bytes
/
meanfield.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
from __future__ import division
import numpy as np
from matplotlib import pyplot as plt
from pyhsmm import models, distributions
# Fix the global NumPy seed so the synthetic data (and any random
# initialisation that draws from the global RNG) is reproducible.
np.random.seed(0)
# Make invalid floating-point operations (e.g. 0/0 producing NaN) raise
# FloatingPointError instead of silently propagating NaNs.
np.seterr(invalid='raise')

# Prior hyperparameters shared by every 2-D Gaussian observation
# distribution below (presumably a Normal-Inverse-Wishart prior: mean,
# scale matrix, kappa, degrees of freedom -- confirm against pyhsmm).
obs_hypparams = dict(
    mu_0=np.zeros(2),
    sigma_0=np.eye(2),
    kappa_0=0.05,
    nu_0=5,
)

### generate data
num_modes = 3

# One "true" observation distribution per mode. They are built in order,
# before any sampling, so the RNG is consumed exactly as before.
true_obs_distns = []
for _ in range(num_modes):
    true_obs_distns.append(distributions.Gaussian(**obs_hypparams))

# 25 consecutive segments of 25 samples each, cycling through the modes
# 0, 1, 2, 0, 1, ...; stacking them gives the observation sequence.
segments = [true_obs_distns[k % num_modes].rvs(25) for k in range(25)]
data = np.concatenate(segments)
## inference!
# Deliberately over-provision the model with three times as many states as
# there are true modes; each state gets its own Gaussian observation
# distribution drawn from the shared prior hyperparameters.
hmm = models.HMM(
    obs_distns=[
        distributions.Gaussian(**obs_hypparams)
        for _ in range(num_modes * 3)
    ],
    alpha=3.,
    init_state_concentration=1.,
)
hmm.add_data(data)

# First mean-field sweep; its return value is intentionally discarded.
hmm.meanfield_coordinate_descent_step()

# 50 further sweeps, recording the value each step returns (presumably the
# variational lower bound -- confirm against pyhsmm) to monitor convergence.
score_trace = []
for _ in range(50):
    score_trace.append(hmm.meanfield_coordinate_descent_step())
scores = np.array(score_trace)

# Plot the fitted model, then the per-iteration scores in a new figure.
hmm.plot()
plt.figure()
plt.plot(scores)
def normalize(A):
    """Return *A* rescaled so that it sums to one along its last axis.

    For the 2-D matrices used in this script, each row becomes a
    probability vector (row-stochastic matrix). A 1-D input is rescaled as
    a whole.

    Parameters
    ----------
    A : array_like
        Non-negative values. Every sum along the last axis should be
        non-zero: a zero sum yields inf/nan, or raises FloatingPointError
        if ``np.seterr`` is set to raise on invalid operations.

    Returns
    -------
    numpy.ndarray
        Float array with the same shape as *A* whose last-axis sums are 1.
    """
    A = np.asarray(A)
    # keepdims=True keeps the reduced axis as size 1 so the division
    # broadcasts along it; for 2-D input this equals A.sum(1)[:, None].
    return A / A.sum(axis=-1, keepdims=True)
# Show the model's transition matrix as an image. The attribute is
# presumably exp(E[log A]) under the variational posterior, whose rows need
# not sum to one -- hence the row normalisation before plotting.
trans_probs = normalize(hmm.trans_distn.exp_expected_log_trans_matrix)
plt.matshow(trans_probs)
plt.show()