Skip to content

fix: allow specifying type of buffers inside MTKParameters #3585

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 32 commits into
base: master
Choose a base branch
from

Conversation

AayushSabharwal
Copy link
Member

Checklist

  • Appropriate tests were added
  • Any code changes were done in a way that does not break public API
  • All documentation related to code changes were updated
  • The new code follows the
    contributor guidelines, in particular the SciML Style Guide and
    COLPRAC.
  • Any new documentation only uses public API

Additional context

Add any other context about the problem here.

t0 = nothing, substitution_limit = 1000, floatT = nothing,
container_type = Vector)
if !(container_type <: AbstractArray)
container_type = Array
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

error?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Resolved

ChrisRackauckas added a commit to SciML/DiffEqGPU.jl that referenced this pull request May 4, 2025
Fixes #330.

Currently the MWE works:

```julia
using DiffEqGPU
using OrdinaryDiffEqTsit5, ModelingToolkit, StaticArrays
using ModelingToolkit: t_nounits as t, D_nounits as D

@parameters σ ρ β
@variables x(t) y(t) z(t)

eqs = [D(D(x)) ~ σ * (y - x),
    D(y) ~ x * (ρ - z) - y,
    D(z) ~ x * y - β * z]

@mtkbuild sys = ODESystem(eqs, t) split=false

u0 = SA[D(x) => 2f0,
    x => 1f0,
    y => 0f0,
    z => 0f0]

p = SA[σ => 28f0,
    ρ => 10f0,
    β => 8f0 / 3f0]

tspan = (0f0, 100f0)
prob = ODEProblem{false}(sys, u0, tspan, p, split=true)
prob = remake(prob, p = p = SVector{10, Float32}(prob.p...))
sol = solve(prob, Tsit5())

using SymbolicIndexingInterface
p_setter = setp_oop(sys, [σ, ρ, β])

using DiffEqGPU, CUDA
function prob_func2(prob, i, repeat)
    remake(prob, p = p_setter(prob,@svector(rand(Float32,3))))
end
monteprob = EnsembleProblem(prob, prob_func = prob_func2, safetycopy = false)
sol = solve(monteprob, GPUTsit5(), EnsembleGPUKernel(CUDA.CUDABackend()),
    trajectories = 10_000)
```

But you need to `#prob = get_updated_symbolic_problem(_get_root_indp(prob), prob; kwargs...)` in DiffEqBase.

What's in here drops the `split=false` part.

We need to fix `get_updated_symbolic_problem` to not promote to `Float64` and fix static array outputs in `split=true`, i.e. SciML/ModelingToolkit.jl#3585, in order to finish this tutorial.
@AayushSabharwal
Copy link
Member Author

Needs SciML/DiffEqBase.jl#1151 for CI to run

@DhairyaLGandhi
Copy link
Member

Not ProjectTo, but probable that the Ref in the stacktrace actually contains a tangent as well, which should get accumulated. I would check where the original argument gets mutated.

@AayushSabharwal
Copy link
Member Author

So this happens because the initialization is trivial and it skips calling solve altogether, instead just calling initializeprobmap(initializeprob).

@AayushSabharwal
Copy link
Member Author

AayushSabharwal commented May 13, 2025

Here's an MWE which reproduces the error and does not do any mutation:

@variables x(t)[1:3] y(t)
@parameters p[1:3, 1:3] q
eqs = [D(x) ~ p * x
       D(y) ~ sum(p) + q * y]
u0 = [x => zeros(3),
    y => 1.0]
ps = [p => zeros(3, 3),
    q => 1.0]
tspan = (0.0, 10.0)
@mtkbuild sys = ODESystem(eqs, t)
prob = ODEProblem(sys, u0, tspan, ps)
gs = gradient(prob) do prob
           new_iprob = prob.f.initialization_data.update_initializeprob!(prob.f.initialization_data.initializeprob, prob)
           sum(prob.f.initialization_data.initializeprobmap(new_iprob))
       end

Despite the name, update_initializeprob! does not mutate its arguments or any internal buffer.

@DhairyaLGandhi
Copy link
Member

The MWE doesn't quite work for me.. perhaps I need some branches etc?

prob = ODEProblem(sys, u0, tspan, ps)
upf = prob.f.initialization_data.update_initializeprob!
iprob = prob.f.initialization_data.initializeprob

