Skip to content

Commit

Permalink
6 generate graphs from icosahedral meshes (#11)
Browse files Browse the repository at this point in the history
* Global Encoder-Processor-Decoder graph (#9)

* feat: Initial implementation of global graphs

Co-authored by: Mario Santa Cruz <[email protected]>
Co-authored-by: Helen Theissen <[email protected]>
Co-authored-by: Jesper Dramsch <[email protected]>

* fix: attributes as torch.float32

* new test: attributes must be float32

* fix typo

* Homogeneize base builders

* improve test docstrings

* homogeneize (name as class attribute)

* new input config

* new default

* feat: Initial implementation of global graphs

Co-authored by: Mario Santa Cruz <[email protected]>

* add cli command

* Ignore .pt files

* run pre-commit

* docstring + log erros

* initial tests

* feat: initial version of AttributeBuilder

* refactor: separate into node edge attribute builders

* feat: edge_length moved to edges/attributes.py

* remove __init__

* bugfix (encoder edge lengths) + refector

* feat: support path and dict for `config` argument

* fix: error

* refactor: naming

* fix: pre-commit

* feat: builders icosahedral

* feat: Add icosahedral graph generation

Co-authored-by: Mario Santa Cruz <[email protected]>

* refactor: remove create_shere

* feat: Icosahedral edge builder

* feat: hexagonal graph generation

Co-authored-by: Mario Santa Cruz <[email protected]>

* feat: hexagonal builders

* fix: AOI not implemented yet

* fix: abstractmethod and renaming

* chore: add dependencies

* test: add tests for trimesh

* test: add tests for hex (h3)

* fix: imports

* fix: output type

* refactor: delete unused file

* refactor: renaming and positioning

* feat: ensure src and dst always the same

* fix: imports

* fix: edge_name not supported

* test: add tests for TriIcosahedralEdges

* fix: assert missing for Hexagonal edges

* test: hexagonal edges

* fix: avoid same name

* fix: imports

* fix: conflicts

* update tests

* Include xhops to hexagonal edges

* docs: update docstrings

* fix: update attribute name

* refactor: rename multiscale nodes

* refactor: rename icosahedral nodes

* improve: clarity of function

* improve: function syntax

* refactor: simplify resolution assignment

* refactor: improve variable naming in icosahedral graph generation

* more comments

* refactor: naming

* refactor: separate into functions

* refactor: remove unused code

* refactor: remove unused options and rename

* doc: clarify cells and nodes

* fix: add return statements

* homogeneize: tri & hex edges

* naming: xhops to x_hops

* naming: x_hops in docstring

* docstring

* blank lines

* fix: add return statement

* simplify icoshaedral edges

* feat: select edge builder method based on ico node type

* test: adjust tests to new MultiScaleEdges

* docs: improve docstrings

* fix: remove LAM filtering for icosahedral, leave for next PR

* h3 (v4) not supported

---------
  • Loading branch information
theissenhelen authored Jul 25, 2024
1 parent 5b3bbe1 commit 746ea2b
Show file tree
Hide file tree
Showing 12 changed files with 770 additions and 10 deletions.
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,12 @@ dynamic = [
dependencies = [
"anemoi-datasets[data]>=0.3.3",
"anemoi-utils>=0.3.6",
"h3>=3.7.6,<4",
"hydra-core>=1.3",
"networkx>=3.1",
"torch>=2.2",
"torch-geometric>=2.3.1,<2.5",
"trimesh>=4.1",
]

optional-dependencies.all = [
Expand Down
3 changes: 2 additions & 1 deletion src/anemoi/graphs/edges/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from .builder import CutOffEdges
from .builder import KNNEdges
from .builder import MultiScaleEdges

__all__ = ["KNNEdges", "CutOffEdges"]
__all__ = ["KNNEdges", "CutOffEdges", "MultiScaleEdges"]
74 changes: 70 additions & 4 deletions src/anemoi/graphs/edges/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from abc import abstractmethod
from typing import Optional

import networkx as nx
import numpy as np
import torch
from anemoi.utils.config import DotDict
Expand All @@ -12,6 +13,10 @@
from torch_geometric.data.storage import NodeStorage

from anemoi.graphs import EARTH_RADIUS
from anemoi.graphs.generate import hexagonal
from anemoi.graphs.generate import icosahedral
from anemoi.graphs.nodes.builder import HexNodes
from anemoi.graphs.nodes.builder import TriNodes
from anemoi.graphs.utils import get_grid_reference_distance

LOGGER = logging.getLogger(__name__)
Expand All @@ -33,7 +38,7 @@ def name(self) -> tuple[str, str, str]:
def get_adjacency_matrix(self, source_nodes: NodeStorage, target_nodes: NodeStorage): ...

def prepare_node_data(self, graph: HeteroData) -> tuple[NodeStorage, NodeStorage]:
"""Prepare nodes information."""
"""Prepare node information and get source and target nodes."""
return graph[self.source_name], graph[self.target_name]

def get_edge_index(self, graph: HeteroData) -> torch.Tensor:
Expand Down Expand Up @@ -188,8 +193,6 @@ class CutOffEdges(BaseEdgeBuilder):
The name of the target nodes.
cutoff_factor : float
Factor to multiply the grid reference distance to get the cut-off radius.
radius : float
Cut-off radius.
Methods
-------
Expand Down Expand Up @@ -235,7 +238,7 @@ def get_cutoff_radius(self, graph: HeteroData, mask_attr: Optional[torch.Tensor]
return radius

def prepare_node_data(self, graph: HeteroData) -> tuple[NodeStorage, NodeStorage]:
"""Prepare nodes information."""
"""Prepare node information and get source and target nodes."""
self.radius = self.get_cutoff_radius(graph)
return super().prepare_node_data(graph)

Expand All @@ -260,3 +263,66 @@ def get_adjacency_matrix(self, source_nodes: NodeStorage, target_nodes: NodeStor
nearest_neighbour.fit(source_nodes.x)
adj_matrix = nearest_neighbour.radius_neighbors_graph(target_nodes.x, radius=self.radius).tocoo()
return adj_matrix


class MultiScaleEdges(BaseEdgeBuilder):
"""Base class for multi-scale edges in the nodes of a graph."""

def __init__(self, source_name: str, target_name: str, x_hops: int):
super().__init__(source_name, target_name)
assert source_name == target_name, f"{self.__class__.__name__} requires source and target nodes to be the same."
assert isinstance(x_hops, int), "Number of x_hops must be an integer"
assert x_hops > 0, "Number of x_hops must be positive"
self.x_hops = x_hops

def adjacency_from_tri_nodes(self, source_nodes: NodeStorage):
source_nodes["nx_graph"] = icosahedral.add_edges_to_nx_graph(
source_nodes["nx_graph"],
resolutions=source_nodes["resolutions"],
x_hops=self.x_hops,
) # HeteroData refuses to accept None

adjmat = nx.to_scipy_sparse_array(
source_nodes["nx_graph"], nodelist=list(range(len(source_nodes["nx_graph"]))), format="coo"
)
return adjmat

def adjacency_from_hex_nodes(self, source_nodes: NodeStorage):

source_nodes["nx_graph"] = hexagonal.add_edges_to_nx_graph(
source_nodes["nx_graph"],
resolutions=source_nodes["resolutions"],
x_hops=self.x_hops,
)

adjmat = nx.to_scipy_sparse_array(source_nodes["nx_graph"], format="coo")
return adjmat

def get_adjacency_matrix(self, source_nodes: NodeStorage, target_nodes: NodeStorage):
if self.node_type == TriNodes.__name__:
adjmat = self.adjacency_from_tri_nodes(source_nodes)
elif self.node_type == HexNodes.__name__:
adjmat = self.adjacency_from_hex_nodes(source_nodes)
else:
raise ValueError(f"Invalid node type {self.node_type}")

adjmat = self.post_process_adjmat(source_nodes, adjmat)

return adjmat

def post_process_adjmat(self, nodes: NodeStorage, adjmat):
graph_sorted = {node_pos: i for i, node_pos in enumerate(nodes["node_ordering"])}
sort_func = np.vectorize(graph_sorted.get)
adjmat.row = sort_func(adjmat.row)
adjmat.col = sort_func(adjmat.col)
return adjmat

def update_graph(self, graph: HeteroData, attrs_config: DotDict | None = None) -> HeteroData:
assert (
graph[self.source_name].node_type == TriNodes.__name__
or graph[self.source_name].node_type == HexNodes.__name__
), f"{self.__class__.__name__} requires {TriNodes.__name__} or {HexNodes.__name__}."

self.node_type = graph[self.source_name].node_type

return super().update_graph(graph, attrs_config)
229 changes: 229 additions & 0 deletions src/anemoi/graphs/generate/hexagonal.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,229 @@
from typing import Optional

import h3
import networkx as nx
import numpy as np
from sklearn.metrics.pairwise import haversine_distances


def create_hexagonal_nodes(
resolutions: list[int],
area: Optional[dict] = None,
) -> tuple[nx.Graph, np.ndarray, list[int]]:
"""Creates a global mesh from a refined icosahedro.
This method relies on the H3 python library, which covers the earth with hexagons (and 5 pentagons). At each
refinement level, a hexagon cell (nodes) has 7 child cells (aperture 7).
Parameters
----------
resolutions : list[int]
Levels of mesh resolution to consider.
area : dict
A region, in GeoJSON data format, to be contained by all cells. Defaults to None, which computes the global
mesh.
Returns
-------
graph : networkx.Graph
The specified graph (nodes & edges).
coords_rad : np.ndarray
The node coordinates (not ordered) in radians.
node_ordering : list[int]
Order of the nodes in the graph to be sorted by latitude and longitude.
"""
graph = nx.Graph()

area_kwargs = {"area": area}

for resolution in resolutions:
graph = add_nodes_for_resolution(graph, resolution, **area_kwargs)

coords = np.deg2rad(np.array([h3.h3_to_geo(node) for node in graph.nodes]))

# Sort nodes by latitude and longitude
node_ordering = np.lexsort(coords.T[::-1], axis=0)

return graph, coords, list(node_ordering)


def add_nodes_for_resolution(
graph: nx.Graph,
resolution: int,
**area_kwargs: Optional[dict],
) -> nx.Graph:
"""Add all nodes at a specified refinement level to a graph.
Parameters
----------
graph : networkx.Graph
The graph to add the nodes.
resolution : int
The H3 refinement level. It can be an integer from 0 to 15.
area_kwargs: dict
Additional arguments to pass to the get_nodes_at_resolution function.
"""

nodes = get_nodes_at_resolution(resolution, **area_kwargs)

for idx in nodes:
graph.add_node(idx, hcoords_rad=np.deg2rad(h3.h3_to_geo(idx)))

return graph


def get_nodes_at_resolution(
resolution: int,
area: Optional[dict] = None,
) -> set[str]:
"""Get nodes at a specified refinement level over the entire globe.
If area is not None, it will return the nodes within the specified area
Parameters
----------
resolution : int
The H3 refinement level. It can be an integer from 0 to 15.
area : dict
An area as GeoJSON dictionary specifying a polygon. Defaults to None.
Returns
-------
nodes : set[str]
The set of H3 indexes at the specified resolution level.
"""
nodes = h3.uncompact(h3.get_res0_indexes(), resolution) if area is None else h3.polyfill(area, resolution)

# TODO: AOI not used in the current implementation.

return nodes


def add_edges_to_nx_graph(
graph: nx.Graph,
resolutions: list[int],
x_hops: int = 1,
depth_children: int = 1,
) -> nx.Graph:
"""Adds the edges to the graph.
This method includes multi-scale connections to the existing graph. The different scales
are defined by the resolutions (or refinement levels) specified.
Parameters
----------
graph : networkx.Graph
The graph to add the edges.
resolutions : list[int]
Levels of mesh resolution to consider.
x_hops: int
The number of hops to consider for the neighbours.
depth_children : int
The number of resolution levels to consider for the connections of children. Defaults to 1, which includes
connections up to the next resolution level.
Returns
-------
graph : networkx.Graph
The graph with the added edges.
"""

graph = add_neighbour_edges(graph, resolutions, x_hops)
graph = add_edges_to_children(
graph,
resolutions,
depth_children,
)
return graph


def add_neighbour_edges(
graph: nx.Graph,
refinement_levels: tuple[int],
x_hops: int = 1,
) -> nx.Graph:
for resolution in refinement_levels:
nodes = select_nodes_from_graph_at_resolution(graph, resolution)

for idx in nodes:
# neighbours
for idx_neighbour in h3.k_ring(idx, k=x_hops) & set(nodes):
graph = add_edge(
graph,
h3.h3_to_center_child(idx, refinement_levels[-1]),
h3.h3_to_center_child(idx_neighbour, refinement_levels[-1]),
)

return graph


def add_edges_to_children(
graph: nx.Graph,
refinement_levels: tuple[int],
depth_children: Optional[int] = None,
) -> nx.Graph:
"""Adds edges to the children of the nodes at the specified resolution levels.
Parameters
----------
graph : nx.Graph
graph to which the edges will be added
refinement_levels : tuple[int]
set of refinement levels
depth_children : Optional[int], optional
The number of resolution levels to consider for the connections of children. Defaults to 1, which includes
connections up to the next resolution level, by default None
"""
if depth_children is None:
depth_children = len(refinement_levels)

for i_level, resolution_parent in enumerate(refinement_levels[0:-1]):
parent_nodes = select_nodes_from_graph_at_resolution(graph, resolution_parent)

for parent_idx in parent_nodes:
# add own children
for resolution_child in refinement_levels[i_level + 1 : i_level + depth_children + 1]:
for child_idx in h3.h3_to_children(parent_idx, res=resolution_child):
graph = add_edge(
graph,
h3.h3_to_center_child(parent_idx, refinement_levels[-1]),
h3.h3_to_center_child(child_idx, refinement_levels[-1]),
)

return graph


def select_nodes_from_graph_at_resolution(graph: nx.Graph, resolution: int):
parent_nodes = [node for node in graph.nodes if h3.h3_get_resolution(node) == resolution]
return parent_nodes


def add_edge(
graph: nx.Graph,
source_node_h3_idx: str,
target_node_h3_idx: str,
) -> nx.Graph:
"""Add edge between two nodes to a graph.
The edge will only be added in case both target and source nodes are included in the graph.
Parameters
----------
graph : networkx.Graph
The graph to add the nodes.
source_node_h3_idx : str
The H3 index of the tail of the edge.
target_node_h3_idx : str
The H3 index of the head of the edge.
"""
if not graph.has_node(source_node_h3_idx) or not graph.has_node(target_node_h3_idx):
return graph

if source_node_h3_idx != target_node_h3_idx:
source_location = np.deg2rad(h3.h3_to_geo(source_node_h3_idx))
target_location = np.deg2rad(h3.h3_to_geo(target_node_h3_idx))
graph.add_edge(
source_node_h3_idx, target_node_h3_idx, weight=haversine_distances([source_location, target_location])[0][1]
)

return graph
Loading

0 comments on commit 746ea2b

Please sign in to comment.