forked from jump-dev/JuMP.jl
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathconversion.jl
84 lines (75 loc) · 2.76 KB
/
conversion.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
# Flatten the Julia expression `ex` into NodeData form.
# Returns a tuple: the vector of NodeData entries, and the vector of Float64
# constants collected while walking the expression.
# `r` is consulted for operators that are not present in the built-in tables.
function expr_to_nodedata(ex::Expr,r::UserOperatorRegistry=UserOperatorRegistry())
    nodes = NodeData[]
    constants = Float64[]
    # The root node is given -1 as its parent id.
    expr_to_nodedata(ex, nodes, constants, -1, r)
    return nodes, constants
end
# Recursive worker: append the node describing `ex` to `nd`, then append the
# nodes of its operands with this node's index as their parent id.
#   ex       -- expression to convert (:call, :ref, :comparison, :&& or :||)
#   nd       -- node vector being built (mutated)
#   values   -- constant vector being built (mutated by the Number method)
#   parentid -- parent id stored in the node created for `ex`
#   r        -- registry looked up for operators missing from the built-in tables
# Always returns `nothing`; raises an error for any other expression head.
function expr_to_nodedata(ex::Expr,nd::Vector{NodeData},values::Vector{Float64},parentid,r::UserOperatorRegistry)
    # Index that the node for `ex` will occupy in `nd` once it is pushed.
    myid = length(nd) + 1
    head = ex.head
    if head === :call
        op = ex.args[1]
        if length(ex.args) == 2
            # Operator plus exactly one operand: univariate call.
            if haskey(univariate_operator_to_id, op)
                id = univariate_operator_to_id[op]
            else
                # Not built in: take the registry's id and shift it by the
                # user-operator offset.
                id = r.univariate_operator_to_id[op] + USER_UNIVAR_OPERATOR_ID_START - 1
            end
            push!(nd, NodeData(CALLUNIVAR, id, parentid))
        elseif op in comparison_operators
            # Comparison written in call form.
            push!(nd, NodeData(COMPARISON, comparison_operator_to_id[op], parentid))
        else
            if haskey(operator_to_id, op)
                id = operator_to_id[op]
            else
                id = r.multivariate_operator_to_id[op] + USER_OPERATOR_ID_START - 1
            end
            push!(nd, NodeData(CALL, id, parentid))
        end
        # Everything after the operator is an operand.
        for operand in Iterators.drop(ex.args, 1)
            expr_to_nodedata(operand, nd, values, myid, r)
        end
    elseif head === :ref
        # Only references of the form x[...] are accepted; the index is stored
        # in the VARIABLE node.
        @assert ex.args[1] == :x
        push!(nd, NodeData(VARIABLE, ex.args[2], parentid))
    elseif head === :comparison
        # Arguments alternate operand, operator, operand, ...; every operator
        # position must hold the same operator as the first one.
        op = ex.args[2]
        opid = comparison_operator_to_id[op]
        for k in 2:2:length(ex.args)-1
            @assert ex.args[k] == op
        end
        push!(nd, NodeData(COMPARISON, opid, parentid))
        # Odd positions hold the operands.
        for operand in ex.args[1:2:end]
            expr_to_nodedata(operand, nd, values, myid, r)
        end
    elseif head in (:&&, :||)
        @assert length(ex.args) == 2
        push!(nd, NodeData(LOGIC, logic_operator_to_id[head], parentid))
        lhs, rhs = ex.args[1], ex.args[2]
        expr_to_nodedata(lhs, nd, values, myid, r)
        expr_to_nodedata(rhs, nd, values, myid, r)
    else
        error("Unrecognized expression $ex: $(ex.head)")
    end
    return nothing
end
# Leaf case for numeric constants: store `ex` in `values` and append a VALUE
# node to `nd` that holds the position of the constant within `values`.
# `r` is accepted for signature compatibility with the Expr method and is unused.
# Returns `nothing`.
function expr_to_nodedata(ex::Number,nd::Vector{NodeData},values::Vector{Float64},parentid,r::UserOperatorRegistry)
    push!(values, ex)
    # The constant was just appended, so its index is the current length.
    push!(nd, NodeData(VALUE, length(values), parentid))
    return nothing
end
export expr_to_nodedata
# Build the sparse Bool adjacency matrix of the expression graph stored in `nd`.
# A nonzero at (i,j) means there is an edge *from* j *to* i, i.e. node j is the
# parent of node i. Because the result is stored column by column, reading
# column j is a fast way to list the edges leaving node j (its children).
# Nodes whose parent id is not positive contribute no entry.
function adjmat(nd::Vector{NodeData})
    n = length(nd)
    children = Int[]
    parents = Int[]
    for (child, node) in enumerate(nd)
        if node.parent > 0
            push!(children, child)
            push!(parents, node.parent)
        end
    end
    return sparse(children, parents, fill(true, length(children)), n, n)
end
export adjmat