Skip to content

Commit

Permalink
Merge branch 'main' into mscroggs/reference_
Browse files Browse the repository at this point in the history
  • Loading branch information
jorgensd authored Oct 2, 2024
2 parents ed13fa5 + bde4f2f commit 07dc71a
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 29 deletions.
2 changes: 2 additions & 0 deletions ufl/algorithms/apply_derivatives.py
Original file line number Diff line number Diff line change
Expand Up @@ -1139,6 +1139,8 @@ def compute_gprimeterm(ngrads, vval, vcomp, wshape, wcomp):
gprimesum = gprimesum + compute_gprimeterm(
ngrads, vval, vcomp, wshape, wcomp
)
elif isinstance(v, Zero):
pass

else:
if wshape != ():
Expand Down
85 changes: 56 additions & 29 deletions ufl/algorithms/formsplitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
#
# Modified by Cecile Daversin-Catty, 2018

from typing import Optional

from ufl.algorithms.map_integrands import map_integrand_dags
from ufl.argument import Argument
from ufl.classes import FixedIndex, ListTensor
Expand All @@ -31,20 +33,10 @@ def argument(self, obj):
if obj.part() is not None:
# Mixed element built from MixedFunctionSpace,
# whose sub-function spaces are indexed by obj.part()
if len(obj.ufl_shape) == 0:
if obj.part() == self.idx[obj.number()]:
return obj
else:
return Zero()
if obj.part() == self.idx[obj.number()]:
return obj
else:
indices = [()]
for m in obj.ufl_shape:
indices = [(k + (j,)) for k in indices for j in range(m)]

if obj.part() == self.idx[obj.number()]:
return as_vector([obj[j] for j in indices])
else:
return as_vector([Zero() for j in indices])
return Zero(obj.ufl_shape)
else:
# Mixed element built from MixedElement,
# whose sub-elements need their function space to be created
Expand Down Expand Up @@ -92,33 +84,63 @@ def multi_index(self, obj):
expr = MultiFunction.reuse_if_untouched


def extract_blocks(form, i=None, j=None):
"""Extract blocks."""
def extract_blocks(form, i: Optional[int] = None, j: Optional[None] = None):
"""Extract blocks of a form.
If arity is 0, returns the form.
If arity is 1, return the ith block. If ``i`` is ``None``, return all blocks.
If arity is 2, return the ``(i,j)`` entry. If ``j`` is ``None``, return the ith row.
If neither `i` nor `j` are set, return all blocks (as a scalar, vector or tensor).
Args:
form: A form
i: Index of the block to extract. If set to ``None``, ``j`` must be None.
j: Index of the block to extract.
"""
if i is None and j is not None:
raise RuntimeError(f"Cannot extract block with {j=} and {i=}.")

fs = FormSplitter()
arguments = form.arguments()
forms = []
numbers = tuple(sorted(set(a.number() for a in arguments)))
arity = len(numbers)
assert arity <= 2
if arity == 0:
return (form,)

parts = []
for a in arguments:
if len(a.ufl_element().sub_elements) > 0:
return fs.split(form, i, j)
# If mixed element, each argument has no sub-elements
parts = tuple(sorted(set(part for a in arguments if (part := a.part()) is not None)))
if parts == ():
if i is None and j is None:
num_sub_elements = arguments[0].ufl_element().num_sub_elements
forms = []
for pi in range(num_sub_elements):
form_i = []
for pj in range(num_sub_elements):
f = fs.split(form, pi, pj)
if f.empty():
form_i.append(None)
else:
form_i.append(f)
forms.append(tuple(form_i))
return tuple(forms)
else:
# If standard element, extract only part
parts.append(a.part())
parts = tuple(sorted(set(parts)))
for pi in parts:
return fs.split(form, i, j)

# If mixed function space, each argument has sub-elements
forms = []
num_parts = len(parts)
for pi in range(num_parts):
form_i = []
if arity > 1:
for pj in parts:
for pj in range(num_parts):
f = fs.split(form, pi, pj)
if f.empty():
forms.append(None)
form_i.append(None)
else:
forms.append(f)
form_i.append(f)
forms.append(tuple(form_i))
else:
f = fs.split(form, pi)
if f.empty():
Expand All @@ -131,10 +153,15 @@ def extract_blocks(form, i=None, j=None):
except TypeError:
# Only one form returned
forms_tuple = (forms,)

if i is not None:
if (num_rows := len(forms_tuple)) <= i:
raise RuntimeError(f"Cannot extract block {i} from form with {num_rows} blocks.")
if arity > 1 and j is not None:
return forms_tuple[i * len(parts) + j]
if (num_cols := len(forms_tuple[i])) <= j:
raise RuntimeError(
f"Cannot extract block {i},{j} from form with {num_rows}x{num_cols} blocks."
)
return forms_tuple[i][j]
else:
return forms_tuple[i]
else:
Expand Down

0 comments on commit 07dc71a

Please sign in to comment.