gs = Zygote.gradient(prob) do prob
    new_iprob = upf(iprob, prob)
    @show new_iprob
    out = prob.f.initialization_data.initializeprobmap(new_iprob)
    @show out
    sum(out)
end

returns

new_iprob = [nothing, nothing, nothing, nothing, nothing, nothing, nothing]
ERROR: BoundsError: attempt to access 7-element Vector{Nothing} at index [3:11]
Stacktrace:
  [1] throw_boundserror(A::Vector{Nothing}, I::Tuple{UnitRange{Int64}})
    @ Base ./abstractarray.jl:737
  [2] checkbounds
    @ ./abstractarray.jl:702 [inlined]
  [3] view
    @ ./subarray.jl:184 [inlined]
  [4] rrule
    @ ~/.julia/packages/ChainRules/Q16hj/src/rulesets/Base/indexing.jl:202 [inlined]
  [5] rrule
    @ ~/.julia/packages/ChainRulesCore/U6wNx/src/rules.jl:138 [inlined]
  [6] chain_rrule
    @ ~/.julia/packages/Zygote/1GK3J/src/compiler/chainrules.jl:224 [inlined]
  [7] macro expansion
    @ ~/.julia/packages/Zygote/1GK3J/src/compiler/interface2.jl:0 [inlined]
  [8] _pullback
    @ ~/.julia/packages/Zygote/1GK3J/src/compiler/interface2.jl:91 [inlined]
  [9] generated_callfunc
    @ ~/.julia/packages/SymbolicUtils/0GKgW/src/code.jl:411 [inlined]
 [10] _pullback(::Zygote.Context{…}, ::typeof(RuntimeGeneratedFunctions.generated_callfunc), ::RuntimeGeneratedFunctions.RuntimeGeneratedFunction{…}, ::Vector{…}, ::Tuple{…})
    @ Zygote ~/.julia/packages/Zygote/1GK3J/src/compiler/interface2.jl:0
 [11] _apply(::Function, ::Vararg{Any})
    @ Core ./boot.jl:838
 [12] adjoint
    @ ~/.julia/packages/Zygote/1GK3J/src/lib/lib.jl:202 [inlined]
 [13] _pullback
    @ ~/.julia/packages/ZygoteRules/CkVIK/src/adjoint.jl:67 [inlined]
 [14] RuntimeGeneratedFunction
    @ ~/.julia/packages/RuntimeGeneratedFunctions/RrXEW/src/RuntimeGeneratedFunctions.jl:148 [inlined]
 [15] _generated_call
    @ ~/.julia/packages/ModelingToolkit/Ljerk/src/systems/codegen_utils.jl:0 [inlined]
 [16] _pullback(::Zygote.Context{…}, ::typeof(ModelingToolkit._generated_call), ::ModelingToolkit.GeneratedFunctionWrapper{…}, ::Vector{…}, ::Vector{…})
    @ Zygote ~/.julia/packages/Zygote/1GK3J/src/compiler/interface2.jl:0
 [17] _apply(::Function, ::Vararg{Any})
    @ Core ./boot.jl:838
 [18] adjoint
    @ ~/.julia/packages/Zygote/1GK3J/src/lib/lib.jl:202 [inlined]
 [19] _pullback
    @ ~/.julia/packages/ZygoteRules/CkVIK/src/adjoint.jl:67 [inlined]
 [20] GeneratedFunctionWrapper
    @ ~/.julia/packages/ModelingToolkit/Ljerk/src/systems/codegen_utils.jl:259 [inlined]
 [21] _pullback(::Zygote.Context{…}, ::ModelingToolkit.GeneratedFunctionWrapper{…}, ::Vector{…}, ::Vector{…})
    @ Zygote ~/.julia/packages/Zygote/1GK3J/src/compiler/interface2.jl:0
 [22] TimeIndependentObservedFunction
    @ ~/.julia/packages/SymbolicIndexingInterface/3UAF0/src/state_indexing.jl:142 [inlined]
 [23] _pullback(::Zygote.Context{…}, ::SymbolicIndexingInterface.TimeIndependentObservedFunction{…}, ::NotTimeseries, ::Vector{…})
    @ Zygote ~/.julia/packages/Zygote/1GK3J/src/compiler/interface2.jl:0
 [24] AbstractStateGetIndexer
    @ ~/.julia/packages/SymbolicIndexingInterface/3UAF0/src/value_provider_interface.jl:166 [inlined]
 [25] _pullback(ctx::Zygote.Context{…}, f::SymbolicIndexingInterface.TimeIndependentObservedFunction{…}, args::Vector{…})
    @ Zygote ~/.julia/packages/Zygote/1GK3J/src/compiler/interface2.jl:0
 [26] _apply(::Function, ::Vararg{Any})
    @ Core ./boot.jl:838
 [27] adjoint
    @ ~/.julia/packages/Zygote/1GK3J/src/lib/lib.jl:202 [inlined]
 [28] _pullback
    @ ~/.julia/packages/ZygoteRules/CkVIK/src/adjoint.jl:67 [inlined]
 [29] call_composed
    @ ./operators.jl:1045 [inlined]
 [30] call_composed
    @ ./operators.jl:1044 [inlined]
 [31] #_#103
    @ ./operators.jl:1041 [inlined]
 [32] _pullback(::Zygote.Context{…}, ::Base.var"##_#103", ::@Kwargs{}, ::ComposedFunction{…}, ::Vector{…})
    @ Zygote ~/.julia/packages/Zygote/1GK3J/src/compiler/interface2.jl:0
 [33] _apply(::Function, ::Vararg{Any})
    @ Core ./boot.jl:838
 [34] adjoint
    @ ~/.julia/packages/Zygote/1GK3J/src/lib/lib.jl:202 [inlined]
 [35] _pullback
    @ ~/.julia/packages/ZygoteRules/CkVIK/src/adjoint.jl:67 [inlined]
 [36] ComposedFunction
    @ ./operators.jl:1041 [inlined]
 [37] _pullback(ctx::Zygote.Context{…}, f::ComposedFunction{…}, args::Vector{…})
    @ Zygote ~/.julia/packages/Zygote/1GK3J/src/compiler/interface2.jl:0
 [38] #57
    @ ~/Downloads/arpa/jsmo/t2/JuliaSimExampleComponents/dc_motor.jl:160 [inlined]
 [39] _pullback(ctx::Zygote.Context{…}, f::var"#57#58", args::ODEProblem{…})
    @ Zygote ~/.julia/packages/Zygote/1GK3J/src/compiler/interface2.jl:0
 [40] pullback(f::Function, cx::Zygote.Context{…}, args::ODEProblem{…})
    @ Zygote ~/.julia/packages/Zygote/1GK3J/src/compiler/interface.jl:90
 [41] pullback
    @ ~/.julia/packages/Zygote/1GK3J/src/compiler/interface.jl:88 [inlined]
 [42] gradient(f::Function, args::ODEProblem{…})
    @ Zygote ~/.julia/packages/Zygote/1GK3J/src/compiler/interface.jl:147
 [43] top-level scope
    @ ~/Downloads/arpa/jsmo/t2/JuliaSimExampleComponents/dc_motor.jl:157
