Skip to content

Commit

Permalink
Correct tabulate for blocked elements (#698)
Browse files Browse the repository at this point in the history
* correct tabulate for blocked elements

* correct test
  • Loading branch information
mscroggs authored Sep 15, 2023
1 parent 463734f commit bc24440
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 2 deletions.
12 changes: 11 additions & 1 deletion python/basix/ufl.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import functools as _functools
import hashlib as _hashlib
import itertools as _itertools
import typing as _typing
from abc import abstractmethod as _abstractmethod
from warnings import warn as _warn
Expand Down Expand Up @@ -1034,7 +1035,16 @@ def tabulate(self, nderivs: int, points: _npt.NDArray[_np.float64]) -> _npt.NDAr
output = []
for table in self.sub_element.tabulate(nderivs, points):
# Repeat sub element horizontally
new_table = _np.repeat(table, self._block_size, axis=-1)
assert len(table.shape) == 2
new_table = _np.zeros((table.shape[0], *self.block_shape,
self._block_size * table.shape[1]))
for i, j in enumerate(_itertools.product(*[range(s) for s in self.block_shape])):
if len(j) == 1:
new_table[:, j[0], i::self._block_size] = table
elif len(j) == 2:
new_table[:, j[0], j[1], i::self._block_size] = table
else:
raise NotImplementedError()
output.append(new_table)
return _np.asarray(output, dtype=_np.float64)

Expand Down
2 changes: 1 addition & 1 deletion test/test_ufl_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def test_finite_element(inputs):
def test_vector_element(inputs):
e = basix.ufl.element(*inputs, rank=1)
table = e.tabulate(0, [[0, 0]])
assert table.shape == (1, 1, e.dim)
assert table.shape == (1, 1, e.value_size, e.dim)


@pytest.mark.parametrize("inputs", [
Expand Down

0 comments on commit bc24440

Please sign in to comment.