Skip to content

Commit f0f3088

Browse files
committed
correct tabulate for blocked elements
1 parent 463734f commit f0f3088

File tree

1 file changed

+11
-1
lines changed

1 file changed

+11
-1
lines changed

python/basix/ufl.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import functools as _functools
44
import hashlib as _hashlib
5+
import itertools as _itertools
56
import typing as _typing
67
from abc import abstractmethod as _abstractmethod
78
from warnings import warn as _warn
@@ -1034,7 +1035,16 @@ def tabulate(self, nderivs: int, points: _npt.NDArray[_np.float64]) -> _npt.NDAr
10341035
output = []
10351036
for table in self.sub_element.tabulate(nderivs, points):
10361037
# Repeat sub element horizontally
1037-
new_table = _np.repeat(table, self._block_size, axis=-1)
1038+
assert len(table.shape) == 2
1039+
new_table = _np.zeros((table.shape[0], *self.block_shape,
1040+
self._block_size * table.shape[1]))
1041+
for i, j in enumerate(_itertools.product(*[range(s) for s in self.block_shape])):
1042+
if len(j) == 1:
1043+
new_table[:, j[0], i::self._block_size] = table
1044+
elif len(j) == 2:
1045+
new_table[:, j[0], j[1], i::self._block_size] = table
1046+
else:
1047+
raise NotImplementedError()
10381048
output.append(new_table)
10391049
return _np.asarray(output, dtype=_np.float64)
10401050

0 commit comments

Comments
 (0)