Skip to content

Commit

Permalink
small updates
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexanderChristgau committed Apr 14, 2021
1 parent 56aa148 commit ba07883
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 3 deletions.
2 changes: 1 addition & 1 deletion mermaid/forward_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -580,7 +580,7 @@ class MomentODE(ForwardModel):
'''
Forward model for first order momentum equations of stochastic EPDiff.
'''
def __init__(self, sz, spacing, smoother, sigma, params=None):
def __init__(self, sz, spacing, smoother,sigma,params=None):
super(MomentODE, self).__init__(sz, spacing, params)
self.smoother = smoother
self.sigma = sigma
Expand Down
2 changes: 1 addition & 1 deletion mermaid/ode_int.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ def get_dt(self):

def init_solver(self,pars_to_pass_i,variables_from_optimizer,has_combined_input=False):
if self.use_sdeint:
self.integrator = RK.Milstein(self.model.f, self.model.g, pars_to_pass_i, self.cparams)
self.integrator = RK.MultiEulerHeun(self.model.f, self.model.g, pars_to_pass_i, self.cparams)
self.integrator.set_pars(pars_to_pass_i)
elif self.use_odeint:
self.integrator = ODEBlock(self.cparams)
Expand Down
31 changes: 31 additions & 0 deletions mermaid/rungekutta_integrators.py
Original file line number Diff line number Diff line change
Expand Up @@ -499,6 +499,37 @@ def solve_one_step(self, x, t, dt, dW, vo=None):
return xp1


class MultiEulerHeun(MultiSDEIntegrator):
"""
Euler Heun scheme for forward integration of:
x(t + dt) = x(t) + f(x(t))dt + g_a(x(t))dW_t^a
"""

def solve_one_step(self, x, t, dt, dW, vo=None):
"""
One step for Euler-forward
:param x: state at time t
:param t: initial time
:param dt: time increment
:param vo: variables from optimizer
:return: state at x+dt
"""
# Compute Euler-Maryuama step as supporting value
f_val = self.f(t, x, self.pars, vo)
g_val = self.g(t,x,self.pars,vo)
noise1 = self._sumBtdW(g_val, dW)
x_bar = self._xpytspz(x,f_val,dt,noise1)

# Take midpoints of current and supporting values
f_bar = self.f(t, x_bar, self.pars, vo)
g_bar = self.g(t,x_bar,self.pars,vo)
noise_bar = self._xpyts(self._sumBtdW(g_bar, 0.5*dW),noise1,0.5)
xp1 = self._xpytspz(x,self._xpy(f_val,f_bar),0.5*dt,noise_bar)
return xp1



class Milstein(MultiSDEIntegrator):
"""
Derivative free Milstein forward integration of:
Expand Down
1 change: 0 additions & 1 deletion mermaid/similarity_measure_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,6 @@ def compute_similarity_multiNC(self, I0, I1, I0Source=None, phi=None):
I1mean = I1.mean(dim=2, keepdim=True)
I0_m_mean = I0-I0mean
I1_m_mean = I1-I1mean
I1_m_mean = I1-I1mean
nccSqr = (((I0_m_mean)*(I1_m_mean)).mean(dim=2)**2)/\
(((I0_m_mean)**2).mean(dim=2)*((I1_m_mean)**2).mean(dim=2))
nccSqr =nccSqr.mean(dim=1).sum()
Expand Down

0 comments on commit ba07883

Please sign in to comment.