Skip to content

Commit

Permalink
More progress towards clean BlockMerge
Browse files Browse the repository at this point in the history
  • Loading branch information
meggart committed Jan 17, 2025
1 parent 85b03fe commit c486092
Show file tree
Hide file tree
Showing 7 changed files with 171 additions and 88 deletions.
3 changes: 3 additions & 0 deletions devscripts/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
GLMakie = "e9467ef8-e4e7-5192-8a1a-b1aee30e663a"
GraphMakie = "1ecd5474-83a3-4783-bb4f-06765db800d2"
Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
Inflate = "d25df0c9-e2be-5dd7-82c8-3ad0b3e990b9"
LoggingExtras = "e6f89c97-d47a-5376-807f-9c37f3926c36"
NetCDF = "30363a11-5582-574a-97bb-aa9a979735b9"
NetworkLayout = "46757867-2c16-5918-afeb-47bfcb05e46a"
Expand All @@ -25,7 +26,9 @@ PkgTemplates = "14b8a8f1-9102-5b29-a752-f990bacb7fe1"
Primes = "27ebfcd6-29c5-5fa9-bf4b-fb8fc14df3ae"
ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
Proj = "c94c279d-25a6-4763-9509-64d165bea63e"
SHA = "ea8e919c-243c-51af-8825-aaa63cd721ce"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
Tar = "a4e569a6-e804-4fa4-b0f3-eef7a1d5b13e"
YAXArrays = "c21b50f5-aa40-41ea-b809-c0f5e47bfa5c"
Zarr = "0a941bbe-ad1d-11e8-39d9-ab76183a1d99"
65 changes: 65 additions & 0 deletions devscripts/distribexperiment.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
using Distributed

addprocs(8)

# Toy workload defined on every process: waits one second, then prints the
# id of the process it ran on followed by its argument.
@everywhere function f(x)
    sleep(1)
    pid = myid()
    println(pid, " ", x)
end

# Toy input: the integers 1..216 arranged as a 6×6×6 array.
data = reshape(1:(6^3), 6, 6, 6)

"""
    takeall!(res, c)

Move every item currently available in `c` into `res`, one at a time, using
`take!` on `c` and `push!` on `res`. Stops as soon as `c` reports empty and
returns `nothing`.
"""
function takeall!(res, c)
    while true
        # Stop once there is nothing left to take from the source.
        isempty(c) && break
        item = take!(c)
        push!(res, item)
    end
    return nothing
end

# Experimental recursive scheduler: splits `data` along its last dimension and
# hands workers from `workerlist` to the sub-problems while they are running.
#
# - `data`: array (or view) to process; inputs with more than 10 elements are split.
# - `workerlist`: source of worker ids; `take!` is called on it and `takeall!`
#   pushes returned workers onto it (presumably a `Channel{Int}`, like the
#   per-slice lists created below — TODO confirm with callers).
#
# The leaf case (`length(data) <= 10`) is not implemented yet; the `else`
# branch is empty.
function run_slice(data, workerlist)
s = size(data)
# Number of slices along the last dimension.
n = last(s)
if length(data) > 10
# One unbounded worker queue per slice.
subworkerlists = [Channel{Int}(Inf) for _ in 1:n]
# Start one asynchronous task per slice; each recurses on a view of its slice.
newtasks = map(1:n) do i
@async begin
# Index tuple of all colons, with the last dimension fixed to `i`.
inds = map(_ -> :, size(data))
inds = Base.setindex(inds, i, ndims(data))
vdata = view(data, inds...)
run_slice(vdata, subworkerlists[i])
end
end
# Scheduling loop: runs until every sub-task has finished.
while true
# Tracks the still-running task that currently has the fewest queued workers.
nexttask = (id=-1, nworkers=typemax(Int))
anyrunning = false
for i in eachindex(newtasks)
t = newtasks[i]
if istaskdone(t)
# Finished slice: hand its remaining workers back to the parent list.
takeall!(workerlist, subworkerlists[i])
else
anyrunning = true
if length(subworkerlists[i]) < nexttask.nworkers
nexttask = (id=i, nworkers=length(subworkerlists[i]))
end
end
end
if !anyrunning
break
end
# Move one worker from the parent list to the running slice with the fewest workers.
# NOTE(review): if `workerlist` is a Channel this `take!` waits until a worker is available — confirm.
worker = take!(workerlist)
put!(subworkerlists[nexttask.id], worker)
end
else

