Skip to content

Commit

Permalink
Merge pull request spcl#1170 from spcl/users/lukas/cutout-fix
Browse files Browse the repository at this point in the history
Added support for implicitly defined arrays in cutouts
  • Loading branch information
lukastruemper authored Dec 6, 2022
2 parents ef405d7 + 095418b commit cee4205
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 5 deletions.
14 changes: 9 additions & 5 deletions dace/sdfg/analysis/cutout.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,14 +35,18 @@ def cutout_state(state: SDFGState, *nodes: nd.Node, make_copy: bool = True) -> S
for sym in freesyms:
new_sdfg.add_symbol(sym, defined_syms[sym])

for dnode in subgraph.data_nodes():
if dnode.data in new_sdfg.arrays:
for edge in subgraph.edges():
if edge.data is None:
continue

memlet = edge.data
if memlet.data in new_sdfg.arrays:
continue
new_desc = sdfg.arrays[dnode.data].clone()
new_desc = sdfg.arrays[memlet.data].clone()
# If transient is defined outside, it becomes a global
if dnode.data in other_arrays:
if memlet.data in other_arrays:
new_desc.transient = False
new_sdfg.add_datadesc(dnode.data, new_desc)
new_sdfg.add_datadesc(memlet.data, new_desc)

# Add a single state with the extended subgraph
new_state = new_sdfg.add_state(state.label, is_start_state=True)
Expand Down
34 changes: 34 additions & 0 deletions tests/sdfg/cutout_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

def test_cutout_onenode():
""" Tests cutout on a single node in a state. """

@dace.program
def simple_matmul(A: dace.float64[20, 20], B: dace.float64[20, 20]):
return A @ B + 5
Expand All @@ -26,6 +27,7 @@ def simple_matmul(A: dace.float64[20, 20], B: dace.float64[20, 20]):

def test_cutout_multinode():
""" Tests cutout on multiple nodes in a state. """

@dace.program
def simple_matmul(A: dace.float64[20, 20], B: dace.float64[20, 20]):
return A @ B + 5
Expand Down Expand Up @@ -117,8 +119,40 @@ def test_cutout_scope_fail():
cutout.cutout_state(state, t)


def test_cutout_implicit_array():
N = dace.symbol("N")
C = dace.symbol("C")
nnz = dace.symbol("nnz")

@dace.program
def spmm(
A_row: dace.int32[C + 1],
A_col: dace.int32[nnz],
A_val: dace.float32[nnz],
B: dace.float32[C, N],
):
out = dace.define_local((C, N), dtype=B.dtype)

for i in dace.map[0:C]:
for j in dace.map[A_row[i]:A_row[i + 1]]:
for k in dace.map[0:N]:
b_col = B[:, k]
with dace.tasklet:
w << A_val[j]
b << b_col[A_col[j]]
o >> out(0, lambda x, y: x + y)[i, k]
o = w * b

return out

sdfg = spmm.to_sdfg()
c = cutout.cutout_state(sdfg.start_state, *sdfg.start_state.nodes())
c.validate()


if __name__ == '__main__':
test_cutout_onenode()
test_cutout_multinode()
test_cutout_complex_case()
test_cutout_scope_fail()
test_cutout_implicit_array()

0 comments on commit cee4205

Please sign in to comment.