Skip to content

Commit

Permalink
Enum and type updates (#832)
Browse files Browse the repository at this point in the history
* Make 'default' enum lower case in Python interface

* Simplifications

* Import fix

* Type updates

* Syntax fix
  • Loading branch information
garth-wells authored Jun 5, 2024
1 parent 6460e77 commit 8534dd2
Show file tree
Hide file tree
Showing 13 changed files with 40 additions and 125 deletions.
4 changes: 2 additions & 2 deletions python/basix/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,9 @@
# Template placeholder for injecting Windows dll directories in CI
# WINDOWSDLL

from basix import cell, finite_element, lattice, polynomials, quadrature, sobolev_spaces
from basix._basixcpp import MapType
from basix._basixcpp import __version__ # type: ignore
from basix import cell, finite_element, lattice, polynomials, quadrature, sobolev_spaces
from basix.cell import CellType, geometry, topology
from basix.finite_element import (
DPCVariant,
Expand All @@ -25,7 +26,6 @@
)
from basix.interpolation import compute_interpolation_operator
from basix.lattice import LatticeSimplexMethod, LatticeType, create_lattice
from basix.maps import MapType
from basix.polynomials import PolynomialType, PolysetType, tabulate_polynomials
from basix.polynomials import restriction as polyset_restriction
from basix.polynomials import superset as polyset_superset
Expand Down
24 changes: 23 additions & 1 deletion python/basix/_basixcpp.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ from numpy.typing import ArrayLike


class CellType(enum.IntEnum):
"""Cell type."""

point = 0

interval = 1
Expand All @@ -23,6 +25,8 @@ class CellType(enum.IntEnum):
pyramid = 7

class DPCVariant(enum.IntEnum):
"""DPC variant."""

unset = 0

simplex_equispaced = 1
Expand All @@ -40,6 +44,8 @@ class DPCVariant(enum.IntEnum):
legendre = 7

class ElementFamily(enum.IntEnum):
"""Finite element family."""

custom = 0

P = 1
Expand Down Expand Up @@ -317,6 +323,8 @@ class FiniteElement_float64:
def dtype(self) -> str: ...

class LagrangeVariant(enum.IntEnum):
"""Lagrange element variant."""

unset = 0

equispaced = 1
Expand Down Expand Up @@ -344,6 +352,8 @@ class LagrangeVariant(enum.IntEnum):
bernstein = 12

class LatticeSimplexMethod(enum.IntEnum):
"""Lattice simplex method."""

none = 0

warp = 1
Expand All @@ -353,6 +363,8 @@ class LatticeSimplexMethod(enum.IntEnum):
centroid = 3

class LatticeType(enum.IntEnum):
"""Lattice type."""

equispaced = 0

gll = 1
Expand All @@ -362,6 +374,8 @@ class LatticeType(enum.IntEnum):
gl = 4

class MapType(enum.IntEnum):
"""Element map type."""

identity = 0

L2Piola = 1
Expand All @@ -375,17 +389,23 @@ class MapType(enum.IntEnum):
doubleContravariantPiola = 5

class PolynomialType(enum.IntEnum):
"""Polynomial type."""

legendre = 0

bernstein = 1

class PolysetType(enum.IntEnum):
"""Polyset type."""

standard = 0

macroedge = 1

class QuadratureType(enum.IntEnum):
Default = 0
"""Quadrature type."""

default = 0

gauss_jacobi = 1

Expand All @@ -394,6 +414,8 @@ class QuadratureType(enum.IntEnum):
xiao_gimbutas = 3

class SobolevSpace(enum.IntEnum):
"""Sobolev space."""

L2 = 0

H1 = 1
Expand Down
13 changes: 0 additions & 13 deletions python/basix/cell.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
from basix._basixcpp import topology as _topology

__all__ = [
"string_to_type",
"sub_entity_connectivity",
"volume",
"facet_jacobians",
Expand All @@ -30,18 +29,6 @@
]


def string_to_type(cell: str) -> CellType:
"""Convert a string to a Basix CellType.
Args:
cell: Name of the cell as a string.
Returns:
The cell type.
"""
return CellType[cell]


def sub_entity_connectivity(celltype: CellType) -> list[list[list[list[int]]]]:
"""Numbers of entities connected to each sub-entity of the cell.
Expand Down
35 changes: 1 addition & 34 deletions python/basix/finite_element.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from basix._basixcpp import tp_dof_ordering as _tp_dof_ordering
from basix._basixcpp import tp_factors as _tp_factors
from basix.cell import CellType
from basix.maps import MapType
from basix import MapType
from basix.polynomials import PolysetType
from basix.sobolev_spaces import SobolevSpace

Expand All @@ -34,8 +34,6 @@
"create_custom_element",
"create_tp_element",
"string_to_family",
"string_to_lagrange_variant",
"string_to_dpc_variant",
"tp_factors",
"tp_dof_ordering",
]
Expand Down Expand Up @@ -872,34 +870,3 @@ def string_to_family(family: str, cell: str) -> ElementFamily:
return families[family]
except KeyError:
raise ValueError(f"Unknown element family: {family} with cell type {cell}")


def string_to_lagrange_variant(variant: str) -> LagrangeVariant:
"""Convert a string to a Basix LagrangeVariant enum.
Args:
variant: Lagrange variant string.
Returns:
The Lagrange variant.
"""
if variant.lower() == "gll":
return LagrangeVariant.gll_warped
elif variant.lower() == "chebyshev":
return LagrangeVariant.chebyshev_isaac
elif variant.lower() == "gl":
return LagrangeVariant.gl_isaac

return LagrangeVariant[variant.lower()]


def string_to_dpc_variant(variant: str) -> DPCVariant:
"""Convert a string to a Basix DPCVariant enum.
Args:
variant: DPC variant as a string.
Returns:
The DPC variant.
"""
return DPCVariant[variant.lower()]
26 changes: 1 addition & 25 deletions python/basix/lattice.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,31 +11,7 @@
from basix._basixcpp import create_lattice as _create_lattice
from basix.cell import CellType

__all__ = ["string_to_type", "string_to_simplex_method"]


def string_to_type(lattice: str) -> LatticeType:
"""Convert a string to a Basix LatticeType enum.
Args:
lattice: Lattice type as a string.
Returns:
Lattice type.
"""
return LatticeType[lattice]


def string_to_simplex_method(method: str) -> LatticeSimplexMethod:
"""Convert a string to a Basix LatticeSimplexMethod enum.
Args:
method: Simplex method as a string.
Returns:
Simplex method.
"""
return LatticeSimplexMethod[method]
__all__ = ["create_lattice"]


def create_lattice(
Expand Down
15 changes: 1 addition & 14 deletions python/basix/maps.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,5 @@
# SPDX-License-Identifier: MIT
"""Maps."""

from basix._basixcpp import MapType

__all__ = ["string_to_type"]


def string_to_type(mapname: str) -> MapType:
"""Convert a string to a Basix MapType.
Args:
mapname: Name of the map as a string.
Returns:
The map type.
"""
return MapType[mapname]
# __all__ = ["string_to_type"]
12 changes: 0 additions & 12 deletions python/basix/polynomials.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,18 +195,6 @@ def superset(cell: CellType, type1: PolysetType, type2: PolysetType) -> PolysetT
return _superset(cell, type1, type2)


def string_to_polyset_type(pname: str) -> PolysetType:
"""Convert a string to a Basix PolysetType.
Args:
pname: Name of the polyset type as a string.
Returns:
The polyset type.
"""
return PolysetType[pname]


def tabulate_polynomial_set(
celltype: CellType, ptype: PolysetType, degree: int, nderiv: int, pts: npt.NDArray
) -> npt.ArrayLike:
Expand Down
4 changes: 2 additions & 2 deletions python/basix/quadrature.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def string_to_type(rule: str) -> QuadratureType:
The quadrature type.
"""
if rule == "default":
return QuadratureType.Default
return QuadratureType.default
elif rule in ["Gauss-Lobatto-Legendre", "GLL"]:
return QuadratureType.gll
elif rule in ["Gauss-Legendre", "GL", "Gauss-Jacobi"]:
Expand All @@ -39,7 +39,7 @@ def string_to_type(rule: str) -> QuadratureType:
def make_quadrature(
cell: CellType,
degree: int,
rule: QuadratureType = QuadratureType.Default,
rule: QuadratureType = QuadratureType.default,
polyset_type: PolysetType = PolysetType.standard,
) -> tuple[_npt.ArrayLike, _npt.ArrayLike]:
"""Create a quadrature rule.
Expand Down
14 changes: 1 addition & 13 deletions python/basix/sobolev_spaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from basix._basixcpp import SobolevSpace
from basix._basixcpp import sobolev_space_intersection as _ssi

__all__ = ["intersection", "string_to_sobolev_space"]
__all__ = ["intersection"]


def intersection(spaces: list[SobolevSpace]) -> SobolevSpace:
Expand All @@ -24,15 +24,3 @@ def intersection(spaces: list[SobolevSpace]) -> SobolevSpace:
for s in spaces[1:]:
space = _ssi(space, s)
return SobolevSpace[space.name]


def string_to_sobolev_space(space: str) -> SobolevSpace:
"""Convert a string to a Basix SobolevSpace.
Args:
space: Name of the space.
Returns:
Cell type.
"""
return SobolevSpace[space]
8 changes: 4 additions & 4 deletions python/basix/ufl.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def _ufl_sobolev_space_from_enum(s: _basix.SobolevSpace):
return _spacemap[s]


def _ufl_pullback_from_enum(m: _basix.maps.MapType) -> _AbstractPullback:
def _ufl_pullback_from_enum(m: _basix.MapType) -> _AbstractPullback:
"""Convert an enum to a UFL pull back.
Args:
Expand Down Expand Up @@ -2014,7 +2014,7 @@ def element(
"""
# Conversion of string arguments to types
if isinstance(cell, str):
cell = _basix.cell.string_to_type(cell)
cell = _basix.CellType[cell]
if isinstance(family, str):
if family.startswith("Discontinuous "):
family = family[14:]
Expand Down Expand Up @@ -2265,7 +2265,7 @@ def quadrature_element(
A 'quadrature' finite element.
"""
if isinstance(cell, str):
cell = _basix.cell.string_to_type(cell)
cell = _basix.CellType[cell]

if points is None:
assert weights is None
Expand Down Expand Up @@ -2303,7 +2303,7 @@ def real_element(
"""
if isinstance(cell, str):
cell = _basix.cell.string_to_type(cell)
cell = _basix.CellType[cell]

return _RealElement(cell, value_shape)

Expand Down
2 changes: 1 addition & 1 deletion python/wrapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -491,7 +491,7 @@ NB_MODULE(_basixcpp, m)

nb::enum_<quadrature::type>(m, "QuadratureType", nb::is_arithmetic(),
"Quadrature type.")
.value("Default", quadrature::type::Default)
.value("default", quadrature::type::Default)
.value("gauss_jacobi", quadrature::type::gauss_jacobi)
.value("gll", quadrature::type::gll)
.value("xiao_gimbutas", quadrature::type::xiao_gimbutas);
Expand Down
2 changes: 1 addition & 1 deletion test/test_elements.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@


def random_point(cell):
vertices = basix.geometry(basix.cell.string_to_type(cell))
vertices = basix.geometry(basix.CellType[cell])
w = [random.random() for _ in vertices]
return sum(v * i for v, i in zip(vertices, w)) / sum(w)

Expand Down
6 changes: 3 additions & 3 deletions test/test_quadrature.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def test_cell_quadrature(celltype, order):


@pytest.mark.parametrize("m", range(7))
@pytest.mark.parametrize("scheme", [basix.QuadratureType.Default, basix.QuadratureType.gll])
@pytest.mark.parametrize("scheme", [basix.QuadratureType.default, basix.QuadratureType.gll])
def test_qorder_line(m, scheme):
Qpts, Qwts = basix.make_quadrature(basix.CellType.interval, m, rule=scheme)
x = sympy.Symbol("x")
Expand All @@ -41,7 +41,7 @@ def test_qorder_line(m, scheme):

@pytest.mark.parametrize("m", range(6))
@pytest.mark.parametrize(
"scheme", [basix.QuadratureType.Default, basix.QuadratureType.gauss_jacobi]
"scheme", [basix.QuadratureType.default, basix.QuadratureType.gauss_jacobi]
)
def test_qorder_tri(m, scheme):
Qpts, Qwts = basix.make_quadrature(basix.CellType.triangle, m, rule=scheme)
Expand Down Expand Up @@ -86,7 +86,7 @@ def test_xiao_gimbutas_tet(m, scheme):

@pytest.mark.parametrize("m", range(9))
@pytest.mark.parametrize(
"scheme", [basix.QuadratureType.Default, basix.QuadratureType.gauss_jacobi]
"scheme", [basix.QuadratureType.default, basix.QuadratureType.gauss_jacobi]
)
def test_qorder_tet(m, scheme):
Qpts, Qwts = basix.make_quadrature(basix.CellType.tetrahedron, m, rule=scheme)
Expand Down

0 comments on commit 8534dd2

Please sign in to comment.