-
Notifications
You must be signed in to change notification settings - Fork 137
Open
Description
This is in response to the discussion we had in the 2025-06-12 meeting about Memlets between MapEntry
nodes.
There is the first situation, which is not so strange.
More interesting is this situation: what is the "destination" of the Memlet between the two MapEntry
nodes?
Here are the reproducers.
Case 1:
from typing import Tuple
import dace
from dace.transformation import dataflow as dace_tdf
import numpy as np
import copy
def reference(a, b, c):
    """Write ``sin(a)**2 + cos(transpose(b))**2`` into ``c`` in place.

    This is the NumPy ground truth the compiled SDFG result is compared
    against. ``c`` is overwritten through slice assignment; nothing is
    returned.
    """
    sin_squared = np.sin(a) ** 2
    cos_squared = np.cos(np.transpose(b)) ** 2
    c[:] = sin_squared + cos_squared
def make_data():
    """Create random input data and an independent deep copy of it.

    Returns a pair of dictionaries with keys ``"a"``, ``"b"`` and ``"c"``,
    each holding a 10x10 array. ``"c"`` is filled with non-positive values
    (negated absolute value of random numbers). The second dictionary is a
    deep copy of the first, so one can be handed to the reference
    implementation and the other to the compiled SDFG.
    """
    np_dtype = dace.float64.as_numpy_dtype()
    # Random draws happen in the order a, b, c.
    arrays = {
        "a": np.array(np.random.rand(10, 10), dtype=np_dtype),
        "b": np.array(np.random.rand(10, 10), dtype=np_dtype),
        "c": -np.abs(np.array(np.random.rand(10, 10), dtype=np_dtype)),
    }
    return arrays, copy.deepcopy(arrays)
def make_sdfg() -> Tuple[dace.SDFG, dace.SDFGState]:
    """Build the Case 1 reproducer SDFG and apply ``MapExpansion`` to its map.

    The SDFG has three non-transient 10x10 float64 arrays (``a``, ``b``,
    ``c``), four transient float64 scalars (``t1`` .. ``t4``) and one map over
    ``__i`` and ``__j`` (both ``0:10``). Inside the map, ``tlet1`` applies
    ``math.sin`` to ``t1``, ``tlet2`` applies ``math.cos`` to ``t2`` and
    ``tlet3`` adds the squares of both results, which is written to ``c``.

    Returns:
        The SDFG (after ``MapExpansion`` and ``validate()``) and its only state.
    """
    sdfg = dace.SDFG("crazy_memlet_example")
    state = sdfg.add_state(is_start_block=True)

    # Global arrays and their access nodes.
    for array_name in "abc":
        sdfg.add_array(array_name, shape=(10, 10), dtype=dace.float64, transient=False)
    a, b, c = [state.add_access(array_name) for array_name in "abc"]

    # Transient scalars t1..t4; each is registered and then given an access node.
    scalar_nodes = []
    for idx in range(1, 5):
        scalar_name = f"t{idx}"
        sdfg.add_scalar(scalar_name, dtype=dace.float64, transient=True)
        scalar_nodes.append(state.add_access(scalar_name))
    t1, t2, t3, t4 = scalar_nodes

    map_entry, map_exit = state.add_map("map", ndrange={"__i": "0:10", "__j": "0:10"})

    # Input `a`: the memlet leaving the map entry names `a` and maps to `[0]`.
    state.add_edge(a, None, map_entry, "IN_a", dace.Memlet("a[0:10, 0:10]"))
    state.add_edge(map_entry, "OUT_a", t1, None, dace.Memlet("a[__i, __j] -> [0]"))
    map_entry.add_scope_connectors("a")

    # Input `b`: compared to the `a` memlet above this one is swapped, it names
    # `t2` and uses the index order `[__j, __i]`.
    state.add_edge(b, None, map_entry, "IN_b", dace.Memlet("b[0:10, 0:10]"))
    state.add_edge(map_entry, "OUT_b", t2, None, dace.Memlet("t2[0] -> [__j, __i]"))
    map_entry.add_scope_connectors("b")

    # t1 -> sin -> t3
    sin_tasklet = state.add_tasklet(
        "tlet1", inputs={"__in"}, outputs={"__out"}, code="__out = math.sin(__in)"
    )
    state.add_edge(t1, None, sin_tasklet, "__in", dace.Memlet("t1[0]"))
    state.add_edge(sin_tasklet, "__out", t3, None, dace.Memlet("t3[0]"))
    # The connector names were already passed to `add_tasklet`; the explicit
    # calls are kept exactly as in the reproducer.
    sin_tasklet.add_in_connector("__in")
    sin_tasklet.add_out_connector("__out")

    # t2 -> cos -> t4
    cos_tasklet = state.add_tasklet(
        "tlet2", inputs={"__in"}, outputs={"__out"}, code="__out = math.cos(__in)"
    )
    state.add_edge(t2, None, cos_tasklet, "__in", dace.Memlet("t2[0]"))
    state.add_edge(cos_tasklet, "__out", t4, None, dace.Memlet("t4[0]"))
    cos_tasklet.add_in_connector("__in")
    cos_tasklet.add_out_connector("__out")

    # (t3, t4) -> t3**2 + t4**2 -> c[__i, __j]
    sum_tasklet = state.add_tasklet(
        "tlet3",
        inputs={"__in1", "__in2"},
        outputs={"__out"},
        code="__out = (__in1 ** 2) + (__in2 ** 2)",
    )
    state.add_edge(t3, None, sum_tasklet, "__in1", dace.Memlet("t3[0]"))
    state.add_edge(t4, None, sum_tasklet, "__in2", dace.Memlet("t4[0]"))
    sum_tasklet.add_in_connector("__in1")
    sum_tasklet.add_in_connector("__in2")

    state.add_edge(sum_tasklet, "__out", map_exit, "IN_c", dace.Memlet("c[__i, __j]"))
    state.add_edge(map_exit, "OUT_c", c, None, dace.Memlet("c[0:10, 0:10]"))
    sum_tasklet.add_out_connector("__out")
    map_exit.add_scope_connectors("c")

    # Apply the transformation under test to the map, then validate the result.
    dace_tdf.MapExpansion.apply_to(sdfg=sdfg, map_entry=map_entry, verify=True)
    sdfg.validate()

    return sdfg, state
