diff --git a/cpp/dolfinx/fem/DirichletBC.h b/cpp/dolfinx/fem/DirichletBC.h index e0aa9f29c4..83e8994a36 100644 --- a/cpp/dolfinx/fem/DirichletBC.h +++ b/cpp/dolfinx/fem/DirichletBC.h @@ -1,5 +1,5 @@ -// Copyright (C) 2007-2021 Michal Habera, Anders Logg, Garth N. Wells -// and Jørgen S.Dokken +// Copyright (C) 2007-2024 Michal Habera, Anders Logg, Garth N. Wells, Jørgen +// S.Dokken and Paul T. Kühner // // This file is part of DOLFINx (https://www.fenicsproject.org) // @@ -300,12 +300,11 @@ class DirichletBC /// @note The size of of `g` must be equal to the block size if `V`. /// Use the Function version if this is not the case, e.g. for some /// mixed spaces. - template - or std::is_convertible_v>>> - requires std::is_convertible_v, - std::vector> + template + requires(std::is_convertible_v + || std::is_convertible_v>) + && std::is_convertible_v, + std::vector> DirichletBC(const S& g, X&& dofs, std::shared_ptr> V) : DirichletBC(std::make_shared>(g), dofs, V) { @@ -497,76 +496,86 @@ class DirichletBC void set(std::span x, std::optional> x0, T alpha = 1) const { - std::int32_t x_size = x.size(); - if (alpha == T(0)) // Optimisation for when alpha == 0 + // set_fn is a lambda which gets evaluated for every index in [0, + // _dofs0.size()) and its result is assigned to x[_dofs0[i]]. + auto apply = [&](std::invocable auto set_fn) { - for (std::int32_t idx : _dofs0) + static_assert( + std::is_same_v, + T>); + + std::int32_t x_size = x.size(); + for (std::size_t i = 0; i < _dofs0.size(); ++i) { - if (idx < x_size) - x[idx] = 0; + if (_dofs0[i] < x_size) + x[_dofs0[i]] = set_fn(i); } + }; + + if (alpha == T(0)) // Optimisation for when alpha == 0 + { + apply([](std::int32_t) -> T { return 0; }); + return; } - else + + if (std::holds_alternative>>(_g)) { - if (std::holds_alternative>>(_g)) + auto g = std::get>>(_g); + assert(g); + std::span values = g->x()->array(); + + // Extract degrees of freedom associated with g. If g is in a collapsed + // sub-space, get the dofs in this space, otherwise the degrees of g is + // the same as for x + auto dofs_g = _dofs1_g.empty() ? std::span(_dofs0) : std::span(_dofs1_g); + + if (x0) { - auto g = std::get>>(_g); - assert(g); - auto dofs1_g - = _dofs1_g.empty() ? std::span(_dofs0) : std::span(_dofs1_g); - std::span values = g->x()->array(); - if (x0) - { - std::span _x0 = x0.value(); - assert(x.size() <= _x0.size()); - for (std::size_t i = 0; i < _dofs0.size(); ++i) - { - if (_dofs0[i] < x_size) + assert(x.size() <= x0->size()); + apply( + [dofs_g, x0 = *x0, alpha, values, + &dofs0 = this->_dofs0](std::int32_t i) -> T { - assert(dofs1_g[i] < (std::int32_t)values.size()); - x[_dofs0[i]] = alpha * (values[dofs1_g[i]] - _x0[_dofs0[i]]); - } - } - } - else - { - for (std::size_t i = 0; i < _dofs0.size(); ++i) - { - if (_dofs0[i] < x_size) + assert(dofs_g[i] < static_cast(values.size())); + return alpha * (values[dofs_g[i]] - x0[dofs0[i]]); + }); + } + else + { + apply( + [dofs_g, values, alpha](std::int32_t i) -> T { - assert(dofs1_g[i] < (std::int32_t)values.size()); - x[_dofs0[i]] = alpha * values[dofs1_g[i]]; - } - } - } + assert(dofs_g[i] < static_cast(values.size())); + return alpha * values[dofs_g[i]]; + }); } - else if (std::holds_alternative>>(_g)) + } + else if (std::holds_alternative>>(_g)) + { + auto g = std::get>>(_g); + const std::vector& value = g->value; + std::int32_t bs = _function_space->dofmap()->bs(); + if (x0) { - auto g = std::get>>(_g); - const std::vector& value = g->value; - std::int32_t bs = _function_space->dofmap()->bs(); - if (x0) - { - assert(x.size() <= x0.value().size()); - std::ranges::for_each( - _dofs0, - [x_size, &x, x0 = x0.value(), &value, alpha, bs](auto dof) - { - if (dof < x_size) - x[dof] = alpha * (value[dof % bs] - x0[dof]); - }); - } - else - { - std::ranges::for_each(_dofs0, - [x_size, bs, alpha, &value, &x](auto dof) - { - if (dof < x_size) - x[dof] = alpha * value[dof % bs]; - }); - } + assert(x.size() <= x0->size()); + apply( + [x0 = *x0, alpha, bs, &value, &dofs0 = _dofs0](std::int32_t i) -> T + { + auto dof = dofs0[i]; + return alpha * (value[dof % bs] - x0[dof]); + }); + } + else + { + apply([alpha, bs, &value, &dofs0 = _dofs0](std::int32_t i) -> T + { return alpha * value[dofs0[i] % bs]; }); } } + else + { + // replace with std::unreachable once C++23 is supported + assert(false); + } } /// @brief Set `markers[i] = true` if dof `i` has a boundary condition