forked from mattjj/pyhsmm
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathhmm-EM.py
57 lines (40 loc) · 1.26 KB
/
hmm-EM.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
from __future__ import division
import numpy as np
np.seterr(divide='ignore') # these warnings are usually harmless for this code
from matplotlib import pyplot as plt
import matplotlib
import os
matplotlib.rcParams['font.size'] = 8
import pyhsmm
from pyhsmm.util.text import progprint_xrange
# Presumably a toggle for writing figures to disk; no statement in the
# visible part of this script reads it.
save_images = False

#### load data

# The example data file is looked up next to this script, so the path is
# built from the script's own directory rather than the working directory.
_script_dir = os.path.dirname(__file__)
_data_path = os.path.join(_script_dir, 'example-data.txt')
data = np.loadtxt(_data_path)
#### EM

# Number of states; one observation distribution is built per state below.
N = 4
# Observation dimensionality is the number of columns of the loaded data.
obs_dim = data.shape[1]
# Hyperparameters handed to every state's Gaussian observation distribution.
obs_hypparams = {
    'mu_0': np.zeros(obs_dim),
    'sigma_0': np.eye(obs_dim),
    'kappa_0': 0.25,
    'nu_0': obs_dim + 2,
}

# `range` instead of the Python 2-only `xrange`, so this line runs under both
# Python 2 and Python 3; N is tiny, so building a list on Python 2 is harmless.
obs_distns = [pyhsmm.distributions.Gaussian(**obs_hypparams) for _ in range(N)]

# Build the HMM model that will represent the fitmodel
fitmodel = pyhsmm.models.HMM(
    alpha=50., init_state_concentration=50.,  # these are only used for initialization
    obs_distns=obs_distns)
fitmodel.add_data(data)
# Run a short burst of Gibbs sampling so EM starts from a sampled state.
# The parenthesized single-argument form prints the same text under
# Python 2's print statement and Python 3's print function.
print('Gibbs sampling for initialization')

# 25 resampling sweeps, with progress reported by progprint_xrange.
for idx in progprint_xrange(25):
    fitmodel.resample_model()

# Plot the model once, after the sampling loop has finished.
plt.figure()
fitmodel.plot()
plt.gcf().suptitle('Gibbs-sampled initialization')
# Parenthesized so the line is valid on both Python 2 and Python 3.
print('EM')

# The value returned by EM_fit is plotted below as the log likelihoods
# recorded during EM (presumably one entry per iteration -- confirm in pyhsmm).
likes = fitmodel.EM_fit()

# Figure 1: the model after EM.
plt.figure()
fitmodel.plot()
plt.gcf().suptitle('EM fit')

# Figure 2: the log-likelihood trace returned by EM_fit.
plt.figure()
plt.plot(likes)
plt.gcf().suptitle('log likelihoods during EM')

plt.show()