From 4e32e8617ac9f08938af4d66455453bc37fd45e3 Mon Sep 17 00:00:00 2001 From: ksagiyam <46749170+ksagiyam@users.noreply.github.com> Date: Sat, 15 Jul 2023 08:17:51 +0100 Subject: [PATCH 1/4] Ksagiyam/fix repr withmapping (#182) * WithMapping: fix repr * TensorProductCell: fix reconstruct --- test/test_cell.py | 6 ++++++ test/test_elements.py | 6 ++++++ ufl/cell.py | 2 +- ufl/finiteelement/hdivcurl.py | 2 +- 4 files changed, 14 insertions(+), 2 deletions(-) diff --git a/test/test_cell.py b/test/test_cell.py index 0be9b7a7b..21a5240fa 100644 --- a/test/test_cell.py +++ b/test/test_cell.py @@ -59,3 +59,9 @@ def test_cells_2d(cell): assert cell.num_peaks() == cell.num_vertices() +def test_tensorproductcell(): + orig = ufl.TensorProductCell(ufl.interval, ufl.interval) + cell = orig.reconstruct() + assert cell.sub_cells() == orig.sub_cells() + assert cell.topological_dimension() == orig.topological_dimension() + assert cell.geometric_dimension() == orig.geometric_dimension() diff --git a/test/test_elements.py b/test/test_elements.py index a83c24f69..49a700cd7 100755 --- a/test/test_elements.py +++ b/test/test_elements.py @@ -229,3 +229,9 @@ def test_mse(): element = FiniteElement('GLL-Edge L2', interval, degree - 1) assert element == eval(repr(element)) + + +def test_withmapping(): + base = FiniteElement("CG", interval, 1) + element = WithMapping(base, "identity") + assert element == eval(repr(element)) diff --git a/ufl/cell.py b/ufl/cell.py index 625887a10..96150bc61 100644 --- a/ufl/cell.py +++ b/ufl/cell.py @@ -409,7 +409,7 @@ def reconstruct(self, **kwargs: typing.Any) -> Cell: gdim = value else: raise TypeError(f"reconstruct() got unexpected keyword argument '{key}'") - return TensorProductCell(self._cellname, geometric_dimension=gdim) + return TensorProductCell(*self._cells, geometric_dimension=gdim) def simplex(topological_dimension: int, geometric_dimension: typing.Optional[int] = None): diff --git a/ufl/finiteelement/hdivcurl.py b/ufl/finiteelement/hdivcurl.py index 6dda827bc..f6412ef23 100644 --- a/ufl/finiteelement/hdivcurl.py +++ b/ufl/finiteelement/hdivcurl.py @@ -118,7 +118,7 @@ def __getattr__(self, attr): (type(self).__name__, attr)) def __repr__(self): - return f"WithMapping({repr(self.wrapee)}, {self._mapping})" + return f"WithMapping({repr(self.wrapee)}, '{self._mapping}')" def value_shape(self): gdim = self.cell().geometric_dimension() From 83f340873822c0e95bc8bd19e6e3a7bb8ae55a16 Mon Sep 17 00:00:00 2001 From: Matthew Scroggs Date: Tue, 8 Aug 2023 13:12:12 +0100 Subject: [PATCH 2/4] Fix flake8 checks with new flake8 release (#185) * fix flake * more flake --- test/test_complex.py | 8 ++++---- test/test_duals.py | 8 ++++---- test/test_new_ad.py | 5 +++-- test/test_signature.py | 3 +-- ufl/argument.py | 2 +- ufl/cell.py | 2 +- ufl/core/ufl_type.py | 4 ++-- ufl/exprequals.py | 2 +- ufl/finiteelement/finiteelementbase.py | 2 +- ufl/form.py | 4 ++-- ufl/operators.py | 15 ++++++++------- ufl/sobolevspace.py | 5 +++-- 12 files changed, 31 insertions(+), 29 deletions(-) diff --git a/test/test_complex.py b/test/test_complex.py index 4559457ef..72926d225 100755 --- a/test/test_complex.py +++ b/test/test_complex.py @@ -10,10 +10,10 @@ from ufl.algorithms import estimate_total_polynomial_degree from ufl.algorithms.comparison_checker import do_comparison_check, ComplexComparisonError from ufl.algorithms.formtransformations import compute_form_adjoint -from ufl import TestFunction, TrialFunction, triangle, FiniteElement, \ - as_ufl, inner, grad, dx, dot, outer, conj, sqrt, sin, cosh, \ - atan, ln, exp, as_tensor, real, imag, conditional, \ - min_value, max_value, gt, lt, cos, ge, le, Coefficient +from ufl import (TestFunction, TrialFunction, triangle, FiniteElement, + as_ufl, inner, grad, dx, dot, outer, conj, sqrt, sin, cosh, + atan, ln, exp, as_tensor, real, imag, conditional, + min_value, max_value, gt, lt, cos, ge, le, Coefficient) def test_conj(self): diff --git a/test/test_duals.py b/test/test_duals.py index 8fdf357bb..d29c62639 100644 --- a/test/test_duals.py +++ b/test/test_duals.py @@ -1,10 +1,10 @@ #!/usr/bin/env py.test # -*- coding: utf-8 -*- -from ufl import FiniteElement, FunctionSpace, MixedFunctionSpace, \ - Coefficient, Matrix, Cofunction, FormSum, Argument, Coargument,\ - TestFunction, TrialFunction, Adjoint, Action, \ - action, adjoint, derivative, tetrahedron, triangle, interval, dx +from ufl import (FiniteElement, FunctionSpace, MixedFunctionSpace, + Coefficient, Matrix, Cofunction, FormSum, Argument, Coargument, + TestFunction, TrialFunction, Adjoint, Action, + action, adjoint, derivative, tetrahedron, triangle, interval, dx) from ufl.constantvalue import Zero from ufl.form import ZeroBaseForm diff --git a/test/test_new_ad.py b/test/test_new_ad.py index 99022fd80..65b3dd158 100755 --- a/test/test_new_ad.py +++ b/test/test_new_ad.py @@ -8,8 +8,9 @@ from ufl.classes import Grad from ufl.algorithms import tree_format from ufl.algorithms.renumbering import renumber_indices -from ufl.algorithms.apply_derivatives import apply_derivatives, GenericDerivativeRuleset, \ - GradRuleset, VariableRuleset, GateauxDerivativeRuleset +from ufl.algorithms.apply_derivatives import ( + apply_derivatives, GenericDerivativeRuleset, + GradRuleset, VariableRuleset, GateauxDerivativeRuleset) # Note: the old tests in test_automatic_differentiation.py are a bit messy diff --git a/test/test_signature.py b/test/test_signature.py index dd2c494c4..96d6754e2 100755 --- a/test/test_signature.py +++ b/test/test_signature.py @@ -9,8 +9,7 @@ from ufl import * from ufl.classes import MultiIndex, FixedIndex -from ufl.algorithms.signature import compute_multiindex_hashdata, \ - compute_terminal_hashdata +from ufl.algorithms.signature import compute_multiindex_hashdata, compute_terminal_hashdata from itertools import chain diff --git a/ufl/argument.py b/ufl/argument.py index a29fcefda..74358d503 100644 --- a/ufl/argument.py +++ b/ufl/argument.py @@ -130,7 +130,7 @@ def __eq__(self, other): are the same ufl element but different dolfin function spaces. """ return ( - type(self) == type(other) and self._number == other._number and # noqa: W504 + type(self) is type(other) and self._number == other._number and # noqa: W504 self._part == other._part and self._ufl_function_space == other._ufl_function_space ) diff --git a/ufl/cell.py b/ufl/cell.py index 96150bc61..9ef8e9975 100644 --- a/ufl/cell.py +++ b/ufl/cell.py @@ -63,7 +63,7 @@ def reconstruct(self, **kwargs: typing.Any) -> Cell: def __lt__(self, other: AbstractCell) -> bool: """Define an arbitrarily chosen but fixed sort order for all cells.""" - if type(self) == type(other): + if type(self) is type(other): s = (self.geometric_dimension(), self.topological_dimension()) o = (other.geometric_dimension(), other.topological_dimension()) if s != o: diff --git a/ufl/core/ufl_type.py b/ufl/core/ufl_type.py index 85b917e79..da86467d6 100644 --- a/ufl/core/ufl_type.py +++ b/ufl/core/ufl_type.py @@ -39,7 +39,7 @@ def __hash__(self) -> int: return hash(self._ufl_hash_data_()) def __eq__(self, other): - return type(self) == type(other) and self._ufl_hash_data_() == other._ufl_hash_data_() + return type(self) is type(other) and self._ufl_hash_data_() == other._ufl_hash_data_() def __ne__(self, other): return not self.__eq__(other) @@ -61,7 +61,7 @@ def __hash__(self): def __eq__(self, other): "__eq__ implementation attached in attach_operators_from_hash_data" - return type(self) == type(other) and self._ufl_hash_data_() == other._ufl_hash_data_() + return type(self) is type(other) and self._ufl_hash_data_() == other._ufl_hash_data_() cls.__eq__ = __eq__ def __ne__(self, other): diff --git a/ufl/exprequals.py b/ufl/exprequals.py index e3cdd3c38..dc00ae4fb 100644 --- a/ufl/exprequals.py +++ b/ufl/exprequals.py @@ -15,7 +15,7 @@ def expr_equals(self, other): # Fast cutoffs for common cases, type difference or hash # difference will cutoff more or less all nonequal types - if type(self) != type(other) or hash(self) != hash(other): + if type(self) is not type(other) or hash(self) != hash(other): return False # Large objects are costly to compare with themselves diff --git a/ufl/finiteelement/finiteelementbase.py b/ufl/finiteelement/finiteelementbase.py index 9e6aa58cf..c92178e6f 100644 --- a/ufl/finiteelement/finiteelementbase.py +++ b/ufl/finiteelement/finiteelementbase.py @@ -83,7 +83,7 @@ def __hash__(self): def __eq__(self, other): "Compute element equality for insertion in hashmaps." - return type(self) == type(other) and self._ufl_hash_data_() == other._ufl_hash_data_() + return type(self) is type(other) and self._ufl_hash_data_() == other._ufl_hash_data_() def __ne__(self, other): "Compute element inequality for insertion in hashmaps." diff --git a/ufl/form.py b/ufl/form.py index 1c7b8cf95..019c7c12f 100644 --- a/ufl/form.py +++ b/ufl/form.py @@ -438,7 +438,7 @@ def __ne__(self, other): def equals(self, other): "Evaluate ``bool(lhs_form == rhs_form)``." - if type(other) != Form: + if type(other) is not Form: return False if len(self._integrals) != len(other._integrals): return False @@ -754,7 +754,7 @@ def __hash__(self): def equals(self, other): "Evaluate ``bool(lhs_form == rhs_form)``." - if type(other) != FormSum: + if type(other) is not FormSum: return False if self is other: return True diff --git a/ufl/operators.py b/ufl/operators.py index 7b3c6171d..44dfea0c4 100644 --- a/ufl/operators.py +++ b/ufl/operators.py @@ -19,17 +19,18 @@ from ufl.form import Form from ufl.constantvalue import Zero, RealValue, ComplexValue, as_ufl from ufl.differentiation import VariableDerivative, Grad, Div, Curl, NablaGrad, NablaDiv -from ufl.tensoralgebra import Transposed, Inner, Outer, Dot, Cross, \ - Determinant, Inverse, Cofactor, Trace, Deviatoric, Skew, Sym +from ufl.tensoralgebra import ( + Transposed, Inner, Outer, Dot, Cross, + Determinant, Inverse, Cofactor, Trace, Deviatoric, Skew, Sym) from ufl.coefficient import Coefficient from ufl.variable import Variable from ufl.tensors import as_tensor, as_matrix, as_vector, ListTensor -from ufl.conditional import EQ, NE, \ - AndCondition, OrCondition, NotCondition, Conditional, MaxValue, MinValue +from ufl.conditional import ( + EQ, NE, AndCondition, OrCondition, NotCondition, Conditional, MaxValue, MinValue) from ufl.algebra import Conj, Real, Imag -from ufl.mathfunctions import Sqrt, Exp, Ln, Erf,\ - Cos, Sin, Tan, Cosh, Sinh, Tanh, Acos, Asin, Atan, Atan2,\ - BesselJ, BesselY, BesselI, BesselK +from ufl.mathfunctions import ( + Sqrt, Exp, Ln, Erf, Cos, Sin, Tan, Cosh, Sinh, Tanh, Acos, Asin, Atan, Atan2, + BesselJ, BesselY, BesselI, BesselK) from ufl.averaging import CellAvg, FacetAvg from ufl.indexed import Indexed from ufl.geometry import SpatialCoordinate, FacetNormal diff --git a/ufl/sobolevspace.py b/ufl/sobolevspace.py index fb6056c40..f6a624d14 100644 --- a/ufl/sobolevspace.py +++ b/ufl/sobolevspace.py @@ -111,8 +111,9 @@ def __init__(self, orders): the position denotes in what spatial variable the smoothness requirement is enforced. """ - assert all(isinstance(x, int) or isinf(x) for x in orders), \ - ("Order must be an integer or infinity.") + assert all( + isinstance(x, int) or isinf(x) + for x in orders), "Order must be an integer or infinity." name = "DirectionalH" parents = [L2] super(DirectionalSobolevSpace, self).__init__(name, parents) From b0d635a2315bba4530d90bff40966d3705631fd1 Mon Sep 17 00:00:00 2001 From: Connor Ward Date: Tue, 8 Aug 2023 13:31:50 +0100 Subject: [PATCH 3/4] Add Counted mixin class and refactor form signature computation (#178) * Add Counted mixin class and refactor form signature computation * fixup * fixup --------- Co-authored-by: Matthew Scroggs --- ufl/algorithms/signature.py | 8 +---- ufl/coefficient.py | 13 +++------ ufl/constant.py | 10 ++----- ufl/core/multiindex.py | 13 +++------ ufl/form.py | 58 ++++++++++++++++++++++++++++++------- ufl/matrix.py | 11 +++---- ufl/utils/counted.py | 45 ++++++++++++++-------------- ufl/variable.py | 16 +++++----- 8 files changed, 91 insertions(+), 83 deletions(-) diff --git a/ufl/algorithms/signature.py b/ufl/algorithms/signature.py index ab9690bd6..4650818c1 100644 --- a/ufl/algorithms/signature.py +++ b/ufl/algorithms/signature.py @@ -43,7 +43,6 @@ def compute_terminal_hashdata(expressions, renumbering): # arguments, and just take repr of the rest of the terminals while # we're iterating over them terminal_hashdata = {} - labels = {} index_numbering = {} for expression in expressions: for expr in traverse_unique_terminals(expression): @@ -69,12 +68,7 @@ def compute_terminal_hashdata(expressions, renumbering): data = expr._ufl_signature_data_(renumbering) elif isinstance(expr, Label): - # Numbering labels as we visit them # TODO: Include in - # renumbering - data = labels.get(expr) - if data is None: - data = "L%d" % len(labels) - labels[expr] = data + data = expr._ufl_signature_data_(renumbering) elif isinstance(expr, ExprList): # Not really a terminal but can have 0 operands... diff --git a/ufl/coefficient.py b/ufl/coefficient.py index d88743375..704c3912d 100644 --- a/ufl/coefficient.py +++ b/ufl/coefficient.py @@ -19,13 +19,13 @@ from ufl.functionspace import AbstractFunctionSpace, FunctionSpace, MixedFunctionSpace from ufl.form import BaseForm from ufl.split_functions import split -from ufl.utils.counted import counted_init +from ufl.utils.counted import Counted from ufl.duals import is_primal, is_dual # --- The Coefficient class represents a coefficient in a form --- -class BaseCoefficient(object): +class BaseCoefficient(Counted): """UFL form argument type: Parent Representation of a form coefficient.""" # Slots are disabled here because they cause trouble in PyDOLFIN @@ -33,14 +33,13 @@ class BaseCoefficient(object): # __slots__ = ("_count", "_ufl_function_space", "_repr", "_ufl_shape") _ufl_noslots_ = True __slots__ = () - _globalcount = 0 _ufl_is_abstract_ = True def __getnewargs__(self): return (self._ufl_function_space, self._count) def __init__(self, function_space, count=None): - counted_init(self, count, Coefficient) + Counted.__init__(self, count, Coefficient) if isinstance(function_space, FiniteElementBase): # For legacy support for .ufl files using cells, we map @@ -57,9 +56,6 @@ def __init__(self, function_space, count=None): self._repr = "BaseCoefficient(%s, %s)" % ( repr(self._ufl_function_space), repr(self._count)) - def count(self): - return self._count - @property def ufl_shape(self): "Return the associated UFL shape." @@ -111,6 +107,7 @@ class Cofunction(BaseCoefficient, BaseForm): __slots__ = ( "_count", + "_counted_class", "_arguments", "_ufl_function_space", "ufl_operands", @@ -118,7 +115,6 @@ class Cofunction(BaseCoefficient, BaseForm): "_ufl_shape", "_hash" ) - # _globalcount = 0 _primal = False _dual = True @@ -161,7 +157,6 @@ class Coefficient(FormArgument, BaseCoefficient): """UFL form argument type: Representation of a form coefficient.""" _ufl_noslots_ = True - _globalcount = 0 _primal = True _dual = False diff --git a/ufl/constant.py b/ufl/constant.py index 66e7a8f33..2819647e3 100644 --- a/ufl/constant.py +++ b/ufl/constant.py @@ -11,17 +11,16 @@ from ufl.core.ufl_type import ufl_type from ufl.core.terminal import Terminal from ufl.domain import as_domain -from ufl.utils.counted import counted_init +from ufl.utils.counted import Counted @ufl_type() -class Constant(Terminal): +class Constant(Terminal, Counted): _ufl_noslots_ = True - _globalcount = 0 def __init__(self, domain, shape=(), count=None): Terminal.__init__(self) - counted_init(self, count=count, countedclass=Constant) + Counted.__init__(self, count, Constant) self._ufl_domain = as_domain(domain) self._ufl_shape = shape @@ -31,9 +30,6 @@ def __init__(self, domain, shape=(), count=None): self._repr = "Constant({}, {}, {})".format( repr(self._ufl_domain), repr(self._ufl_shape), repr(self._count)) - def count(self): - return self._count - @property def ufl_shape(self): return self._ufl_shape diff --git a/ufl/core/multiindex.py b/ufl/core/multiindex.py index 9b9d15b8d..8d9c5dee6 100644 --- a/ufl/core/multiindex.py +++ b/ufl/core/multiindex.py @@ -10,7 +10,7 @@ # Modified by Massimiliano Leoni, 2016. -from ufl.utils.counted import counted_init +from ufl.utils.counted import Counted from ufl.core.ufl_type import ufl_type from ufl.core.terminal import Terminal @@ -70,20 +70,15 @@ def __repr__(self): return r -class Index(IndexBase): +class Index(IndexBase, Counted): """UFL value: An index with no value assigned. Used to represent free indices in Einstein indexing notation.""" - __slots__ = ("_count",) - - _globalcount = 0 + __slots__ = ("_count", "_counted_class") def __init__(self, count=None): IndexBase.__init__(self) - counted_init(self, count, Index) - - def count(self): - return self._count + Counted.__init__(self, count, Index) def __hash__(self): return hash(("Index", self._count)) diff --git a/ufl/form.py b/ufl/form.py index 019c7c12f..eeff85020 100644 --- a/ufl/form.py +++ b/ufl/form.py @@ -16,12 +16,15 @@ from itertools import chain from ufl.checks import is_scalar_constant_expression +from ufl.constant import Constant from ufl.constantvalue import Zero from ufl.core.expr import Expr, ufl_err_str from ufl.core.ufl_type import UFLType, ufl_type from ufl.domain import extract_unique_domain, sort_domains from ufl.equation import Equation from ufl.integral import Integral +from ufl.utils.counted import Counted +from ufl.utils.sorting import sorted_by_count # Export list for ufl.classes __all_classes__ = ["Form", "BaseForm", "ZeroBaseForm"] @@ -257,8 +260,9 @@ class Form(BaseForm): "_arguments", "_coefficients", "_coefficient_numbering", - "_constant_numbering", "_constants", + "_constant_numbering", + "_terminal_numbering", "_hash", "_signature", # --- Dict that external frameworks can place framework-specific @@ -289,11 +293,10 @@ def __init__(self, integrals): self._coefficients = None self._coefficient_numbering = None self._constant_numbering = None + self._terminal_numbering = None from ufl.algorithms.analysis import extract_constants self._constants = extract_constants(self) - self._constant_numbering = dict( - (c, i) for i, c in enumerate(self._constants)) # Internal variables for caching of hash and signature after # first request @@ -406,8 +409,15 @@ def coefficients(self): def coefficient_numbering(self): """Return a contiguous numbering of coefficients in a mapping ``{coefficient:number}``.""" + # cyclic import + from ufl.coefficient import Coefficient + if self._coefficient_numbering is None: - self._analyze_form_arguments() + self._coefficient_numbering = { + expr: num + for expr, num in self.terminal_numbering().items() + if isinstance(expr, Coefficient) + } return self._coefficient_numbering def constants(self): @@ -416,8 +426,38 @@ def constants(self): def constant_numbering(self): """Return a contiguous numbering of constants in a mapping ``{constant:number}``.""" + if self._constant_numbering is None: + self._constant_numbering = { + expr: num + for expr, num in self.terminal_numbering().items() + if isinstance(expr, Constant) + } return self._constant_numbering + def terminal_numbering(self): + """Return a contiguous numbering for all counted objects in the form. + + The returned object is mapping from terminal to its number (an integer). + + The numbering is computed per type so :class:`Coefficient`s, + :class:`Constant`s, etc will each be numbered from zero. + + """ + # cyclic import + from ufl.algorithms.analysis import extract_type + + if self._terminal_numbering is None: + exprs_by_type = defaultdict(set) + for counted_expr in extract_type(self, Counted): + exprs_by_type[counted_expr._counted_class].add(counted_expr) + + numbering = {} + for exprs in exprs_by_type.values(): + for i, expr in enumerate(sorted_by_count(exprs)): + numbering[expr] = i + self._terminal_numbering = numbering + return self._terminal_numbering + def signature(self): "Signature for use with jit cache (independent of incidental numbering of indices etc.)" if self._signature is None: @@ -625,23 +665,19 @@ def _analyze_form_arguments(self): sorted(set(arguments), key=lambda x: x.number())) self._coefficients = tuple( sorted(set(coefficients), key=lambda x: x.count())) - self._coefficient_numbering = dict( - (c, i) for i, c in enumerate(self._coefficients)) def _compute_renumbering(self): # Include integration domains and coefficients in renumbering dn = self.domain_numbering() - cn = self.coefficient_numbering() - cnstn = self.constant_numbering() + tn = self.terminal_numbering() renumbering = {} renumbering.update(dn) - renumbering.update(cn) - renumbering.update(cnstn) + renumbering.update(tn) # Add domains of coefficients, these may include domains not # among integration domains k = len(dn) - for c in cn: + for c in self.coefficients(): d = extract_unique_domain(c) if d is not None and d not in renumbering: renumbering[d] = k diff --git a/ufl/matrix.py b/ufl/matrix.py index bd3a25d91..0b120f414 100644 --- a/ufl/matrix.py +++ b/ufl/matrix.py @@ -13,24 +13,24 @@ from ufl.core.ufl_type import ufl_type from ufl.argument import Argument from ufl.functionspace import AbstractFunctionSpace -from ufl.utils.counted import counted_init +from ufl.utils.counted import Counted # --- The Matrix class represents a matrix, an assembled two form --- @ufl_type() -class Matrix(BaseForm): +class Matrix(BaseForm, Counted): """An assemble linear operator between two function spaces.""" __slots__ = ( "_count", + "_counted_class", "_ufl_function_spaces", "ufl_operands", "_repr", "_hash", "_ufl_shape", "_arguments") - _globalcount = 0 def __getnewargs__(self): return (self._ufl_function_spaces[0], self._ufl_function_spaces[1], @@ -38,7 +38,7 @@ def __getnewargs__(self): def __init__(self, row_space, column_space, count=None): BaseForm.__init__(self) - counted_init(self, count, Matrix) + Counted.__init__(self, count, Matrix) if not isinstance(row_space, AbstractFunctionSpace): raise ValueError("Expecting a FunctionSpace as the row space.") @@ -52,9 +52,6 @@ def __init__(self, row_space, column_space, count=None): self._hash = None self._repr = f"Matrix({self._ufl_function_spaces[0]!r}, {self._ufl_function_spaces[1]!r}, {self._count!r})" - def count(self): - return self._count - def ufl_function_spaces(self): "Get the tuple of function spaces of this coefficient." return self._ufl_function_spaces diff --git a/ufl/utils/counted.py b/ufl/utils/counted.py index 66d3fd79f..f04bc2aca 100644 --- a/ufl/utils/counted.py +++ b/ufl/utils/counted.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -"Utilites for types with a global unique counter attached to each object." +"Mixin class for types with a global unique counter attached to each object." # Copyright (C) 2008-2016 Martin Sandve Alnæs # @@ -7,37 +7,34 @@ # # SPDX-License-Identifier: LGPL-3.0-or-later +import itertools -def counted_init(self, count=None, countedclass=None): - "Initialize a counted object, see ExampleCounted below for how to use." - if countedclass is None: - countedclass = type(self) +class Counted: + """Mixin class for globally counted objects.""" - if count is None: - count = countedclass._globalcount + # Mixin classes do not work well with __slots__ so _count must be + # added to the __slots__ of the inheriting class + __slots__ = () - self._count = count + _counter = None - if self._count >= countedclass._globalcount: - countedclass._globalcount = self._count + 1 + def __init__(self, count=None, counted_class=None): + """Initialize the Counted instance. + :arg count: The object count, if ``None`` defaults to the next value + according to the global counter (per type). + :arg counted_class: Class to attach the global counter too. If ``None`` + then ``type(self)`` will be used. -class ExampleCounted(object): - """An example class for classes of objects identified by a global counter. + """ + # create a new counter for each subclass + counted_class = counted_class or type(self) + if counted_class._counter is None: + counted_class._counter = itertools.count() - Mimic this class to create globally counted objects within a single type. - """ - # Store the count for each object - __slots__ = ("_count",) + self._count = count if count is not None else next(counted_class._counter) + self._counted_class = counted_class - # Store a global counter with the class - _globalcount = 0 - - # Call counted_init with an optional constructor argument and the class - def __init__(self, count=None): - counted_init(self, count, ExampleCounted) - - # Make the count accessible def count(self): return self._count diff --git a/ufl/variable.py b/ufl/variable.py index 1b68d6c06..b5a29dc26 100644 --- a/ufl/variable.py +++ b/ufl/variable.py @@ -8,7 +8,7 @@ # # SPDX-License-Identifier: LGPL-3.0-or-later -from ufl.utils.counted import counted_init +from ufl.utils.counted import Counted from ufl.core.expr import Expr from ufl.core.ufl_type import ufl_type from ufl.core.terminal import Terminal @@ -17,17 +17,12 @@ @ufl_type() -class Label(Terminal): - __slots__ = ("_count",) - - _globalcount = 0 +class Label(Terminal, Counted): + __slots__ = ("_count", "_counted_class") def __init__(self, count=None): Terminal.__init__(self) - counted_init(self, count, Label) - - def count(self): - return self._count + Counted.__init__(self, count, Label) def __str__(self): return "Label(%d)" % self._count @@ -55,6 +50,9 @@ def ufl_domains(self): "Return tuple of domains related to this terminal object." return () + def _ufl_signature_data_(self, renumbering): + return ("Label", renumbering[self]) + @ufl_type(is_shaping=True, is_index_free=True, num_ops=1, inherit_shape_from_operand=0) class Variable(Operator): From a49736c7e4372b9bac61d728c223c8b6784b25a3 Mon Sep 17 00:00:00 2001 From: Matthew Scroggs Date: Wed, 9 Aug 2023 16:13:24 +0100 Subject: [PATCH 4/4] avoid error if Label not in renumbering (#187) --- ufl/variable.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/ufl/variable.py b/ufl/variable.py index b5a29dc26..935c2f8c8 100644 --- a/ufl/variable.py +++ b/ufl/variable.py @@ -51,6 +51,8 @@ def ufl_domains(self): return () def _ufl_signature_data_(self, renumbering): + if self not in renumbering: + return ("Label", self._count) return ("Label", renumbering[self])