Skip to content

Commit

Permalink
Merge branch 'main' into dokken/form-split-no-arguments
Browse files Browse the repository at this point in the history
  • Loading branch information
jorgensd authored Oct 2, 2024
2 parents 02cedcf + 3b85665 commit 7a4ae70
Show file tree
Hide file tree
Showing 3 changed files with 103 additions and 89 deletions.
34 changes: 33 additions & 1 deletion test/test_indices.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import pytest

import ufl.algorithms
import ufl.classes
from ufl import (
Argument,
Coefficient,
Expand All @@ -15,6 +17,7 @@
exp,
i,
indices,
interval,
j,
k,
l,
Expand Down Expand Up @@ -305,4 +308,33 @@ def test_spatial_derivative(self):


def test_renumbering(self):
pass
"""Test that kernels with common integral data, but different index numbering,
are correctly renumbered."""
cell = interval
mesh = Mesh(FiniteElement("Lagrange", cell, 1, (2,), identity_pullback, H1))
V = FunctionSpace(mesh, FiniteElement("Lagrange", cell, 1, (2,), identity_pullback, H1))
v = TestFunction(V)
u = TrialFunction(V)
i = indices(1)
a0 = u[i].dx(0) * v[i].dx(0) * ufl.dx((1))
a1 = (
u[i].dx(0)
* v[i].dx(0)
* ufl.dx(
(
2,
3,
)
)
)
form_data = ufl.algorithms.compute_form_data(
a0 + a1,
do_apply_function_pullbacks=True,
do_apply_integral_scaling=True,
do_apply_geometry_lowering=True,
preserve_geometry_types=(ufl.classes.Jacobian,),
do_apply_restrictions=True,
do_append_everywhere_integrals=False,
)

assert len(form_data.integral_data) == 1
60 changes: 36 additions & 24 deletions ufl/algorithms/domain_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
attach_coordinate_derivatives,
strip_coordinate_derivatives,
)
from ufl.algorithms.renumbering import renumber_indices
from ufl.form import Form
from ufl.integral import Integral
from ufl.protocols import id_or_none
Expand Down Expand Up @@ -262,37 +263,16 @@ def build_integral_data(integrals):
itgs = defaultdict(list)

# --- Merge integral data that has the same integrals,
unique_integrals = defaultdict(tuple)
metadata_table = defaultdict(dict)
for integral in integrals:
integrand = integral.integrand()
integral_type = integral.integral_type()
ufl_domain = integral.ufl_domain()
metadata = integral.metadata()
meta_hash = hash(canonicalize_metadata(metadata))
subdomain_id = integral.subdomain_id()
subdomain_data = id_or_none(integral.subdomain_data())
if subdomain_id == "everywhere":
subdomain_ids = integral.subdomain_id()
if "everywhere" in subdomain_ids:
raise ValueError(
"'everywhere' not a valid subdomain id. "
"Did you forget to call group_form_integrals?"
)
unique_integrals[(integral_type, ufl_domain, meta_hash, integrand, subdomain_data)] += (
subdomain_id,
)
metadata_table[(integral_type, ufl_domain, meta_hash, integrand, subdomain_data)] = metadata

for integral_data, subdomain_ids in unique_integrals.items():
(integral_type, ufl_domain, metadata, integrand, subdomain_data) = integral_data

integral = Integral(
integrand,
integral_type,
ufl_domain,
subdomain_ids,
metadata_table[integral_data],
subdomain_data,
)
# Group for integral data (One integral data object for all
# integrals with same domain, itype, (but possibly different metadata).
itgs[(ufl_domain, integral_type, subdomain_ids)].append(integral)
Expand Down Expand Up @@ -380,7 +360,39 @@ def calc_hash(cd):
)
integral = attach_coordinate_derivatives(integral, samecd_integrals[0])
integrals.append(integral)
return Form(integrals)

# Group integrals by common integrand
# u.dx(0)*dx(1) + u.dx(0)*dx(2) -> u.dx(0)*dx((1,2))
# to avoid duplicate kernels generated after geometry lowering
unique_integrals = defaultdict(tuple)
metadata_table = defaultdict(dict)
for integral in integrals:
integral_type = integral.integral_type()
ufl_domain = integral.ufl_domain()
metadata = integral.metadata()
meta_hash = hash(canonicalize_metadata(metadata))
subdomain_id = integral.subdomain_id()
subdomain_data = id_or_none(integral.subdomain_data())
integrand = renumber_indices(integral.integrand())
unique_integrals[(integral_type, ufl_domain, meta_hash, integrand, subdomain_data)] += (
subdomain_id,
)
metadata_table[(integral_type, ufl_domain, meta_hash, integrand, subdomain_data)] = metadata

