From bde4f2f98f80a8e1f1a6ef9c004c0f13615dbc72 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=B8rgen=20Schartum=20Dokken?= Date: Wed, 2 Oct 2024 09:32:21 +0200 Subject: [PATCH] Fix extract block for tensor spaces (#308) * Fix extract block for tensor spaces * Remove conditional from text, as it is described in detail above. * Various sanity checks and error handling * Add handling of zero (happens with extract_block) --- ufl/algorithms/apply_derivatives.py | 2 + ufl/algorithms/formsplitter.py | 85 +++++++++++++++++++---------- ufl/finiteelement.py | 2 +- 3 files changed, 59 insertions(+), 30 deletions(-) diff --git a/ufl/algorithms/apply_derivatives.py b/ufl/algorithms/apply_derivatives.py index 6fcbb1b5f..bd9624f12 100644 --- a/ufl/algorithms/apply_derivatives.py +++ b/ufl/algorithms/apply_derivatives.py @@ -1139,6 +1139,8 @@ def compute_gprimeterm(ngrads, vval, vcomp, wshape, wcomp): gprimesum = gprimesum + compute_gprimeterm( ngrads, vval, vcomp, wshape, wcomp ) + elif isinstance(v, Zero): + pass else: if wshape != (): diff --git a/ufl/algorithms/formsplitter.py b/ufl/algorithms/formsplitter.py index 2e7671cc3..00519963d 100644 --- a/ufl/algorithms/formsplitter.py +++ b/ufl/algorithms/formsplitter.py @@ -8,6 +8,8 @@ # # Modified by Cecile Daversin-Catty, 2018 +from typing import Optional + from ufl.algorithms.map_integrands import map_integrand_dags from ufl.argument import Argument from ufl.classes import FixedIndex, ListTensor @@ -31,20 +33,10 @@ def argument(self, obj): if obj.part() is not None: # Mixed element built from MixedFunctionSpace, # whose sub-function spaces are indexed by obj.part() - if len(obj.ufl_shape) == 0: - if obj.part() == self.idx[obj.number()]: - return obj - else: - return Zero() + if obj.part() == self.idx[obj.number()]: + return obj else: - indices = [()] - for m in obj.ufl_shape: - indices = [(k + (j,)) for k in indices for j in range(m)] - - if obj.part() == self.idx[obj.number()]: - return as_vector([obj[j] for j in indices]) - else: - return as_vector([Zero() for j in indices]) + return Zero(obj.ufl_shape) else: # Mixed element built from MixedElement, # whose sub-elements need their function space to be created @@ -92,33 +84,63 @@ def multi_index(self, obj): expr = MultiFunction.reuse_if_untouched -def extract_blocks(form, i=None, j=None): - """Extract blocks.""" +def extract_blocks(form, i: Optional[int] = None, j: Optional[None] = None): + """Extract blocks of a form. + + If arity is 0, returns the form. + If arity is 1, return the ith block. If ``i`` is ``None``, return all blocks. + If arity is 2, return the ``(i,j)`` entry. If ``j`` is ``None``, return the ith row. + + If neither `i` nor `j` are set, return all blocks (as a scalar, vector or tensor). + + Args: + form: A form + i: Index of the block to extract. If set to ``None``, ``j`` must be None. + j: Index of the block to extract. + """ + if i is None and j is not None: + raise RuntimeError(f"Cannot extract block with {j=} and {i=}.") + fs = FormSplitter() arguments = form.arguments() - forms = [] numbers = tuple(sorted(set(a.number() for a in arguments))) arity = len(numbers) assert arity <= 2 if arity == 0: return (form,) - parts = [] - for a in arguments: - if len(a.ufl_element().sub_elements) > 0: - return fs.split(form, i, j) + # If mixed element, each argument has no sub-elements + parts = tuple(sorted(set(part for a in arguments if (part := a.part()) is not None))) + if parts == (): + if i is None and j is None: + num_sub_elements = arguments[0].ufl_element().num_sub_elements + forms = [] + for pi in range(num_sub_elements): + form_i = [] + for pj in range(num_sub_elements): + f = fs.split(form, pi, pj) + if f.empty(): + form_i.append(None) + else: + form_i.append(f) + forms.append(tuple(form_i)) + return tuple(forms) else: - # If standard element, extract only part - parts.append(a.part()) - parts = tuple(sorted(set(parts))) - for pi in parts: + return fs.split(form, i, j) + + # If mixed function space, each argument has sub-elements + forms = [] + num_parts = len(parts) + for pi in range(num_parts): + form_i = [] if arity > 1: - for pj in parts: + for pj in range(num_parts): f = fs.split(form, pi, pj) if f.empty(): - forms.append(None) + form_i.append(None) else: - forms.append(f) + form_i.append(f) + forms.append(tuple(form_i)) else: f = fs.split(form, pi) if f.empty(): @@ -131,10 +153,15 @@ def extract_blocks(form, i=None, j=None): except TypeError: # Only one form returned forms_tuple = (forms,) - if i is not None: + if (num_rows := len(forms_tuple)) <= i: + raise RuntimeError(f"Cannot extract block {i} from form with {num_rows} blocks.") if arity > 1 and j is not None: - return forms_tuple[i * len(parts) + j] + if (num_cols := len(forms_tuple[i])) <= j: + raise RuntimeError( + f"Cannot extract block {i},{j} from form with {num_rows}x{num_cols} blocks." + ) + return forms_tuple[i][j] else: return forms_tuple[i] else: diff --git a/ufl/finiteelement.py b/ufl/finiteelement.py index a5ffff6a6..413746b45 100644 --- a/ufl/finiteelement.py +++ b/ufl/finiteelement.py @@ -148,7 +148,7 @@ def components(self) -> _typing.Dict[_typing.Tuple[int, ...], int]: offset = 0 c_offset = 0 for e in self.sub_elements: - for i, j in enumerate(np.ndindex(e.value_shape)): + for i, j in enumerate(np.ndindex(e.reference_value_shape)): components[(offset + i,)] = c_offset + e.components[j] c_offset += max(e.components.values()) + 1 offset += e.value_size