Some type information was truncated. Use `show(err)` to see complete types.

Is it expected to have the new_iprob = [nothing, nothing, nothing, nothing, nothing, nothing, nothing]?

@AayushSabharwal
Copy link
Member Author

Just this PR should be fine. This is my environment:

  [d360d2e6] ChainRulesCore v1.25.1
  [cdddcdb0] ChainRulesTestUtils v1.13.0
  [aaaaaaaa] ControlSystemsBase v1.14.8
⌅ [f6369f11] ForwardDiff v0.10.38
  [961ee093] ModelingToolkit v9.78.0 `..`
  [16a59e39] ModelingToolkitStandardLibrary v2.21.0
  [8913a72c] NonlinearSolve v4.8.0
  [1dea7af3] OrdinaryDiffEq v6.96.0
  [0bca4576] SciMLBase v2.89.1
  [1ed8b502] SciMLSensitivity v7.79.0
  [53ae85a6] SciMLStructures v1.7.0
  [2efcf032] SymbolicIndexingInterface v0.3.40
  [e88e6eb3] Zygote v0.7.7

@AayushSabharwal
Copy link
Member Author

new_iprob = [...] I think means you're not on this branch

@DhairyaLGandhi
Copy link
Member

DhairyaLGandhi commented May 13, 2025

prob = ODEProblem(sys, u0, tspan, ps)
upf = prob.f.initialization_data.update_initializeprob!
iprob = prob.f.initialization_data.initializeprob
pgetter = prob.f.initialization_data.metadata.oop_reconstruct_u0_p.pgetter

