@@ -248,36 +248,18 @@ def shortstr(self):
248
248
class VectorElement (MixedElement ):
249
249
"A special case of a mixed finite element where all elements are equal"
250
250
251
- def __new__ (cls , family , cell , degree , dim = None ,
252
- form_degree = None , quad_scheme = None ):
253
- """Intercepts construction, such that it returns an
254
- TensorProductVectorElement when FiniteElement returns an
255
- TensorProductElement.
256
- """
257
- # Create mixed element from list of finite elements
258
- sub_element = FiniteElement (family , cell , degree ,
259
- form_degree = form_degree ,
260
- quad_scheme = quad_scheme )
261
-
262
- from ufl .finiteelement .tensorproductelement import TensorProductElement
263
- from ufl .finiteelement .tensorproductelement import TensorProductVectorElement
264
- if isinstance (sub_element , TensorProductElement ):
265
- return TensorProductVectorElement (sub_element , dim = dim )
266
-
267
- return super (VectorElement , cls ).__new__ (cls )
268
-
269
- def __init__ (self , family , cell , degree , dim = None ,
251
+ def __init__ (self , family , cell = None , degree = None , dim = None ,
270
252
form_degree = None , quad_scheme = None ):
271
253
"""
272
254
Create vector element (repeated mixed element)
273
255
274
256
*Arguments*
275
257
family (string)
276
- The finite element family
258
+ The finite element family (or a FiniteElement)
277
259
cell
278
- The geometric cell
260
+ The geometric cell (ignored if family is FiniteElement)
279
261
degree (int)
280
- The polynomial degree
262
+ The polynomial degree (ignored if family is a FiniteElement)
281
263
dim (int)
282
264
The value dimension of the element (optional)
283
265
form_degree (int)
@@ -286,20 +268,23 @@ def __init__(self, family, cell, degree, dim=None,
286
268
quad_scheme
287
269
The quadrature scheme (optional)
288
270
"""
289
- if cell is not None :
290
- cell = as_cell (cell )
271
+ if isinstance (family , FiniteElementBase ):
272
+ sub_element = family
273
+ cell = sub_element .cell ()
274
+ else :
275
+ if cell is not None :
276
+ cell = as_cell (cell )
277
+ # Create sub element
278
+ sub_element = FiniteElement (family , cell , degree ,
279
+ form_degree = form_degree ,
280
+ quad_scheme = quad_scheme )
291
281
292
282
# Set default size if not specified
293
283
if dim is None :
294
284
ufl_assert (cell is not None ,
295
285
"Cannot infer vector dimension without a cell." )
296
286
dim = cell .geometric_dimension ()
297
287
298
- # Create sub element
299
- sub_element = FiniteElement (family , cell , degree ,
300
- form_degree = form_degree ,
301
- quad_scheme = quad_scheme )
302
-
303
288
# Create list of sub elements for mixed element constructor
304
289
sub_elements = [sub_element ]* dim
305
290
@@ -311,16 +296,15 @@ def __init__(self, family, cell, degree, dim=None,
311
296
MixedElement .__init__ (self , sub_elements , value_shape = value_shape , reference_value_shape = reference_value_shape )
312
297
# FIXME: Storing this here is strange, isn't that handled by subclass?
313
298
self ._family = sub_element .family ()
314
- self ._degree = degree
299
+ self ._degree = sub_element . degree ()
315
300
self ._sub_element = sub_element
316
301
self ._form_degree = form_degree # Storing for signature_data, not sure if it's needed
317
302
318
303
# Cache repr string
319
304
qs = self .quadrature_scheme ()
320
305
quad_str = "" if qs is None else ", quad_scheme=%r" % (qs ,)
321
- self ._repr = ("VectorElement(%r, %r, %r, dim=%d%s)" %
322
- (self ._family , self .cell (), self ._degree ,
323
- len (self ._sub_elements ), quad_str ))
306
+ self ._repr = ("VectorElement(%r, dim=%d%s)" %
307
+ (sub_element , len (self ._sub_elements ), quad_str ))
324
308
325
309
def __str__ (self ):
326
310
"Format as string for pretty printing."
@@ -340,38 +324,77 @@ class TensorElement(MixedElement):
340
324
"_sub_element_mapping" , "_flattened_sub_element_mapping" ,
341
325
"_mapping" )
342
326
343
- def __new__ (cls , family , cell , degree , shape = None ,
344
- symmetry = None , quad_scheme = None ):
345
- """Intercepts construction, such that it returns an
346
- TensorProductTensorElement when FiniteElement returns an
347
- TensorProductElement.
348
- """
349
- # Compute sub element
350
- sub_element = FiniteElement (family , cell , degree , quad_scheme )
351
-
352
- from ufl .finiteelement .tensorproductelement import TensorProductElement
353
- from ufl .finiteelement .tensorproductelement import TensorProductTensorElement
354
- if isinstance (sub_element , TensorProductElement ):
355
- return TensorProductTensorElement (sub_element , shape = shape , symmetry = symmetry )
356
-
357
- return super (TensorElement , cls ).__new__ (cls )
358
-
359
- def __init__ (self , family , cell , degree , shape = None ,
360
- symmetry = None , quad_scheme = None ):
327
+ def __init__ (self , family , cell = None , degree = None , shape = None , symmetry = None , quad_scheme = None ):
361
328
"Create tensor element (repeated mixed element with optional symmetries)"
362
- # Create scalar sub element
363
- sub_element = FiniteElement (family , cell , degree , quad_scheme )
329
+ if isinstance (family , FiniteElementBase ):
330
+ sub_element = family
331
+ else :
332
+ if cell is not None :
333
+ cell = as_cell (cell )
334
+ sub_element = FiniteElement (family , cell , degree , quad_scheme )
364
335
ufl_assert (sub_element .value_shape () == (),
365
336
"Expecting only scalar valued subelement for TensorElement." )
366
337
367
- shape , symmetry , sub_elements , sub_element_mapping , flattened_sub_element_mapping , \
368
- reference_value_shape , mapping = _tensor_sub_elements (sub_element , shape , symmetry )
338
+ # Set default shape if not specified
339
+ if shape is None :
340
+ ufl_assert (sub_element .cell () is not None ,
341
+ "Cannot infer tensor shape without a cell." )
342
+ dim = sub_element .cell ().geometric_dimension ()
343
+ shape = (dim , dim )
344
+
345
+ if symmetry is None :
346
+ symmetry = EmptyDict
347
+ elif symmetry is True :
348
+ # Construct default symmetry dict for matrix elements
349
+ ufl_assert (len (shape ) == 2 and shape [0 ] == shape [1 ],
350
+ "Cannot set automatic symmetry for non-square tensor." )
351
+ symmetry = dict ( ((i , j ), (j , i )) for i in range (shape [0 ])
352
+ for j in range (shape [1 ]) if i > j )
353
+ else :
354
+ ufl_assert (isinstance (symmetry , dict ), "Expecting symmetry to be None (unset), True, or dict." )
355
+
356
+ # Validate indices in symmetry dict
357
+ for i , j in iteritems (symmetry ):
358
+ ufl_assert (len (i ) == len (j ),
359
+ "Non-matching length of symmetry index tuples." )
360
+ for k in range (len (i )):
361
+ ufl_assert (i [k ] >= 0 and j [k ] >= 0 and
362
+ i [k ] < shape [k ] and j [k ] < shape [k ],
363
+ "Symmetry dimensions out of bounds." )
364
+
365
+ # Compute all index combinations for given shape
366
+ indices = compute_indices (shape )
367
+
368
+ # Compute mapping from indices to sub element number, accounting for symmetry
369
+ sub_elements = []
370
+ sub_element_mapping = {}
371
+ for index in indices :
372
+ if index in symmetry :
373
+ continue
374
+ sub_element_mapping [index ] = len (sub_elements )
375
+ sub_elements += [sub_element ]
376
+
377
+ # Update mapping for symmetry
378
+ for index in indices :
379
+ if index in symmetry :
380
+ sub_element_mapping [index ] = sub_element_mapping [symmetry [index ]]
381
+ flattened_sub_element_mapping = [sub_element_mapping [index ] for i , index in enumerate (indices )]
382
+
383
+ # Compute reference value shape based on symmetries
384
+ if symmetry :
385
+ # Flatten and subtract symmetries
386
+ reference_value_shape = (product (shape )- len (symmetry ),)
387
+ mapping = "symmetries"
388
+ else :
389
+ # Do not flatten if there are no symmetries
390
+ reference_value_shape = shape
391
+ mapping = "identity"
369
392
370
393
# Initialize element data
371
394
MixedElement .__init__ (self , sub_elements , value_shape = shape ,
372
395
reference_value_shape = reference_value_shape )
373
396
self ._family = sub_element .family ()
374
- self ._degree = degree
397
+ self ._degree = sub_element . degree ()
375
398
self ._sub_element = sub_element
376
399
self ._shape = shape
377
400
self ._symmetry = symmetry
@@ -382,9 +405,8 @@ def __init__(self, family, cell, degree, shape=None,
382
405
# Cache repr string
383
406
qs = self .quadrature_scheme ()
384
407
quad_str = "" if qs is None else ", quad_scheme=%r" % (qs ,)
385
- self ._repr = ("TensorElement(%r, %r, %r, shape=%r, symmetry=%r%s)" %
386
- (self ._family , self .cell (), self ._degree , self ._shape ,
387
- self ._symmetry , quad_str ))
408
+ self ._repr = ("TensorElement(%r, shape=%r, symmetry=%r%s)" %
409
+ (sub_element , self ._shape , self ._symmetry , quad_str ))
388
410
389
411
def mapping (self ):
390
412
if self ._symmetry :
@@ -435,64 +457,3 @@ def shortstr(self):
435
457
sym = ""
436
458
return "Tensor<%s x %s%s>" % (self .value_shape (),
437
459
self ._sub_element .shortstr (), sym )
438
-
439
-
440
- def _tensor_sub_elements (sub_element , shape , symmetry ):
441
- # Set default shape if not specified
442
- if shape is None :
443
- ufl_assert (sub_element .cell () is not None ,
444
- "Cannot infer tensor shape without a cell." )
445
- dim = sub_element .cell ().geometric_dimension ()
446
- shape = (dim , dim )
447
-
448
- if symmetry is None :
449
- symmetry = EmptyDict
450
- elif symmetry is True :
451
- # Construct default symmetry dict for matrix elements
452
- ufl_assert (len (shape ) == 2 and shape [0 ] == shape [1 ],
453
- "Cannot set automatic symmetry for non-square tensor." )
454
- symmetry = dict ( ((i , j ), (j , i )) for i in range (shape [0 ])
455
- for j in range (shape [1 ]) if i > j )
456
- else :
457
- ufl_assert (isinstance (symmetry , dict ), "Expecting symmetry to be None (unset), True, or dict." )
458
-
459
- # Validate indices in symmetry dict
460
- for i , j in iteritems (symmetry ):
461
- ufl_assert (len (i ) == len (j ),
462
- "Non-matching length of symmetry index tuples." )
463
- for k in range (len (i )):
464
- ufl_assert (i [k ] >= 0 and j [k ] >= 0 and
465
- i [k ] < shape [k ] and j [k ] < shape [k ],
466
- "Symmetry dimensions out of bounds." )
467
-
468
- # Compute all index combinations for given shape
469
- indices = compute_indices (shape )
470
-
471
- # Compute mapping from indices to sub element number, accounting for symmetry
472
- sub_elements = []
473
- sub_element_mapping = {}
474
- for index in indices :
475
- if index in symmetry :
476
- continue
477
- sub_element_mapping [index ] = len (sub_elements )
478
- sub_elements += [sub_element ]
479
-
480
- # Update mapping for symmetry
481
- for index in indices :
482
- if index in symmetry :
483
- sub_element_mapping [index ] = sub_element_mapping [symmetry [index ]]
484
- flattened_sub_element_mapping = [sub_element_mapping [index ] for i , index in enumerate (indices )]
485
-
486
- # Compute reference value shape based on symmetries
487
- if symmetry :
488
- # Flatten and subtract symmetries
489
- reference_value_shape = (product (shape )- len (symmetry ),)
490
- mapping = "symmetries"
491
- else :
492
- # Do not flatten if there are no symmetries
493
- reference_value_shape = shape
494
- mapping = "identity"
495
-
496
-
497
- return shape , symmetry , sub_elements , sub_element_mapping , \
498
- flattened_sub_element_mapping , reference_value_shape , mapping
0 commit comments