Skip to content

Commit

Permalink
Merge pull request #2286 from devitocodes/tensor-fit-2
Browse files Browse the repository at this point in the history
misc: Fix gpu-fit for multiple tensors
  • Loading branch information
mloubout authored Dec 19, 2023
2 parents 063d07d + 2294029 commit 3126fb0
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 1 deletion.
3 changes: 2 additions & 1 deletion devito/core/gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,8 @@ def _normalize_kwargs(cls, **kwargs):
def _normalize_gpu_fit(cls, oo, **kwargs):
try:
gfit = as_tuple(oo.pop('gpu-fit'))
gfit = set().union([f.values() if f.is_AbstractTensor else f for f in gfit])
gfit = set().union(*[f.values() if f.is_AbstractTensor else [f]
for f in gfit])
return tuple(gfit)
except KeyError:
if any(i in kwargs['mode'] for i in ['tasking', 'streaming']):
Expand Down
9 changes: 9 additions & 0 deletions tests/test_gpu_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -1428,13 +1428,22 @@ def test_gpu_fit_w_tensor_functions(self):

u = TensorTimeFunction(name='u', grid=grid)
usave = TensorTimeFunction(name="usave", grid=grid, save=10)
usave2 = TensorTimeFunction(name="usave2", grid=grid, save=10)

eqns = [Eq(u.forward, u + 1),
Eq(usave, u.forward)]

op = Operator(eqns, opt=('noop', {'gpu-fit': usave}))
assert set(op._options['gpu-fit']) - set(usave.values()) == set()

eqns = [Eq(u.forward, u + 1),
Eq(usave, u.forward),
Eq(usave2, u.forward)]

op = Operator(eqns, opt=('noop', {'gpu-fit': [usave, usave2]}))
vals = set().union(usave.values(), usave2.values())
assert set(op._options['gpu-fit']) - vals == set()


class TestMisc(object):

Expand Down

0 comments on commit 3126fb0

Please sign in to comment.