From 6cfcade797ba91d21ea186b81012db46aa3557ed Mon Sep 17 00:00:00 2001 From: fprill <4728053+fprill@users.noreply.github.com> Date: Wed, 13 Nov 2024 16:51:17 +0100 Subject: [PATCH] feature: support icon graphs (#53) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * feat: topology-based encoder/processor/decoder graphs derived from an ICON grid file. * docs: attempt to set a correct license header. * doc: changed the docstring style to the Numpy-style instead of Google-style. * refactor: changed variable name `verts` into `vertices`. * fix: changed float division to double slash operator. * refactor: removed unnecessary `else` branch. * refactor: more descriptive variable name in for loop. * refactor: renamed variable (ignoring the minus sign of `phi`). * refactor: changed name of temporary variable. * refactor: remove unnecessary `else´branch. * refactor: added type annotation for completeness. * refactor: remove default in function argument. * refactor: change argument name to a more understandable name. * refactor: remove redundant code. * refactor: more Pythonic if-else statement. * refactor: more Pythonic if-else statement. * refactor: use more appropriate `LOGGER.debug` instead of verbosity flag. * refactor: more appropriate variable name. * refactor: more appropriate variable name. * refactor: unified the three ICON Edgebuilders and the two Nodebuilders. * refactor: more verbose but also clearer names for variables in mesh construction algorithm. * remove: removed obsolete function `set_constant_edge_id`. * refactor: replaced the sequential ID counter by a UUID. * revert change of copyright notice * [fix] add encoder & processor edges to the __all__ variable * [refactor] move auxiliary functions to utils.py in graphs/generate/ * [doc] added empty torso for class documentation. * [fix] fixed interfaces (masks). * [refactor] remove edge attribute calculation from this PR. * [chore] adjust copyright notice. 
* [doc] elaborate on icon mesh classes in rst file. * Add Icon tests * update Change log * Add __future__ annotations * fix change log --------- Co-authored-by: Marek Jacob <1129-b380572@users.noreply.gitlab.dkrz.de> Co-authored-by: Florian Prill --- CHANGELOG.md | 1 + docs/graphs/node_coordinates/icon_mesh.rst | 64 +++ pyproject.toml | 2 + src/anemoi/graphs/edges/__init__.py | 12 +- src/anemoi/graphs/edges/builder.py | 137 +++++- src/anemoi/graphs/generate/icon_mesh.py | 395 ++++++++++++++++++ src/anemoi/graphs/generate/utils.py | 76 ++++ src/anemoi/graphs/nodes/__init__.py | 6 + src/anemoi/graphs/nodes/builders/from_icon.py | 83 ++++ tests/nodes/test_icon.py | 196 +++++++++ 10 files changed, 968 insertions(+), 4 deletions(-) create mode 100644 docs/graphs/node_coordinates/icon_mesh.rst create mode 100644 src/anemoi/graphs/generate/icon_mesh.py create mode 100644 src/anemoi/graphs/nodes/builders/from_icon.py create mode 100644 tests/nodes/test_icon.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 9307b58..7743f30 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,7 @@ Please add your functional changes to the appropriate section in the PR. Keep it human-readable, your future self will thank you! 
## [Unreleased](https://github.com/ecmwf/anemoi-graphs/compare/0.4.0...HEAD) +- feat: Define node sets and edges based on an ICON icosahedral mesh (#53) ## [0.4.0 - LAM and stretched graphs](https://github.com/ecmwf/anemoi-graphs/compare/0.3.0...0.4.0) - 2024-11-08 diff --git a/docs/graphs/node_coordinates/icon_mesh.rst b/docs/graphs/node_coordinates/icon_mesh.rst new file mode 100644 index 0000000..dc306ed --- /dev/null +++ b/docs/graphs/node_coordinates/icon_mesh.rst @@ -0,0 +1,64 @@ +#################################### + Triangular Mesh with ICON Topology +#################################### + +The classes `ICONMultimeshNodes` and `ICONCellGridNodes` define node +sets based on an ICON icosahedral mesh: + +- class `ICONCellGridNodes`: data grid, representing cell circumcenters +- class `ICONMultimeshNodes`: hidden mesh, representing the vertices of + a grid hierarchy + +Both classes, together with the corresponding edge builders + +- class `ICONTopologicalProcessorEdges` +- class `ICONTopologicalEncoderEdges` +- class `ICONTopologicalDecoderEdges` + +are based on the mesh hierarchy that is reconstructed from an ICON mesh +file in NetCDF format, making use of the `refinement_level_v` and +`refinement_level_c` properties contained therein. + +- `refinement_level_v[vertex] = 0,1,2, ...`, + where 0 denotes the vertices of the base grid, i.e. the icosahedron + including the step of root subdivision RXXB00. + +- `refinement_level_c[cell]`: cell refinement level index such that + value 0 denotes the cells of the base grid, i.e. the icosahedron + including the step of root subdivision RXXB00. + +To avoid multiple runs of the reconstruction algorithm, a separate +`ICONNodes` instance is created and used by the builders, see the +following YAML example: + +.. 
code:: yaml + + nodes: + # ICON mesh + icon_mesh: + node_builder: + _target_: anemoi.graphs.nodes.ICONNodes + name: "icon_grid_0026_R03B07_G" + grid_filename: "icon_grid_0026_R03B07_G.nc" + max_level_multimesh: 3 + max_level_dataset: 3 + # Data nodes + data: + node_builder: + _target_: anemoi.graphs.nodes.ICONCellGridNodes + icon_mesh: "icon_mesh" + attributes: ${graph.attributes.nodes} + # Hidden nodes + hidden: + node_builder: + _target_: anemoi.graphs.nodes.ICONMultimeshNodes + icon_mesh: "icon_mesh" + + edges: + # Processor configuration + - source_name: ${graph.hidden} + target_name: ${graph.hidden} + edge_builder: + _target_: anemoi.graphs.edges.ICONTopologicalProcessorEdges + icon_mesh: "icon_mesh" + attributes: ${graph.attributes.edges} diff --git a/pyproject.toml b/pyproject.toml index 48634ea..36168df 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -45,11 +45,13 @@ dependencies = [ "healpy>=1.17", "hydra-core>=1.3", "matplotlib>=3.4", + "netcdf4", "networkx>=3.1", "plotly>=5.19", "torch>=2.2", "torch-geometric>=2.3.1,<2.5", "trimesh>=4.1", + "typeguard", ] optional-dependencies.all = [ ] diff --git a/src/anemoi/graphs/edges/__init__.py b/src/anemoi/graphs/edges/__init__.py index 478c1b5..ffd5354 100644 --- a/src/anemoi/graphs/edges/__init__.py +++ b/src/anemoi/graphs/edges/__init__.py @@ -8,7 +8,17 @@ # nor does it submit to any jurisdiction. 
from .builder import CutOffEdges +from .builder import ICONTopologicalDecoderEdges +from .builder import ICONTopologicalEncoderEdges +from .builder import ICONTopologicalProcessorEdges from .builder import KNNEdges from .builder import MultiScaleEdges -__all__ = ["KNNEdges", "CutOffEdges", "MultiScaleEdges"] +__all__ = [ + "KNNEdges", + "CutOffEdges", + "MultiScaleEdges", + "ICONTopologicalProcessorEdges", + "ICONTopologicalEncoderEdges", + "ICONTopologicalDecoderEdges", +] diff --git a/src/anemoi/graphs/edges/builder.py b/src/anemoi/graphs/edges/builder.py index b5a5df8..92c3003 100644 --- a/src/anemoi/graphs/edges/builder.py +++ b/src/anemoi/graphs/edges/builder.py @@ -15,6 +15,7 @@ import networkx as nx import numpy as np +import scipy import torch from anemoi.utils.config import DotDict from hydra.utils import instantiate @@ -80,10 +81,8 @@ def get_edge_index(self, graph: HeteroData) -> torch.Tensor: source_nodes, target_nodes = self.prepare_node_data(graph) adjmat = self.get_adjacency_matrix(source_nodes, target_nodes) - # Get source & target indices of the edges edge_index = np.stack([adjmat.col, adjmat.row], axis=0) - return torch.from_numpy(edge_index).to(torch.int32) def register_edges(self, graph: HeteroData) -> HeteroData: @@ -381,7 +380,13 @@ class MultiScaleEdges(BaseEdgeBuilder): Update the graph with the edges. """ - VALID_NODES = [TriNodes, HexNodes, LimitedAreaTriNodes, LimitedAreaHexNodes, StretchedTriNodes] + VALID_NODES = [ + TriNodes, + HexNodes, + LimitedAreaTriNodes, + LimitedAreaHexNodes, + StretchedTriNodes, + ] def __init__(self, source_name: str, target_name: str, x_hops: int, **kwargs): super().__init__(source_name, target_name) @@ -444,3 +449,129 @@ def update_graph(self, graph: HeteroData, attrs_config: DotDict | None = None) - ), f"{self.__class__.__name__} requires {','.join(valid_node_names)} nodes." 
return super().update_graph(graph, attrs_config) + + +class ICONTopologicalBaseEdgeBuilder(BaseEdgeBuilder): + """Base class for computing edges based on ICON grid topology. + + Attributes + ---------- + source_name : str + The name of the source nodes. + target_name : str + The name of the target nodes. + icon_mesh : str + The name of the ICON mesh (defines both the processor mesh and the data) + """ + + def __init__( + self, + source_name: str, + target_name: str, + icon_mesh: str, + source_mask_attr_name: str | None = None, + target_mask_attr_name: str | None = None, + ): + self.icon_mesh = icon_mesh + super().__init__(source_name, target_name, source_mask_attr_name, target_mask_attr_name) + + def update_graph(self, graph: HeteroData, attrs_config: DotDict = None) -> HeteroData: + """Update the graph with the edges.""" + assert self.icon_mesh is not None, f"{self.__class__.__name__} requires initialized icon_mesh." + self.icon_sub_graph = graph[self.icon_mesh][self.sub_graph_address] + return super().update_graph(graph, attrs_config) + + def get_adjacency_matrix(self, source_nodes: NodeStorage, target_nodes: NodeStorage): + """Parameters + ---------- + source_nodes : NodeStorage + The source nodes. + target_nodes : NodeStorage + The target nodes. + """ + LOGGER.info(f"Using ICON topology {self.source_name}>{self.target_name}") + nrows = self.icon_sub_graph.num_edges + adj_matrix = scipy.sparse.coo_matrix( + ( + np.ones(nrows), + ( + self.icon_sub_graph.edge_vertices[:, self.vertex_index[0]], + self.icon_sub_graph.edge_vertices[:, self.vertex_index[1]], + ), + ) + ) + return adj_matrix + + +class ICONTopologicalProcessorEdges(ICONTopologicalBaseEdgeBuilder): + """Computes edges based on ICON grid topology: processor grid built + from ICON grid vertices. 
+ """ + + def __init__( + self, + source_name: str, + target_name: str, + icon_mesh: str, + source_mask_attr_name: str | None = None, + target_mask_attr_name: str | None = None, + ): + self.sub_graph_address = "_multi_mesh" + self.vertex_index = (1, 0) + super().__init__( + source_name, + target_name, + icon_mesh, + source_mask_attr_name, + target_mask_attr_name, + ) + + +class ICONTopologicalEncoderEdges(ICONTopologicalBaseEdgeBuilder): + """Computes encoder edges based on ICON grid topology: ICON cell + circumcenters mapped onto the processor grid built from ICON grid + vertices. + """ + + def __init__( + self, + source_name: str, + target_name: str, + icon_mesh: str, + source_mask_attr_name: str | None = None, + target_mask_attr_name: str | None = None, + ): + self.sub_graph_address = "_cell_grid" + self.vertex_index = (1, 0) + super().__init__( + source_name, + target_name, + icon_mesh, + source_mask_attr_name, + target_mask_attr_name, + ) + + +class ICONTopologicalDecoderEdges(ICONTopologicalBaseEdgeBuilder): + """Computes decoder edges based on ICON grid topology: mapping from + processor grid built from ICON grid vertices onto ICON cell + circumcenters. + """ + + def __init__( + self, + source_name: str, + target_name: str, + icon_mesh: str, + source_mask_attr_name: str | None = None, + target_mask_attr_name: str | None = None, + ): + self.sub_graph_address = "_cell_grid" + self.vertex_index = (0, 1) + super().__init__( + source_name, + target_name, + icon_mesh, + source_mask_attr_name, + target_mask_attr_name, + ) diff --git a/src/anemoi/graphs/generate/icon_mesh.py b/src/anemoi/graphs/generate/icon_mesh.py new file mode 100644 index 0000000..bf0a851 --- /dev/null +++ b/src/anemoi/graphs/generate/icon_mesh.py @@ -0,0 +1,395 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. 
+# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +import itertools +import logging +import uuid +from dataclasses import dataclass +from functools import cached_property +from typing import Optional + +import netCDF4 +import numpy as np +import scipy +from typeguard import typechecked +from typing_extensions import Self + +from anemoi.graphs.generate.utils import convert_adjacency_matrix_to_list +from anemoi.graphs.generate.utils import convert_list_to_adjacency_matrix +from anemoi.graphs.generate.utils import selection_matrix + +LOGGER = logging.getLogger(__name__) + + +@typechecked +class NodeSet: + """Stores nodes on the unit sphere.""" + + id_iter: int = itertools.count() # unique ID for each object + gc_vertices: np.ndarray # geographical (lat/lon) coordinates [rad], shape [:,2] + + def __init__(self, lon: np.ndarray, lat: np.ndarray): + self.gc_vertices = np.column_stack((lon, lat)) + self.id = uuid.uuid4() + + @property + def num_vertices(self) -> int: + return self.gc_vertices.shape[0] + + @cached_property + def cc_vertices(self): + """Cartesian coordinates [rad], shape [:,3].""" + return self._gc_to_cartesian() + + def __add__(self, other: Self) -> Self: + """concatenates two node sets.""" + gc_vertices = np.concatenate((self.gc_vertices, other.gc_vertices)) + return NodeSet(gc_vertices[:, 0], gc_vertices[:, 1]) + + def __eq__(self, other: Self) -> bool: + """Compares two node sets.""" + return self.id == other.id + + def _gc_to_cartesian(self, radius: float = 1.0) -> np.ndarray: + """Returns Cartesian coordinates of the node set, shape [:,3].""" + xyz = ( + radius * np.cos(lat_rad := self.gc_vertices[:, 1]) * np.cos(lon_rad := self.gc_vertices[:, 0]), + radius * np.cos(lat_rad) * np.sin(lon_rad), + radius * np.sin(lat_rad), + ) + return np.stack(xyz, axis=-1) + + +@typechecked +@dataclass +class 
EdgeID: + """Stores additional categorical data for each edge (IDs for heterogeneous input).""" + + edge_id: np.ndarray + num_classes: int + + def __add__(self, other: Self): + """Concatenates two edge ID datasets.""" + assert self.num_classes == other.num_classes + return EdgeID( + edge_id=np.concatenate((self.edge_id, other.edge_id)), + num_classes=self.num_classes, + ) + + +@typechecked +class GeneralGraph: + """Stores edges for a given node set.""" + + nodeset: NodeSet # graph nodes + edge_vertices: np.ndarray # vertex indices for each edge, shape [:,2] + + def __init__(self, nodeset: NodeSet, bidirectional: bool, edge_vertices: np.ndarray): + self.nodeset = nodeset + # (optional) duplicate edges (bi-directional): + if bidirectional: + self.edge_vertices = np.concatenate([edge_vertices, np.fliplr(edge_vertices)]) + else: + self.edge_vertices = edge_vertices + + @property + def num_vertices(self) -> int: + return self.nodeset.num_vertices + + @property + def num_edges(self) -> int: + return self.edge_vertices.shape[0] + + +@typechecked +class BipartiteGraph: + """Graph defined on a pair of NodeSets.""" + + nodeset: tuple[NodeSet, NodeSet] # source and target node set + edge_vertices: np.ndarray # vertex indices for each edge, shape [:,2] + edge_id: np.ndarray # additional ID for each edge (markers for heterogeneous input) + + def __init__( + self, + nodeset: tuple[NodeSet, NodeSet], + edge_vertices: np.ndarray, + edge_id: Optional[EdgeID] = None, + ): + self.nodeset = nodeset + self.edge_vertices = edge_vertices + self.edge_id = edge_id + + @property + def num_edges(self) -> int: + return self.edge_vertices.shape[0] + + def __add__(self, other: "BipartiteGraph"): + """Concatenates two bipartite graphs that share a common target node set. + Shifts the node indices of the second bipartite graph. 
+ """ + + if not self.nodeset[1] == other.nodeset[1]: + raise ValueError("Only bipartite graphs with common target node set can be merged.") + shifted_edge_vertices = other.edge_vertices + shifted_edge_vertices[:, 0] += self.nodeset[0].num_vertices + # (Optional:) merge one-hot-encoded categorical data (`edge_id`) + edge_id = None if None in (self.edge_id, other.edge_id) else self.edge_id + other.edge_id + + return BipartiteGraph( + nodeset=(self.nodeset[0] + other.nodeset[0], self.nodeset[1]), + edge_vertices=np.concatenate((self.edge_vertices, shifted_edge_vertices)), + edge_id=edge_id, + ) + + +@typechecked +class ICONMultiMesh(GeneralGraph): + """Reads vertices and topology from an ICON grid file; creates multi-mesh.""" + + uuidOfHGrid: str + max_level: int + nodeset: NodeSet # set of ICON grid vertices + cell_vertices: np.ndarray + + def __init__(self, icon_grid_filename: str, max_level: Optional[int] = None): + + # open file, representing the finest level + LOGGER.debug(f"{type(self).__name__}: read ICON grid file '{icon_grid_filename}'") + with netCDF4.Dataset(icon_grid_filename, "r") as ncfile: + # read vertex coordinates + vlon = read_coordinate_array(ncfile, "vlon", "vertex") + vlat = read_coordinate_array(ncfile, "vlat", "vertex") + + edge_vertices_fine = np.asarray(ncfile.variables["edge_vertices"][:] - 1, dtype=np.int64).transpose() + assert ncfile.variables["edge_vertices"].dimensions == ("nc", "edge") + + cell_vertices_fine = np.asarray(ncfile.variables["vertex_of_cell"][:] - 1, dtype=np.int64).transpose() + assert ncfile.variables["vertex_of_cell"].dimensions == ("nv", "cell") + + reflvl_vertex = ncfile.variables["refinement_level_v"][:] + assert ncfile.variables["refinement_level_v"].dimensions == ("vertex",) + + self.uuidOfHGrid = ncfile.uuidOfHGrid + + self.max_level = max_level if max_level is not None else reflvl_vertex.max() + + # generate edge-vertex relations for coarser levels: + (edge_vertices, cell_vertices) = 
self._get_hierarchy_of_icon_edge_graphs( + edge_vertices_fine=edge_vertices_fine, + cell_vertices_fine=cell_vertices_fine, + reflvl_vertex=reflvl_vertex, + ) + # restrict edge-vertex list to multi_mesh level "max_level": + if self.max_level < len(edge_vertices): + (self.edge_vertices, self.cell_vertices, vlon, vlat) = self._restrict_multi_mesh_level( + edge_vertices, + cell_vertices, + reflvl_vertex=reflvl_vertex, + vlon=vlon, + vlat=vlat, + ) + # store vertices as a `NodeSet`: + self.nodeset = NodeSet(vlon, vlat) + # concatenate edge-vertex lists (= edges of the multi-level mesh): + multi_mesh_edges = np.concatenate([edges for edges in self.edge_vertices], axis=0) + # generate multi-mesh graph data structure: + super().__init__(nodeset=self.nodeset, bidirectional=True, edge_vertices=multi_mesh_edges) + + def _restrict_multi_mesh_level( + self, + edge_vertices: list[np.ndarray], + cell_vertices: np.ndarray, + reflvl_vertex: np.ndarray, + vlon: np.ndarray, + vlat: np.ndarray, + ) -> tuple[list[np.ndarray], np.ndarray, np.ndarray, np.ndarray]: + """Creates a new mesh with only the vertices at the desired level.""" + + num_vertices = reflvl_vertex.shape[0] + vertex_mask = reflvl_vertex <= self.max_level + vertex_glb2loc = np.zeros(num_vertices, dtype=int) + vertex_glb2loc[vertex_mask] = np.arange(vertex_mask.sum()) + return ( + [vertex_glb2loc[vertices] for vertices in edge_vertices[: self.max_level + 1]], + # cell_vertices: preserve negative indices (incomplete cells) + np.where(cell_vertices >= 0, vertex_glb2loc[cell_vertices], cell_vertices), + vlon[vertex_mask], + vlat[vertex_mask], + ) + + def _get_hierarchy_of_icon_edge_graphs( + self, + edge_vertices_fine: np.ndarray, + cell_vertices_fine: np.ndarray, + reflvl_vertex: np.ndarray, + ) -> tuple[list[np.ndarray], np.ndarray]: + """Returns a list of edge-vertex relations (coarsest to finest level).""" + + edge_vertices = [edge_vertices_fine] # list of edge-vertex relations (coarsest to finest level). 
+ + num_vertices = reflvl_vertex.shape[0] + # edge-to-vertex adjacency matrix with 2 non-zero entries per row: + edge2vertex_matrix = convert_list_to_adjacency_matrix(edge_vertices_fine, num_vertices) + # cell-to-vertex adjacency matrix with 3 non-zero entries per row: + cell2vertex_matrix = convert_list_to_adjacency_matrix(cell_vertices_fine, num_vertices) + vertex2vertex_matrix = edge2vertex_matrix.transpose() * edge2vertex_matrix + vertex2vertex_matrix.setdiag(np.ones(num_vertices)) # vertices are self-connected + + selected_vertex_coarse = scipy.sparse.diags(np.ones(num_vertices), dtype=bool) + + # coarsen edge-vertex list from level `ilevel -> ilevel - 1`: + for ilevel in reversed(range(1, reflvl_vertex.max() + 1)): + LOGGER.debug(f" edges[{ilevel}] = {edge_vertices[0].shape[0] : >9}") + + # define edge selection matrix (selecting only edges which have + # exactly one coarse vertex): + # + # get a boolean mask, matching all edges where one of its vertices + # has refinement level index `ilevel`: + ref_level_mask = reflvl_vertex[edge_vertices[0]] == ilevel + edges_coarse = np.logical_xor(ref_level_mask[:, 0], ref_level_mask[:, 1]) # = bisected coarse edges + idx_edge2edge = np.argwhere(edges_coarse).flatten() + selected_edges = selection_matrix(idx_edge2edge, edges_coarse.shape[0]) + + # define vertex selection matrix selecting only vertices of + # level `ilevel`: + idx_v_fine = np.argwhere(reflvl_vertex == ilevel).flatten() + selected_vertex_fine = selection_matrix(idx_v_fine, num_vertices) + # define vertex selection matrix, selecting only vertices of + # level < `ilevel`, by successively removing `selected_vertex_fine` from an identity matrix. 
+ selected_vertex_coarse.data[0][idx_v_fine] = False + + # create an adjacency matrix which links each fine level + # vertex to its two coarser neighbor vertices: + vertex2vertex_fine2coarse = selected_vertex_fine * vertex2vertex_matrix * selected_vertex_coarse + # remove rows that have only one non-zero entry + # (corresponding to incomplete parent triangles in LAM grids): + csum = vertex2vertex_fine2coarse * np.ones((vertex2vertex_fine2coarse.shape[1], 1)) + selected_vertex2vertex = selection_matrix( + np.argwhere(csum == 2).flatten(), vertex2vertex_fine2coarse.shape[0] + ) + vertex2vertex_fine2coarse = selected_vertex2vertex * vertex2vertex_fine2coarse + + # then construct the edges-to-parent-vertex adjacency matrix: + parent_edge_vertices = selected_edges * edge2vertex_matrix * vertex2vertex_fine2coarse + # again, we might have selected edges within + # `selected_edges` which are part of an incomplete parent edge + # (LAM case). We filter these here: + csum = parent_edge_vertices * np.ones((parent_edge_vertices.shape[1], 1)) + selected_edge2edge = selection_matrix(np.argwhere(csum == 2).flatten(), parent_edge_vertices.shape[0]) + parent_edge_vertices = selected_edge2edge * parent_edge_vertices + + # note: the edges-vertex adjacency matrix still has duplicate + # rows, since two child edges have the same parent edge. + edge_vertices_coarse = convert_adjacency_matrix_to_list(parent_edge_vertices, ncols_per_row=2) + edge_vertices.insert(0, edge_vertices_coarse) + + # store cell-to-vert adjacency matrix + if ilevel > self.max_level: + cell2vertex_matrix = cell2vertex_matrix * vertex2vertex_fine2coarse + # similar to the treatment above, we need to handle + # coarse LAM cells which are incomplete. 
+ csum = cell2vertex_matrix * np.ones((cell2vertex_matrix.shape[1], 1)) + selected_cell2cell = selection_matrix(np.argwhere(csum == 3).flatten(), cell2vertex_matrix.shape[0]) + cell2vertex_matrix = selected_cell2cell * cell2vertex_matrix + + # replace edge-to-vertex and vert-to-vert adjacency matrices (for next level): + if ilevel > 1: + vertex2vertex_matrix = selected_vertex_coarse * vertex2vertex_matrix * vertex2vertex_fine2coarse + edge2vertex_matrix = convert_list_to_adjacency_matrix(edge_vertices_coarse, num_vertices) + + # Fine-level cells outside of multi-mesh (LAM boundary) + # correspond to empty rows in the adjacency matrix. We + # substitute these by three additional, non-existent vertices: + csum = 3 - cell2vertex_matrix * np.ones((cell2vertex_matrix.shape[1], 1)) + nvmax = cell2vertex_matrix.shape[1] + cell2vertex_matrix = scipy.sparse.csr_matrix(scipy.sparse.hstack((cell2vertex_matrix, csum, csum, csum))) + + # build a list of cell-vertices [1..num_cells,1..3] for all + # fine-level cells: + cell_vertices = convert_adjacency_matrix_to_list(cell2vertex_matrix, remove_duplicates=False, ncols_per_row=3) + + # finally, translate non-existent vertices into "-1" indices: + cell_vertices = np.where( + cell_vertices >= nvmax, + -np.ones(cell_vertices.shape, dtype=np.int32), + cell_vertices, + ) + + return (edge_vertices, cell_vertices) + + +@typechecked +class ICONCellDataGrid(BipartiteGraph): + """Reads cell locations from an ICON grid file; builds grid-to-mesh edges based on ICON topology.""" + + uuidOfHGrid: str + nodeset: NodeSet # set of ICON cell circumcenters + max_level: int + select_c: np.ndarray + + def __init__( + self, + icon_grid_filename: str, + multi_mesh: Optional[ICONMultiMesh] = None, + max_level: Optional[int] = None, + ): + # open file, representing the finest level + LOGGER.debug(f"{type(self).__name__}: read ICON grid file '{icon_grid_filename}'") + with netCDF4.Dataset(icon_grid_filename, "r") as ncfile: + # read cell circumcenter 
coordinates + clon = read_coordinate_array(ncfile, "clon", "cell") + clat = read_coordinate_array(ncfile, "clat", "cell") + + reflvl_cell = ncfile.variables["refinement_level_c"][:] + assert ncfile.variables["refinement_level_c"].dimensions == ("cell",) + + self.uuidOfHGrid = ncfile.uuidOfHGrid + + if max_level is not None: + self.max_level = max_level + else: + self.max_level = reflvl_cell.max() + + # restrict to level `max_level`: + self.select_c = np.argwhere(reflvl_cell <= self.max_level) + # generate source grid node set: + self.nodeset = NodeSet(clon[self.select_c], clat[self.select_c]) + + if multi_mesh is not None: + # generate edges between source grid nodes and multi-mesh nodes: + edge_vertices = self._get_grid2mesh_edges(self.select_c, multi_mesh=multi_mesh) + super().__init__((self.nodeset, multi_mesh.nodeset), edge_vertices) + + def _get_grid2mesh_edges(self, select_c: np.ndarray, multi_mesh: ICONMultiMesh) -> np.ndarray: + """Create "grid-to-mesh" edges, ie. edges from (clat,clon) to the + vertices of the multi-mesh. + """ + + num_cells = select_c.shape[0] + num_vertices_per_cell = multi_mesh.cell_vertices.shape[1] + src_list = np.kron(np.arange(num_cells), np.ones((1, num_vertices_per_cell), dtype=np.int64)).flatten() + dst_list = multi_mesh.cell_vertices[select_c[:, 0], :].flatten() + edge_vertices = np.stack((src_list, dst_list), axis=1, dtype=np.int64) + return edge_vertices + + +# ------------------------------------------------------------- + + +@typechecked +def read_coordinate_array(ncfile, arrname: str, dimname: str) -> np.ndarray: + """Auxiliary routine, reading a coordinate array, checking consistency.""" + arr = ncfile.variables[arrname][:] + assert ncfile.variables[arrname].dimensions == (dimname,) + assert ncfile.variables[arrname].units == "radian" + # netCDF4 returns all variables as numpy.ma.core.MaskedArray. 
+ # -> convert to regular arrays + assert not arr.mask.any(), f"There are missing values in {arrname}" + return arr.data diff --git a/src/anemoi/graphs/generate/utils.py b/src/anemoi/graphs/generate/utils.py index 2dfc320..ccb2c0c 100644 --- a/src/anemoi/graphs/generate/utils.py +++ b/src/anemoi/graphs/generate/utils.py @@ -8,6 +8,8 @@ # nor does it submit to any jurisdiction. import numpy as np +import scipy +from typeguard import typechecked def get_coordinates_ordering(coords: np.ndarray) -> np.ndarray: @@ -29,3 +31,77 @@ def get_coordinates_ordering(coords: np.ndarray) -> np.ndarray: index_longitude = np.argsort(coords[index_latitude][:, 0])[::-1] node_ordering = np.arange(coords.shape[0])[index_latitude][index_longitude] return node_ordering + + +@typechecked +def convert_list_to_adjacency_matrix(list_matrix: np.ndarray, ncols: int = 0) -> scipy.sparse.csr_matrix: + """Convert an edge list into an adjacency matrix. + + Parameters + ---------- + list_matrix : np.ndarray + boolean matrix given by list of column indices for each row. + ncols : int + number of columns in result matrix. + + Returns + ------- + scipy.sparse.csr_matrix + sparse matrix [nrows, ncols] + """ + nrows, ncols_per_row = list_matrix.shape + indptr = np.arange(ncols_per_row * (nrows + 1), step=ncols_per_row) + indices = list_matrix.ravel() + return scipy.sparse.csr_matrix((np.ones(nrows * ncols_per_row), indices, indptr), dtype=bool, shape=(nrows, ncols)) + + +@typechecked +def convert_adjacency_matrix_to_list( + adj_matrix: scipy.sparse.csr_matrix, + ncols_per_row: int, + remove_duplicates: bool = True, +) -> np.ndarray: + """Convert an adjacency matrix into an edge list. + + Parameters + ---------- + adj_matrix : scipy.sparse.csr_matrix + sparse (boolean) adjacency matrix + ncols_per_row : int + number of nonzero entries per row + remove_duplicates : bool + logical flag: remove duplicate rows. + + Returns + ------- + np.ndarray + boolean matrix given by list of column indices for each row. 
+ """ + if remove_duplicates: + # The edges-vertex adjacency matrix may have duplicate rows, remove + # them by selecting the rows that are unique: + nrows = int(adj_matrix.nnz // ncols_per_row) + mat = adj_matrix.indices.reshape((nrows, ncols_per_row)) + return np.unique(mat, axis=0) + + nrows = adj_matrix.shape[0] + return adj_matrix.indices.reshape((nrows, ncols_per_row)) + + +@typechecked +def selection_matrix(idx: np.ndarray, num_diagonals: int) -> scipy.sparse.csr_matrix: + """Create a diagonal selection matrix. + + Parameters + ---------- + idx : np.ndarray + integer array of indices + num_diagonals : int + size of (square) selection matrix + + Returns + ------- + scipy.sparse.csr_matrix + diagonal matrix with ones at selected indices (idx,idx). + """ + return scipy.sparse.csr_matrix((np.ones((len(idx))), (idx, idx)), dtype=bool, shape=(num_diagonals, num_diagonals)) diff --git a/src/anemoi/graphs/nodes/__init__.py b/src/anemoi/graphs/nodes/__init__.py index 3a917f2..228a282 100644 --- a/src/anemoi/graphs/nodes/__init__.py +++ b/src/anemoi/graphs/nodes/__init__.py @@ -12,6 +12,9 @@ from .builders.from_file import ZarrDatasetNodes from .builders.from_healpix import HEALPixNodes from .builders.from_healpix import LimitedAreaHEALPixNodes +from .builders.from_icon import ICONCellGridNodes +from .builders.from_icon import ICONMultimeshNodes +from .builders.from_icon import ICONNodes from .builders.from_refined_icosahedron import HexNodes from .builders.from_refined_icosahedron import LimitedAreaHexNodes from .builders.from_refined_icosahedron import LimitedAreaTriNodes @@ -29,4 +32,7 @@ "LimitedAreaTriNodes", "LimitedAreaHexNodes", "StretchedTriNodes", + "ICONMultimeshNodes", + "ICONCellGridNodes", + "ICONNodes", ] diff --git a/src/anemoi/graphs/nodes/builders/from_icon.py b/src/anemoi/graphs/nodes/builders/from_icon.py new file mode 100644 index 0000000..f0473e9 --- /dev/null +++ b/src/anemoi/graphs/nodes/builders/from_icon.py @@ -0,0 +1,83 @@ +# (C) Copyright 
2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +from __future__ import annotations + +import numpy as np +import torch +from anemoi.utils.config import DotDict +from torch_geometric.data import HeteroData + +from anemoi.graphs.generate.icon_mesh import ICONCellDataGrid +from anemoi.graphs.generate.icon_mesh import ICONMultiMesh +from anemoi.graphs.nodes.builders.base import BaseNodeBuilder + + +class ICONNodes(BaseNodeBuilder): + """ICON grid (cell and vertex locations).""" + + def __init__(self, name: str, grid_filename: str, max_level_multimesh: int, max_level_dataset: int) -> None: + self.grid_filename = grid_filename + + self.multi_mesh = ICONMultiMesh(self.grid_filename, max_level=max_level_multimesh) + self.cell_grid = ICONCellDataGrid(self.grid_filename, self.multi_mesh, max_level=max_level_dataset) + + super().__init__(name) + + def get_coordinates(self) -> torch.Tensor: + return torch.from_numpy(self.multi_mesh.nodeset.gc_vertices.astype(np.float32)).fliplr() + + def register_attributes(self, graph: HeteroData, config: DotDict) -> HeteroData: + graph[self.name]["_grid_filename"] = self.grid_filename + graph[self.name]["_multi_mesh"] = self.multi_mesh + graph[self.name]["_cell_grid"] = self.cell_grid + return super().register_attributes(graph, config) + + +class ICONTopologicalBaseNodeBuilder(BaseNodeBuilder): + """Base class for data mesh or processor mesh based on an ICON grid. + + Parameters + ---------- + name : str + key for the nodes in the HeteroData graph object. + icon_mesh : str + key corresponding to the ICON mesh (cells and vertices). 
+ """ + + def __init__(self, name: str, icon_mesh: str) -> None: + self.icon_mesh = icon_mesh + super().__init__(name) + + def update_graph(self, graph: HeteroData, attr_config: DotDict | None = None) -> HeteroData: + """Update the graph with new nodes.""" + self.icon_sub_graph = graph[self.icon_mesh][self.sub_graph_address] + return super().update_graph(graph, attr_config) + + +class ICONMultimeshNodes(ICONTopologicalBaseNodeBuilder): + """Processor mesh based on an ICON grid.""" + + def __init__(self, name: str, icon_mesh: str) -> None: + self.sub_graph_address = "_multi_mesh" + super().__init__(name, icon_mesh) + + def get_coordinates(self) -> torch.Tensor: + return torch.from_numpy(self.icon_sub_graph.nodeset.gc_vertices.astype(np.float32)).fliplr() + + +class ICONCellGridNodes(ICONTopologicalBaseNodeBuilder): + """Data mesh based on an ICON grid.""" + + def __init__(self, name: str, icon_mesh: str) -> None: + self.sub_graph_address = "_cell_grid" + super().__init__(name, icon_mesh) + + def get_coordinates(self) -> torch.Tensor: + return torch.from_numpy(self.icon_sub_graph.nodeset[0].gc_vertices.astype(np.float32)).fliplr() diff --git a/tests/nodes/test_icon.py b/tests/nodes/test_icon.py new file mode 100644 index 0000000..528222a --- /dev/null +++ b/tests/nodes/test_icon.py @@ -0,0 +1,196 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. 
+ +import netCDF4 +import numpy as np +import pytest +import torch +from torch_geometric.data import HeteroData + +from anemoi.graphs.edges import ICONTopologicalDecoderEdges +from anemoi.graphs.edges import ICONTopologicalEncoderEdges +from anemoi.graphs.edges import ICONTopologicalProcessorEdges +from anemoi.graphs.generate.icon_mesh import ICONCellDataGrid +from anemoi.graphs.generate.icon_mesh import ICONMultiMesh +from anemoi.graphs.nodes import ICONCellGridNodes +from anemoi.graphs.nodes import ICONMultimeshNodes +from anemoi.graphs.nodes import ICONNodes +from anemoi.graphs.nodes.builders.base import BaseNodeBuilder + + +class DatasetMock: + """This dataset emulates the most primitive unstructured grid with + refinement. + + Enumeration of cells, edges and vertices in the netCDF file is 1-based. + C: cell + E: edge + V: vertex + + Cell C2 with its additional vertex V4 and edges E4 and E5 were added as + a first refinement. + + [V1: 0, 1]🢀-E3--[V3: 1, 1] + 🢁 ╲ 🢁 + | ╲ [C1: ⅔, ⅔] | + | ╲ | + E5 E1 E2 + | ╲ | + | ╲ | + | [C2: ⅓, ⅓] ╲ | + | 🢆 | + [V4: 0, 0]🢀-E4--[V2: 1, 0] + + Note: Triangular refinement does not actually work like this. This grid + mock serves testing purposes only. 
+ + """ + + def __init__(self, *args, **kwargs): + + class MockVariable: + def __init__(self, data, units, dimensions): + self.data = np.ma.asarray(data) + self.shape = data.shape + self.units = units + self.dimensions = dimensions + + def __getitem__(self, key): + return self.data[key] + + self.variables = { + "vlon": MockVariable(np.array([0, 1, 1, 0]), "radian", ("vertex",)), + "vlat": MockVariable(np.array([1, 0, 1, 0]), "radian", ("vertex",)), + "clon": MockVariable(np.array([0.66, 0.33]), "radian", ("cell",)), + "clat": MockVariable(np.array([0.66, 0.33]), "radian", ("cell",)), + "edge_vertices": MockVariable(np.array([[1, 2], [2, 3], [3, 1], [2, 4], [4, 1]]).T, "", ("nc", "edge")), + "vertex_of_cell": MockVariable(np.array([[1, 2, 3], [1, 2, 4]]).T, "", ("nv", "cell")), + "refinement_level_v": MockVariable(np.array([0, 0, 0, 1]), "", ("vertex",)), + "refinement_level_c": MockVariable(np.array([0, 1]), "", ("cell",)), + } + """common array dimensions: + nc: 2, # constant + nv: 3, # constant + vertex: 4, + edge: 5, + cell: 2, + """ + self.uuidOfHGrid = "__test_data__" + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + pass + + +@pytest.mark.parametrize("max_level_multimesh,max_level_dataset", [(0, 0), (0, 1), (1, 1)]) +def test_init(monkeypatch, max_level_multimesh: int, max_level_dataset: int): + """Test ICONNodes initialization.""" + + monkeypatch.setattr(netCDF4, "Dataset", DatasetMock) + node_builder = ICONNodes( + name="test_nodes", + grid_filename="test.nc", + max_level_multimesh=max_level_multimesh, + max_level_dataset=max_level_dataset, + ) + assert isinstance(node_builder, BaseNodeBuilder) + assert isinstance(node_builder, ICONNodes) + + +@pytest.mark.parametrize("Node_builder", [ICONCellGridNodes, ICONMultimeshNodes]) +def test_Node_builder_dependencies(monkeypatch, Node_builder): + """Test that the `Node_builder` depends on the presence of ICONNodes.""" + + monkeypatch.setattr(netCDF4, "Dataset", DatasetMock) 
+ nodes = ICONNodes("test_icon_nodes", "test.nc", 0, 0) + data_nodes = Node_builder("data_nodes", "test_icon_nodes") + config = {} + graph = HeteroData() + graph = nodes.register_attributes(graph, config) + + data_nodes.update_graph(graph) + + data_nodes = ICONCellGridNodes("data_nodes2", "missing_icon_nodes") + with pytest.raises(KeyError): + data_nodes.update_graph(graph) + + +class Test_Edge_builder_dependencies: + @pytest.fixture() + def icon_graph(self, monkeypatch) -> HeteroData: + """Return a HeteroData object with ICONNodes nodes.""" + graph = HeteroData() + monkeypatch.setattr(netCDF4, "Dataset", DatasetMock) + nodes = ICONNodes("test_icon_nodes", "test.nc", 1, 0) + + graph = nodes.update_graph(graph, {}) + + data_nodes = ICONCellGridNodes("data", "test_icon_nodes") + graph = data_nodes.register_attributes(graph, {}) + + return graph + + @pytest.mark.parametrize( + "Edge_builder", [ICONTopologicalProcessorEdges, ICONTopologicalEncoderEdges, ICONTopologicalDecoderEdges] + ) + def test_ProcessorEdges_dependencies(self, icon_graph, Edge_builder): + """Test that the `Edge_builder` depends on the presence of ICONNodes.""" + + edges = Edge_builder( + source_name="data", + target_name="data", + icon_mesh="test_icon_nodes", + ) + edges.update_graph(icon_graph) + + edges2 = Edge_builder( + source_name="data", + target_name="data", + icon_mesh="missing_icon_nodes", + ) + with pytest.raises(KeyError): + edges2.update_graph(icon_graph) + + +def test_register_nodes(monkeypatch): + """Test that ICONNodes registers the nodes correctly.""" + monkeypatch.setattr(netCDF4, "Dataset", DatasetMock) + nodes = ICONNodes("test_icon_nodes", "test.nc", 0, 0) + graph = HeteroData() + + graph = nodes.register_nodes(graph) + + assert graph["test_icon_nodes"].x is not None + assert isinstance(graph["test_icon_nodes"].x, torch.Tensor) + assert graph["test_icon_nodes"].x.shape[1] == 2 + assert graph["test_icon_nodes"].x.shape[0] == 3, "number of vertices at refinement_level_v == 0" + assert 
graph["test_icon_nodes"].node_type == "ICONNodes" + + nodes2 = ICONNodes("test_icon_nodes", "test.nc", 1, 0) + graph = nodes2.register_nodes(graph) + assert graph["test_icon_nodes"].x.shape[0] == 4, "number of vertices at refinement_level_v == 1" + + +def test_register_attributes( + monkeypatch, + graph_with_nodes: HeteroData, +): + """Test that ICONNodes registers the attributes correctly.""" + monkeypatch.setattr(netCDF4, "Dataset", DatasetMock) + nodes = ICONNodes("test_icon_nodes", "test.nc", 0, 0) + config = {"test_attr": {"_target_": "anemoi.graphs.nodes.attributes.UniformWeights"}} + graph = HeteroData() + + graph = nodes.register_attributes(graph, config) + + assert graph["test_icon_nodes"]["_grid_filename"] is not None + assert isinstance(graph["test_icon_nodes"]["_multi_mesh"], ICONMultiMesh) + assert isinstance(graph["test_icon_nodes"]["_cell_grid"], ICONCellDataGrid)