Skip to content

Commit

Permalink
fix[tutorials]: Changes integrator for pure casadi example.
Browse files Browse the repository at this point in the history
Changes the casadi integrator to idas in example on section 10.3 with
pure casadi.
  • Loading branch information
maxspahn committed Jan 17, 2023
1 parent d5d68b9 commit 76de800
Showing 1 changed file with 32 additions and 23 deletions.
55 changes: 32 additions & 23 deletions tutorial/section_10_3_casadi.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
import pdb
import numpy as np
from scipy.integrate import odeint
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import casadi as ca

EPS = 1e-5

n = 2
q = ca.SX.sym("q", n)
qdot = ca.SX.sym("qdot", n)
Expand Down Expand Up @@ -179,41 +182,47 @@ def createSolver(self, dt=0.01):
ode = {}
ode['x'] = self._z
ode['ode'] = self._rhs_aug
self._int_fun = ca.integrator("int_fun", 'cvodes', ode, {'tf':dt})
self._int_fun = ca.integrator("int_fun", 'idas', ode, {'tf':dt})

def computePath(self, z0, T):
num_steps = int(T/self._dt)
z = z0
sols = np.zeros((num_steps, 4))
max_step = num_steps
for i in range(num_steps):
res = self._int_fun(x0=z)
if np.linalg.norm(z) < 0.1:
break
try:
res = self._int_fun(x0=z)
except Exception as e:
if i < max_step:
max_step = i
break
z = np.array(res['xf'])[:, 0]
qdot = z[n:2*n]
qdot_norm = np.linalg.norm(qdot)
sols[i, :] = z
if qdot_norm < 0.030:
max_step = i
if i < max_step:
max_step = i
print("zero velocity")
break
print("finished")
return sols[:max_step, :]

def update(num, x1, x2, x3, y1, y2, y3, line1, line2, line3, point1, point2, point3):
def update(num, x1, x2, y1, y2, line1, line2, point1, point2):
start = max(0, num - 100)
line1.set_data(x1[start:num], y1[start:num])
point1.set_data(x1[num], y1[num])
line2.set_data(x2[start:num], y2[start:num])
point2.set_data(x2[num], y2[num])
line3.set_data(x3[start:num], y3[start:num])
point3.set_data(x3[num], y3[num])
return line1, point1, line2, point2, line3, point3
return line1, point1, line2, point2

def plotTraj(sol, ax, fig):
x = sol[:, 0]
y = sol[:, 1]
ax.set_xlim([-4, 4])
ax.set_ylim([-4, 4])
ax.set_xlim([-10, 5])
ax.set_ylim([-10, 5])
ax.plot(x, y)
(line,) = ax.plot(x, y, color="k")
(point,) = ax.plot(x, y, "rx")
Expand All @@ -237,24 +246,26 @@ def main():
forcing = ForcingPotential(M_b, der_psi)
limit_forcing = ForcingPotential(M_b, der_psi2)
speedController = SpeedController(lex, le, beta, eta_switch)
geo1 = Geometry(h = h_b, Forcing=limit_forcing)
geo1 = Geometry(h = h_b, Forcing=forcing)
geo1.createSolver(dt=0.01)
geo2 = Geometry(h = h_b, Forcing=forcing, SpeedController=speedController)
geo2.createSolver(dt=0.01)
#geo3 = Geometry(h_fun = h_b_fun, Forcing=forcing, SpeedController=speedController)
geos = [geo1, geo2]
q0 = np.array([2.0, 3.0])
v0 = 1.5
a0s = [((i * np.pi)/7) for i in range(14)]
T = 16.0
a0s = [(i * np.pi)/7 for i in range(14)]
T = 20.0
# solving
sols = []
for geo in geos:
geoSols = []
for a0 in a0s:
print(a0)
if a0 == 0.0:
continue
q0_dot = v0 * np.array([np.cos(a0), np.sin(a0)])
print(q0)
z0 = np.concatenate((q0, q0_dot))
print("Compute path for a0 : ", a0)
geoSols.append(geo.computePath(z0, T))
Expand All @@ -263,19 +274,17 @@ def main():
sol2 = sols[1][0]
sol3 = sols[1][0]
# plotting
fig, ax = plt.subplots(2, 2, figsize=(10, 10))
fig.suptitle("<Title>")
ax[0][0].set_title("<Subtitle>")
ax[0][1].set_title("<Subtitle>")
ax[1][0].set_title("<Subtitle>")
plotMultipleTraj(sols[0], ax[0][0], fig)
plotMultipleTraj(sols[1], ax[0][1], fig)
(x, y, line, point) = plotTraj(sol1, ax[0][0], fig)
(x2, y2, line2, point2) = plotTraj(sol2, ax[0][1], fig)
(x3, y3, line3, point3) = plotTraj(sol3, ax[1][0], fig)
fig, ax = plt.subplots(1, 2, figsize=(10, 10))
fig.suptitle("Example for goal attraction")
ax[0].set_title("No Damping")
ax[1].set_title("With Damping")
plotMultipleTraj(sols[0], ax[0], fig)
plotMultipleTraj(sols[1], ax[1], fig)
(x, y, line, point) = plotTraj(sol1, ax[0], fig)
(x2, y2, line2, point2) = plotTraj(sol2, ax[1], fig)
ani = animation.FuncAnimation(
fig, update, len(x),
fargs=[x, x2, x3, y, y2, y3, line, line2, line3, point, point2, point3],
fargs=[x, x2, y, y2, line, line2, point, point2],
interval=10, blit=True
)
plt.show()
Expand Down

0 comments on commit 76de800

Please sign in to comment.