Skip to content

Commit

Permalink
Start simplification and consistency fixes for fem::FiniteElement (#…
Browse files Browse the repository at this point in the history
…3502)

* Simplify element

* Small fix

* Tidy up

* Update comment

* Simplification

* Simplify

* Logic fix

* Fix test

* Fix

* Update comment

* Test update

* Update docs
  • Loading branch information
garth-wells authored Nov 8, 2024
1 parent 6cc0210 commit b40ab4d
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 37 deletions.
42 changes: 26 additions & 16 deletions cpp/dolfinx/fem/FiniteElement.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,10 @@ FiniteElement<T>::FiniteElement(mesh::CellType cell_type,
std::size_t block_size, bool symmetric)
: _signature("Quadrature element " + std::to_string(pshape[0]) + " "
+ std::to_string(block_size)),
_space_dim(pshape[0] * block_size), _reference_value_shape({}),
_bs(block_size), _is_mixed(false), _symmetric(symmetric),
_needs_dof_permutations(false), _needs_dof_transformations(false),
_space_dim(pshape[0] * block_size),
_reference_value_shape(std::vector<std::size_t>{}), _bs(block_size),
_symmetric(symmetric), _needs_dof_permutations(false),
_needs_dof_transformations(false),
_entity_dofs(mesh::cell_dim(cell_type) + 1),
_entity_closure_dofs(mesh::cell_dim(cell_type) + 1),
_points(std::vector<T>(points.begin(), points.end()), pshape)
Expand All @@ -98,7 +99,6 @@ FiniteElement<T>::FiniteElement(const basix::FiniteElement<T>& element,
std::size_t block_size, bool symmetric)
: _space_dim(block_size * element.dim()),
_reference_value_shape(element.value_shape()), _bs(block_size),
_is_mixed(false),
_element(std::make_unique<basix::FiniteElement<T>>(element)),
_symmetric(symmetric),
_needs_dof_permutations(
Expand Down Expand Up @@ -139,11 +139,16 @@ FiniteElement<T>::FiniteElement(const basix::FiniteElement<T>& element,
template <std::floating_point T>
FiniteElement<T>::FiniteElement(
const std::vector<std::shared_ptr<const FiniteElement<T>>>& elements)
: _space_dim(0), _sub_elements(elements), _bs(1), _is_mixed(true),
_symmetric(false), _needs_dof_permutations(false),
_needs_dof_transformations(false)
: _space_dim(0), _sub_elements(elements),
_reference_value_shape(std::nullopt), _bs(1), _symmetric(false),
_needs_dof_permutations(false), _needs_dof_transformations(false)
{
std::size_t vsize = 0;
if (elements.size() < 2)
{
throw std::runtime_error("FiniteElement constructor for mixed elements "
"called with a single element.");
}

_signature = "Mixed element (";

const std::vector<std::vector<std::vector<int>>>& ed
Expand All @@ -159,7 +164,6 @@ FiniteElement<T>::FiniteElement(
int dof_offset = 0;
for (auto& e : elements)
{
vsize += e->reference_value_size();
_signature += e->signature() + ", ";

if (e->needs_dof_permutations())
Expand Down Expand Up @@ -191,7 +195,6 @@ FiniteElement<T>::FiniteElement(
}

_space_dim = dof_offset;
_reference_value_shape = {vsize};
_signature += ")";
}
//-----------------------------------------------------------------------------
Expand Down Expand Up @@ -226,10 +229,12 @@ int FiniteElement<T>::space_dimension() const noexcept
}
//-----------------------------------------------------------------------------
template <std::floating_point T>
std::span<const std::size_t>
FiniteElement<T>::reference_value_shape() const noexcept
std::span<const std::size_t> FiniteElement<T>::reference_value_shape() const
{
return _reference_value_shape;
if (_reference_value_shape)
return *_reference_value_shape;
else
throw std::runtime_error("Element does not have a reference_value_shape.");
}
//-----------------------------------------------------------------------------
template <std::floating_point T>
Expand All @@ -255,8 +260,13 @@ bool FiniteElement<T>::symmetric() const
template <std::floating_point T>
int FiniteElement<T>::reference_value_size() const
{
return std::accumulate(_reference_value_shape.begin(),
_reference_value_shape.end(), 1, std::multiplies{});
if (_reference_value_shape)
{
return std::accumulate(_reference_value_shape->begin(),
_reference_value_shape->end(), 1, std::multiplies{});
}
else
throw std::runtime_error("Element does not have a reference_value_shape.");
}
//-----------------------------------------------------------------------------
template <std::floating_point T>
Expand Down Expand Up @@ -292,7 +302,7 @@ int FiniteElement<T>::num_sub_elements() const noexcept
template <std::floating_point T>
bool FiniteElement<T>::is_mixed() const noexcept
{
return _is_mixed;
return !_reference_value_shape;
}
//-----------------------------------------------------------------------------
template <std::floating_point T>
Expand Down
46 changes: 28 additions & 18 deletions cpp/dolfinx/fem/FiniteElement.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include <dolfinx/mesh/cell_types.h>
#include <functional>
#include <memory>
#include <optional>
#include <span>
#include <utility>
#include <vector>
Expand Down Expand Up @@ -103,19 +104,32 @@ class FiniteElement
/// @return Dimension of the finite element space
int space_dimension() const noexcept;

/// Block size of the finite element function space. For
/// BlockedElements, this is the number of DOFs
/// colocated at each DOF point. For other elements, this is always 1.
/// @brief Block size of the finite element function space.
///
/// For BlockedElements, this is the number of DOFs colocated at each
/// DOF point. For other elements, this is always 1.
/// @return Block size of the finite element space
int block_size() const noexcept;

/// The value size, e.g. 1 for a scalar function, 2 for a 2D vector, 9
/// for a second-order tensor in 3D, for the reference element
/// @return The value size for the reference element
/// @brief Value size.
///
/// The value size is the product of the value shape, e.g. is is 1
/// for a scalar function, 2 for a 2D vector, 9 for a second-order
/// tensor in 3D.
/// @throws Exception is thrown for a mixed element as mixed elements
/// do not have a value shape.
/// @return The value size.
int reference_value_size() const;

/// The reference value shape
std::span<const std::size_t> reference_value_shape() const noexcept;
/// @brief Value shape.
///
/// The value shape described the shape of the finite element field,
/// e.g. {} for a scalar, {3, 3} for a tensor in 3D. Mixed elements do
/// not have a value shape.
/// @throws Exception is thrown for a mixed element as mixed elements
/// do not have a value shape.
/// @return The value shape.
std::span<const std::size_t> reference_value_shape() const;

/// The local DOFs associated with each subentity of the cell
const std::vector<std::vector<std::vector<int>>>&
Expand Down Expand Up @@ -324,9 +338,8 @@ class FiniteElement

if (!_sub_elements.empty())
{
if (_is_mixed)
if (!_reference_value_shape) // Mixed element
{
// Mixed element
std::vector<std::function<void(
std::span<U>, std::span<const std::uint32_t>, std::int32_t, int)>>
sub_element_fns;
Expand Down Expand Up @@ -426,11 +439,10 @@ class FiniteElement
// Do nothing
};
}
else if (_sub_elements.size() != 0)
else if (!_sub_elements.empty())
{
if (_is_mixed)
if (!_reference_value_shape) // Mixed element
{
// Mixed element
std::vector<std::function<void(
std::span<U>, std::span<const std::uint32_t>, std::int32_t, int)>>
sub_element_fns;
Expand Down Expand Up @@ -724,16 +736,14 @@ class FiniteElement
std::vector<std::shared_ptr<const FiniteElement<geometry_type>>>
_sub_elements;

// Dimension of each value space
std::vector<std::size_t> _reference_value_shape;
// Value space shape, e.g. {} for a scalar, {3, 3} for a tensor in 3D.
// For a mixed element it is std::nullopt.
std::optional<std::vector<std::size_t>> _reference_value_shape;

// Block size for BlockedElements. This gives the number of DOFs
// co-located at each dof 'point'.
int _bs;

// Indicate whether this is a mixed element
bool _is_mixed;

// Basix Element (nullptr for mixed elements)
std::unique_ptr<basix::FiniteElement<geometry_type>> _element;

Expand Down
2 changes: 1 addition & 1 deletion python/test/unit/fem/test_dofmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ def test_block_size():
V = functionspace(mesh, mixed_element([P2, P2]))
assert V.dofmap.index_map_bs == 1

for i in range(1, 6):
for i in range(2, 6):
W = functionspace(mesh, mixed_element(i * [P2]))
assert W.dofmap.index_map_bs == 1

Expand Down
2 changes: 0 additions & 2 deletions python/test/unit/fem/test_mixed_element.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,6 @@ def test_mixed_element(rank, family, cell, degree):
A.scatter_reverse()
norms.append(A.squared_norm())

U_el = mixed_element([U_el])

for i in norms[1:]:
assert np.isclose(norms[0], i)

Expand Down

0 comments on commit b40ab4d

Please sign in to comment.