forked from SciML/ModelingToolkit.jl
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathparameters.jl
114 lines (101 loc) · 2.98 KB
/
parameters.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import SymbolicUtils: symtype, term, hasmetadata, issym
# Kinds a symbolic variable can be tagged with. The tag is stored as metadata
# under the `MTKVariableTypeCtx` key (see `getvariabletype`).
@enum VariableType begin
    VARIABLE
    PARAMETER
    BROWNIAN
end

# Metadata key type used to attach a `VariableType` to a symbolic.
struct MTKVariableTypeCtx end

"""
    getvariabletype(x, def = VARIABLE)

Look up the `MTKVariableTypeCtx` metadata on `unwrap(x)`. `def` is handed to
`getmetadata` as the fallback value and defaults to `VARIABLE`.
"""
function getvariabletype(x, def = VARIABLE)
    return getmetadata(unwrap(x), MTKVariableTypeCtx, def)
end
"""
    isparameter(x)

Return `true` if `x` (after `unwrap`) is considered a parameter.

The checks are tried in this order:

1. If `x` is a `Symbolic` carrying `MTKVariableTypeCtx` metadata, the answer is
   whether that metadata is `PARAMETER`.
2. If `x` is a `Symbolic` with a parent (`Symbolics.getparent`), the answer is
   whether the parent is a parameter, or whether the parent's
   `Symbolics.VariableSource` metadata starts with `:parameters`.
3. If `x` is a call whose operation is itself a `Symbolic`, the operation is
   checked recursively.
4. If `x` is a `getindex` call, its first argument is checked recursively.
5. Otherwise `false`.
"""
function isparameter(x)
    x = unwrap(x)
    # Look the tag up once, and only for symbolics. Binding `varT` on every path
    # means no later branch can hit an unassigned local.
    varT = x isa Symbolic ? getvariabletype(x, nothing) : nothing
    if varT !== nothing
        return varT === PARAMETER
    end

    # From here on `varT` is `nothing`, so only structural checks remain.
    #TODO: Delete this branch
    if x isa Symbolic && Symbolics.getparent(x, false) !== false
        p = Symbolics.getparent(x)
        return isparameter(p) ||
               (hasmetadata(p, Symbolics.VariableSource) &&
                getmetadata(p, Symbolics.VariableSource)[1] == :parameters)
    elseif istree(x) && operation(x) isa Symbolic
        return isparameter(operation(x))
    elseif istree(x) && operation(x) == (getindex)
        return isparameter(arguments(x)[1])
    else
        return false
    end
end
"""
toparam(s)
Maps the variable to a parameter.
"""
function toparam(s)
if s isa Symbolics.Arr
Symbolics.wrap(toparam(Symbolics.unwrap(s)))
elseif s isa AbstractArray
map(toparam, s)
else
setmetadata(s, MTKVariableTypeCtx, PARAMETER)
end
end
toparam(s::Num) = wrap(toparam(value(s)))
"""
tovar(s)
Maps the variable to an unknown.
"""
tovar(s::Symbolic) = setmetadata(s, MTKVariableTypeCtx, VARIABLE)
tovar(s::Num) = Num(tovar(value(s)))
"""
$(SIGNATURES)
Define one or more known parameters.
"""
macro parameters(xs...)
Symbolics._parse_vars(:parameters,
Real,
xs,
toparam) |> esc
end
"""
    find_types(array)

Label every element of `array` with an integer identifying its concrete type.

The first distinct `typeof` encountered gets label `1`, the next new one `2`,
and so on; elements sharing a type share a label. The labels are produced by
broadcasting over `array`, so the result has the shape broadcasting gives it.
"""
function find_types(array)
    labels = Dict{Any, Int}()
    nseen = Ref(0)
    function label(x)
        T = typeof(x)
        # `get!` only runs the block for a type not seen before; the block's
        # value (the incremented counter) becomes that type's label.
        return get!(labels, T) do
            nseen[] += 1
        end
    end
    return label.(array)
end
"""
    split_parameters_by_type(ps)

Group the elements of `ps` by concrete type.

Returns a pair `(split_ps, split_idxs)`:

- `split_idxs` is a vector of `Vector{Int}`; entry `k` holds the positions (as
  counted by `enumerate`) of the elements whose type was the `k`-th distinct
  type encountered.
- `split_ps` holds the matching values, each group re-collected with
  `identity.(...)` to narrow its element type. With a single group the group
  itself is returned; otherwise the groups are returned as a tuple.

Special cases:

- `ps === SciMLBase.NullParameters()` returns `(Float64[], [])`.
- If `ps isa StaticArray`, every group and the collection of groups are
  converted to `SArray`s before being returned.
"""
function split_parameters_by_type(ps)
    if ps === SciMLBase.NullParameters()
        return Float64[], [] #use Float64 to avoid Any type warning
    end

    # Reuse the shared helper instead of re-implementing the type-labelling closure.
    idxs = find_types(ps)

    # Start with one (possibly empty) group so an empty `ps` still yields a group.
    split_idxs = [Int[]]
    for (i, idx) in enumerate(idxs)
        # Labels grow by at most one at a time, so a single push is enough.
        if idx > length(split_idxs)
            push!(split_idxs, Int[])
        end
        push!(split_idxs[idx], i)
    end

    tighten_types = x -> identity.(x)
    split_ps = tighten_types.(Base.Fix1(getindex, ps).(split_idxs))
    if ps isa StaticArray
        parrs = map(x -> SArray{Tuple{size(x)...}}(x), split_ps)
        split_ps = SArray{Tuple{size(parrs)...}}(parrs)
    end

    if length(split_ps) == 1 #Tuple not needed, only 1 type
        return split_ps[1], split_idxs
    else
        return (split_ps...,), split_idxs
    end
end