Zygote.gradient(prob, iprob) do prob, iprob
    mtkp = pgetter(prob, iprob)
    ip = remake(iprob; p = mtkp)
    out = prob.f.initialization_data.initializeprobmap(ip)
    sum(out)
end

Is that the same as expanding upf(iprob, prob) with the pgetter and the remake? Followed from https://github.com/AayushSabharwal/ModelingToolkit.jl/blob/604d8383bf44c1a9c570728e8e58a5864d3c09c7/src/systems/problem_utils.jl#L827

It returns for me:

((f = nothing, u0 = nothing, tspan = (0.0, nothing), p = (tunable = [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], initials = [4.0, 4.0, 4.0, 0.0, 0.0, 0.0, 4.0, 0.0], discrete = nothing, constant = nothing, nonnumeric = nothing, caches = nothing), kwargs = nothing, problem_type = nothing), (f = nothing, u0 = nothing, p = nothing, problem_type = nothing, kwargs = (data = NamedTuple(), itr = nothing)))

@AayushSabharwal
Copy link
Member Author

Interesting. So what makes this fail is if you include iprob = prob.f.initialization_data.initializeprob inside the gradient call.

@DhairyaLGandhi
Copy link
Member

Actually it is including the pgetter that causes it to fail

Zygote.gradient(prob, iprob) do prob, iprob
    pgetter2 = prob.f.initialization_data.metadata.oop_reconstruct_u0_p.pgetter
    mtkp = pgetter2(prob, iprob)
    ip = remake(iprob; p = mtkp)
    out = prob.f.initialization_data.initializeprobmap(ip)
    sum(out)
end

gives the same error, however, using pgetter returns the gradients.

@DhairyaLGandhi
Copy link
Member

Wrapping it inside a Zygote.ignore() do ... end block would also work

@AayushSabharwal
Copy link
Member Author

For the MWE, sure, but that isn't a solution for MTK. Why does including pgetter make it fail like this?

@DhairyaLGandhi
Copy link
Member

Of course. I am not sure why quite yet. I don't suppose there's any mutation going around since we had to allocate a ref.

@AayushSabharwal
Copy link
Member Author

AayushSabharwal commented May 15, 2025

This still doesn't work

@variables x(t)[1:3] y(t)
@parameters p[1:3, 1:3] q
eqs = [D(x) ~ p * x
       D(y) ~ sum(p) + q * y]
u0 = [x => zeros(3),
    y => 1.0]
ps = [p => zeros(3, 3),
    q => 1.0]
tspan = (0.0, 10.0)
@mtkbuild sys = ODESystem(eqs, t)
prob = ODEProblem(sys, u0, tspan, ps)
sol = solve(prob, Tsit5())

mtkparams = parameter_values(prob)
new_p = rand(14)
gs = gradient(new_p) do new_p
    new_params = SciMLStructures.replace(SciMLStructures.Tunable(), mtkparams, new_p)
    new_prob = remake(prob, p = new_params)
    new_sol = solve(new_prob, Tsit5())
    sum(new_sol)
end

Throws

ERROR: MethodError: no method matching +(::@NamedTuple{}, ::Base.RefValue{…})
The function `+` exists, but no method is defined for this combination of argument types.

Closest candidates are:
  +(::Any, ::Any, ::Any, ::Any...)
   @ Base operators.jl:596
  +(::MutableArithmetics.Zero, ::Any)
   @ MutableArithmetics ~/.julia/packages/MutableArithmetics/tNSBd/src/rewrite.jl:64
  +(::Any, ::MutableArithmetics.Zero)
   @ MutableArithmetics ~/.julia/packages/MutableArithmetics/tNSBd/src/rewrite.jl:65
  ...

