Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Relax assumption on BaseFormOperator's dual argument slot #283

Merged
merged 11 commits into from
Jul 17, 2024
42 changes: 36 additions & 6 deletions test/test_external_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,12 @@
from ufl import (
Action,
Argument,
Coargument,
Coefficient,
Constant,
Form,
FunctionSpace,
Matrix,
Mesh,
TestFunction,
TrialFunction,
Expand All @@ -21,6 +23,7 @@
derivative,
dx,
inner,
replace,
sin,
triangle,
)
Expand Down Expand Up @@ -266,7 +269,7 @@ def get_external_operators(form_base):
elif isinstance(form_base, BaseForm):
return form_base.base_form_operators()
else:
raise ValueError("Expecting FormBase argument!")
raise ValueError("Expecting BaseForm argument!")


def test_adjoint_action_jacobian(V1, V2, V3):
Expand Down Expand Up @@ -339,16 +342,17 @@ def vstar_N(number):
dFdu_adj = adjoint(dFdu)
dFdm_adj = adjoint(dFdm)

assert dFdu_adj.arguments() == (u_hat(n_arg),) + v_F
assert dFdm_adj.arguments() == (m_hat(n_arg),) + v_F
V = v_F[0].ufl_function_space()
assert dFdu_adj.arguments() == (TestFunction(V1), TrialFunction(V))
assert dFdm_adj.arguments() == (TestFunction(V2), TrialFunction(V))

# Action of the adjoint
q = Coefficient(v_F[0].ufl_function_space())
q = Coefficient(V)
action_dFdu_adj = action(dFdu_adj, q)
action_dFdm_adj = action(dFdm_adj, q)

assert action_dFdu_adj.arguments() == (u_hat(n_arg),)
assert action_dFdm_adj.arguments() == (m_hat(n_arg),)
assert action_dFdu_adj.arguments() == (TestFunction(V1),)
assert action_dFdm_adj.arguments() == (TestFunction(V2),)


def test_multiple_external_operators(V1, V2):
Expand Down Expand Up @@ -486,3 +490,29 @@ def test_multiple_external_operators(V1, V2):

dFdu = expand_derivatives(derivative(F, u))
assert dFdu == dFdu_partial + Action(dFdN1_partial, dN1du) + Action(dFdN5_partial, dN5du)


def test_replace(V1):
u = Coefficient(V1, count=0)
N = ExternalOperator(u, function_space=V1)

# dN(u; uhat, v*)
dN = expand_derivatives(derivative(N, u))
vstar, uhat = dN.arguments()
assert isinstance(vstar, Coargument)

# Replace v* by a Form
v = TestFunction(V1)
F = inner(u, v) * dx
G = replace(dN, {vstar: F})

dN_replaced = dN._ufl_expr_reconstruct_(u, argument_slots=(F, uhat))
assert G == dN_replaced

# Replace v* by an Action
M = Matrix(V1, V1)
A = Action(M, u)
G = replace(dN, {vstar: A})

dN_replaced = dN._ufl_expr_reconstruct_(u, argument_slots=(A, uhat))
assert G == dN_replaced
4 changes: 1 addition & 3 deletions ufl/action.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,9 +122,7 @@ def _analyze_domains(self):
from ufl.domain import join_domains

# Collect domains
self._domains = join_domains(
chain.from_iterable(e.ufl_domains() for e in self.ufl_operands)
)
self._domains = join_domains(chain.from_iterable(e.ufl_domain() for e in self.ufl_operands))

def equals(self, other):
"""Check if two Actions are equal."""
Expand Down
6 changes: 5 additions & 1 deletion ufl/adjoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,11 @@ def form(self):

def _analyze_form_arguments(self):
"""The arguments of adjoint are the reverse of the form arguments."""
self._arguments = self._form.arguments()[::-1]
reversed_args = self._form.arguments()[::-1]
# Canonical numbering for arguments that is consistent with other BaseForm objects.
self._arguments = tuple(
type(arg)(arg.ufl_function_space(), number=i) for i, arg in enumerate(reversed_args)
)
self._coefficients = self._form.coefficients()

def _analyze_domains(self):
Expand Down
4 changes: 2 additions & 2 deletions ufl/algorithms/replace.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from ufl.algorithms.analysis import has_exact_type
from ufl.algorithms.map_integrands import map_integrand_dags
from ufl.classes import CoefficientDerivative, Form
from ufl.classes import BaseForm, CoefficientDerivative
from ufl.constantvalue import as_ufl
from ufl.core.external_operator import ExternalOperator
from ufl.core.interpolate import Interpolate
Expand All @@ -28,7 +28,7 @@ def __init__(self, mapping):
# One can replace Coarguments by 1-Forms
def get_shape(x):
"""Get the shape of an object."""
if isinstance(x, Form):
if isinstance(x, BaseForm):
return x.arguments()[0].ufl_shape
return x.ufl_shape

Expand Down
11 changes: 9 additions & 2 deletions ufl/core/base_form_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,14 +133,21 @@ def count(self):
@property
def ufl_shape(self):
"""Return the UFL shape of the coefficient.produced by the operator."""
return self.arguments()[0]._ufl_shape
arg, *_ = self.argument_slots()
if isinstance(arg, BaseForm):
arg, *_ = arg.arguments()
return arg._ufl_shape

def ufl_function_space(self):
"""Return the function space associated to the operator.

I.e. return the dual of the base form operator's Coargument.
"""
return self.arguments()[0]._ufl_function_space.dual()
arg, *_ = self.argument_slots()
if isinstance(arg, BaseForm):
arg, *_ = arg.arguments()
return arg._ufl_function_space
return arg._ufl_function_space.dual()

def _ufl_expr_reconstruct_(
self, *operands, function_space=None, derivatives=None, argument_slots=None
Expand Down
12 changes: 8 additions & 4 deletions ufl/core/interpolate.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,14 @@
#
# Modified by Nacime Bouziani, 2021-2022

from ufl.action import Action
from ufl.argument import Argument, Coargument
from ufl.coefficient import Cofunction
from ufl.constantvalue import as_ufl
from ufl.core.base_form_operator import BaseFormOperator
from ufl.core.ufl_type import ufl_type
from ufl.duals import is_dual
from ufl.form import Form
from ufl.form import BaseForm, Form
from ufl.functionspace import AbstractFunctionSpace


Expand All @@ -35,7 +36,7 @@ def __init__(self, expr, v):
defined on the dual of the FunctionSpace to interpolate into.
"""
# This check could be more rigorous.
dual_args = (Coargument, Cofunction, Form)
dual_args = (Coargument, Cofunction, Form, Action, BaseFormOperator)

if isinstance(v, AbstractFunctionSpace):
if is_dual(v):
Expand All @@ -53,8 +54,11 @@ def __init__(self, expr, v):
# Reversed order convention
argument_slots = (v, expr)
# Get the primal space (V** = V)
vv = v if not isinstance(v, Form) else v.arguments()[0]
function_space = vv.ufl_function_space().dual()
if isinstance(v, BaseForm):
arg, *_ = v.arguments()
function_space = arg.ufl_function_space()
else:
function_space = v.ufl_function_space().dual()
# Set the operand as `expr` for DAG traversal purpose.
operand = expr
BaseFormOperator.__init__(
Expand Down
Loading