grouped_integrals = []
for integral_data, subdomain_ids in unique_integrals.items():
(integral_type, ufl_domain, metadata, integrand, subdomain_data) = integral_data
integral = Integral(
integrand,
integral_type,
ufl_domain,
subdomain_ids,
metadata_table[integral_data],
subdomain_data,
)
grouped_integrals.append(integral)

return Form(grouped_integrals)


def reconstruct_form_from_integral_data(integral_data):
Expand Down
98 changes: 34 additions & 64 deletions ufl/algorithms/renumbering.py
Original file line number Diff line number Diff line change
@@ -1,87 +1,57 @@
"""Algorithms for renumbering of counted objects, currently variables and indices."""
# Copyright (C) 2008-2016 Martin Sandve Alnæs and Anders Logg
# Copyright (C) 2008-2024 Martin Sandve Alnæs, Anders Logg, Jørgen S. Dokken and Lawrence Mitchell
#
# This file is part of UFL (https://www.fenicsproject.org)
#
# SPDX-License-Identifier: LGPL-3.0-or-later

from ufl.algorithms.transformer import ReuseTransformer, apply_transformer
from ufl.classes import Zero
from ufl.core.expr import Expr
from ufl.core.multiindex import FixedIndex, Index, MultiIndex
from ufl.variable import Label, Variable
from collections import defaultdict
from itertools import count as _count

from ufl.algorithms.map_integrands import map_integrand_dags
from ufl.core.multiindex import Index
from ufl.corealg.multifunction import MultiFunction

class VariableRenumberingTransformer(ReuseTransformer):
"""Variable renumbering transformer."""

def __init__(self):
"""Initialise."""
ReuseTransformer.__init__(self)
self.variable_map = {}

def variable(self, o):
"""Apply to variable."""
e, l = o.ufl_operands # noqa: E741
v = self.variable_map.get(l)
if v is None:
e = self.visit(e)
l2 = Label(len(self.variable_map))
v = Variable(e, l2)
self.variable_map[l] = v
return v

class IndexRelabeller(MultiFunction):
"""Renumber indices to have a consistent index numbering starting from 0."""

class IndexRenumberingTransformer(VariableRenumberingTransformer):
"""Index renumbering transformer.
def __init__(self):
"""Initialize index relabeller with a zero count."""
super().__init__()
count = _count()
self.index_cache = defaultdict(lambda: Index(next(count)))

This is a poorly designed algorithm. It is used in some tests,
please do not use for anything else.
"""
expr = MultiFunction.reuse_if_untouched

def __init__(self):
"""Initialise."""
VariableRenumberingTransformer.__init__(self)
self.index_map = {}
def multi_index(self, o):
"""Apply to multi-indices."""
return type(o)(
tuple(self.index_cache[i] if isinstance(i, Index) else i for i in o.indices())
)

def zero(self, o):
"""Apply to zero."""
fi = o.ufl_free_indices
fid = o.ufl_index_dimensions
mapped_fi = tuple(self.index(Index(count=i)) for i in fi)
paired_fid = [(mapped_fi[pos], fid[pos]) for pos, a in enumerate(fi)]
new_fi, new_fid = zip(*tuple(sorted(paired_fid)))
return Zero(o.ufl_shape, new_fi, new_fid)

def index(self, o):
"""Apply to index."""
if isinstance(o, FixedIndex):
new_indices = [self.index_cache[Index(i)].count() for i in fi]
if fi == () and fid == ():
return o
else:
c = o._count
i = self.index_map.get(c)
if i is None:
i = Index(count=len(self.index_map))
self.index_map[c] = i
return i
new_fi, new_fid = zip(*sorted(zip(new_indices, fid), key=lambda x: x[0]))
return type(o)(o.ufl_shape, tuple(new_fi), tuple(new_fid))

def multi_index(self, o):
"""Apply to multi_index."""
new_indices = tuple(self.index(i) for i in o.indices())
return MultiIndex(new_indices)

def renumber_indices(form):
"""Renumber indices to have a consistent index numbering starting from 0.
def renumber_indices(expr):
"""Renumber indices."""
if isinstance(expr, Expr):
num_free_indices = len(expr.ufl_free_indices)
This is useful to avoid multiple kernels for the same integrand,
but with different subdomain ids.
result = apply_transformer(expr, IndexRenumberingTransformer())
Args:
form: A UFL form, integral or expression.
if isinstance(expr, Expr):
if num_free_indices != len(result.ufl_free_indices):
raise ValueError(
"The number of free indices left in expression "
"should be invariant w.r.t. renumbering."
)
return result
Returns:
A new form, integral or expression with renumbered indices.
"""
reindexer = IndexRelabeller()
return map_integrand_dags(reindexer, form)

0 comments on commit 7a4ae70

Please sign in to comment.