end




end


# Scratch experiment with a worker pool: build a CachingPool from worker ids 2 and 3,
# take one worker id out of it and put it back.
pool = CachingPool(Int[2, 3])

i = take!(pool)

put!(pool, i)
139 changes: 63 additions & 76 deletions devscripts/maintest.jl
Original file line number Diff line number Diff line change
Expand Up @@ -72,12 +72,69 @@ newconn, newnodes = DAE.merged_connection(DAE.BlockMerge, g, conn1, conn2, 2, ne
@test newconn.inputids == [3, 4]
@test newconn.outputids == [2, 1]

newconn.inwindows[1].windows.members[1]
newconn.inwindows[2].windows.members[1]
newconn.outwindows[1].windows.members[1]
newconn.outwindows[2].windows.members[1]


win1 = newconn.inwindows[1].windows.members[1]
@test win1 isa DAE.Window
@test eltype(win1) <: DAE.WindowGroup
@test length(win1) == 2
@test win1[1].g == 1:2
@test win1[1].parent == [1:5, 6:10, 11:15, 16:20]
@test win1[2].g == 3:4
@test win1[2].parent == [1:5, 6:10, 11:15, 16:20]
@test DAE.avg_step(win1) == 10
@test DAE.max_size(win1) == 10


wout1 = newconn.outwindows[1].windows.members[1]
@test wout1 isa DAE.Window
@test eltype(wout1) <: DAE.WindowGroup
@test length(wout1) == 2
@test wout1[1].g == 1:2
@test wout1[1].parent == 1:4
@test wout1[2].g == 3:4
@test wout1[2].parent == 1:4
@test DAE.avg_step(wout1) == 2
@test DAE.max_size(wout1) == 2

win2 = newconn.inwindows[2].windows.members[1]
@test win2 isa DAE.Window
@test eltype(win2) <: DAE.WindowGroup
@test length(win2) == 2
@test win2[1].g == 1:1
@test win2[1].parent == [1:2, 3:4]
@test win2[2].g == 2:2
@test win2[2].parent == [1:2, 3:4]
@test DAE.avg_step(win2) == 2
@test DAE.max_size(win2) == 2

wout2 = newconn.outwindows[2].windows.members[1]
@test wout2 isa DAE.Window
@test eltype(wout2) <: DAE.WindowGroup
@test length(wout2) == 2
@test wout2[1].g == 1:1
@test wout2[1].parent == 1:2
@test wout2[2].g == 2:2
@test wout2[2].parent == 1:2
@test DAE.avg_step(wout2) == 1
@test DAE.max_size(wout2) == 1

@test length(newnodes) == 1
@test newnodes[1] == DAE.EmptyInput{Float64,1}((4,))

append!(g.nodes, newnodes)

deleteat!(g.connections, [1, 2])
push!(g.connections, newconn)

newop = DAE.gmwop_from_reducedgraph(g)

inar = newop.inars[1]
cspec = DAE.get_chunkspec(inar, (2,))
@test cspec.app_cs == (2,)
@test cspec.windowfac == (10,)

lr = DAE.custom_loopranges(newop, (1,))

runner = DAE.LocalRunner(newop, lr)



Expand Down Expand Up @@ -116,76 +173,6 @@ p = graphplot(g, elabels=DAE.edgenames(g), ilabels=DAE.nodenames(g))
#DAE.fuse_step_direct!(g)


nodemergestrategies = DAE.collect_strategies(g)
i_eliminate = findfirst(nodemergestrategies) do strat
!isempty(strat) && !all(isnothing, strat)
end
### DAE.eliminate_node(g, i_eliminate, nodemergestrategies[i_eliminate], BlockMerge)
nodegraph = g
inconids = DAE.inconnections(nodegraph, i_eliminate)
outconids = DAE.outconnections(nodegraph, i_eliminate)
inconns = nodegraph.connections[inconids]
outconns = nodegraph.connections[outconids]

inconn = only(inconns)
outconn = only(outconns)

dimmap = DAE.create_loopdimmap(inconn, outconn, i_eliminate)

newop = DAE.merge_operations(DAE.BlockMerge, inconn, outconn, i_eliminate, dimmap)

newconn = DAE.merged_connection(DAE.BlockMerge, nodegraph, inconn, outconn, i_eliminate, newop, nodemergestrategies, dimmap)

newconn.inputids
newconn.outputids
newconn.inwindows[2].windows.members[2]


nodemergestrategies = DAE.collect_strategies(g)
i_eliminate = findfirst(nodemergestrategies) do strat
!isempty(strat) && !all(isnothing, strat)
end

nodegraph = g;
inconids = DAE.inconnections(nodegraph, i_eliminate)
outconids = DAE.outconnections(nodegraph, i_eliminate)
inconns = nodegraph.connections[inconids]
outconns = nodegraph.connections[outconids]

inconn = only(inconns)
outconn = only(outconns)

dimmap = DAE.create_loopdimmap(inconn, outconn, i_eliminate)

chain1 = DAE.BlockFunctionChain(inconn)
chain2 = DAE.BlockFunctionChain(outconn)

to_eliminate = i_eliminate

chain1.args
chain2.args
ifrom = findfirst(==(to_eliminate), inconn.outputids)
ito = findall(==(to_eliminate), outconn.inputids)
transfer = ifrom => ito

newfunc = DAE.build_chain(chain1, chain2, dimmap, transfer)


newop = DAE.merge_operations(DAE.BlockMerge, inconn, outconn, i_eliminate, dimmap)

newconn = DAE.merged_connection(DAE.BlockMerge, nodegraph, inconn, outconn, i_eliminate, newop, nodemergestrategies, dimmap)

deleteat!(nodegraph.connections, [inconids; outconids])
push!(nodegraph.connections, newconn)

nodegraph.connections

conn = only(nodegraph.connections)
op = conn.f
inputs = InputArray.(g.nodes[conn.inputids], conn.inwindows)
outspecs = map(g.nodes[conn.outputids], conn.outwindows) do outnode, outwindow
(; lw=outwindow, chunks=outnode.chunks, ismem=outnode.ismem)
end


function gmwop_from_conn(conn)
Expand Down
7 changes: 4 additions & 3 deletions src/buffers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,9 @@ getloopinds(b::ArrayBuffer) = getloopinds(b.lw)
function getbufsize(ia, lr)
map(windowbuffersize, mysub(ia, lr.members), ia.lw.windows.members)
end
windowbuffersize(looprange, window) = maximum(c -> internal_size(inner_index(window, c)), looprange)
# Buffer size along one dimension: for every entry `c` of `looprange`, take the span from the
# start of `inner_range(first(c))` to the end of `inner_range(last(c))`, and return the largest.
# NOTE(review): `window` is accepted but never referenced in the body, so the size is derived
# from the loop-range entries alone — confirm this is intended.
windowbuffersize(looprange, window) = maximum(looprange) do c
last(inner_range(last(c))) - first(inner_range(first(c))) + 1
end

"Creates buffers for input arrays"
function generate_inbuffers(inars, loopranges)
Expand Down Expand Up @@ -135,8 +137,7 @@ end
function get_bufferindices(r, outspecs)
mywindowrange = mysub(outspecs, r)
BufferIndex(map(outspecs.lw.windows.members, mywindowrange) do w, r
i = inner_index(w, r)
first(first(i)):last(last(i))
first(inner_range(first(w))):last(inner_range(last(w)))
end)
end
get_bufferindices(r::BufferIndex, _) = r
Expand Down
29 changes: 21 additions & 8 deletions src/executionplan.jl
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ function access_per_chunk(p::ExecutionPlan)
(;input_times,output_times)
end
function actual_access_per_chunk(p::ExecutionPlan)

input_times = map(p.input_chunkspecs) do chunkspec
mylr = mysub(chunkspec.lw,p.lr.members)
actual_chunk_access.(chunkspec.cs,mylr,chunkspec.lw.windows.members)
Expand Down Expand Up @@ -170,7 +170,7 @@ function actual_chunk_access(cs,looprange,window)
n_access = 0
for lr in looprange
w1,w2 = mapreduce(((a,b),(c,d))->(min(a,c),max(b,d)),lr) do ii
extrema(window[ii])
extrema(inner_range(window[ii]))
end
i = DiskArrays.findchunk(cs,w1:w2)
n_access = n_access+length(i)
Expand All @@ -195,7 +195,7 @@ end
function get_chunkspec(outspec,ot)
cs = outspec.chunks
avgs = avg_step.(outspec.lw.windows.members)
si = map(m->last(last(m))-first(first(m))+1,outspec.lw.windows.members)
si = map(m->last(inner_range(last(m)))-first(inner_range(first(m)))+1,outspec.lw.windows.members)
if cs isa GridChunks
cs = cs.chunks
elseif cs === nothing
Expand Down Expand Up @@ -262,9 +262,9 @@ all_constraints(window,chunkspec) = (compute_bufsize(window,chunkspec...),window
all_constraints!(res,window,chunkspec) = res.=all_constraints(window,chunkspec)

avg_step(lw) = avg_step(lw,get_ordering(lw),get_overlap(lw))
avg_step(lw,::Union{Increasing,Decreasing},::Any) = length(lw) > 1 ? mean(diff(first.(lw))) : 1.0
avg_step(lw,::Union{Increasing,Decreasing},::Any) = length(lw) > 1 ? mean(diff(first.(inner_range.(lw)))) : 1.0
avg_step(lw,::Any,::Any) = error("Not implemented")
max_size(lw) = maximum(length,lw)
max_size(lw) = maximum(length,inner_range.(lw))

estimate_singleread(ia::InputArray)= ismem(ia) ? 1e-16 : 1.0
estimate_singleread(ia) = ia.ismem ? 1e-16 : 3.0
Expand All @@ -285,6 +285,17 @@ function optimize_loopranges(op::GMDWop,max_cache;tol_low=0.2,tol_high = 0.05,ma
ExecutionPlan(input_chunkspecs, output_chunkspecs,(sol.u...,),totsize,sol.objective,lr)
end

"""
    custom_loopranges(op, steps::Tuple)

Build an `ExecutionPlan` for `op` from caller-supplied loop-range step sizes.

`steps` must have one entry per entry of `op.windowsize`; otherwise an error is
raised. The loop ranges are `DiskArrays.RegularChunks` with the given step,
offset 0 and the corresponding window size, wrapped in a `ProductArray`. The
objective stored in the plan is `compute_time` evaluated at `steps` for all
input and output chunk specs.
"""
function custom_loopranges(op, steps::Tuple)
    totsize = op.windowsize
    inspecs = get_chunkspec.(op.inars, (totsize,))
    outspecs = get_chunkspec.(op.outspecs, op.f.outtype)
    if length(steps) != length(totsize)
        error("Steps for loop ranges does not fit number of loop vars ($(length(totsize)))")
    end
    allspecs = (inspecs..., outspecs...)
    # `steps` is already a tuple, so it can be passed on unchanged.
    objective = compute_time(steps, allspecs)
    ranges = DiskArrays.RegularChunks.(steps, 0, totsize)
    return ExecutionPlan(
        inspecs, outspecs, Float64.(steps), totsize, objective, ProductArray(ranges)
    )
end

using OrderedCollections, Primes

function kgv(i...)
Expand Down Expand Up @@ -382,7 +393,7 @@ function adjust_loopranges(optotal,approx_opti;tol_low=0.2,tol_high = 0.05,max_o
adj_cands = first.(r)
adj_chunks = last.(r)


@debug "Adjust candidates: ", adj_cands
lr = if force_regular
DiskArrays.RegularChunks.(round.(Int,adj_cands),0,optotal.windowsize)
Expand All @@ -403,6 +414,7 @@ loop ranges for a reduction group. This will try to correct loopranges to avoid
the problems mentioned above.
"""
function fix_output_overlap(outspecs,lrbreaks)
@show lrbreaks
for outspec in outspecs
mylr = mysub(outspec.lw,lrbreaks)
newbreaks = map(mylr,outspec.lw.windows.members) do breaks,window
Expand All @@ -428,6 +440,7 @@ function fix_output_overlap(outspecs,lrbreaks)
breaks
end
end
@show newbreaks
for (lr,b) in zip(mylr,newbreaks)
lr.=b
end
Expand All @@ -441,12 +454,12 @@ function output_chunks(outspec,lr)
map(mylr,outspec.lw.windows.members) do llr,wi
ww = map(llr) do l
wnow = wi[l]
first(first(wnow)):last(last(wnow))
first(inner_range(first(wnow))):last(inner_range(last(wnow)))
end
DiskArrays.chunktype_from_chunksizes(length.(sort(unique(ww),lt=rangelt)))
end
end

function output_chunks(p::ExecutionPlan)
output_chunks.(p.output_chunkspecs,(p.lr,))
output_chunks.(p.output_chunkspecs,(p.lr,))
end
12 changes: 11 additions & 1 deletion src/graphs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -328,4 +328,14 @@ function result_to_graph(res)
g = MwopGraph()
to_graph!(g, res)
g
end
end

"""
    gmwop_from_reducedgraph(g)

Convert a graph that holds exactly one connection into a `GMDWop`.

The single connection's function becomes the operation. Its input nodes are
paired with the connection's input windows as `InputArray`s. Each output node
is paired with its output window in a named tuple `(lw, chunks, ismem)`.
`only` throws if `g.connections` does not contain exactly one element.
"""
function gmwop_from_reducedgraph(g)
    connection = only(g.connections)
    operation = connection.f
    innodes = g.nodes[connection.inputids]
    inars = InputArray.(innodes, connection.inwindows)
    outnodes = g.nodes[connection.outputids]
    outspecs = map(outnodes, connection.outwindows) do node, win
        (lw=win, chunks=node.chunks, ismem=node.ismem)
    end
    return GMDWop((inars...,), (outspecs...,), operation)
end
4 changes: 4 additions & 0 deletions src/windows.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,15 @@ struct WindowGroup{P}
g::UnitRange{Int}
end
inner_range(w::WindowGroup) = first(w.parent[first(w.g)]):last(w.parent[last(w.g)])
"A plain number covers only itself, so its inner range is the one-element range `i:i`."
function inner_range(i::Number)
    return i:i
end

"A range is already its own inner range and is returned unchanged."
function inner_range(i::AbstractRange)
    return i
end
inner_values(w::WindowGroup) = w.parent[w.g]
inner_values(i) = i
compute_ordering(r::AbstractVector{<:WindowGroup}) = compute_ordering(first(r).parent)
compute_overlap(r::AbstractVector{<:WindowGroup}, ordering) = compute_overlap(inner_range.(r), ordering)
compute_sparsity(r::AbstractVector{<:WindowGroup}) = compute_sparsity(first(r).parent)


function compute_ordering(r)
exts = extrema.(r)
allsorted(x;rev=false) = issorted(x,by=first;rev) && issorted(x,by=last;rev)
Expand Down

0 comments on commit c486092

Please sign in to comment.