Skip to content

Commit

Permalink
Propagating temperature parameter
Browse files Browse the repository at this point in the history
  • Loading branch information
alexbw committed May 17, 2013
1 parent a41ccc1 commit 3cc4794
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 10 deletions.
8 changes: 4 additions & 4 deletions examples/library-mm.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

T = 1000

f = np.load('/Users/mattjj/Dropbox/Test Data/TMT_50p_mixtures_and_data.npz')
f = np.load("/Users/Alex/Dropbox/Science/Datta lab/Posture Tracking/Test Data/TMT_50p_mixtures_and_data.npz")
mus = f['mu']
sigmas = f['sigma']
data = f['data'][:T]
Expand Down Expand Up @@ -87,12 +87,12 @@
# for the GMMs described above. roughly, gamma controls the total number
# of states while alpha controls the diversity of the transition
# distributions.
alpha=10.,gamma=10.,
# alpha=10.,gamma=10.,
# NOTE: as with a_0 and b_0 for the GMMs described above, we can also
# put gamma priors over alpha and gamma by commenting out the direct
# alpha= and gamma= lines and using these instead
# alpha_a_0=1.,alpha_b_0=1./10,
# gamma_a_0=1.,gamma_b_0=1./10,
alpha_a_0=1.,alpha_b_0=1./10,
gamma_a_0=1.,gamma_b_0=1./10,
obs_distns=hsmm_obs_distns,
dur_distns=dur_distns)
model.trans_distn.max_likelihood([rle(labels)[0]])
Expand Down
12 changes: 6 additions & 6 deletions models.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,15 +104,15 @@ def _clear_caches(self):

### Gibbs sampling

def resample_model(self):
self.resample_obs_distns()
def resample_model(self, temp=None):
self.resample_obs_distns(temp=temp)
self.resample_trans_distn()
self.resample_init_state_distn()
self.resample_states()

def resample_obs_distns(self):
def resample_obs_distns(self, temp=None):
for state, distn in enumerate(self.obs_distns):
distn.resample([s.data[s.stateseq == state] for s in self.states_list])
distn.resample([s.data[s.stateseq == state] for s in self.states_list], temp=temp)
self._clear_caches()

def resample_trans_distn(self):
Expand Down Expand Up @@ -381,9 +381,9 @@ def generate(self,T,keep=True,**kwargs):

### Gibbs sampling

def resample_model(self):
def resample_model(self, temp=None):
self.resample_dur_distns()
super(HSMM,self).resample_model()
super(HSMM,self).resample_model(temp=temp)

def resample_dur_distns(self):
for state, distn in enumerate(self.dur_distns):
Expand Down

0 comments on commit 3cc4794

Please sign in to comment.