Skip to content

fix: allow specifying type of buffers inside MTKParameters #3585

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 32 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
58072eb
fix: allow specifying type of buffers inside `MTKParameters`
AayushSabharwal Apr 28, 2025
a69c71e
fix: fix accidental narrowing of nonnumeric buffer
AayushSabharwal Apr 28, 2025
2be5495
fix: error if `container_type` passed to `MTKParameters` is not an `A…
AayushSabharwal Apr 30, 2025
c7b1923
fix: retain `split` kwarg when simplifying initialization system
AayushSabharwal May 5, 2025
f66a91b
refactor: make `EmptySciMLFunction` subtype `SciMLBase.AbstractSciMLF…
AayushSabharwal May 5, 2025
995ab5f
fix: retain `iip` and `u0Type` in `SCCNonlinearProblem` constructor
AayushSabharwal May 5, 2025
0b40f06
feat: add `p_constructor` kwarg to `MTKParameters` constructor
AayushSabharwal May 5, 2025
28a8215
feat: implement `ArrayInterface.ismutable` for `MTKParameters`
AayushSabharwal May 5, 2025
b5718a2
fix: handle immutable MTKParameters in `remake_buffer`
AayushSabharwal May 5, 2025
2b6b2ed
feat: add `p_constructor` kwarg to problem constructors
AayushSabharwal May 5, 2025
c809a92
feat: propagate `p_constructor` to `InitializationProblem`
AayushSabharwal May 5, 2025
ae54df9
fix: propagate `iip` of parent problem to `InitializationProblem`
AayushSabharwal May 5, 2025
0084445
fix: retain type of buffers when promoting `u0`/`p` of initialization…
AayushSabharwal May 5, 2025
38fc7f6
fix: handle immutable buffers in initialization
AayushSabharwal May 5, 2025
b3464be
fix: handle `u0_constructor`, `p_constructor` in `remake_initializati…
AayushSabharwal May 5, 2025
15a72dd
fix: handle immutable MTKParameters in symbolic `late_binding_update_…
AayushSabharwal May 5, 2025
369d986
fix: handle edge case in floating point type promotion
AayushSabharwal May 5, 2025
de214f4
fix: call `u0_constructor` on `resid_prototype`
AayushSabharwal May 5, 2025
1c00d9d
test: test initialization on static array problems
AayushSabharwal May 5, 2025
dc36ba1
build: bump SciMLBase, StochasticDelayDiffEq compat
AayushSabharwal May 8, 2025
3f51efe
fix: handle empty `syms` in `concrete_getu`
AayushSabharwal May 12, 2025
1a0acc7
refactor: format
AayushSabharwal May 12, 2025
4ce8cb6
fix: fix `is_update_oop` passed as type
AayushSabharwal May 12, 2025
f60c0de
fix: fix call to `remake_initialization_data`
AayushSabharwal May 12, 2025
9f6ae44
Update Project.toml
ChrisRackauckas May 12, 2025
604d838
fix: fix `update_initializeprob!`
AayushSabharwal May 12, 2025
1936dc0
fix: add `f` field to `MockIntegrator`
AayushSabharwal May 12, 2025
14d6867
fixup! test: test initialization on static array problems
AayushSabharwal May 15, 2025
cccab96
fix: handle discretes properly in `get_mtkparameters_reconstructor`
AayushSabharwal May 15, 2025
a6c5b40
refactor: move ChainRulesCoreExt into main package
AayushSabharwal May 15, 2025
3044635
fix: use `@ignore_derivatives` inside `update_initializeprob!`
AayushSabharwal May 15, 2025
6aa4400
Update Project.toml
ChrisRackauckas May 15, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 3 additions & 4 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
CommonSolve = "38540f10-b2f7-11e9-35d8-d573e4eb0ff2"
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
Expand Down Expand Up @@ -65,7 +66,6 @@ Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"
[weakdeps]
BifurcationKit = "0f109fa4-8a5d-4b75-95aa-f515264e7665"
CasADi = "c49709b8-5c63-11e9-2fb2-69db5844192f"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
DeepDiffs = "ab62b9b5-e342-54a8-a765-a90f495de1a6"
FMI = "14a09403-18e3-468f-ad8a-74f8dda2d9ac"
InfiniteOpt = "20393b10-9daf-11e9-18c9-8db751c92c57"
Expand All @@ -74,7 +74,6 @@ LabelledArrays = "2ee39098-c373-598a-b85f-a56591580800"
[extensions]
MTKBifurcationKitExt = "BifurcationKit"
MTKCasADiDynamicOptExt = "CasADi"
MTKChainRulesCoreExt = "ChainRulesCore"
MTKDeepDiffsExt = "DeepDiffs"
MTKFMIExt = "FMI"
MTKInfiniteOptExt = "InfiniteOpt"
Expand Down Expand Up @@ -142,15 +141,15 @@ RecursiveArrayTools = "3.26"
Reexport = "0.2, 1"
RuntimeGeneratedFunctions = "0.5.9"
SCCNonlinearSolve = "1.0.0"
SciMLBase = "2.84"
SciMLBase = "2.90.0"
SciMLStructures = "1.7"
Serialization = "1"
Setfield = "0.7, 0.8, 1"
SimpleNonlinearSolve = "0.1.0, 1, 2"
SparseArrays = "1"
SpecialFunctions = "0.7, 0.8, 0.9, 0.10, 1.0, 2"
StaticArrays = "0.10, 0.11, 0.12, 1.0"
StochasticDelayDiffEq = "1.8.1"
StochasticDelayDiffEq = "1.10"
StochasticDiffEq = "6.72.1"
SymbolicIndexingInterface = "0.3.39"
SymbolicUtils = "3.26.1"
Expand Down
66 changes: 36 additions & 30 deletions ext/MTKCasADiDynamicOptExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,14 +42,14 @@ struct CasADiDynamicOptProblem{uType, tType, isinplace, P, F, K} <:
end
end

