Skip to content

Commit

Permalink
compiler: Inline halo_to_halo
Browse files Browse the repository at this point in the history
  • Loading branch information
georgebisbas committed Nov 5, 2024
1 parent 69e7e59 commit fee66c8
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 73 deletions.
5 changes: 4 additions & 1 deletion devito/mpi/halo_scheme.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,10 @@ def union(self, halo_schemes):
Create a new HaloScheme from the union of a set of HaloSchemes.
"""
halo_schemes = [hs for hs in halo_schemes if hs is not None]
if not halo_schemes:

if len(halo_schemes) == 1:
return halo_schemes[0]
elif not halo_schemes:
return None

fmapper = {}
Expand Down
92 changes: 42 additions & 50 deletions devito/passes/iet/mpi.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def _hoist_halospots(iet):
Example:
haloupd v[t0]
for time for time
W v[t1]- R v[t0] W v[t1]- R v[t0]
W v[t1]- R v[t0] W v[t1]- R v[t0]
haloupd v[t1] haloupd v[t1]
R v[t1] R v[t1]
haloupd v[t0] R v[t0]
Expand All @@ -74,7 +74,7 @@ def _hoist_halospots(iet):
# Precompute scopes to save time
scopes = {i: Scope([e.expr for e in v]) for i, v in MapNodes().visit(iet).items()}

cond_mapper = _create_cond_mapper(iet)
cond_mapper = _make_cond_mapper(iet)

# Analysis
hsmapper = {}
Expand All @@ -95,13 +95,46 @@ def _hoist_halospots(iet):
continue

# If there are overlapping time accesses, skip
if any(i in hs0.halo_scheme.loc_values
for i in hs1.halo_scheme.loc_values):
if hs0.halo_scheme.loc_values.intersection(hs1.halo_scheme.loc_values):
continue

# Compare hs0 to subsequent halo_spots, looking for optimization
# possibilities
_process_halo_to_halo(hs0, hs1, iters, scopes, hsmapper, imapper)
# Loop over the functions in the HaloSpots
for f, v in hs1.fmapper.items():
# If no time accesses, skip
if not hs1.halo_scheme.fmapper[f].loc_indices:
continue

# If the function is not in both HaloSpots, skip
if f not in hs0.functions:
continue

for it in iters:
# If also merge-able we can start hoisting the latter
for dep in scopes[it].d_flow.project(f):
if not any(r(dep, hs1, v.loc_indices) for r in merge_rules()):
break
else:
hse = hs1.halo_scheme.fmapper[f]
raw_loc_indices = {}
# Entering here means we can lift, and we need to update
# the loc_indices with known values
# TODO: Can I get this in a more elegant way?
for d in hse.loc_indices:
if hse.loc_indices[d].is_Symbol:
assert d in hse.loc_indices[d]._defines
root_min = hse.loc_indices[d].symbolic_min
new_min = root_min.subs(hse.loc_indices[d].root,
hse.loc_indices[d].root.symbolic_min)
raw_loc_indices[d] = new_min
else:
assert d.symbolic_min in hse.loc_indices[d].free_symbols
raw_loc_indices[d] = hse.loc_indices[d]

hse = hse.rebuild(loc_indices=frozendict(raw_loc_indices))
hs1.halo_scheme.fmapper[f] = hse

hsmapper[hs1] = hsmapper.get(hs1, hs1.halo_scheme).drop(f)
imapper[it].append(hs1.halo_scheme.project(f))

mapper = {i: HaloSpot(i._rebuild(), HaloScheme.union(hss))
for i, hss in imapper.items()}
Expand All @@ -113,55 +146,14 @@ def _hoist_halospots(iet):
return iet


def _process_halo_to_halo(hs0, hs1, iters, scopes, hsmapper, imapper):

# Loop over the functions in the HaloSpots
for f, v in hs1.fmapper.items():
# If no time accesses, skip
if not hs1.halo_scheme.fmapper[f].loc_indices:
continue

# If the function is not in both HaloSpots, skip
if (*hs0.functions, *hs1.functions).count(f) < 2:
continue

for iter in iters:
# If also merge-able we can start hoisting the latter
for dep in scopes[iter].d_flow.project(f):
if not any(r(dep, hs1, v.loc_indices) for r in merge_rules()):
break
else:
hse = hs1.halo_scheme.fmapper[f]
raw_loc_indices = {}
# Entering here means we can lift, and we need to update
# the loc_indices with known values
# TODO: Can I get this in a more elegant way?
for d in hse.loc_indices:
if hse.loc_indices[d].is_Symbol:
assert d in hse.loc_indices[d]._defines
root_min = hse.loc_indices[d].symbolic_min
new_min = root_min.subs(hse.loc_indices[d].root,
hse.loc_indices[d].root.symbolic_min)
raw_loc_indices[d] = new_min
else:
assert d.symbolic_min in hse.loc_indices[d].free_symbols
raw_loc_indices[d] = hse.loc_indices[d]

hse = hse.rebuild(loc_indices=frozendict(raw_loc_indices))
hs1.halo_scheme.fmapper[f] = hse

hsmapper[hs1] = hsmapper.get(hs1, hs1.halo_scheme).drop(f)
imapper[iter].append(hs1.halo_scheme.project(f))


def _merge_halospots(iet):
"""
Merge HaloSpots on the same Iteration tree level where all data dependencies
would be honored.
"""

# Analysis
cond_mapper = _create_cond_mapper(iet)
cond_mapper = _make_cond_mapper(iet)

mapper = {}
for iter, halo_spots in MapNodes(Iteration, HaloSpot, 'immediate').visit(iet).items():
Expand Down Expand Up @@ -363,7 +355,7 @@ def mpiize(graph, **kwargs):

# Utility functions to avoid code duplication

def _create_cond_mapper(iet):
def _make_cond_mapper(iet):
cond_mapper = MapHaloSpots().visit(iet)
return {hs: {i for i in v if i.is_Conditional and
not isinstance(i.condition, GuardFactorEq)}
Expand Down
35 changes: 13 additions & 22 deletions tests/test_mpi.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from devito.types.dimension import SpaceDimension

from examples.seismic.acoustic import acoustic_setup
from examples.seismic import Receiver, TimeAxis, demo_model
from examples.seismic import demo_model
from tests.test_dse import TestTTI


Expand Down Expand Up @@ -1013,23 +1013,20 @@ def test_avoid_haloupdate_if_distr_but_sequential(self, mode):

@pytest.mark.parallel(mode=1)
def test_issue_2448(self, mode):
extent = (10., )
shape = (2, )
extent = (10.,)
shape = (2,)
so = 2
to = 1

x = SpaceDimension(name='x', spacing=Constant(name='h_x',
value=extent[0]/(shape[0]-1)))
grid = Grid(extent=extent, shape=shape, dimensions=(x, ))
grid = Grid(extent=extent, shape=shape, dimensions=(x,))

# Time related
t0, tn = 0., 30.
dt = (10. / np.sqrt(2.)) / 6.
time_range = TimeAxis(start=t0, stop=tn, step=dt)
tn = 30

# Velocity and pressure fields
v = TimeFunction(name='v', grid=grid, space_order=so, time_order=to)
tau = TimeFunction(name='tau', grid=grid, space_order=so, time_order=to)
v = TimeFunction(name='v', grid=grid, space_order=so)
tau = TimeFunction(name='tau', grid=grid, space_order=so)

# First order elastic-like dependencies equations
pde_v = v.dt - (tau.dx)
Expand All @@ -1040,7 +1037,7 @@ def test_issue_2448(self, mode):

# Test two variants of receiver interpolation
nrec = 1
rec = Receiver(name="rec", grid=grid, npoint=nrec, time_range=time_range)
rec = SparseTimeFunction(name="rec", grid=grid, npoint=nrec, nt=tn)
rec.coordinates.data[:, 0] = np.linspace(0., extent[0], num=nrec)

# The receiver 0
Expand Down Expand Up @@ -2811,18 +2808,12 @@ def test_elastic_structure(self, mode):
shape=(301, 301), spacing=(10., 10.),
space_order=so)

t0, tn = 0., 2000.
dt = model.critical_dt
time_range = TimeAxis(start=t0, stop=tn, step=dt)

x, z = model.grid.dimensions

v = VectorTimeFunction(name='v', grid=model.grid, space_order=so, time_order=1)
tau = TensorTimeFunction(name='t', grid=model.grid, space_order=so, time_order=1)
v = VectorTimeFunction(name='v', grid=model.grid, space_order=so)
tau = TensorTimeFunction(name='t', grid=model.grid, space_order=so)

# The receiver
nrec = 301
rec = Receiver(name="rec", grid=model.grid, npoint=nrec, time_range=time_range)
rec = SparseTimeFunction(name="rec", grid=model.grid, npoint=nrec, nt=10)
rec.coordinates.data[:, 0] = np.linspace(0., model.domain_size[0], num=nrec)
rec.coordinates.data[:, -1] = 5.

Expand Down Expand Up @@ -2864,9 +2855,9 @@ def test_elastic_structure(self, mode):
assert calls[4].arguments[1] is v[1]


class TestTTI_w_MPI:
class TestTTIwMPI:

@pytest.mark.parallel(mode=[(1)])
@pytest.mark.parallel(mode=1)
def test_halo_structure(self, mode):

mytest = TestTTI()
Expand Down

0 comments on commit fee66c8

Please sign in to comment.