Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Unify code path of DirichletBC::set #3505

Merged
merged 21 commits into from
Nov 15, 2024
Merged
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
149 changes: 84 additions & 65 deletions cpp/dolfinx/fem/DirichletBC.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// Copyright (C) 2007-2021 Michal Habera, Anders Logg, Garth N. Wells
// and Jørgen S.Dokken
// Copyright (C) 2007-2024 Michal Habera, Anders Logg, Garth N. Wells, Jørgen
// S.Dokken and Paul T. Kühner
//
// This file is part of DOLFINx (https://www.fenicsproject.org)
//
Expand Down Expand Up @@ -246,6 +246,20 @@ std::array<std::vector<std::int32_t>, 2> locate_dofs_geometrical(
return dofs;
}

namespace
{
// To be used with std::variant, compare
// https://en.cppreference.com/w/cpp/utility/variant/visit
template <class... Ts>
struct overloaded : Ts...
{
using Ts::operator()...;
};
// explicit deduction guide
template <class... Ts>
overloaded(Ts...) -> overloaded<Ts...>;
} // namespace

/// Object for setting (strong) Dirichlet boundary conditions
/// \f[u = g \ \text{on} \ G,\f]
/// where \f$u\f$ is the solution to be computed, \f$g\f$ is a function
Expand Down Expand Up @@ -300,12 +314,11 @@ class DirichletBC
/// @note The size of of `g` must be equal to the block size if `V`.
/// Use the Function version if this is not the case, e.g. for some
/// mixed spaces.
template <typename S, typename X,
typename
= std::enable_if_t<std::is_convertible_v<S, T>
or std::is_convertible_v<S, std::span<const T>>>>
template <typename S, typename X>
requires std::is_convertible_v<std::remove_cvref_t<X>,
std::vector<std::int32_t>>
&& (std::is_convertible_v<S, T>
|| std::is_convertible_v<S, std::span<const T>>)
DirichletBC(const S& g, X&& dofs, std::shared_ptr<const FunctionSpace<U>> V)
schnellerhase marked this conversation as resolved.
Show resolved Hide resolved
: DirichletBC(std::make_shared<Constant<T>>(g), dofs, V)
{
Expand Down Expand Up @@ -495,76 +508,82 @@ class DirichletBC
void set(std::span<T> x, std::optional<std::span<const T>> x0,
T alpha = 1) const
{
std::int32_t x_size = x.size();
if (alpha == T(0)) // Optimisation for when alpha == 0
// set_fn is a lambda which gets evaluated for every index in [0,
// _dofs0.size()) and its result is assigned to x[_dofs0[i]].
auto apply = [&](std::invocable<std::int32_t> auto set_fn)
{
schnellerhase marked this conversation as resolved.
Show resolved Hide resolved
for (std::int32_t idx : _dofs0)
static_assert(
std::is_same_v<std::invoke_result_t<decltype(set_fn), std::int32_t>,
T>);

std::int32_t x_size = x.size();
for (std::int32_t i = 0; i < _dofs0.size(); ++i)
{
schnellerhase marked this conversation as resolved.
Show resolved Hide resolved
if (idx < x_size)
x[idx] = 0;
if (_dofs0[i] < x_size)
x[_dofs0[i]] = set_fn(i);
}
};

if (alpha == T(0)) // Optimisation for when alpha == 0
{
apply([](std::int32_t) -> T { return 0; });
return;
}
else

auto handle_function = [&](std::shared_ptr<const Function<T, U>> g)
{
schnellerhase marked this conversation as resolved.
Show resolved Hide resolved
if (std::holds_alternative<std::shared_ptr<const Function<T, U>>>(_g))
assert(g);
std::span<const T> values = g->x()->array();

// Extract degrees of freedom associated with g. If g is in a collapsed
// sub-space, get the dofs in this space, otherwise the degrees of g is
// the same as for x
auto dofs_g = _dofs1_g.empty() ? std::span(_dofs0) : std::span(_dofs1_g);

if (x0.has_value())
{
schnellerhase marked this conversation as resolved.
Show resolved Hide resolved
auto g = std::get<std::shared_ptr<const Function<T, U>>>(_g);
assert(g);
auto dofs1_g
= _dofs1_g.empty() ? std::span(_dofs0) : std::span(_dofs1_g);
std::span<const T> values = g->x()->array();
if (x0.has_value())
{
std::span<const T> _x0 = x0.value();
assert(x.size() <= _x0.size());
for (std::size_t i = 0; i < _dofs0.size(); ++i)
{
if (_dofs0[i] < x_size)
std::span<const T> _x0 = x0.value();
assert(x.size() <= _x0.size());
schnellerhase marked this conversation as resolved.
Show resolved Hide resolved
apply(
schnellerhase marked this conversation as resolved.
Show resolved Hide resolved
[&](std::int32_t i) -> T
{
assert(dofs1_g[i] < (std::int32_t)values.size());
x[_dofs0[i]] = alpha * (values[dofs1_g[i]] - _x0[_dofs0[i]]);
}
}
}
else
{
for (std::size_t i = 0; i < _dofs0.size(); ++i)
{
if (_dofs0[i] < x_size)
assert(dofs_g[i] < static_cast<std::int32_t>(values.size()));
return alpha * (values[dofs_g[i]] - _x0[_dofs0[i]]);
});
schnellerhase marked this conversation as resolved.
Show resolved Hide resolved
}
else
{
apply(
[&](std::int32_t i) -> T
{
assert(dofs1_g[i] < (std::int32_t)values.size());
x[_dofs0[i]] = alpha * values[dofs1_g[i]];
}
}
}
assert(dofs_g[i] < static_cast<std::int32_t>(values.size()));
return alpha * values[dofs_g[i]];
});
}
else if (std::holds_alternative<std::shared_ptr<const Constant<T>>>(_g))
};

auto handle_constant = [&](std::shared_ptr<const Constant<T>> g)
{
schnellerhase marked this conversation as resolved.
Show resolved Hide resolved
const std::vector<T>& value = g->value;
std::int32_t bs = _function_space->dofmap()->bs();
if (x0.has_value())
{
schnellerhase marked this conversation as resolved.
Show resolved Hide resolved
auto g = std::get<std::shared_ptr<const Constant<T>>>(_g);
const std::vector<T>& value = g->value;
std::int32_t bs = _function_space->dofmap()->bs();
if (x0.has_value())
{
assert(x.size() <= x0.value().size());
std::ranges::for_each(
_dofs0,
[x_size, &x, x0 = x0.value(), &value, alpha, bs](auto dof)
{
if (dof < x_size)
x[dof] = alpha * (value[dof % bs] - x0[dof]);
});
}
else
{
std::ranges::for_each(_dofs0,
[x_size, bs, alpha, &value, &x](auto dof)
{
if (dof < x_size)
x[dof] = alpha * value[dof % bs];
});
}
assert(x.size() <= x0.value().size());
apply(
schnellerhase marked this conversation as resolved.
Show resolved Hide resolved
[&](std::int32_t i) -> T
{
auto dof = _dofs0[i];
return alpha * (value[dof % bs] - x0.value()[dof]);
});
schnellerhase marked this conversation as resolved.
Show resolved Hide resolved
}
}
else
{
apply([&](std::int32_t i) -> T
{ return alpha * value[_dofs0[i] % bs]; });
}
};

std::visit(overloaded{handle_function, handle_constant}, _g);
}
schnellerhase marked this conversation as resolved.
Show resolved Hide resolved

/// @brief Set `markers[i] = true` if dof `i` has a boundary condition
Expand Down
Loading