def main():
    """Build the SDFG, run it and compare the result with the NumPy reference.

    The reference implementation operates on one copy of the data and the
    compiled SDFG on the other; afterwards every array in both copies must
    agree according to ``np.allclose``.
    """
    sdfg, _state = make_sdfg()
    expected, actual = make_data()

    sdfg.view()
    reference(**expected)

    # NOTE(review): the reproducer's indentation was lost, so the exact extent
    # of this `with` block is an assumption; compilation and the call are both
    # kept inside the temporary configuration.
    with dace.config.temporary_config():
        dace.Config.set("compiler", "allow_view_arguments", value=True)
        dace.Config.set("store_history", value=False)
        dace.Config.set("optimizer", "match_exception", value=False)
        dace.Config.set("compiler", "use_cache", value=False)
        dace.Config.set("cache", value="name")
        compiled_sdfg = sdfg.compile()
        compiled_sdfg(**actual)

    assert all(np.allclose(expected[key], actual[key]) for key in expected)


if __name__ == "__main__":
    main()
Case 2:
from typing import Tuple
import dace
from dace.transformation import dataflow as dace_tdf
import numpy as np
import copy
def reference(a, b):
    """Write ``sin(a)**2 + cos(a)**2`` into ``b`` in place.

    This is the NumPy ground truth the compiled SDFG result is compared
    against. ``b`` is overwritten through slice assignment; nothing is
    returned.
    """
    sin_squared = np.sin(a) ** 2
    cos_squared = np.cos(a) ** 2
    b[:] = sin_squared + cos_squared
def make_data():
    """Create random input data and an independent deep copy of it.

    Returns a pair of dictionaries with keys ``"a"`` and ``"b"``, each holding
    a 10x10 array. ``"b"`` is filled with non-positive values (negated
    absolute value of random numbers). The second dictionary is a deep copy of
    the first, so one can be handed to the reference implementation and the
    other to the compiled SDFG.
    """
    np_dtype = dace.float64.as_numpy_dtype()
    # Random draws happen in the order a, b.
    arrays = {
        "a": np.array(np.random.rand(10, 10), dtype=np_dtype),
        "b": -np.abs(np.array(np.random.rand(10, 10), dtype=np_dtype)),
    }
    return arrays, copy.deepcopy(arrays)
