diff --git a/ufl/algorithms/compute_form_data.py b/ufl/algorithms/compute_form_data.py index b2c34b1fc..49c58214e 100644 --- a/ufl/algorithms/compute_form_data.py +++ b/ufl/algorithms/compute_form_data.py @@ -30,7 +30,7 @@ from ufl.algorithms.remove_complex_nodes import remove_complex_nodes from ufl.classes import Coefficient, Form, FunctionSpace, GeometricFacetQuantity from ufl.corealg.traversal import traverse_unique_terminals -from ufl.domain import extract_unique_domain, collect_domains_in_form +from ufl.domain import extract_unique_domain, extract_domains from ufl.utils.sequences import max_degree @@ -296,7 +296,7 @@ def compute_form_data( form = apply_integral_scaling(form) # Apply default restriction to fully continuous terminals - have_multiple_domains = len(collect_domains_in_form(form)) > 1 + have_multiple_domains = len(extract_domains(form)) > 1 if do_apply_default_restrictions: form = apply_default_restrictions(form, have_multiple_domains=have_multiple_domains) diff --git a/ufl/domain.py b/ufl/domain.py index 4820b139e..aeaffbdbe 100644 --- a/ufl/domain.py +++ b/ufl/domain.py @@ -261,7 +261,6 @@ def as_domain(domain): if isinstance(domain, MixedMesh): domain, = set(domain._meshes) return domain - try: return extract_unique_domain(domain) except AttributeError: @@ -337,7 +336,24 @@ def extract_domains(expr, expand_mixed_mesh=True): domainlist = [] if isinstance(expr, Form): - pass + form = expr + # Add integration domains + domainlist.extend(expr.ufl_domains()) + # Add domains of coefficients, these may include domains not + # among integration domains + for c in form.coefficients(): + domainlist.extend(extract_domains(c)) + # Add domains of arguments, these may include domains not + # among integration domains + for a in form._arguments: + domainlist.extend(extract_domains(a)) + # Add domains of constants, these may include domains not + # among integration domains + for c in form._constants: + domainlist.extend(extract_domains(c)) + # Add domains of geometric quantities + for gq in form._geometric_quantities: + domainlist.append(gq._domain) else: for t in traverse_unique_terminals(expr): domainlist.extend(t.ufl_domains()) @@ -355,16 +371,6 @@ def extract_unique_domain(expr, expand_mixed_mesh=True): return None -def collect_domains_in_form(form): - meshes = form.ufl_domains() # form._integration_domains - if any(isinstance(m, MixedMesh) for m in meshes): - raise RuntimeError("Found a MixedMesh in form._integration_domains") - meshes = set(meshes) - for integral in form.integrals(): - meshes.update(extract_domains(integral.integrand())) - return sort_domains(meshes) - - def find_geometric_dimension(expr): """Find the geometric dimension of an expression.""" gdims = set() diff --git a/ufl/form.py b/ufl/form.py index e46611bee..83c6d73da 100644 --- a/ufl/form.py +++ b/ufl/form.py @@ -650,41 +650,12 @@ def _compute_renumbering(self): renumbering = {} renumbering.update(dn) renumbering.update(tn) - - # Add domains of coefficients, these may include domains not - # among integration domains k = len(dn) - for c in self.coefficients(): - ds = extract_domains(c) - for d in ds: - if d not in renumbering: - renumbering[d] = k - k += 1 - - # Add domains of arguments, these may include domains not - # among integration domains - for a in self._arguments: - ds = extract_domains(a) - for d in ds: - if d not in renumbering: - renumbering[d] = k - k += 1 - - # Add domains of constants, these may include domains not - # among integration domains - for c in self._constants: - ds = extract_domains(c) - for d in ds: - if d not in renumbering: - renumbering[d] = k - k += 1 - - for gq in self._geometric_quantities: - d = gq._domain + ds = extract_domains(self) + for d in ds: if d not in renumbering: renumbering[d] = k k += 1 - return renumbering def _compute_signature(self):