function (M::MXLinearInterpolation)(τ)
function (M::MXLinearInterpolation)(τ)
nt = (τ - M.t[1]) / M.dt
i = 1 + floor(Int, nt)
Δ = nt - i + 1

(i > length(M.t) || i < 1) && error("Cannot extrapolate past the tspan.")
if i < length(M.t)
M.u[:, i] + Δ*(M.u[:, i + 1] - M.u[:, i])
M.u[:, i] + Δ * (M.u[:, i + 1] - M.u[:, i])
else
M.u[:, i]
end
Expand All @@ -74,7 +74,7 @@ The constraints are:
function MTK.CasADiDynamicOptProblem(sys::ODESystem, u0map, tspan, pmap;
dt = nothing,
steps = nothing,
guesses = Dict(), kwargs...)
guesses = Dict(), kwargs...)
MTK.warn_overdetermined(sys, u0map)
_u0map = has_alg_eqs(sys) ? u0map : merge(Dict(u0map), Dict(guesses))
f, u0, p = MTK.process_SciMLProblem(ODEInputFunction, sys, _u0map, pmap;
Expand Down Expand Up @@ -104,21 +104,21 @@ function init_model(sys, tspan, steps, u0map, pmap, u0; is_free_t = false)
subject_to!(opti, tₛ >= lo)
subject_to!(opti, tₛ >= hi)
end
pmap[te_sym] = tₛ
pmap[te_sym] = tₛ
tsteps = LinRange(0, 1, steps)
else
tₛ = MX(1)
tsteps = LinRange(tspan[1], tspan[2], steps)
end

U = CasADi.variable!(opti, length(states), steps)
V = CasADi.variable!(opti, length(ctrls), steps)
set_initial!(opti, U, DM(repeat(u0, 1, steps)))
c0 = MTK.value.([pmap[c] for c in ctrls])
!isempty(c0) && set_initial!(opti, V, DM(repeat(c0, 1, steps)))

U_interp = MXLinearInterpolation(U, tsteps, tsteps[2]-tsteps[1])
V_interp = MXLinearInterpolation(V, tsteps, tsteps[2]-tsteps[1])
U_interp = MXLinearInterpolation(U, tsteps, tsteps[2] - tsteps[1])
V_interp = MXLinearInterpolation(V, tsteps, tsteps[2] - tsteps[1])
for (i, ct) in enumerate(ctrls)
pmap[ct] = V[i, :]
end
Expand Down Expand Up @@ -185,8 +185,8 @@ function add_user_constraints!(model::CasADiModel, sys, tspan, pmap; is_free_t)
x = MTK.operation(st)
t = only(MTK.arguments(st))
MTK.symbolic_type(t) === MTK.NotSymbolic() || continue
if haskey(stidxmap, x(iv))
idx = stidxmap[x(iv)]
if haskey(stidxmap, x(iv))
idx = stidxmap[x(iv)]
cv = U
else
idx = ctidxmap[x(iv)]
Expand All @@ -196,11 +196,11 @@ function add_user_constraints!(model::CasADiModel, sys, tspan, pmap; is_free_t)
end

if cons isa Equation
subject_to!(opti, cons.lhs - cons.rhs==0)
subject_to!(opti, cons.lhs - cons.rhs == 0)
elseif cons.relational_op === Symbolics.geq
subject_to!(opti, cons.lhs - cons.rhs0)
subject_to!(opti, cons.lhs - cons.rhs0)
else
subject_to!(opti, cons.lhs - cons.rhs0)
subject_to!(opti, cons.lhs - cons.rhs0)
end
end
end
Expand All @@ -227,8 +227,8 @@ function add_cost_function!(model::CasADiModel, sys, tspan, pmap; is_free_t)
x = operation(st)
t = only(arguments(st))
MTK.symbolic_type(t) === MTK.NotSymbolic() || continue
if haskey(stidxmap, x(iv))
idx = stidxmap[x(iv)]
if haskey(stidxmap, x(iv))
idx = stidxmap[x(iv)]
cv = U
else
idx = ctidxmap[x(iv)]
Expand All @@ -244,7 +244,8 @@ function add_cost_function!(model::CasADiModel, sys, tspan, pmap; is_free_t)
op = MTK.operation(int)
arg = only(arguments(MTK.value(int)))
lo, hi = (op.domain.domain.left, op.domain.domain.right)
!isequal((lo, hi), tspan) && error("Non-whole interval bounds for integrals are not currently supported for CasADiDynamicOptProblem.")
!isequal((lo, hi), tspan) &&
error("Non-whole interval bounds for integrals are not currently supported for CasADiDynamicOptProblem.")
# Approximate integral as sum.
intmap[int] = dt * tₛ * sum(arg)
end
Expand All @@ -253,7 +254,8 @@ function add_cost_function!(model::CasADiModel, sys, tspan, pmap; is_free_t)
minimize!(opti, MX(MTK.value(consolidate(jcosts))))
end

function substitute_casadi_vars(model::CasADiModel, sys, pmap, exprs; auxmap::Dict = Dict(), is_free_t)
function substitute_casadi_vars(
model::CasADiModel, sys, pmap, exprs; auxmap::Dict = Dict(), is_free_t)
@unpack opti, U, V, tₛ = model
iv = MTK.get_iv(sys)
sts = unknowns(sys)
Expand Down Expand Up @@ -281,44 +283,44 @@ end

function add_solve_constraints(prob, tableau)
@unpack A, α, c = tableau
@unpack model, f, p = prob
@unpack model, f, p = prob
@unpack opti, U, V, tₛ = model
solver_opti = copy(opti)

tsteps = U.t
tsteps = U.t
dt = tsteps[2] - tsteps[1]

nᵤ = size(U.u, 1)
nᵥ = size(V.u, 1)

if MTK.is_explicit(tableau)
K = MX[]
for k in 1:length(tsteps)-1
for k in 1:(length(tsteps) - 1)
τ = tsteps[k]
for (i, h) in enumerate(c)
ΔU = sum([A[i, j] * K[j] for j in 1:(i - 1)], init = MX(zeros(nᵤ)))
Uₙ = U.u[:, k] + ΔU*dt
Uₙ = U.u[:, k] + ΔU * dt
Vₙ = V.u[:, k]
Kₙ = tₛ * f(Uₙ, Vₙ, p, τ + h * dt) # scale the time
push!(K, Kₙ)
end
ΔU = dt * sum([α[i] * K[i] for i in 1:length(α)])
subject_to!(solver_opti, U.u[:, k] + ΔU == U.u[:, k+1])
subject_to!(solver_opti, U.u[:, k] + ΔU == U.u[:, k + 1])
empty!(K)
end
else
for k in 1:length(tsteps)-1
for k in 1:(length(tsteps) - 1)
τ = tsteps[k]
Kᵢ = variable!(solver_opti, nᵤ, length(α))
ΔUs = A * Kᵢ' # the stepsize at each stage of the implicit method
for (i, h) in enumerate(c)
ΔU = ΔUs[i,:]'
Uₙ = U.u[:,k] + ΔU*dt
Vₙ = V.u[:,k]
subject_to!(solver_opti, Kᵢ[:,i] == tₛ * f(Uₙ, Vₙ, p, τ + h*dt))
ΔU = ΔUs[i, :]'
Uₙ = U.u[:, k] + ΔU * dt
Vₙ = V.u[:, k]
subject_to!(solver_opti, Kᵢ[:, i] == tₛ * f(Uₙ, Vₙ, p, τ + h * dt))
end
ΔU_tot = dt*(Kᵢ*α)
subject_to!(solver_opti, U.u[:, k] + ΔU_tot == U.u[:,k+1])
ΔU_tot = dt * (Kᵢ * α)
subject_to!(solver_opti, U.u[:, k] + ΔU_tot == U.u[:, k + 1])
end
end
solver_opti
Expand All @@ -331,7 +333,10 @@ end

NOTE: the solver should be passed in as a string to CasADi. "ipopt"
"""
function DiffEqBase.solve(prob::CasADiDynamicOptProblem, solver::Union{String, Symbol} = "ipopt", tableau_getter = MTK.constructDefault; plugin_options::Dict = Dict(), solver_options::Dict = Dict(), silent = false)
function DiffEqBase.solve(
prob::CasADiDynamicOptProblem, solver::Union{String, Symbol} = "ipopt",
tableau_getter = MTK.constructDefault; plugin_options::Dict = Dict(),
solver_options::Dict = Dict(), silent = false)
@unpack model, u0, p, tspan, f = prob
tableau = tableau_getter()
@unpack opti, U, V, tₛ = model
Expand Down Expand Up @@ -366,7 +371,8 @@ function DiffEqBase.solve(prob::CasADiDynamicOptProblem, solver::Union{String, S
end

if failed
ode_sol = SciMLBase.solution_new_retcode(ode_sol, SciMLBase.ReturnCode.ConvergenceFailure)
ode_sol = SciMLBase.solution_new_retcode(
ode_sol, SciMLBase.ReturnCode.ConvergenceFailure)
!isnothing(input_sol) && (input_sol = SciMLBase.solution_new_retcode(
input_sol, SciMLBase.ReturnCode.ConvergenceFailure))
end
Expand Down
7 changes: 6 additions & 1 deletion src/ModelingToolkit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@ import BlockArrays: BlockArray, BlockedArray, Block, blocksize, blocksizes, bloc
using OffsetArrays: Origin
import CommonSolve
import EnumX
import ChainRulesCore
import ChainRulesCore: Tangent, ZeroTangent, NoTangent, zero_tangent, unthunk

using RuntimeGeneratedFunctions
using RuntimeGeneratedFunctions: drop_expr
Expand Down Expand Up @@ -204,6 +206,8 @@ include("structural_transformation/StructuralTransformations.jl")
@reexport using .StructuralTransformations
include("inputoutput.jl")

include("adjoints.jl")

for S in subtypes(ModelingToolkit.AbstractSystem)
S = nameof(S)
@eval convert_system(::Type{<:$S}, sys::$S) = sys
Expand Down Expand Up @@ -349,7 +353,8 @@ export AnalysisPoint, get_sensitivity_function, get_comp_sensitivity_function,
function FMIComponent end

include("systems/optimal_control_interface.jl")
export AbstractDynamicOptProblem, JuMPDynamicOptProblem, InfiniteOptDynamicOptProblem, CasADiDynamicOptProblem
export AbstractDynamicOptProblem, JuMPDynamicOptProblem, InfiniteOptDynamicOptProblem,
CasADiDynamicOptProblem
export DynamicOptSolution

end # module
32 changes: 12 additions & 20 deletions ext/MTKChainRulesCoreExt.jl → src/adjoints.jl
Original file line number Diff line number Diff line change
@@ -1,17 +1,11 @@
module MTKChainRulesCoreExt

import ModelingToolkit as MTK
import ChainRulesCore
import ChainRulesCore: Tangent, ZeroTangent, NoTangent, zero_tangent, unthunk

function ChainRulesCore.rrule(::Type{MTK.MTKParameters}, tunables, args...)
function ChainRulesCore.rrule(::Type{MTKParameters}, tunables, args...)
function mtp_pullback(dt)
dt = unthunk(dt)
dtunables = dt isa AbstractArray ? dt : dt.tunable
(NoTangent(), dtunables[1:length(tunables)],
ntuple(_ -> NoTangent(), length(args))...)
end
MTK.MTKParameters(tunables, args...), mtp_pullback
MTKParameters(tunables, args...), mtp_pullback
end

function subset_idxs(idxs, portion, template)
Expand Down Expand Up @@ -70,23 +64,23 @@ function selected_tangents(
end

function ChainRulesCore.rrule(
::typeof(MTK.remake_buffer), indp, oldbuf::MTK.MTKParameters, idxs, vals)
::typeof(remake_buffer), indp, oldbuf::MTKParameters, idxs, vals)
if idxs isa AbstractSet
idxs = collect(idxs)
end
idxs = map(idxs) do i
i isa MTK.ParameterIndex ? i : MTK.parameter_index(indp, i)
i isa ParameterIndex ? i : parameter_index(indp, i)
end
newbuf = MTK.remake_buffer(indp, oldbuf, idxs, vals)
newbuf = remake_buffer(indp, oldbuf, idxs, vals)
tunable_idxs = reduce(
vcat, (idx.idx for idx in idxs if idx.portion isa MTK.SciMLStructures.Tunable);
vcat, (idx.idx for idx in idxs if idx.portion isa SciMLStructures.Tunable);
init = Union{Int, AbstractVector{Int}}[])
initials_idxs = reduce(
vcat, (idx.idx for idx in idxs if idx.portion isa MTK.SciMLStructures.Initials);
vcat, (idx.idx for idx in idxs if idx.portion isa SciMLStructures.Initials);
init = Union{Int, AbstractVector{Int}}[])
disc_idxs = subset_idxs(idxs, MTK.SciMLStructures.Discrete(), oldbuf.discrete)
const_idxs = subset_idxs(idxs, MTK.SciMLStructures.Constants(), oldbuf.constant)
nn_idxs = subset_idxs(idxs, MTK.NONNUMERIC_PORTION, oldbuf.nonnumeric)
disc_idxs = subset_idxs(idxs, SciMLStructures.Discrete(), oldbuf.discrete)
const_idxs = subset_idxs(idxs, SciMLStructures.Constants(), oldbuf.constant)
nn_idxs = subset_idxs(idxs, NONNUMERIC_PORTION, oldbuf.nonnumeric)

pullback = let idxs = idxs
function remake_buffer_pullback(buf′)
Expand All @@ -102,13 +96,11 @@ function ChainRulesCore.rrule(
oldbuf′ = Tangent{typeof(oldbuf)}(;
tunable, initials, discrete, constant, nonnumeric)
idxs′ = NoTangent()
vals′ = map(i -> MTK._ducktyped_parameter_values(buf′, i), idxs)
vals′ = map(i -> _ducktyped_parameter_values(buf′, i), idxs)
return f′, indp′, oldbuf′, idxs′, vals′
end
end
newbuf, pullback
end

ChainRulesCore.@non_differentiable Base.getproperty(sys::MTK.AbstractSystem, x::Symbol)

end
ChainRulesCore.@non_differentiable Base.getproperty(sys::AbstractSystem, x::Symbol)
13 changes: 9 additions & 4 deletions src/linearization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,7 @@ function (linfun::LinearizationFunction)(u, p, t)
linfun.num_states == length(u) ||
error("Number of unknown variables ($(linfun.num_states)) does not match the number of input unknowns ($(length(u)))")
integ_cache = (linfun.caches,)
integ = MockIntegrator{true}(u, p, t, integ_cache, nothing)
integ = MockIntegrator{true}(u, p, t, fun, integ_cache, nothing)
u, p, success = SciMLBase.get_initial_values(
linfun.prob, integ, fun, linfun.initializealg, Val(true);
linfun.initialize_kwargs...)
Expand Down Expand Up @@ -325,7 +325,7 @@ Mock `DEIntegrator` to allow using `CheckInit` without having to create a new in

$(TYPEDFIELDS)
"""
struct MockIntegrator{iip, U, P, T, C, O} <: SciMLBase.DEIntegrator{Nothing, iip, U, T}
struct MockIntegrator{iip, U, P, T, F, C, O} <: SciMLBase.DEIntegrator{Nothing, iip, U, T}
"""
The state vector.
"""
Expand All @@ -339,6 +339,10 @@ struct MockIntegrator{iip, U, P, T, C, O} <: SciMLBase.DEIntegrator{Nothing, iip
"""
t::T
"""
The wrapped `SciMLFunction`.
"""
f::F
"""
The integrator cache.
"""
cache::C
Expand All @@ -348,8 +352,9 @@ struct MockIntegrator{iip, U, P, T, C, O} <: SciMLBase.DEIntegrator{Nothing, iip
opts::O
end

function MockIntegrator{iip}(u::U, p::P, t::T, cache::C, opts::O) where {iip, U, P, T, C, O}
return MockIntegrator{iip, U, P, T, C, O}(u, p, t, cache, opts)
function MockIntegrator{iip}(
u::U, p::P, t::T, f::F, cache::C, opts::O) where {iip, U, P, T, F, C, O}
return MockIntegrator{iip, U, P, T, F, C, O}(u, p, t, f, cache, opts)
end

SymbolicIndexingInterface.state_values(integ::MockIntegrator) = integ.u
Expand Down
4 changes: 2 additions & 2 deletions src/systems/diffeqs/abstractodesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1479,7 +1479,7 @@ function InitializationProblem{iip, specialize}(sys::AbstractSystem,
end

if simplify_system
isys = structural_simplify(isys; fully_determined)
isys = structural_simplify(isys; fully_determined, split = is_split(sys))
end

ts = get_tearing_state(isys)
Expand Down Expand Up @@ -1554,6 +1554,6 @@ function InitializationProblem{iip, specialize}(sys::AbstractSystem,
else
NonlinearLeastSquaresProblem
end
TProb(isys, u0map, parammap; kwargs...,
TProb{iip}(isys, u0map, parammap; kwargs...,
build_initializeprob = false, is_initializeprob = true)
end
Loading
Loading