def make_sdfg() -> Tuple[dace.SDFG, dace.SDFGState]:
    """Build the Case 2 reproducer SDFG and apply ``MapExpansion`` to its map.

    The SDFG has two non-transient 10x10 float64 arrays (``a``, ``b``), four
    transient float64 scalars (``t1`` .. ``t4``) and one map over ``__i`` and
    ``__j`` (both ``0:10``). Both ``t1`` and ``t2`` are fed from the same
    ``OUT_a`` connector of the map entry, but through differently written
    Memlets. ``tlet1`` applies ``math.sin`` to ``t1``, ``tlet2`` applies
    ``math.cos`` to ``t2`` and ``tlet3`` adds the squares of both results,
    which is written to ``b``.

    Returns:
        The SDFG (after ``MapExpansion`` and ``validate()``) and its only state.
    """
    sdfg = dace.SDFG("crazy_memlet_example_2")
    state = sdfg.add_state(is_start_block=True)

    # Global arrays and their access nodes.
    for array_name in "ab":
        sdfg.add_array(array_name, shape=(10, 10), dtype=dace.float64, transient=False)
    a, b = [state.add_access(array_name) for array_name in "ab"]

    # Transient scalars t1..t4; each is registered and then given an access node.
    scalar_nodes = []
    for idx in range(1, 5):
        scalar_name = f"t{idx}"
        sdfg.add_scalar(scalar_name, dtype=dace.float64, transient=True)
        scalar_nodes.append(state.add_access(scalar_name))
    t1, t2, t3, t4 = scalar_nodes

    map_entry, map_exit = state.add_map("map", ndrange={"__i": "0:10", "__j": "0:10"})

    # First consumer of `a`: the memlet names `a` and maps to `[0]`.
    state.add_edge(a, None, map_entry, "IN_a", dace.Memlet("a[0:10, 0:10]"))
    state.add_edge(map_entry, "OUT_a", t1, None, dace.Memlet("a[__i, __j] -> [0]"))
    map_entry.add_scope_connectors("a")

    # Second consumer of the same `OUT_a` connector: here the memlet names `t2`
    # and maps to `[__i, __j]`.
    state.add_edge(map_entry, "OUT_a", t2, None, dace.Memlet("t2[0] -> [__i, __j]"))

    # t1 -> sin -> t3
    sin_tasklet = state.add_tasklet(
        "tlet1", inputs={"__in"}, outputs={"__out"}, code="__out = math.sin(__in)"
    )
    state.add_edge(t1, None, sin_tasklet, "__in", dace.Memlet("t1[0]"))
    state.add_edge(sin_tasklet, "__out", t3, None, dace.Memlet("t3[0]"))
    # The connector names were already passed to `add_tasklet`; the explicit
    # calls are kept exactly as in the reproducer.
    sin_tasklet.add_in_connector("__in")
    sin_tasklet.add_out_connector("__out")

    # t2 -> cos -> t4
    cos_tasklet = state.add_tasklet(
        "tlet2", inputs={"__in"}, outputs={"__out"}, code="__out = math.cos(__in)"
    )
    state.add_edge(t2, None, cos_tasklet, "__in", dace.Memlet("t2[0]"))
    state.add_edge(cos_tasklet, "__out", t4, None, dace.Memlet("t4[0]"))
    cos_tasklet.add_in_connector("__in")
    cos_tasklet.add_out_connector("__out")

    # (t3, t4) -> t3**2 + t4**2 -> b[__i, __j]
    sum_tasklet = state.add_tasklet(
        "tlet3",
        inputs={"__in1", "__in2"},
        outputs={"__out"},
        code="__out = (__in1 ** 2) + (__in2 ** 2)",
    )
    state.add_edge(t3, None, sum_tasklet, "__in1", dace.Memlet("t3[0]"))
    state.add_edge(t4, None, sum_tasklet, "__in2", dace.Memlet("t4[0]"))
    sum_tasklet.add_in_connector("__in1")
    sum_tasklet.add_in_connector("__in2")

    state.add_edge(sum_tasklet, "__out", map_exit, "IN_b", dace.Memlet("b[__i, __j]"))
    state.add_edge(map_exit, "OUT_b", b, None, dace.Memlet("b[0:10, 0:10]"))
    sum_tasklet.add_out_connector("__out")
    map_exit.add_scope_connectors("b")

    # Apply the transformation under test to the map, then validate the result.
    dace_tdf.MapExpansion.apply_to(sdfg=sdfg, map_entry=map_entry, verify=True)
    sdfg.validate()

    return sdfg, state
def main():
    """Build the SDFG, run it and compare the result with the NumPy reference.

    The reference implementation operates on one copy of the data and the
    compiled SDFG on the other; afterwards every array in both copies must
    agree according to ``np.allclose``.
    """
    sdfg, _state = make_sdfg()
    expected, actual = make_data()

    sdfg.view()
    reference(**expected)

    # NOTE(review): the reproducer's indentation was lost, so the exact extent
    # of this `with` block is an assumption; compilation and the call are both
    # kept inside the temporary configuration.
    with dace.config.temporary_config():
        dace.Config.set("compiler", "allow_view_arguments", value=True)
        dace.Config.set("store_history", value=False)
        dace.Config.set("optimizer", "match_exception", value=False)
        dace.Config.set("compiler", "use_cache", value=False)
        dace.Config.set("cache", value="name")
        compiled_sdfg = sdfg.compile()
        compiled_sdfg(**actual)

    assert all(np.allclose(expected[key], actual[key]) for key in expected)


if __name__ == "__main__":
    main()
Metadata
Metadata
Assignees
Labels
No labels