From bc24440d97fb7c2091876e5c3bd5688852266bcc Mon Sep 17 00:00:00 2001 From: Matthew Scroggs Date: Fri, 15 Sep 2023 15:58:02 +0100 Subject: [PATCH] Correct tabulate for blocked elements (#698) * correct tabulate for blocked elements * correct test --- python/basix/ufl.py | 12 +++++++++++- test/test_ufl_wrapper.py | 2 +- 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/python/basix/ufl.py b/python/basix/ufl.py index aed646293..b90dcf3f6 100644 --- a/python/basix/ufl.py +++ b/python/basix/ufl.py @@ -2,6 +2,7 @@ import functools as _functools import hashlib as _hashlib +import itertools as _itertools import typing as _typing from abc import abstractmethod as _abstractmethod from warnings import warn as _warn @@ -1034,7 +1035,16 @@ def tabulate(self, nderivs: int, points: _npt.NDArray[_np.float64]) -> _npt.NDAr output = [] for table in self.sub_element.tabulate(nderivs, points): # Repeat sub element horizontally - new_table = _np.repeat(table, self._block_size, axis=-1) + assert len(table.shape) == 2 + new_table = _np.zeros((table.shape[0], *self.block_shape, + self._block_size * table.shape[1])) + for i, j in enumerate(_itertools.product(*[range(s) for s in self.block_shape])): + if len(j) == 1: + new_table[:, j[0], i::self._block_size] = table + elif len(j) == 2: + new_table[:, j[0], j[1], i::self._block_size] = table + else: + raise NotImplementedError() output.append(new_table) return _np.asarray(output, dtype=_np.float64) diff --git a/test/test_ufl_wrapper.py b/test/test_ufl_wrapper.py index 22257001b..8c194414b 100644 --- a/test/test_ufl_wrapper.py +++ b/test/test_ufl_wrapper.py @@ -25,7 +25,7 @@ def test_finite_element(inputs): def test_vector_element(inputs): e = basix.ufl.element(*inputs, rank=1) table = e.tabulate(0, [[0, 0]]) - assert table.shape == (1, 1, e.dim) + assert table.shape == (1, 1, e.value_size, e.dim) @pytest.mark.parametrize("inputs", [