Skip to content

Commit

Permalink
Purge now unnecessary OPVE and OPTE
Browse files Browse the repository at this point in the history
Now that we can build vector and tensor elements on top of existing
elements, we no longer need the unwieldy OuterProductVectorElement and
OuterProductTensorElement constructors.
  • Loading branch information
wence- committed Feb 12, 2016
1 parent 87af423 commit 3dd3e86
Show file tree
Hide file tree
Showing 7 changed files with 61 additions and 224 deletions.
2 changes: 1 addition & 1 deletion test/test_change_to_reference_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
from ufl.algorithms.transformer import ReuseTransformer, apply_transformer
from ufl.compound_expressions import determinant_expr, cross_expr, inverse_expr
from ufl.finiteelement import FiniteElement, EnrichedElement, VectorElement, MixedElement, TensorProductElement, TensorProductVectorElement, TensorElement, FacetElement, InteriorElement, BrokenElement, TraceElement
from ufl.finiteelement import FiniteElement, EnrichedElement, VectorElement, MixedElement, TensorProductElement, TensorElement, FacetElement, InteriorElement, BrokenElement, TraceElement
'''


Expand Down
7 changes: 3 additions & 4 deletions ufl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,8 @@
FiniteElement,
MixedElement, VectorElement, TensorElement
EnrichedElement, RestrictedElement,
TensorProductElement, TensorProductVectorElement, TensorProductTensorElement,
HDiv, HCurl
TensorProductElement,
HDivElement, HCurlElement
BrokenElement, TraceElement
FacetElement, InteriorElement
Expand Down Expand Up @@ -215,7 +215,6 @@
from ufl.finiteelement import FiniteElementBase, FiniteElement, \
MixedElement, VectorElement, TensorElement, EnrichedElement, \
RestrictedElement, TensorProductElement, \
TensorProductVectorElement, TensorProductTensorElement, \
HDivElement, HCurlElement, BrokenElement, TraceElement, \
FacetElement, InteriorElement

Expand Down Expand Up @@ -313,7 +312,7 @@
'FiniteElementBase', 'FiniteElement',
'MixedElement', 'VectorElement', 'TensorElement', 'EnrichedElement',
'RestrictedElement', 'TensorProductElement',
'TensorProductVectorElement', 'TensorProductTensorElement', 'HDivElement', 'HCurlElement',
'HDivElement', 'HCurlElement',
'BrokenElement', 'TraceElement', 'FacetElement', 'InteriorElement',
'register_element', 'show_elements',
'FunctionSpace',
Expand Down
4 changes: 2 additions & 2 deletions ufl/algorithms/apply_function_pullbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@
from ufl.tensors import as_tensor, as_vector

from ufl.finiteelement import (FiniteElement, EnrichedElement, VectorElement, MixedElement,
TensorProductElement, TensorProductVectorElement, TensorElement,
FacetElement, InteriorElement, BrokenElement, TraceElement)
TensorProductElement, FacetElement, InteriorElement,
BrokenElement, TraceElement)
from ufl.utils.sequences import product

def sub_elements_with_mappings(element):
Expand Down
2 changes: 1 addition & 1 deletion ufl/algorithms/change_to_reference.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
from ufl.permutation import compute_indices

from ufl.compound_expressions import determinant_expr, cross_expr, inverse_expr
from ufl.finiteelement import FiniteElement, EnrichedElement, VectorElement, MixedElement, TensorProductElement, TensorProductVectorElement, TensorElement, FacetElement, InteriorElement, BrokenElement, TraceElement
from ufl.finiteelement import FiniteElement, EnrichedElement, VectorElement, MixedElement, TensorProductElement, TensorElement, FacetElement, InteriorElement, BrokenElement, TraceElement

from ufl.algorithms.apply_function_pullbacks import apply_function_pullbacks
from ufl.algorithms.apply_geometry_lowering import apply_geometry_lowering
Expand Down
4 changes: 0 additions & 4 deletions ufl/finiteelement/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,6 @@
from ufl.finiteelement.enrichedelement import EnrichedElement
from ufl.finiteelement.restrictedelement import RestrictedElement
from ufl.finiteelement.tensorproductelement import TensorProductElement
from ufl.finiteelement.tensorproductelement import TensorProductVectorElement
from ufl.finiteelement.tensorproductelement import TensorProductTensorElement
from ufl.finiteelement.hdivcurl import HDivElement, HCurlElement
from ufl.finiteelement.brokenelement import BrokenElement
from ufl.finiteelement.traceelement import TraceElement
Expand All @@ -49,8 +47,6 @@
"EnrichedElement",
"RestrictedElement",
"TensorProductElement",
"TensorProductVectorElement",
"TensorProductTensorElement",
"HDivElement",
"HCurlElement",
"BrokenElement",
Expand Down
118 changes: 54 additions & 64 deletions ufl/finiteelement/mixedelement.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,16 +328,67 @@ def __init__(self, family, cell=None, degree=None, shape=None, symmetry=None, qu
"Create tensor element (repeated mixed element with optional symmetries)"
if isinstance(family, FiniteElementBase):
sub_element = family
cell = sub_element.cell()
else:
if cell is not None:
cell = as_cell(cell)
sub_element = FiniteElement(family, cell, degree, quad_scheme)
ufl_assert(sub_element.value_shape() == (),
"Expecting only scalar valued subelement for TensorElement.")

shape, symmetry, sub_elements, sub_element_mapping, flattened_sub_element_mapping, \
reference_value_shape, mapping = _tensor_sub_elements(sub_element, shape, symmetry)
# Set default shape if not specified
if shape is None:
ufl_assert(sub_element.cell() is not None,
"Cannot infer tensor shape without a cell.")
dim = sub_element.cell().geometric_dimension()
shape = (dim, dim)

if symmetry is None:
symmetry = EmptyDict
elif symmetry is True:
# Construct default symmetry dict for matrix elements
ufl_assert(len(shape) == 2 and shape[0] == shape[1],
"Cannot set automatic symmetry for non-square tensor.")
symmetry = dict( ((i, j), (j, i)) for i in range(shape[0])
for j in range(shape[1]) if i > j )
else:
ufl_assert(isinstance(symmetry, dict), "Expecting symmetry to be None (unset), True, or dict.")

# Validate indices in symmetry dict
for i, j in iteritems(symmetry):
ufl_assert(len(i) == len(j),
"Non-matching length of symmetry index tuples.")
for k in range(len(i)):
ufl_assert(i[k] >= 0 and j[k] >= 0 and
i[k] < shape[k] and j[k] < shape[k],
"Symmetry dimensions out of bounds.")

# Compute all index combinations for given shape
indices = compute_indices(shape)

# Compute mapping from indices to sub element number, accounting for symmetry
sub_elements = []
sub_element_mapping = {}
for index in indices:
if index in symmetry:
continue
sub_element_mapping[index] = len(sub_elements)
sub_elements += [sub_element]

# Update mapping for symmetry
for index in indices:
if index in symmetry:
sub_element_mapping[index] = sub_element_mapping[symmetry[index]]
flattened_sub_element_mapping = [sub_element_mapping[index] for i, index in enumerate(indices)]

# Compute reference value shape based on symmetries
if symmetry:
# Flatten and subtract symmetries
reference_value_shape = (product(shape)-len(symmetry),)
mapping = "symmetries"
else:
# Do not flatten if there are no symmetries
reference_value_shape = shape
mapping = "identity"

# Initialize element data
MixedElement.__init__(self, sub_elements, value_shape=shape,
Expand Down Expand Up @@ -406,64 +457,3 @@ def shortstr(self):
sym = ""
return "Tensor<%s x %s%s>" % (self.value_shape(),
self._sub_element.shortstr(), sym)


def _tensor_sub_elements(sub_element, shape, symmetry):
# Set default shape if not specified
if shape is None:
ufl_assert(sub_element.cell() is not None,
"Cannot infer tensor shape without a cell.")
dim = sub_element.cell().geometric_dimension()
shape = (dim, dim)

if symmetry is None:
symmetry = EmptyDict
elif symmetry is True:
# Construct default symmetry dict for matrix elements
ufl_assert(len(shape) == 2 and shape[0] == shape[1],
"Cannot set automatic symmetry for non-square tensor.")
symmetry = dict( ((i, j), (j, i)) for i in range(shape[0])
for j in range(shape[1]) if i > j )
else:
ufl_assert(isinstance(symmetry, dict), "Expecting symmetry to be None (unset), True, or dict.")

# Validate indices in symmetry dict
for i, j in iteritems(symmetry):
ufl_assert(len(i) == len(j),
"Non-matching length of symmetry index tuples.")
for k in range(len(i)):
ufl_assert(i[k] >= 0 and j[k] >= 0 and
i[k] < shape[k] and j[k] < shape[k],
"Symmetry dimensions out of bounds.")

# Compute all index combinations for given shape
indices = compute_indices(shape)

# Compute mapping from indices to sub element number, accounting for symmetry
sub_elements = []
sub_element_mapping = {}
for index in indices:
if index in symmetry:
continue
sub_element_mapping[index] = len(sub_elements)
sub_elements += [sub_element]

# Update mapping for symmetry
for index in indices:
if index in symmetry:
sub_element_mapping[index] = sub_element_mapping[symmetry[index]]
flattened_sub_element_mapping = [sub_element_mapping[index] for i, index in enumerate(indices)]

# Compute reference value shape based on symmetries
if symmetry:
# Flatten and subtract symmetries
reference_value_shape = (product(shape)-len(symmetry),)
mapping = "symmetries"
else:
# Do not flatten if there are no symmetries
reference_value_shape = shape
mapping = "identity"


return shape, symmetry, sub_elements, sub_element_mapping, \
flattened_sub_element_mapping, reference_value_shape, mapping
148 changes: 0 additions & 148 deletions ufl/finiteelement/tensorproductelement.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@

from ufl.assertions import ufl_assert
from ufl.cell import TensorProductCell, as_cell
from ufl.finiteelement.mixedelement import MixedElement, _tensor_sub_elements
from ufl.finiteelement.finiteelementbase import FiniteElementBase


Expand Down Expand Up @@ -81,150 +80,3 @@ def shortstr(self):
"Short pretty-print."
return "TensorProductElement(%s)" \
% str([self._A.shortstr(), self._B.shortstr()])


class TensorProductVectorElement(MixedElement):
"""A special case of a mixed finite element where all
elements are equal TensorProductElements"""
__slots__ = ("_sub_element")

def __init__(self, *args, **kwargs):
if isinstance(args[0], TensorProductElement):
self._from_sub_element(*args, **kwargs)
else:
self._from_product_parts(*args, **kwargs)

def _from_product_parts(self, A, B, cell=None, dim=None):
sub_element = TensorProductElement(A, B, cell=cell)
self._from_sub_element(sub_element, dim=dim)

def _from_sub_element(self, sub_element, dim=None):
assert isinstance(sub_element, TensorProductElement)

dim = dim or sub_element.cell().geometric_dimension()
sub_elements = [sub_element]*dim

# Get common family name (checked in FiniteElement.__init__)
family = sub_element.family()

# Compute value shape
value_shape = (dim,) + sub_element.value_shape()

# Initialize element data
MixedElement.__init__(self, sub_elements, value_shape=value_shape)
self._family = family
self._degree = sub_element.degree()

self._sub_element = sub_element

# Cache repr string
self._repr = "TensorProductVectorElement(%r, dim=%d)" % \
(self._sub_element, len(self._sub_elements))

@property
def _A(self):
return self._sub_element._A

@property
def _B(self):
return self._sub_element._B

def mapping(self):
return self._sub_element.mapping()

def __str__(self):
"Format as string for pretty printing."
return "<Outer product vector element: %r x %r>" % \
(self._sub_element, self.num_sub_elements())

def shortstr(self):
"Format as string for pretty printing."
return "OPVector"


class TensorProductTensorElement(MixedElement):
"""A special case of a mixed finite element where all
elements are equal TensorProductElements"""
__slots__ = ("_sub_element", "_shape", "_symmetry",
"_sub_element_mapping", "_flattened_sub_element_mapping",
"_mapping")

def __init__(self, *args, **kwargs):
if isinstance(args[0], TensorProductElement):
self._from_sub_element(*args, **kwargs)
else:
self._from_product_parts(*args, **kwargs)

def _from_product_parts(self, A, B, cell=None,
shape=None, symmetry=None, quad_scheme=None):
sub_element = TensorProductElement(A, B, cell=cell,
quad_scheme=quad_scheme)
self._from_sub_element(sub_element, shape=shape, symmetry=symmetry)

def _from_sub_element(self, sub_element, shape=None, symmetry=None):
assert isinstance(sub_element, TensorProductElement)

shape, symmetry, sub_elements, sub_element_mapping, flattened_sub_element_mapping, \
reference_value_shape, mapping = _tensor_sub_elements(sub_element, shape, symmetry)

# Initialize element data
MixedElement.__init__(self, sub_elements, value_shape=shape,
reference_value_shape=reference_value_shape)
self._family = sub_element.family()
self._degree = sub_element.degree()
self._sub_element = sub_element
self._shape = shape
self._symmetry = symmetry
self._sub_element_mapping = sub_element_mapping
self._flattened_sub_element_mapping = flattened_sub_element_mapping
self._mapping = mapping

# Cache repr string
self._repr = "TensorProductTensorElement(%r, shape=%r, symmetry=%r)" % \
(self._sub_element, self._shape, self._symmetry)

@property
def _A(self):
return self._sub_element._A

@property
def _B(self):
return self._sub_element._B

def signature_data(self, renumbering):
data = ("TensorProductTensorElement", self._A, self._B,
self._shape, self._symmetry, self._quad_scheme,
("no cell" if self._cell is None else
self._cell.signature_data(renumbering)))
return data

def reconstruct(self, **kwargs):
"""Construct a new TensorProductTensorElement with some properties
replaced with new values."""
cell = kwargs.get("cell", self.cell())
shape = kwargs.get("shape", self._shape)
symmetry = kwargs.get("symmetry", self._symmetry)
return TensorProductTensorElement(self._A, self._B, cell=cell,
shape=shape, symmetry=symmetry)

def reconstruction_signature(self):
"""Format as string for evaluation as Python object.
For use with cross language frameworks, stored in generated code
and evaluated later in Python to reconstruct this object.
This differs from repr in that it does not include domain
label and data, which must be reconstructed or supplied by other means.
"""
return "TensorProductTensorElement(%r, %r, %r, %r)" % (
self._sub_element, self._shape,
self._symmetry, self._quad_scheme)

def __str__(self):
"Format as string for pretty printing."
return "<Outer product tensor element: %r x %r>" % \
(self._sub_element, self._shape)

def shortstr(self):
"Format as string for pretty printing."
return "OPTensor"

0 comments on commit 3dd3e86

Please sign in to comment.