Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Unify code path of DirichletBC::set #3505

Merged
merged 21 commits into from
Nov 15, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
141 changes: 75 additions & 66 deletions cpp/dolfinx/fem/DirichletBC.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// Copyright (C) 2007-2021 Michal Habera, Anders Logg, Garth N. Wells
// and Jørgen S.Dokken
// Copyright (C) 2007-2024 Michal Habera, Anders Logg, Garth N. Wells, Jørgen
// S.Dokken and Paul T. Kühner
//
// This file is part of DOLFINx (https://www.fenicsproject.org)
//
Expand Down Expand Up @@ -300,12 +300,11 @@ class DirichletBC
/// @note The size of of `g` must be equal to the block size if `V`.
/// Use the Function version if this is not the case, e.g. for some
/// mixed spaces.
template <typename S, typename X,
typename
= std::enable_if_t<std::is_convertible_v<S, T>
or std::is_convertible_v<S, std::span<const T>>>>
requires std::is_convertible_v<std::remove_cvref_t<X>,
std::vector<std::int32_t>>
template <typename S, typename X>
requires(std::is_convertible_v<S, T>
|| std::is_convertible_v<S, std::span<const T>>)
&& std::is_convertible_v<std::remove_cvref_t<X>,
std::vector<std::int32_t>>
DirichletBC(const S& g, X&& dofs, std::shared_ptr<const FunctionSpace<U>> V)
: DirichletBC(std::make_shared<Constant<T>>(g), dofs, V)
{
Expand Down Expand Up @@ -497,76 +496,86 @@ class DirichletBC
void set(std::span<T> x, std::optional<std::span<const T>> x0,
T alpha = 1) const
{
std::int32_t x_size = x.size();
if (alpha == T(0)) // Optimisation for when alpha == 0
// set_fn is a lambda which gets evaluated for every index in [0,
// _dofs0.size()) and its result is assigned to x[_dofs0[i]].
auto apply = [&](std::invocable<std::int32_t> auto set_fn)
{
schnellerhase marked this conversation as resolved.
Show resolved Hide resolved
for (std::int32_t idx : _dofs0)
static_assert(
std::is_same_v<std::invoke_result_t<decltype(set_fn), std::int32_t>,
T>);

std::int32_t x_size = x.size();
for (std::size_t i = 0; i < _dofs0.size(); ++i)
{
if (idx < x_size)
x[idx] = 0;
if (_dofs0[i] < x_size)
x[_dofs0[i]] = set_fn(i);
}
};

if (alpha == T(0)) // Optimisation for when alpha == 0
{
apply([](std::int32_t) -> T { return 0; });
return;
}
else

if (std::holds_alternative<std::shared_ptr<const Function<T, U>>>(_g))
{
if (std::holds_alternative<std::shared_ptr<const Function<T, U>>>(_g))
auto g = std::get<std::shared_ptr<const Function<T, U>>>(_g);
assert(g);
std::span<const T> values = g->x()->array();

// Extract degrees of freedom associated with g. If g is in a collapsed
// sub-space, get the dofs in this space, otherwise the degrees of g is
// the same as for x
auto dofs_g = _dofs1_g.empty() ? std::span(_dofs0) : std::span(_dofs1_g);

if (x0)
{
auto g = std::get<std::shared_ptr<const Function<T, U>>>(_g);
assert(g);
auto dofs1_g
= _dofs1_g.empty() ? std::span(_dofs0) : std::span(_dofs1_g);
std::span<const T> values = g->x()->array();
if (x0)
{
std::span<const T> _x0 = x0.value();
assert(x.size() <= _x0.size());
for (std::size_t i = 0; i < _dofs0.size(); ++i)
{
if (_dofs0[i] < x_size)
assert(x.size() <= x0->size());
apply(
[dofs_g, x0 = *x0, alpha, values,
&dofs0 = this->_dofs0](std::int32_t i) -> T
{
assert(dofs1_g[i] < (std::int32_t)values.size());
x[_dofs0[i]] = alpha * (values[dofs1_g[i]] - _x0[_dofs0[i]]);
}
}
}
else
{
for (std::size_t i = 0; i < _dofs0.size(); ++i)
{
if (_dofs0[i] < x_size)
assert(dofs_g[i] < static_cast<std::int32_t>(values.size()));
return alpha * (values[dofs_g[i]] - x0[dofs0[i]]);
});
}
else
{
apply(
[dofs_g, values, alpha](std::int32_t i) -> T
{
assert(dofs1_g[i] < (std::int32_t)values.size());
x[_dofs0[i]] = alpha * values[dofs1_g[i]];
}
}
}
assert(dofs_g[i] < static_cast<std::int32_t>(values.size()));
return alpha * values[dofs_g[i]];
});
}
else if (std::holds_alternative<std::shared_ptr<const Constant<T>>>(_g))
}
else if (std::holds_alternative<std::shared_ptr<const Constant<T>>>(_g))
{
auto g = std::get<std::shared_ptr<const Constant<T>>>(_g);
const std::vector<T>& value = g->value;
std::int32_t bs = _function_space->dofmap()->bs();
if (x0)
{
auto g = std::get<std::shared_ptr<const Constant<T>>>(_g);
const std::vector<T>& value = g->value;
std::int32_t bs = _function_space->dofmap()->bs();
if (x0)
{
assert(x.size() <= x0.value().size());
std::ranges::for_each(
_dofs0,
[x_size, &x, x0 = x0.value(), &value, alpha, bs](auto dof)
{
if (dof < x_size)
x[dof] = alpha * (value[dof % bs] - x0[dof]);
});
}
else
{
std::ranges::for_each(_dofs0,
[x_size, bs, alpha, &value, &x](auto dof)
{
if (dof < x_size)
x[dof] = alpha * value[dof % bs];
});
}
assert(x.size() <= x0->size());
apply(
[x0 = *x0, alpha, bs, &value, &dofs0 = _dofs0](std::int32_t i) -> T
{
auto dof = dofs0[i];
return alpha * (value[dof % bs] - x0[dof]);
});
}
else
{
apply([alpha, bs, &value, &dofs0 = _dofs0](std::int32_t i) -> T
{ return alpha * value[dofs0[i] % bs]; });
}
}
else
{
// replace with std::unreachable once C++23 is supported
assert(false);
}
}

/// @brief Set `markers[i] = true` if dof `i` has a boundary condition
Expand Down
Loading