Stacktrace:
  [1] accum(x::@NamedTuple{f::Nothing, u0::Nothing, p::@NamedTuple{}, problem_type::Nothing, kwargs::Nothing}, y::Base.RefValue{Any})
    @ Zygote ~/.julia/packages/Zygote/wfLOG/src/lib/lib.jl:9
  [2] #get_initial_values#1623
    @ ~/Julia/SciML/SciMLBase.jl/src/initialization.jl:298 [inlined]
  [3] (::Zygote.Pullback{Tuple{…}, Any})(Δ::Tuple{Thunk{…}, @NamedTuple{…}, Nothing})
    @ Zygote ~/.julia/packages/Zygote/wfLOG/src/compiler/interface2.jl:0
  [4] get_initial_values
    @ ~/Julia/SciML/SciMLBase.jl/src/initialization.jl:242 [inlined]
  [5] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Tuple{Thunk{…}, @NamedTuple{…}, Nothing})
    @ Zygote ~/.julia/packages/Zygote/wfLOG/src/compiler/interface2.jl:0
  [6] maybe_eager_initialize_problem
    @ ~/.julia/packages/SciMLBase/JKXkh/src/remake.jl:1211 [inlined]
  [7] (::Zygote.Pullback{Tuple{…}, Any})(Δ::Tuple{Thunk{…}, @NamedTuple{…}})
    @ Zygote ~/.julia/packages/Zygote/wfLOG/src/compiler/interface2.jl:0
  [8] #remake#751
    @ ~/.julia/packages/SciMLBase/JKXkh/src/remake.jl:266 [inlined]
  [9] (::Zygote.Pullback{Tuple{…}, Any})(Δ::Base.RefValue{Any})
    @ Zygote ~/.julia/packages/Zygote/wfLOG/src/compiler/interface2.jl:0
 [10] remake
    @ ~/.julia/packages/SciMLBase/JKXkh/src/remake.jl:214 [inlined]
 [11] (::Zygote.Pullback{Tuple{typeof(Core.kwcall), @NamedTuple{…}, typeof(remake), ODEProblem{…}}, Any})(Δ::Base.RefValue{Any})
    @ Zygote ~/.julia/packages/Zygote/wfLOG/src/compiler/interface2.jl:0
 [12] #24
    @ ./REPL[29]:3 [inlined]
 [13] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/wfLOG/src/compiler/interface2.jl:0
 [14] (::Zygote.var"#88#89"{Zygote.Pullback{Tuple{}, Tuple{}}})(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/wfLOG/src/compiler/interface.jl:97
 [15] gradient(f::Function, args::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/wfLOG/src/compiler/interface.jl:154
 [16] top-level scope
    @ REPL[29]:1

@AayushSabharwal
Copy link
Member Author

Interestingly, if I change SciMLBase/src/initialization.jl:263 to nlsol = solve(initprob, Main.NewtonRaphson()) the stacktrace changes to

Stacktrace:
  [1] accum(x::@NamedTuple{f::Nothing, u0::Nothing, tspan::Tuple{…}, p::@NamedTuple{}, kwargs::Nothing, problem_type::Nothing}, y::Base.RefValue{Any})
    @ Zygote ~/.julia/packages/Zygote/wfLOG/src/lib/lib.jl:9
  [2] accum(x::@NamedTuple{f::Nothing, u0::Nothing, tspan::Tuple{…}, p::@NamedTuple{}, kwargs::Nothing, problem_type::Nothing}, y::Nothing, zs::Base.RefValue{Any})
    @ Zygote ~/.julia/packages/Zygote/wfLOG/src/lib/lib.jl:14
  [3] accum(::Nothing, ::@NamedTuple{f::Nothing, u0::Nothing, tspan::Tuple{…}, p::@NamedTuple{}, kwargs::Nothing, problem_type::Nothing}, ::Nothing, ::Vararg{Any})
    @ Zygote ~/.julia/packages/Zygote/wfLOG/src/lib/lib.jl:14
  [4] maybe_eager_initialize_problem
    @ ~/.julia/packages/SciMLBase/JKXkh/src/remake.jl:1211 [inlined]
  [5] (::Zygote.Pullback{Tuple{…}, Any})(Δ::Tuple{Thunk{…}, @NamedTuple{…}})
    @ Zygote ~/.julia/packages/Zygote/wfLOG/src/compiler/interface2.jl:0
  [6] #remake#751
    @ ~/.julia/packages/SciMLBase/JKXkh/src/remake.jl:266 [inlined]
  [7] (::Zygote.Pullback{Tuple{…}, Any})(Δ::Base.RefValue{Any})
    @ Zygote ~/.julia/packages/Zygote/wfLOG/src/compiler/interface2.jl:0
  [8] remake
    @ ~/.julia/packages/SciMLBase/JKXkh/src/remake.jl:214 [inlined]
  [9] (::Zygote.Pullback{Tuple{typeof(Core.kwcall), @NamedTuple{…}, typeof(remake), ODEProblem{…}}, Any})(Δ::Base.RefValue{Any})
    @ Zygote ~/.julia/packages/Zygote/wfLOG/src/compiler/interface2.jl:0
 [10] #26
    @ ./REPL[31]:3 [inlined]
 [11] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/wfLOG/src/compiler/interface2.jl:0
 [12] (::Zygote.var"#88#89"{Zygote.Pullback{Tuple{}, Tuple{}}})(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/wfLOG/src/compiler/interface.jl:97
 [13] gradient(f::Function, args::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/wfLOG/src/compiler/interface.jl:154
 [14] top-level scope
    @ REPL[31]:1

With the same error message

@DhairyaLGandhi
Copy link
Member

I pulled in the latest of your branch and this is what i see

Zygote.gradient(prob, iprob) do prob, iprob
    pgetter2 = Zygote.ChainRules.ignore_derivatives() do
        prob.f.initialization_data.metadata.oop_reconstruct_u0_p.pgetter
    end
    mtkp = pgetter2(prob, iprob)
    ip = remake(iprob; p = mtkp)
    out = prob.f.initialization_data.initializeprobmap(ip)
    sum(out)
end
((f = nothing, u0 = nothing, tspan = (0.0, nothing), p = (tunable = [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], initials = [4.0, 4.0, 4.0, 0.0, 0.0, 0.0, 4.0, 0.0], discrete = nothing, constant = nothing, nonnumeric = nothing, caches = nothing), kwargs = nothing, problem_type = nothing), (f = nothing, u0 = nothing, p = nothing, problem_type = nothing, kwargs = (data = NamedTuple(), itr = nothing)))

@AayushSabharwal
Copy link
Member Author

Try the MWE in the above comment

@DhairyaLGandhi
Copy link
Member

I do see the same error with with #3585 (comment)

@AayushSabharwal
Copy link
Member Author

I'm guessing we need more @ignore_derivatives but am not sure where

@DhairyaLGandhi
Copy link
Member

I am looking at why this case is causing this failure when the tests in SciMLSensitivity don't

@DhairyaLGandhi
Copy link
Member

DhairyaLGandhi commented May 15, 2025

So I found what caused the ref here, and it is more of a code reorganization issue. Couple places had overwritten the variable u0 and p. https://github.com/SciML/SciMLBase.jl/blob/ecdc172103d25552e0d7e8d7211cce542de15315/src/initialization.jl#L244-L245 and again https://github.com/SciML/SciMLBase.jl/blob/ecdc172103d25552e0d7e8d7211cce542de15315/src/initialization.jl#L297-L302

If instead they were named uniquely, then this wouldn't happen.

Refactoring the last little bit of that dispatch as

diff --git a/src/initialization.jl b/src/initialization.jl
index c580f944..65a738fd 100644
--- a/src/initialization.jl
+++ b/src/initialization.jl
@@ -295,13 +295,16 @@ function get_initial_values(prob, valp, f, alg::OverrideInit,
     end
 
     if initdata.initializeprobmap !== nothing
-        u0 = initdata.initializeprobmap(nlsol)
+        u02 = initdata.initializeprobmap(nlsol)
     end
     if initdata.initializeprobpmap !== nothing
-        p = initdata.initializeprobpmap(valp, nlsol)
+        p2 = initdata.initializeprobpmap(valp, nlsol)
     end
 
-    return u0, p, success
+    u03 = isnothing(initdata.initializeprobmap) ? u02 : u0
+    p3 = isnothing(initdata.initializeprobmap) ? p2 : p
+
+    return u03, p3, success

and running #3585 (comment) now gives

gs = Zygote.gradient(new_p) do new_p
    new_params = SciMLStructures.replace(SciMLStructures.Tunable(), mtkparams, new_p)
    new_prob = remake(prob, p = new_params)
    new_sol = solve(new_prob, Tsit5())
    sum(new_sol)
end
([682491.209731987, 15007.02429573708, 15007.02429573708, 15007.02429573708, 15007.02429573708, 15007.02429573708, 15007.02429573708, 15007.02429573708, 15007.02429573708, 15007.02429573708, 0.0, 0.0, 0.0, 0.0],)

@AayushSabharwal
Copy link
Member Author

Wow, I didn't know that that was a thing. Thanks for finding this. Is there some sort of general rule against not overwriting variable names for Zygote compatibility?

It also looks like we don't need the @ignore_derivatives for this to work, so I'll remove that too.

@AayushSabharwal
Copy link
Member Author

SciML/SciMLBase.jl#1023

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants