-
-
Notifications
You must be signed in to change notification settings - Fork 220
fix: allow specifying type of buffers inside MTKParameters
#3585
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Conversation
5abdb20
to
b1de305
Compare
src/systems/parameter_buffer.jl
Outdated
t0 = nothing, substitution_limit = 1000, floatT = nothing, | ||
container_type = Vector) | ||
if !(container_type <: AbstractArray) | ||
container_type = Array |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
error?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Resolved
b1de305
to
458f944
Compare
Fixes #330. Currently the MWE works: ```julia using DiffEqGPU using OrdinaryDiffEqTsit5, ModelingToolkit, StaticArrays using ModelingToolkit: t_nounits as t, D_nounits as D @parameters σ ρ β @variables x(t) y(t) z(t) eqs = [D(D(x)) ~ σ * (y - x), D(y) ~ x * (ρ - z) - y, D(z) ~ x * y - β * z] @mtkbuild sys = ODESystem(eqs, t) split=false u0 = SA[D(x) => 2f0, x => 1f0, y => 0f0, z => 0f0] p = SA[σ => 28f0, ρ => 10f0, β => 8f0 / 3f0] tspan = (0f0, 100f0) prob = ODEProblem{false}(sys, u0, tspan, p, split=true) prob = remake(prob, p = p = SVector{10, Float32}(prob.p...)) sol = solve(prob, Tsit5()) using SymbolicIndexingInterface p_setter = setp_oop(sys, [σ, ρ, β]) using DiffEqGPU, CUDA function prob_func2(prob, i, repeat) remake(prob, p = p_setter(prob,@svector(rand(Float32,3)))) end monteprob = EnsembleProblem(prob, prob_func = prob_func2, safetycopy = false) sol = solve(monteprob, GPUTsit5(), EnsembleGPUKernel(CUDA.CUDABackend()), trajectories = 10_000) ``` But you need to `#prob = get_updated_symbolic_problem(_get_root_indp(prob), prob; kwargs...)` in DiffEqBase. What's in here drops the `split=false` part. We need to fix `get_updated_symbolic_problem` to not promote to `Float64` and fix static array outputs in `split=true`, i.e. SciML/ModelingToolkit.jl#3585, in order to finish this tutorial.
458f944
to
2361883
Compare
Needs SciML/DiffEqBase.jl#1151 for CI to run |
…bstractArray` subtype
Not ProjectTo, but probable that the Ref in the stacktrace actually contains a tangent as well, which should get accumulated. I would check where the original argument gets mutated. |
So this happens because the initialization is trivial and it skips calling |
Here's an MWE which reproduces the error and does not do any mutation: @variables x(t)[1:3] y(t)
@parameters p[1:3, 1:3] q
eqs = [D(x) ~ p * x
D(y) ~ sum(p) + q * y]
u0 = [x => zeros(3),
y => 1.0]
ps = [p => zeros(3, 3),
q => 1.0]
tspan = (0.0, 10.0)
@mtkbuild sys = ODESystem(eqs, t)
prob = ODEProblem(sys, u0, tspan, ps)
gs = gradient(prob) do prob
new_iprob = prob.f.initialization_data.update_initializeprob!(prob.f.initialization_data.initializeprob, prob)
sum(prob.f.initialization_data.initializeprobmap(new_iprob))
end Despite the name, |
The MWE doesn't quite work for me.. perhaps I need some branches etc? prob = ODEProblem(sys, u0, tspan, ps)
upf = prob.f.initialization_data.update_initializeprob!
iprob = prob.f.initialization_data.initializeprob
gs = Zygote.gradient(prob) do prob
new_iprob = upf(iprob, prob)
@show new_iprob
out = prob.f.initialization_data.initializeprobmap(new_iprob)
@show out
sum(out)
end returns new_iprob = [nothing, nothing, nothing, nothing, nothing, nothing, nothing]
ERROR: BoundsError: attempt to access 7-element Vector{Nothing} at index [3:11]
Stacktrace:
[1] throw_boundserror(A::Vector{Nothing}, I::Tuple{UnitRange{Int64}})
@ Base ./abstractarray.jl:737
[2] checkbounds
@ ./abstractarray.jl:702 [inlined]
[3] view
@ ./subarray.jl:184 [inlined]
[4] rrule
@ ~/.julia/packages/ChainRules/Q16hj/src/rulesets/Base/indexing.jl:202 [inlined]
[5] rrule
@ ~/.julia/packages/ChainRulesCore/U6wNx/src/rules.jl:138 [inlined]
[6] chain_rrule
@ ~/.julia/packages/Zygote/1GK3J/src/compiler/chainrules.jl:224 [inlined]
[7] macro expansion
@ ~/.julia/packages/Zygote/1GK3J/src/compiler/interface2.jl:0 [inlined]
[8] _pullback
@ ~/.julia/packages/Zygote/1GK3J/src/compiler/interface2.jl:91 [inlined]
[9] generated_callfunc
@ ~/.julia/packages/SymbolicUtils/0GKgW/src/code.jl:411 [inlined]
[10] _pullback(::Zygote.Context{…}, ::typeof(RuntimeGeneratedFunctions.generated_callfunc), ::RuntimeGeneratedFunctions.RuntimeGeneratedFunction{…}, ::Vector{…}, ::Tuple{…})
@ Zygote ~/.julia/packages/Zygote/1GK3J/src/compiler/interface2.jl:0
[11] _apply(::Function, ::Vararg{Any})
@ Core ./boot.jl:838
[12] adjoint
@ ~/.julia/packages/Zygote/1GK3J/src/lib/lib.jl:202 [inlined]
[13] _pullback
@ ~/.julia/packages/ZygoteRules/CkVIK/src/adjoint.jl:67 [inlined]
[14] RuntimeGeneratedFunction
@ ~/.julia/packages/RuntimeGeneratedFunctions/RrXEW/src/RuntimeGeneratedFunctions.jl:148 [inlined]
[15] _generated_call
@ ~/.julia/packages/ModelingToolkit/Ljerk/src/systems/codegen_utils.jl:0 [inlined]
[16] _pullback(::Zygote.Context{…}, ::typeof(ModelingToolkit._generated_call), ::ModelingToolkit.GeneratedFunctionWrapper{…}, ::Vector{…}, ::Vector{…})
@ Zygote ~/.julia/packages/Zygote/1GK3J/src/compiler/interface2.jl:0
[17] _apply(::Function, ::Vararg{Any})
@ Core ./boot.jl:838
[18] adjoint
@ ~/.julia/packages/Zygote/1GK3J/src/lib/lib.jl:202 [inlined]
[19] _pullback
@ ~/.julia/packages/ZygoteRules/CkVIK/src/adjoint.jl:67 [inlined]
[20] GeneratedFunctionWrapper
@ ~/.julia/packages/ModelingToolkit/Ljerk/src/systems/codegen_utils.jl:259 [inlined]
[21] _pullback(::Zygote.Context{…}, ::ModelingToolkit.GeneratedFunctionWrapper{…}, ::Vector{…}, ::Vector{…})
@ Zygote ~/.julia/packages/Zygote/1GK3J/src/compiler/interface2.jl:0
[22] TimeIndependentObservedFunction
@ ~/.julia/packages/SymbolicIndexingInterface/3UAF0/src/state_indexing.jl:142 [inlined]
[23] _pullback(::Zygote.Context{…}, ::SymbolicIndexingInterface.TimeIndependentObservedFunction{…}, ::NotTimeseries, ::Vector{…})
@ Zygote ~/.julia/packages/Zygote/1GK3J/src/compiler/interface2.jl:0
[24] AbstractStateGetIndexer
@ ~/.julia/packages/SymbolicIndexingInterface/3UAF0/src/value_provider_interface.jl:166 [inlined]
[25] _pullback(ctx::Zygote.Context{…}, f::SymbolicIndexingInterface.TimeIndependentObservedFunction{…}, args::Vector{…})
@ Zygote ~/.julia/packages/Zygote/1GK3J/src/compiler/interface2.jl:0
[26] _apply(::Function, ::Vararg{Any})
@ Core ./boot.jl:838
[27] adjoint
@ ~/.julia/packages/Zygote/1GK3J/src/lib/lib.jl:202 [inlined]
[28] _pullback
@ ~/.julia/packages/ZygoteRules/CkVIK/src/adjoint.jl:67 [inlined]
[29] call_composed
@ ./operators.jl:1045 [inlined]
[30] call_composed
@ ./operators.jl:1044 [inlined]
[31] #_#103
@ ./operators.jl:1041 [inlined]
[32] _pullback(::Zygote.Context{…}, ::Base.var"##_#103", ::@Kwargs{}, ::ComposedFunction{…}, ::Vector{…})
@ Zygote ~/.julia/packages/Zygote/1GK3J/src/compiler/interface2.jl:0
[33] _apply(::Function, ::Vararg{Any})
@ Core ./boot.jl:838
[34] adjoint
@ ~/.julia/packages/Zygote/1GK3J/src/lib/lib.jl:202 [inlined]
[35] _pullback
@ ~/.julia/packages/ZygoteRules/CkVIK/src/adjoint.jl:67 [inlined]
[36] ComposedFunction
@ ./operators.jl:1041 [inlined]
[37] _pullback(ctx::Zygote.Context{…}, f::ComposedFunction{…}, args::Vector{…})
@ Zygote ~/.julia/packages/Zygote/1GK3J/src/compiler/interface2.jl:0
[38] #57
@ ~/Downloads/arpa/jsmo/t2/JuliaSimExampleComponents/dc_motor.jl:160 [inlined]
[39] _pullback(ctx::Zygote.Context{…}, f::var"#57#58", args::ODEProblem{…})
@ Zygote ~/.julia/packages/Zygote/1GK3J/src/compiler/interface2.jl:0
[40] pullback(f::Function, cx::Zygote.Context{…}, args::ODEProblem{…})
@ Zygote ~/.julia/packages/Zygote/1GK3J/src/compiler/interface.jl:90
[41] pullback
@ ~/.julia/packages/Zygote/1GK3J/src/compiler/interface.jl:88 [inlined]
[42] gradient(f::Function, args::ODEProblem{…})
@ Zygote ~/.julia/packages/Zygote/1GK3J/src/compiler/interface.jl:147
[43] top-level scope
@ ~/Downloads/arpa/jsmo/t2/JuliaSimExampleComponents/dc_motor.jl:157
Some type information was truncated. Use `show(err)` to see complete types. Is it expected to have the |
Just this PR should be fine. This is my environment:
|
|
prob = ODEProblem(sys, u0, tspan, ps)
upf = prob.f.initialization_data.update_initializeprob!
iprob = prob.f.initialization_data.initializeprob
pgetter = prob.f.initialization_data.metadata.oop_reconstruct_u0_p.pgetter
Zygote.gradient(prob, iprob) do prob, iprob
mtkp = pgetter(prob, iprob)
ip = remake(iprob; p = mtkp)
out = prob.f.initialization_data.initializeprobmap(ip)
sum(out)
end Is that the same as expanding It returns for me: ((f = nothing, u0 = nothing, tspan = (0.0, nothing), p = (tunable = [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], initials = [4.0, 4.0, 4.0, 0.0, 0.0, 0.0, 4.0, 0.0], discrete = nothing, constant = nothing, nonnumeric = nothing, caches = nothing), kwargs = nothing, problem_type = nothing), (f = nothing, u0 = nothing, p = nothing, problem_type = nothing, kwargs = (data = NamedTuple(), itr = nothing))) |
Interesting. So what makes this fail is if you include |
Actually it is including the Zygote.gradient(prob, iprob) do prob, iprob
pgetter2 = prob.f.initialization_data.metadata.oop_reconstruct_u0_p.pgetter
mtkp = pgetter2(prob, iprob)
ip = remake(iprob; p = mtkp)
out = prob.f.initialization_data.initializeprobmap(ip)
sum(out)
end gives the same error, however, using |
Wrapping it inside a |
For the MWE, sure, but that isn't a solution for MTK. Why does including |
Of course. I am not sure why quite yet. I don't suppose there's any mutation going around since we had to allocate a ref. |
This still doesn't work @variables x(t)[1:3] y(t)
@parameters p[1:3, 1:3] q
eqs = [D(x) ~ p * x
D(y) ~ sum(p) + q * y]
u0 = [x => zeros(3),
y => 1.0]
ps = [p => zeros(3, 3),
q => 1.0]
tspan = (0.0, 10.0)
@mtkbuild sys = ODESystem(eqs, t)
prob = ODEProblem(sys, u0, tspan, ps)
sol = solve(prob, Tsit5())
mtkparams = parameter_values(prob)
new_p = rand(14)
gs = gradient(new_p) do new_p
new_params = SciMLStructures.replace(SciMLStructures.Tunable(), mtkparams, new_p)
new_prob = remake(prob, p = new_params)
new_sol = solve(new_prob, Tsit5())
sum(new_sol)
end Throws ERROR: MethodError: no method matching +(::@NamedTuple{…}, ::Base.RefValue{…})
The function `+` exists, but no method is defined for this combination of argument types.
Closest candidates are:
+(::Any, ::Any, ::Any, ::Any...)
@ Base operators.jl:596
+(::MutableArithmetics.Zero, ::Any)
@ MutableArithmetics ~/.julia/packages/MutableArithmetics/tNSBd/src/rewrite.jl:64
+(::Any, ::MutableArithmetics.Zero)
@ MutableArithmetics ~/.julia/packages/MutableArithmetics/tNSBd/src/rewrite.jl:65
...
Stacktrace:
[1] accum(x::@NamedTuple{f::Nothing, u0::Nothing, p::@NamedTuple{…}, problem_type::Nothing, kwargs::Nothing}, y::Base.RefValue{Any})
@ Zygote ~/.julia/packages/Zygote/wfLOG/src/lib/lib.jl:9
[2] #get_initial_values#1623
@ ~/Julia/SciML/SciMLBase.jl/src/initialization.jl:298 [inlined]
[3] (::Zygote.Pullback{Tuple{…}, Any})(Δ::Tuple{Thunk{…}, @NamedTuple{…}, Nothing})
@ Zygote ~/.julia/packages/Zygote/wfLOG/src/compiler/interface2.jl:0
[4] get_initial_values
@ ~/Julia/SciML/SciMLBase.jl/src/initialization.jl:242 [inlined]
[5] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Tuple{Thunk{…}, @NamedTuple{…}, Nothing})
@ Zygote ~/.julia/packages/Zygote/wfLOG/src/compiler/interface2.jl:0
[6] maybe_eager_initialize_problem
@ ~/.julia/packages/SciMLBase/JKXkh/src/remake.jl:1211 [inlined]
[7] (::Zygote.Pullback{Tuple{…}, Any})(Δ::Tuple{Thunk{…}, @NamedTuple{…}})
@ Zygote ~/.julia/packages/Zygote/wfLOG/src/compiler/interface2.jl:0
[8] #remake#751
@ ~/.julia/packages/SciMLBase/JKXkh/src/remake.jl:266 [inlined]
[9] (::Zygote.Pullback{Tuple{…}, Any})(Δ::Base.RefValue{Any})
@ Zygote ~/.julia/packages/Zygote/wfLOG/src/compiler/interface2.jl:0
[10] remake
@ ~/.julia/packages/SciMLBase/JKXkh/src/remake.jl:214 [inlined]
[11] (::Zygote.Pullback{Tuple{typeof(Core.kwcall), @NamedTuple{…}, typeof(remake), ODEProblem{…}}, Any})(Δ::Base.RefValue{Any})
@ Zygote ~/.julia/packages/Zygote/wfLOG/src/compiler/interface2.jl:0
[12] #24
@ ./REPL[29]:3 [inlined]
[13] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Float64)
@ Zygote ~/.julia/packages/Zygote/wfLOG/src/compiler/interface2.jl:0
[14] (::Zygote.var"#88#89"{Zygote.Pullback{Tuple{…}, Tuple{…}}})(Δ::Float64)
@ Zygote ~/.julia/packages/Zygote/wfLOG/src/compiler/interface.jl:97
[15] gradient(f::Function, args::Vector{Float64})
@ Zygote ~/.julia/packages/Zygote/wfLOG/src/compiler/interface.jl:154
[16] top-level scope
@ REPL[29]:1 |
Interestingly, if I change Stacktrace:
[1] accum(x::@NamedTuple{f::Nothing, u0::Nothing, tspan::Tuple{…}, p::@NamedTuple{…}, kwargs::Nothing, problem_type::Nothing}, y::Base.RefValue{Any})
@ Zygote ~/.julia/packages/Zygote/wfLOG/src/lib/lib.jl:9
[2] accum(x::@NamedTuple{f::Nothing, u0::Nothing, tspan::Tuple{…}, p::@NamedTuple{…}, kwargs::Nothing, problem_type::Nothing}, y::Nothing, zs::Base.RefValue{Any})
@ Zygote ~/.julia/packages/Zygote/wfLOG/src/lib/lib.jl:14
[3] accum(::Nothing, ::@NamedTuple{f::Nothing, u0::Nothing, tspan::Tuple{…}, p::@NamedTuple{…}, kwargs::Nothing, problem_type::Nothing}, ::Nothing, ::Vararg{Any})
@ Zygote ~/.julia/packages/Zygote/wfLOG/src/lib/lib.jl:14
[4] maybe_eager_initialize_problem
@ ~/.julia/packages/SciMLBase/JKXkh/src/remake.jl:1211 [inlined]
[5] (::Zygote.Pullback{Tuple{…}, Any})(Δ::Tuple{Thunk{…}, @NamedTuple{…}})
@ Zygote ~/.julia/packages/Zygote/wfLOG/src/compiler/interface2.jl:0
[6] #remake#751
@ ~/.julia/packages/SciMLBase/JKXkh/src/remake.jl:266 [inlined]
[7] (::Zygote.Pullback{Tuple{…}, Any})(Δ::Base.RefValue{Any})
@ Zygote ~/.julia/packages/Zygote/wfLOG/src/compiler/interface2.jl:0
[8] remake
@ ~/.julia/packages/SciMLBase/JKXkh/src/remake.jl:214 [inlined]
[9] (::Zygote.Pullback{Tuple{typeof(Core.kwcall), @NamedTuple{…}, typeof(remake), ODEProblem{…}}, Any})(Δ::Base.RefValue{Any})
@ Zygote ~/.julia/packages/Zygote/wfLOG/src/compiler/interface2.jl:0
[10] #26
@ ./REPL[31]:3 [inlined]
[11] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Float64)
@ Zygote ~/.julia/packages/Zygote/wfLOG/src/compiler/interface2.jl:0
[12] (::Zygote.var"#88#89"{Zygote.Pullback{Tuple{…}, Tuple{…}}})(Δ::Float64)
@ Zygote ~/.julia/packages/Zygote/wfLOG/src/compiler/interface.jl:97
[13] gradient(f::Function, args::Vector{Float64})
@ Zygote ~/.julia/packages/Zygote/wfLOG/src/compiler/interface.jl:154
[14] top-level scope
@ REPL[31]:1 With the same error message |
I pulled in the latest of your branch and this is what i see Zygote.gradient(prob, iprob) do prob, iprob
pgetter2 = Zygote.ChainRules.ignore_derivatives() do
prob.f.initialization_data.metadata.oop_reconstruct_u0_p.pgetter
end
mtkp = pgetter2(prob, iprob)
ip = remake(iprob; p = mtkp)
out = prob.f.initialization_data.initializeprobmap(ip)
sum(out)
end ((f = nothing, u0 = nothing, tspan = (0.0, nothing), p = (tunable = [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], initials = [4.0, 4.0, 4.0, 0.0, 0.0, 0.0, 4.0, 0.0], discrete = nothing, constant = nothing, nonnumeric = nothing, caches = nothing), kwargs = nothing, problem_type = nothing), (f = nothing, u0 = nothing, p = nothing, problem_type = nothing, kwargs = (data = NamedTuple(), itr = nothing))) |
Try the MWE in the above comment |
I do see the same error with with #3585 (comment) |
I'm guessing we need more |
I am looking at why this case is causing this failure when the tests in SciMLSensitivity don't |
So I found what caused the ref here, and it is more of a code reorganization issue. Couple places had overwritten the variable If instead they were named uniquely, then this wouldn't happen. Refactoring the last little bit of that dispatch as diff --git a/src/initialization.jl b/src/initialization.jl
index c580f944..65a738fd 100644
--- a/src/initialization.jl
+++ b/src/initialization.jl
@@ -295,13 +295,16 @@ function get_initial_values(prob, valp, f, alg::OverrideInit,
end
if initdata.initializeprobmap !== nothing
- u0 = initdata.initializeprobmap(nlsol)
+ u02 = initdata.initializeprobmap(nlsol)
end
if initdata.initializeprobpmap !== nothing
- p = initdata.initializeprobpmap(valp, nlsol)
+ p2 = initdata.initializeprobpmap(valp, nlsol)
end
- return u0, p, success
+ u03 = isnothing(initdata.initializeprobmap) ? u02 : u0
+ p3 = isnothing(initdata.initializeprobmap) ? p2 : p
+
+ return u03, p3, success and running #3585 (comment) now gives gs = Zygote.gradient(new_p) do new_p
new_params = SciMLStructures.replace(SciMLStructures.Tunable(), mtkparams, new_p)
new_prob = remake(prob, p = new_params)
new_sol = solve(new_prob, Tsit5())
sum(new_sol)
end
([682491.209731987, 15007.02429573708, 15007.02429573708, 15007.02429573708, 15007.02429573708, 15007.02429573708, 15007.02429573708, 15007.02429573708, 15007.02429573708, 15007.02429573708, 0.0, 0.0, 0.0, 0.0],) |
Wow, I didn't know that that was a thing. Thanks for finding this. Is there some sort of general rule against not overwriting variable names for Zygote compatibility? It also looks like we don't need the |
Checklist
contributor guidelines, in particular the SciML Style Guide and
COLPRAC.
Additional context
Add any other context about the problem here.