Skip to content

Commit 00a18d4

Browse files
committed
extend extract_domains for form
1 parent 50dc313 commit 00a18d4

File tree

3 files changed

+22
-45
lines changed

3 files changed

+22
-45
lines changed

ufl/algorithms/compute_form_data.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
from ufl.algorithms.remove_complex_nodes import remove_complex_nodes
3131
from ufl.classes import Coefficient, Form, FunctionSpace, GeometricFacetQuantity
3232
from ufl.corealg.traversal import traverse_unique_terminals
33-
from ufl.domain import extract_unique_domain, collect_domains_in_form
33+
from ufl.domain import extract_unique_domain, extract_domains
3434
from ufl.utils.sequences import max_degree
3535

3636

@@ -296,7 +296,7 @@ def compute_form_data(
296296
form = apply_integral_scaling(form)
297297

298298
# Apply default restriction to fully continuous terminals
299-
have_multiple_domains = len(collect_domains_in_form(form)) > 1
299+
have_multiple_domains = len(extract_domains(form)) > 1
300300
if do_apply_default_restrictions:
301301
form = apply_default_restrictions(form, have_multiple_domains=have_multiple_domains)
302302

ufl/domain.py

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -261,7 +261,6 @@ def as_domain(domain):
261261
if isinstance(domain, MixedMesh):
262262
domain, = set(domain._meshes)
263263
return domain
264-
265264
try:
266265
return extract_unique_domain(domain)
267266
except AttributeError:
@@ -337,7 +336,24 @@ def extract_domains(expr, expand_mixed_mesh=True):
337336

338337
domainlist = []
339338
if isinstance(expr, Form):
340-
pass
339+
form = expr
340+
# Add integration domains
341+
domainlist.extend(expr.ufl_domains())
342+
# Add domains of coefficients, these may include domains not
343+
# among integration domains
344+
for c in form.coefficients():
345+
domainlist.extend(extract_domains(c))
346+
# Add domains of arguments, these may include domains not
347+
# among integration domains
348+
for a in form._arguments:
349+
domainlist.extend(extract_domains(a))
350+
# Add domains of constants, these may include domains not
351+
# among integration domains
352+
for c in form._constants:
353+
domainlist.extend(extract_domains(c))
354+
# Add domains of geometric quantities
355+
for gq in form._geometric_quantities:
356+
domainlist.append(gq._domain)
341357
else:
342358
for t in traverse_unique_terminals(expr):
343359
domainlist.extend(t.ufl_domains())
@@ -355,16 +371,6 @@ def extract_unique_domain(expr, expand_mixed_mesh=True):
355371
return None
356372

357373

358-
def collect_domains_in_form(form):
359-
meshes = form.ufl_domains() # form._integration_domains
360-
if any(isinstance(m, MixedMesh) for m in meshes):
361-
raise RuntimeError("Found a MixedMesh in form._integration_domains")
362-
meshes = set(meshes)
363-
for integral in form.integrals():
364-
meshes.update(extract_domains(integral.integrand()))
365-
return sort_domains(meshes)
366-
367-
368374
def find_geometric_dimension(expr):
369375
"""Find the geometric dimension of an expression."""
370376
gdims = set()

ufl/form.py

Lines changed: 2 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -650,41 +650,12 @@ def _compute_renumbering(self):
650650
renumbering = {}
651651
renumbering.update(dn)
652652
renumbering.update(tn)
653-
654-
# Add domains of coefficients, these may include domains not
655-
# among integration domains
656653
k = len(dn)
657-
for c in self.coefficients():
658-
ds = extract_domains(c)
659-
for d in ds:
660-
if d not in renumbering:
661-
renumbering[d] = k
662-
k += 1
663-
664-
# Add domains of arguments, these may include domains not
665-
# among integration domains
666-
for a in self._arguments:
667-
ds = extract_domains(a)
668-
for d in ds:
669-
if d not in renumbering:
670-
renumbering[d] = k
671-
k += 1
672-
673-
# Add domains of constants, these may include domains not
674-
# among integration domains
675-
for c in self._constants:
676-
ds = extract_domains(c)
677-
for d in ds:
678-
if d not in renumbering:
679-
renumbering[d] = k
680-
k += 1
681-
682-
for gq in self._geometric_quantities:
683-
d = gq._domain
654+
ds = extract_domains(self)
655+
for d in ds:
684656
if d not in renumbering:
685657
renumbering[d] = k
686658
k += 1
687-
688659
return renumbering
689660

690661
def _compute_signature(self):

0 commit comments

Comments
 (0)