diff --git a/python/basix/ufl.py b/python/basix/ufl.py index aed646293..b90dcf3f6 100644 --- a/python/basix/ufl.py +++ b/python/basix/ufl.py @@ -2,6 +2,7 @@ import functools as _functools import hashlib as _hashlib +import itertools as _itertools import typing as _typing from abc import abstractmethod as _abstractmethod from warnings import warn as _warn @@ -1034,7 +1035,16 @@ def tabulate(self, nderivs: int, points: _npt.NDArray[_np.float64]) -> _npt.NDAr output = [] for table in self.sub_element.tabulate(nderivs, points): # Repeat sub element horizontally - new_table = _np.repeat(table, self._block_size, axis=-1) + assert len(table.shape) == 2 + new_table = _np.zeros((table.shape[0], *self.block_shape, + self._block_size * table.shape[1])) + for i, j in enumerate(_itertools.product(*[range(s) for s in self.block_shape])): + if len(j) == 1: + new_table[:, j[0], i::self._block_size] = table + elif len(j) == 2: + new_table[:, j[0], j[1], i::self._block_size] = table + else: + raise NotImplementedError() output.append(new_table) return _np.asarray(output, dtype=_np.float64)