Skip to content

Commit

Permalink
replace matmuls
Browse files Browse the repository at this point in the history
  • Loading branch information
rtqichen committed Jan 17, 2022
1 parent 5a819e4 commit 97e93de
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions torchdiffeq/_impl/rk_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,17 +72,17 @@ def _runge_kutta_step(func, y0, f0, t0, dt, t1, tableau):
else:
ti = t0 + alpha_i * dt
perturb = Perturb.NONE
yi = y0 + k[..., :i + 1].matmul(beta_i * dt).view_as(f0)
yi = y0 + torch.sum(k[..., :i + 1] * (beta_i * dt), dim=-1).view_as(f0)
f = func(ti, yi, perturb=perturb)
k = _UncheckedAssign.apply(k, f, (..., i + 1))

if not (tableau.c_sol[-1] == 0 and (tableau.c_sol[:-1] == tableau.beta[-1]).all()):
# This property (true for Dormand-Prince) lets us save a few FLOPs.
yi = y0 + k.matmul(dt * tableau.c_sol).view_as(f0)
yi = y0 + torch.sum(k * (dt * tableau.c_sol), dim=-1).view_as(f0)

y1 = yi
f1 = k[..., -1]
y1_error = k.matmul(dt * tableau.c_error)
y1_error = torch.sum(k * (dt * tableau.c_error), dim=-1)
return y1, f1, y1_error, k


Expand Down Expand Up @@ -295,7 +295,7 @@ def _adaptive_step(self, rk_state):
def _interp_fit(self, y0, y1, k, dt):
"""Fit an interpolating polynomial to the results of a Runge-Kutta step."""
dt = dt.type_as(y0)
y_mid = y0 + k.matmul(dt * self.mid).view_as(y0)
y_mid = y0 + torch.sum(k * (dt * self.mid), dim=-1).view_as(y0)
f0 = k[..., 0]
f1 = k[..., -1]
return _interp_fit(y0, y1, y_mid, f0, f1, dt)
Expand Down

0 comments on commit 97e93de

Please sign in to comment.