Skip to content

Commit

Permalink
Add adjacency list python wrapper
Browse files Browse the repository at this point in the history
  • Loading branch information
schnellerhase committed Nov 16, 2024
1 parent 871fc27 commit 468f8a2
Show file tree
Hide file tree
Showing 10 changed files with 65 additions and 38 deletions.
21 changes: 11 additions & 10 deletions python/dolfinx/geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,10 @@
import numpy.typing as npt

if typing.TYPE_CHECKING:
from dolfinx.cpp.graph import AdjacencyList_int32
from dolfinx.mesh import Mesh


from dolfinx import cpp as _cpp
from dolfinx.graph import AdjacencyList

__all__ = [
"BoundingBoxTree",
Expand Down Expand Up @@ -155,9 +154,7 @@ def compute_collisions_trees(
return _cpp.geometry.compute_collisions_trees(tree0._cpp_object, tree1._cpp_object)


def compute_collisions_points(
tree: BoundingBoxTree, x: npt.NDArray[np.floating]
) -> _cpp.graph.AdjacencyList_int32:
def compute_collisions_points(tree: BoundingBoxTree, x: npt.NDArray[np.floating]) -> AdjacencyList:
"""Compute collisions between points and leaf bounding boxes.
Bounding boxes can overlap, therefore points can collide with more
Expand All @@ -172,7 +169,7 @@ def compute_collisions_points(
point.
"""
return _cpp.geometry.compute_collisions_points(tree._cpp_object, x)
return AdjacencyList(_cpp.geometry.compute_collisions_points(tree._cpp_object, x))


def compute_closest_entity(
Expand Down Expand Up @@ -216,8 +213,8 @@ def create_midpoint_tree(mesh: Mesh, dim: int, entities: npt.NDArray[np.int32])


def compute_colliding_cells(
mesh: Mesh, candidates: AdjacencyList_int32, x: npt.NDArray[np.floating]
):
mesh: Mesh, candidates: AdjacencyList, x: npt.NDArray[np.floating]
) -> AdjacencyList:
"""From a mesh, find which cells collide with a set of points.
Args:
Expand All @@ -231,10 +228,14 @@ def compute_colliding_cells(
collide with the ith point.
"""
return _cpp.geometry.compute_colliding_cells(mesh._cpp_object, candidates, x)
return AdjacencyList(
_cpp.geometry.compute_colliding_cells(mesh._cpp_object, candidates._cpp_object, x)
)


def squared_distance(mesh: Mesh, dim: int, entities: list[int], points: npt.NDArray[np.floating]):
def squared_distance(
mesh: Mesh, dim: int, entities: list[int], points: npt.NDArray[np.floating]
) -> npt.NDArray[np.floating]:
"""Compute the squared distance between a point and a mesh entity.
The distance is computed between the ith input points and the ith
Expand Down
49 changes: 35 additions & 14 deletions python/dolfinx/graph.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (C) 2021 Garth N. Wells
# Copyright (C) 2021-2024 Garth N. Wells and Paul T. Kühner
#
# This file is part of DOLFINx (https://www.fenicsproject.org)
#
Expand All @@ -7,7 +7,10 @@

from __future__ import annotations

from typing import Optional, Union

import numpy as np
import numpy.typing as npt

from dolfinx import cpp as _cpp
from dolfinx.cpp.graph import partitioner
Expand All @@ -28,10 +31,34 @@
pass


__all__ = ["adjacencylist", "partitioner"]
__all__ = ["AdjacencyList", "adjacencylist", "partitioner"]


class AdjacencyList:
_cpp_object: Union[_cpp.la.AdjacencyList_int32, _cpp.la.AdjacencyList_int64]

def __init__(self, cpp_object: Union[_cpp.la.AdjacencyList_int32, _cpp.la.AdjacencyList_int64]):
self._cpp_object = cpp_object

def links(self, node: Union[np.int32, np.int64]) -> npt.NDArray[Union[np.int32, np.int64]]:
return self._cpp_object.links(node)

@property
def array(self) -> npt.NDArray[Union[np.int32, np.int64]]:
return self._cpp_object.array

def adjacencylist(data: np.ndarray, offsets=None):
@property
def offsets(self) -> npt.NDArray[np.int32]:
return self._cpp_object.offsets

@property
def num_nodes(self) -> np.int32:
return self._cpp_object.num_nodes


def adjacencylist(
data: npt.NDArray[Union[np.int32, np.int64]], offsets: Optional[npt.NDArray[np.int32]] = None
) -> AdjacencyList:
"""Create an AdjacencyList for int32 or int64 datasets.
Args:
Expand All @@ -42,15 +69,9 @@ def adjacencylist(data: np.ndarray, offsets=None):
Returns:
An adjacency list.
"""
if offsets is None:
try:
return _cpp.graph.AdjacencyList_int32(data)
except TypeError:
return _cpp.graph.AdjacencyList_int64(data)
else:
try:
return _cpp.graph.AdjacencyList_int32(data, offsets)
except TypeError:
return _cpp.graph.AdjacencyList_int64(data, offsets)
# Switch to np.isdtype(data.dtype, np.int32) once numpy >= 2.0 is enforced
is_32bit = data.dtype == np.int32
cpp_t = _cpp.graph.AdjacencyList_int32 if is_32bit else _cpp.graph.AdjacencyList_int64
cpp_object = cpp_t(data, offsets) if offsets is not None else cpp_t(data)
return AdjacencyList(cpp_object)
7 changes: 4 additions & 3 deletions python/dolfinx/io/gmshio.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from dolfinx import cpp as _cpp
from dolfinx import default_real_type
from dolfinx.cpp.graph import AdjacencyList_int32
from dolfinx.graph import AdjacencyList, adjacencylist
from dolfinx.io.utils import distribute_entity_data
from dolfinx.mesh import CellType, Mesh, create_mesh, meshtags, meshtags_from_entities

Expand Down Expand Up @@ -298,7 +299,7 @@ def model_to_mesh(
mesh, mesh.topology.dim, cells, cell_values
)
mesh.topology.create_connectivity(mesh.topology.dim, 0)
adj = _cpp.graph.AdjacencyList_int32(local_entities)
adj = adjacencylist(local_entities)
ct = meshtags_from_entities(
mesh, mesh.topology.dim, adj, local_values.astype(np.int32, copy=False)
)
Expand All @@ -323,7 +324,7 @@ def model_to_mesh(
mesh, tdim - 1, marked_facets, facet_values
)
mesh.topology.create_connectivity(topology.dim - 1, tdim)
adj = _cpp.graph.AdjacencyList_int32(local_entities)
adj = adjacencylist(local_entities)
ft = meshtags_from_entities(mesh, tdim - 1, adj, local_values.astype(np.int32, copy=False))
ft.name = "Facet tags"
else:
Expand All @@ -338,7 +339,7 @@ def read_from_msh(
rank: int = 0,
gdim: int = 3,
partitioner: typing.Optional[
typing.Callable[[_MPI.Comm, int, int, AdjacencyList_int32], AdjacencyList_int32]
typing.Callable[[_MPI.Comm, int, int, AdjacencyList], AdjacencyList_int32]
] = None,
) -> tuple[Mesh, _cpp.mesh.MeshTags_int32, _cpp.mesh.MeshTags_int32]:
"""Read a Gmsh .msh file and return a :class:`dolfinx.mesh.Mesh` and cell facet markers.
Expand Down
7 changes: 5 additions & 2 deletions python/dolfinx/mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from dolfinx.cpp.refinement import RefinementOption
from dolfinx.fem import CoordinateElement as _CoordinateElement
from dolfinx.fem import coordinate_element as _coordinate_element
from dolfinx.graph import AdjacencyList

__all__ = [
"meshtags_from_entities",
Expand Down Expand Up @@ -735,7 +736,7 @@ def meshtags(


def meshtags_from_entities(
msh: Mesh, dim: int, entities: _cpp.graph.AdjacencyList_int32, values: npt.NDArray[typing.Any]
msh: Mesh, dim: int, entities: AdjacencyList, values: npt.NDArray[typing.Any]
):
"""Create a :class:dolfinx.mesh.MeshTags` object that associates
data with a subset of mesh entities, where the entities are defined
Expand All @@ -762,7 +763,9 @@ def meshtags_from_entities(
elif isinstance(values, float):
values = np.full(entities.num_nodes, values, dtype=np.double)
values = np.asarray(values)
return MeshTags(_cpp.mesh.create_meshtags(msh.topology._cpp_object, dim, entities, values))
return MeshTags(
_cpp.mesh.create_meshtags(msh.topology._cpp_object, dim, entities._cpp_object, values)
)


def create_interval(
Expand Down
6 changes: 2 additions & 4 deletions python/test/unit/fem/test_assemble_domains.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@
import pytest

import ufl
from dolfinx import cpp as _cpp
from dolfinx import default_scalar_type, fem, la
from dolfinx.fem import Constant, Function, assemble_scalar, dirichletbc, form, functionspace
from dolfinx.graph import adjacencylist
from dolfinx.mesh import (
GhostMode,
Mesh,
Expand All @@ -33,9 +33,7 @@ def mesh():
def create_cell_meshtags_from_entities(mesh: Mesh, dim: int, cells: np.ndarray, values: np.ndarray):
mesh.topology.create_connectivity(mesh.topology.dim, 0)
cell_to_vertices = mesh.topology.connectivity(mesh.topology.dim, 0)
entities = _cpp.graph.AdjacencyList_int32(
np.array([cell_to_vertices.links(cell) for cell in cells])
)
entities = adjacencylist(np.array([cell_to_vertices.links(cell) for cell in cells]))
return meshtags_from_entities(mesh, dim, entities, values)


Expand Down
2 changes: 1 addition & 1 deletion python/test/unit/fem/test_assembler.py
Original file line number Diff line number Diff line change
Expand Up @@ -1059,7 +1059,7 @@ def test_assemble_empty_rank_mesh(self):
def partitioner(comm, nparts, local_graph, num_ghost_nodes):
"""Leave cells on the curent rank"""
dest = np.full(len(cells), comm.rank, dtype=np.int32)
return graph.adjacencylist(dest)
return graph.adjacencylist(dest)._cpp_object

if comm.rank == 0:
# Put cells on rank 0
Expand Down
3 changes: 2 additions & 1 deletion python/test/unit/fem/test_dofmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -436,7 +436,8 @@ def test_empty_rank_collapse():
def self_partitioner(comm: MPI.Intracomm, n, m, topo):
dests = np.full(len(topo[0]) // 2, comm.rank, dtype=np.int32)
offsets = np.arange(len(topo[0]) // 2 + 1, dtype=np.int32)
return dolfinx.graph.adjacencylist(dests, offsets)
# TODO: can we improve on this interface? I.e. warp to do cpp type conversion automatically
return dolfinx.graph.adjacencylist(dests, offsets)._cpp_object

mesh = create_mesh(MPI.COMM_WORLD, cells, nodes, c_el, partitioner=self_partitioner)

Expand Down
2 changes: 1 addition & 1 deletion python/test/unit/io/test_adios2.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ def test_empty_rank_mesh(self, tempdir):
def partitioner(comm, nparts, local_graph, num_ghost_nodes):
"""Leave cells on the current rank"""
dest = np.full(len(cells), comm.rank, dtype=np.int32)
return adjacencylist(dest)
return adjacencylist(dest)._cpp_object

if comm.rank == 0:
cells = np.array([[0, 1, 2], [0, 2, 3]], dtype=np.int64)
Expand Down
4 changes: 3 additions & 1 deletion python/test/unit/mesh/test_dual_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,9 @@ def test_dgrsph_1d():
x = 0
# Circular chain of interval cells
cells = [[n0, n0 + 1], [n0 + 1, n0 + 2], [n0 + 2, x]]
w = mesh.build_dual_graph(MPI.COMM_WORLD, mesh.CellType.interval, to_adj(cells, np.int64))
w = mesh.build_dual_graph(
MPI.COMM_WORLD, mesh.CellType.interval, to_adj(cells, np.int64)._cpp_object
)
assert w.num_nodes == 3
for i in range(w.num_nodes):
assert len(w.links(i)) == 2
2 changes: 1 addition & 1 deletion python/test/unit/mesh/test_mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -555,7 +555,7 @@ def test_empty_rank_mesh(dtype):
def partitioner(comm, nparts, local_graph, num_ghost_nodes):
"""Leave cells on the curent rank"""
dest = np.full(len(cells), comm.rank, dtype=np.int32)
return graph.adjacencylist(dest)
return graph.adjacencylist(dest)._cpp_object

if comm.rank == 0:
cells = np.array([[0, 1, 2], [0, 2, 3]], dtype=np.int64)
Expand Down

0 comments on commit 468f8a2

Please sign in to comment.