Skip to content

Commit

Permalink
Merged in remove-tpve-tpte (pull request #33)
Browse files Browse the repository at this point in the history
* remove-tpve-tpte:
  Purge now unnecessary OPVE and OPTE
  Support building Vector/TensorElement from elements
  • Loading branch information
wence- committed Feb 22, 2016
2 parents 6eb8619 + 3dd3e86 commit 8905425
Show file tree
Hide file tree
Showing 7 changed files with 88 additions and 280 deletions.
2 changes: 1 addition & 1 deletion test/test_change_to_reference_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
from ufl.algorithms.transformer import ReuseTransformer, apply_transformer
from ufl.compound_expressions import determinant_expr, cross_expr, inverse_expr
from ufl.finiteelement import FiniteElement, EnrichedElement, VectorElement, MixedElement, TensorProductElement, TensorProductVectorElement, TensorElement, FacetElement, InteriorElement, BrokenElement, TraceElement
from ufl.finiteelement import FiniteElement, EnrichedElement, VectorElement, MixedElement, TensorProductElement, TensorElement, FacetElement, InteriorElement, BrokenElement, TraceElement
'''


Expand Down
7 changes: 3 additions & 4 deletions ufl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,8 @@
FiniteElement,
MixedElement, VectorElement, TensorElement
EnrichedElement, RestrictedElement,
TensorProductElement, TensorProductVectorElement, TensorProductTensorElement,
HDiv, HCurl
TensorProductElement,
HDivElement, HCurlElement
BrokenElement, TraceElement
FacetElement, InteriorElement
Expand Down Expand Up @@ -215,7 +215,6 @@
from ufl.finiteelement import FiniteElementBase, FiniteElement, \
MixedElement, VectorElement, TensorElement, EnrichedElement, \
RestrictedElement, TensorProductElement, \
TensorProductVectorElement, TensorProductTensorElement, \
HDivElement, HCurlElement, BrokenElement, TraceElement, \
FacetElement, InteriorElement

Expand Down Expand Up @@ -313,7 +312,7 @@
'FiniteElementBase', 'FiniteElement',
'MixedElement', 'VectorElement', 'TensorElement', 'EnrichedElement',
'RestrictedElement', 'TensorProductElement',
'TensorProductVectorElement', 'TensorProductTensorElement', 'HDivElement', 'HCurlElement',
'HDivElement', 'HCurlElement',
'BrokenElement', 'TraceElement', 'FacetElement', 'InteriorElement',
'register_element', 'show_elements',
'FunctionSpace',
Expand Down
4 changes: 2 additions & 2 deletions ufl/algorithms/apply_function_pullbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@
from ufl.tensors import as_tensor, as_vector

from ufl.finiteelement import (FiniteElement, EnrichedElement, VectorElement, MixedElement,
TensorProductElement, TensorProductVectorElement, TensorElement,
FacetElement, InteriorElement, BrokenElement, TraceElement)
TensorProductElement, FacetElement, InteriorElement,
BrokenElement, TraceElement)
from ufl.utils.sequences import product

def sub_elements_with_mappings(element):
Expand Down
2 changes: 1 addition & 1 deletion ufl/algorithms/change_to_reference.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
from ufl.permutation import compute_indices

from ufl.compound_expressions import determinant_expr, cross_expr, inverse_expr
from ufl.finiteelement import FiniteElement, EnrichedElement, VectorElement, MixedElement, TensorProductElement, TensorProductVectorElement, TensorElement, FacetElement, InteriorElement, BrokenElement, TraceElement
from ufl.finiteelement import FiniteElement, EnrichedElement, VectorElement, MixedElement, TensorProductElement, TensorElement, FacetElement, InteriorElement, BrokenElement, TraceElement

from ufl.algorithms.apply_function_pullbacks import apply_function_pullbacks
from ufl.algorithms.apply_geometry_lowering import apply_geometry_lowering
Expand Down
4 changes: 0 additions & 4 deletions ufl/finiteelement/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,6 @@
from ufl.finiteelement.enrichedelement import EnrichedElement
from ufl.finiteelement.restrictedelement import RestrictedElement
from ufl.finiteelement.tensorproductelement import TensorProductElement
from ufl.finiteelement.tensorproductelement import TensorProductVectorElement
from ufl.finiteelement.tensorproductelement import TensorProductTensorElement
from ufl.finiteelement.hdivcurl import HDivElement, HCurlElement
from ufl.finiteelement.brokenelement import BrokenElement
from ufl.finiteelement.traceelement import TraceElement
Expand All @@ -49,8 +47,6 @@
"EnrichedElement",
"RestrictedElement",
"TensorProductElement",
"TensorProductVectorElement",
"TensorProductTensorElement",
"HDivElement",
"HCurlElement",
"BrokenElement",
Expand Down
201 changes: 81 additions & 120 deletions ufl/finiteelement/mixedelement.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,36 +248,18 @@ def shortstr(self):
class VectorElement(MixedElement):
"A special case of a mixed finite element where all elements are equal"

def __new__(cls, family, cell, degree, dim=None,
form_degree=None, quad_scheme=None):
"""Intercepts construction, such that it returns an
TensorProductVectorElement when FiniteElement returns an
TensorProductElement.
"""
# Create mixed element from list of finite elements
sub_element = FiniteElement(family, cell, degree,
form_degree=form_degree,
quad_scheme=quad_scheme)

from ufl.finiteelement.tensorproductelement import TensorProductElement
from ufl.finiteelement.tensorproductelement import TensorProductVectorElement
if isinstance(sub_element, TensorProductElement):
return TensorProductVectorElement(sub_element, dim=dim)

return super(VectorElement, cls).__new__(cls)

def __init__(self, family, cell, degree, dim=None,
def __init__(self, family, cell=None, degree=None, dim=None,
form_degree=None, quad_scheme=None):
"""
Create vector element (repeated mixed element)
*Arguments*
family (string)
The finite element family
The finite element family (or a FiniteElement)
cell
The geometric cell
The geometric cell (ignored if family is FiniteElement)
degree (int)
The polynomial degree
The polynomial degree (ignored if family is a FiniteElement)
dim (int)
The value dimension of the element (optional)
form_degree (int)
Expand All @@ -286,20 +268,23 @@ def __init__(self, family, cell, degree, dim=None,
quad_scheme
The quadrature scheme (optional)
"""
if cell is not None:
cell = as_cell(cell)
if isinstance(family, FiniteElementBase):
sub_element = family
cell = sub_element.cell()
else:
if cell is not None:
cell = as_cell(cell)
# Create sub element
sub_element = FiniteElement(family, cell, degree,
form_degree=form_degree,
quad_scheme=quad_scheme)

# Set default size if not specified
if dim is None:
ufl_assert(cell is not None,
"Cannot infer vector dimension without a cell.")
dim = cell.geometric_dimension()

# Create sub element
sub_element = FiniteElement(family, cell, degree,
form_degree=form_degree,
quad_scheme=quad_scheme)

# Create list of sub elements for mixed element constructor
sub_elements = [sub_element]*dim

Expand All @@ -311,16 +296,15 @@ def __init__(self, family, cell, degree, dim=None,
MixedElement.__init__(self, sub_elements, value_shape=value_shape, reference_value_shape=reference_value_shape)
# FIXME: Storing this here is strange, isn't that handled by subclass?
self._family = sub_element.family()
self._degree = degree
self._degree = sub_element.degree()
self._sub_element = sub_element
self._form_degree = form_degree # Storing for signature_data, not sure if it's needed

# Cache repr string
qs = self.quadrature_scheme()
quad_str = "" if qs is None else ", quad_scheme=%r" % (qs,)
self._repr = ("VectorElement(%r, %r, %r, dim=%d%s)" %
(self._family, self.cell(), self._degree,
len(self._sub_elements), quad_str))
self._repr = ("VectorElement(%r, dim=%d%s)" %
(sub_element, len(self._sub_elements), quad_str))

def __str__(self):
"Format as string for pretty printing."
Expand All @@ -340,38 +324,77 @@ class TensorElement(MixedElement):
"_sub_element_mapping", "_flattened_sub_element_mapping",
"_mapping")

def __new__(cls, family, cell, degree, shape=None,
symmetry=None, quad_scheme=None):
"""Intercepts construction, such that it returns an
TensorProductTensorElement when FiniteElement returns an
TensorProductElement.
"""
# Compute sub element
sub_element = FiniteElement(family, cell, degree, quad_scheme)

from ufl.finiteelement.tensorproductelement import TensorProductElement
from ufl.finiteelement.tensorproductelement import TensorProductTensorElement
if isinstance(sub_element, TensorProductElement):
return TensorProductTensorElement(sub_element, shape=shape, symmetry=symmetry)

return super(TensorElement, cls).__new__(cls)

def __init__(self, family, cell, degree, shape=None,
symmetry=None, quad_scheme=None):
def __init__(self, family, cell=None, degree=None, shape=None, symmetry=None, quad_scheme=None):
"Create tensor element (repeated mixed element with optional symmetries)"
# Create scalar sub element
sub_element = FiniteElement(family, cell, degree, quad_scheme)
if isinstance(family, FiniteElementBase):
sub_element = family
else:
if cell is not None:
cell = as_cell(cell)
sub_element = FiniteElement(family, cell, degree, quad_scheme)
ufl_assert(sub_element.value_shape() == (),
"Expecting only scalar valued subelement for TensorElement.")

shape, symmetry, sub_elements, sub_element_mapping, flattened_sub_element_mapping, \
reference_value_shape, mapping = _tensor_sub_elements(sub_element, shape, symmetry)
# Set default shape if not specified
if shape is None:
ufl_assert(sub_element.cell() is not None,
"Cannot infer tensor shape without a cell.")
dim = sub_element.cell().geometric_dimension()
shape = (dim, dim)

if symmetry is None:
symmetry = EmptyDict
elif symmetry is True:
# Construct default symmetry dict for matrix elements
ufl_assert(len(shape) == 2 and shape[0] == shape[1],
"Cannot set automatic symmetry for non-square tensor.")
symmetry = dict( ((i, j), (j, i)) for i in range(shape[0])
for j in range(shape[1]) if i > j )
else:
ufl_assert(isinstance(symmetry, dict), "Expecting symmetry to be None (unset), True, or dict.")

# Validate indices in symmetry dict
for i, j in iteritems(symmetry):
ufl_assert(len(i) == len(j),
"Non-matching length of symmetry index tuples.")
for k in range(len(i)):
ufl_assert(i[k] >= 0 and j[k] >= 0 and
i[k] < shape[k] and j[k] < shape[k],
"Symmetry dimensions out of bounds.")

# Compute all index combinations for given shape
indices = compute_indices(shape)

# Compute mapping from indices to sub element number, accounting for symmetry
sub_elements = []
sub_element_mapping = {}
for index in indices:
if index in symmetry:
continue
sub_element_mapping[index] = len(sub_elements)
sub_elements += [sub_element]

# Update mapping for symmetry
for index in indices:
if index in symmetry:
sub_element_mapping[index] = sub_element_mapping[symmetry[index]]
flattened_sub_element_mapping = [sub_element_mapping[index] for i, index in enumerate(indices)]

# Compute reference value shape based on symmetries
if symmetry:
# Flatten and subtract symmetries
reference_value_shape = (product(shape)-len(symmetry),)
mapping = "symmetries"
else:
# Do not flatten if there are no symmetries
reference_value_shape = shape
mapping = "identity"

# Initialize element data
MixedElement.__init__(self, sub_elements, value_shape=shape,
reference_value_shape=reference_value_shape)
self._family = sub_element.family()
self._degree = degree
self._degree = sub_element.degree()
self._sub_element = sub_element
self._shape = shape
self._symmetry = symmetry
Expand All @@ -382,9 +405,8 @@ def __init__(self, family, cell, degree, shape=None,
# Cache repr string
qs = self.quadrature_scheme()
quad_str = "" if qs is None else ", quad_scheme=%r" % (qs,)
self._repr = ("TensorElement(%r, %r, %r, shape=%r, symmetry=%r%s)" %
(self._family, self.cell(), self._degree, self._shape,
self._symmetry, quad_str))
self._repr = ("TensorElement(%r, shape=%r, symmetry=%r%s)" %
(sub_element, self._shape, self._symmetry, quad_str))

def mapping(self):
if self._symmetry:
Expand Down Expand Up @@ -435,64 +457,3 @@ def shortstr(self):
sym = ""
return "Tensor<%s x %s%s>" % (self.value_shape(),
self._sub_element.shortstr(), sym)


def _tensor_sub_elements(sub_element, shape, symmetry):
# Set default shape if not specified
if shape is None:
ufl_assert(sub_element.cell() is not None,
"Cannot infer tensor shape without a cell.")
dim = sub_element.cell().geometric_dimension()
shape = (dim, dim)

if symmetry is None:
symmetry = EmptyDict
elif symmetry is True:
# Construct default symmetry dict for matrix elements
ufl_assert(len(shape) == 2 and shape[0] == shape[1],
"Cannot set automatic symmetry for non-square tensor.")
symmetry = dict( ((i, j), (j, i)) for i in range(shape[0])
for j in range(shape[1]) if i > j )
else:
ufl_assert(isinstance(symmetry, dict), "Expecting symmetry to be None (unset), True, or dict.")

# Validate indices in symmetry dict
for i, j in iteritems(symmetry):
ufl_assert(len(i) == len(j),
"Non-matching length of symmetry index tuples.")
for k in range(len(i)):
ufl_assert(i[k] >= 0 and j[k] >= 0 and
i[k] < shape[k] and j[k] < shape[k],
"Symmetry dimensions out of bounds.")

# Compute all index combinations for given shape
indices = compute_indices(shape)

# Compute mapping from indices to sub element number, accounting for symmetry
sub_elements = []
sub_element_mapping = {}
for index in indices:
if index in symmetry:
continue
sub_element_mapping[index] = len(sub_elements)
sub_elements += [sub_element]

# Update mapping for symmetry
for index in indices:
if index in symmetry:
sub_element_mapping[index] = sub_element_mapping[symmetry[index]]
flattened_sub_element_mapping = [sub_element_mapping[index] for i, index in enumerate(indices)]

# Compute reference value shape based on symmetries
if symmetry:
# Flatten and subtract symmetries
reference_value_shape = (product(shape)-len(symmetry),)
mapping = "symmetries"
else:
# Do not flatten if there are no symmetries
reference_value_shape = shape
mapping = "identity"


return shape, symmetry, sub_elements, sub_element_mapping, \
flattened_sub_element_mapping, reference_value_shape, mapping
Loading

0 comments on commit 8905425

Please sign in to comment.