Skip to content
This repository was archived by the owner on Dec 20, 2024. It is now read-only.

Commit 746ea2b

Browse files
6 generate graphs from icosahedral meshes (#11)
* Global Encoder-Processor-Decoder graph (#9) * feat: Initial implementation of global graphs Co-authored by: Mario Santa Cruz <[email protected]> Co-authored-by: Helen Theissen <[email protected]> Co-authored-by: Jesper Dramsch <[email protected]> * fix: attributes as torch.float32 * new test: attributes must be float32 * fix typo * Homogeneize base builders * improve test docstrings * homogeneize (name as class attribute) * new input config * new default * feat: Initial implementation of global graphs Co-authored by: Mario Santa Cruz <[email protected]> * add cli command * Ignore .pt files * run pre-commit * docstring + log erros * initial tests * feat: initial version of AttributeBuilder * refactor: separate into node edge attribute builders * feat: edge_length moved to edges/attributes.py * remove __init__ * bugfix (encoder edge lengths) + refector * feat: support path and dict for `config` argument * fix: error * refactor: naming * fix: pre-commit * feat: builders icosahedral * feat: Add icosahedral graph generation Co-authored-by: Mario Santa Cruz <[email protected]> * refactor: remove create_shere * feat: Icosahedral edge builder * feat: hexagonal graph generation Co-authored-by: Mario Santa Cruz <[email protected]> * feat: hexagonal builders * fix: AOI not implemented yet * fix: abstractmethod and renaming * chore: add dependencies * test: add tests for trimesh * test: add tests for hex (h3) * fix: imports * fix: output type * refactor: delete unused file * refactor: renaming and positioning * feat: ensure src and dst always the same * fix: imports * fix: edge_name not supported * test: add tests for TriIcosahedralEdges * fix: assert missing for Hexagonal edges * test: hexagonal edges * fix: avoid same name * fix: imports * fix: conflicts * update tests * Include xhops to hexagonal edges * docs: update docstrings * fix: update attribute name * refactor: rename multiscale nodes * refactor: rename icosahedral nodes * improve: clarity of function * improve: function syntax * refactor: simplify resolution assignment * refactor: improve variable naming in icosahedral graph generation * more comments * refactor: naming * refactor: separate into functions * refactor: remove unused code * refactor: remove unused options and rename * doc: clarify cells and nodes * fix: add return statements * homogeneize: tri & hex edges * naming: xhops to x_hops * naming: x_hops in docstring * docstring * blank lines * fix: add return statement * simplify icoshaedral edges * feat: select edge builder method based on ico node type * test: adjust tests to new MultiScaleEdges * docs: improve docstrings * fix: remove LAM filtering for icosahedral, leave for next PR * h3 (v4) not supported ---------
1 parent 5b3bbe1 commit 746ea2b

File tree

12 files changed

+770
-10
lines changed

12 files changed

+770
-10
lines changed

pyproject.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,9 +52,12 @@ dynamic = [
5252
dependencies = [
5353
"anemoi-datasets[data]>=0.3.3",
5454
"anemoi-utils>=0.3.6",
55+
"h3>=3.7.6,<4",
5556
"hydra-core>=1.3",
57+
"networkx>=3.1",
5658
"torch>=2.2",
5759
"torch-geometric>=2.3.1,<2.5",
60+
"trimesh>=4.1",
5861
]
5962

6063
optional-dependencies.all = [

src/anemoi/graphs/edges/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from .builder import CutOffEdges
22
from .builder import KNNEdges
3+
from .builder import MultiScaleEdges
34

4-
__all__ = ["KNNEdges", "CutOffEdges"]
5+
__all__ = ["KNNEdges", "CutOffEdges", "MultiScaleEdges"]

src/anemoi/graphs/edges/builder.py

Lines changed: 70 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from abc import abstractmethod
44
from typing import Optional
55

6+
import networkx as nx
67
import numpy as np
78
import torch
89
from anemoi.utils.config import DotDict
@@ -12,6 +13,10 @@
1213
from torch_geometric.data.storage import NodeStorage
1314

1415
from anemoi.graphs import EARTH_RADIUS
16+
from anemoi.graphs.generate import hexagonal
17+
from anemoi.graphs.generate import icosahedral
18+
from anemoi.graphs.nodes.builder import HexNodes
19+
from anemoi.graphs.nodes.builder import TriNodes
1520
from anemoi.graphs.utils import get_grid_reference_distance
1621

1722
LOGGER = logging.getLogger(__name__)
@@ -33,7 +38,7 @@ def name(self) -> tuple[str, str, str]:
3338
def get_adjacency_matrix(self, source_nodes: NodeStorage, target_nodes: NodeStorage): ...
3439

3540
def prepare_node_data(self, graph: HeteroData) -> tuple[NodeStorage, NodeStorage]:
36-
"""Prepare nodes information."""
41+
"""Prepare node information and get source and target nodes."""
3742
return graph[self.source_name], graph[self.target_name]
3843

3944
def get_edge_index(self, graph: HeteroData) -> torch.Tensor:
@@ -188,8 +193,6 @@ class CutOffEdges(BaseEdgeBuilder):
188193
The name of the target nodes.
189194
cutoff_factor : float
190195
Factor to multiply the grid reference distance to get the cut-off radius.
191-
radius : float
192-
Cut-off radius.
193196
194197
Methods
195198
-------
@@ -235,7 +238,7 @@ def get_cutoff_radius(self, graph: HeteroData, mask_attr: Optional[torch.Tensor]
235238
return radius
236239

237240
def prepare_node_data(self, graph: HeteroData) -> tuple[NodeStorage, NodeStorage]:
238-
"""Prepare nodes information."""
241+
"""Prepare node information and get source and target nodes."""
239242
self.radius = self.get_cutoff_radius(graph)
240243
return super().prepare_node_data(graph)
241244

@@ -260,3 +263,66 @@ def get_adjacency_matrix(self, source_nodes: NodeStorage, target_nodes: NodeStor
260263
nearest_neighbour.fit(source_nodes.x)
261264
adj_matrix = nearest_neighbour.radius_neighbors_graph(target_nodes.x, radius=self.radius).tocoo()
262265
return adj_matrix
266+
267+
268+
class MultiScaleEdges(BaseEdgeBuilder):
269+
"""Base class for multi-scale edges in the nodes of a graph."""
270+
271+
def __init__(self, source_name: str, target_name: str, x_hops: int):
272+
super().__init__(source_name, target_name)
273+
assert source_name == target_name, f"{self.__class__.__name__} requires source and target nodes to be the same."
274+
assert isinstance(x_hops, int), "Number of x_hops must be an integer"
275+
assert x_hops > 0, "Number of x_hops must be positive"
276+
self.x_hops = x_hops
277+
278+
def adjacency_from_tri_nodes(self, source_nodes: NodeStorage):
279+
source_nodes["nx_graph"] = icosahedral.add_edges_to_nx_graph(
280+
source_nodes["nx_graph"],
281+
resolutions=source_nodes["resolutions"],
282+
x_hops=self.x_hops,
283+
) # HeteroData refuses to accept None
284+
285+
adjmat = nx.to_scipy_sparse_array(
286+
source_nodes["nx_graph"], nodelist=list(range(len(source_nodes["nx_graph"]))), format="coo"
287+
)
288+
return adjmat
289+
290+
def adjacency_from_hex_nodes(self, source_nodes: NodeStorage):
291+
292+
source_nodes["nx_graph"] = hexagonal.add_edges_to_nx_graph(
293+
source_nodes["nx_graph"],
294+
resolutions=source_nodes["resolutions"],
295+
x_hops=self.x_hops,
296+
)
297+
298+
adjmat = nx.to_scipy_sparse_array(source_nodes["nx_graph"], format="coo")
299+
return adjmat
300+
301+
def get_adjacency_matrix(self, source_nodes: NodeStorage, target_nodes: NodeStorage):
302+
if self.node_type == TriNodes.__name__:
303+
adjmat = self.adjacency_from_tri_nodes(source_nodes)
304+
elif self.node_type == HexNodes.__name__:
305+
adjmat = self.adjacency_from_hex_nodes(source_nodes)
306+
else:
307+
raise ValueError(f"Invalid node type {self.node_type}")
308+
309+
adjmat = self.post_process_adjmat(source_nodes, adjmat)
310+
311+
return adjmat
312+
313+
def post_process_adjmat(self, nodes: NodeStorage, adjmat):
314+
graph_sorted = {node_pos: i for i, node_pos in enumerate(nodes["node_ordering"])}
315+
sort_func = np.vectorize(graph_sorted.get)
316+
adjmat.row = sort_func(adjmat.row)
317+
adjmat.col = sort_func(adjmat.col)
318+
return adjmat
319+
320+
def update_graph(self, graph: HeteroData, attrs_config: DotDict | None = None) -> HeteroData:
321+
assert (
322+
graph[self.source_name].node_type == TriNodes.__name__
323+
or graph[self.source_name].node_type == HexNodes.__name__
324+
), f"{self.__class__.__name__} requires {TriNodes.__name__} or {HexNodes.__name__}."
325+
326+
self.node_type = graph[self.source_name].node_type
327+
328+
return super().update_graph(graph, attrs_config)
Lines changed: 229 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,229 @@
1+
from typing import Optional
2+
3+
import h3
4+
import networkx as nx
5+
import numpy as np
6+
from sklearn.metrics.pairwise import haversine_distances
7+
8+
9+
def create_hexagonal_nodes(
10+
resolutions: list[int],
11+
area: Optional[dict] = None,
12+
) -> tuple[nx.Graph, np.ndarray, list[int]]:
13+
"""Creates a global mesh from a refined icosahedro.
14+
15+
This method relies on the H3 python library, which covers the earth with hexagons (and 5 pentagons). At each
16+
refinement level, a hexagon cell (nodes) has 7 child cells (aperture 7).
17+
18+
Parameters
19+
----------
20+
resolutions : list[int]
21+
Levels of mesh resolution to consider.
22+
area : dict
23+
A region, in GeoJSON data format, to be contained by all cells. Defaults to None, which computes the global
24+
mesh.
25+
26+
Returns
27+
-------
28+
graph : networkx.Graph
29+
The specified graph (nodes & edges).
30+
coords_rad : np.ndarray
31+
The node coordinates (not ordered) in radians.
32+
node_ordering : list[int]
33+
Order of the nodes in the graph to be sorted by latitude and longitude.
34+
"""
35+
graph = nx.Graph()
36+
37+
area_kwargs = {"area": area}
38+
39+
for resolution in resolutions:
40+
graph = add_nodes_for_resolution(graph, resolution, **area_kwargs)
41+
42+
coords = np.deg2rad(np.array([h3.h3_to_geo(node) for node in graph.nodes]))
43+
44+
# Sort nodes by latitude and longitude
45+
node_ordering = np.lexsort(coords.T[::-1], axis=0)
46+
47+
return graph, coords, list(node_ordering)
48+
49+
50+
def add_nodes_for_resolution(
51+
graph: nx.Graph,
52+
resolution: int,
53+
**area_kwargs: Optional[dict],
54+
) -> nx.Graph:
55+
"""Add all nodes at a specified refinement level to a graph.
56+
57+
Parameters
58+
----------
59+
graph : networkx.Graph
60+
The graph to add the nodes.
61+
resolution : int
62+
The H3 refinement level. It can be an integer from 0 to 15.
63+
area_kwargs: dict
64+
Additional arguments to pass to the get_nodes_at_resolution function.
65+
"""
66+
67+
nodes = get_nodes_at_resolution(resolution, **area_kwargs)
68+
69+
for idx in nodes:
70+
graph.add_node(idx, hcoords_rad=np.deg2rad(h3.h3_to_geo(idx)))
71+
72+
return graph
73+
74+
75+
def get_nodes_at_resolution(
76+
resolution: int,
77+
area: Optional[dict] = None,
78+
) -> set[str]:
79+
"""Get nodes at a specified refinement level over the entire globe.
80+
81+
If area is not None, it will return the nodes within the specified area
82+
83+
Parameters
84+
----------
85+
resolution : int
86+
The H3 refinement level. It can be an integer from 0 to 15.
87+
area : dict
88+
An area as GeoJSON dictionary specifying a polygon. Defaults to None.
89+
90+
Returns
91+
-------
92+
nodes : set[str]
93+
The set of H3 indexes at the specified resolution level.
94+
"""
95+
nodes = h3.uncompact(h3.get_res0_indexes(), resolution) if area is None else h3.polyfill(area, resolution)
96+
97+
# TODO: AOI not used in the current implementation.
98+
99+
return nodes
100+
101+
102+
def add_edges_to_nx_graph(
103+
graph: nx.Graph,
104+
resolutions: list[int],
105+
x_hops: int = 1,
106+
depth_children: int = 1,
107+
) -> nx.Graph:
108+
"""Adds the edges to the graph.
109+
110+
This method includes multi-scale connections to the existing graph. The different scales
111+
are defined by the resolutions (or refinement levels) specified.
112+
113+
Parameters
114+
----------
115+
graph : networkx.Graph
116+
The graph to add the edges.
117+
resolutions : list[int]
118+
Levels of mesh resolution to consider.
119+
x_hops: int
120+
The number of hops to consider for the neighbours.
121+
depth_children : int
122+
The number of resolution levels to consider for the connections of children. Defaults to 1, which includes
123+
connections up to the next resolution level.
124+
125+
Returns
126+
-------
127+
graph : networkx.Graph
128+
The graph with the added edges.
129+
"""
130+
131+
graph = add_neighbour_edges(graph, resolutions, x_hops)
132+
graph = add_edges_to_children(
133+
graph,
134+
resolutions,
135+
depth_children,
136+
)
137+
return graph
138+
139+
140+
def add_neighbour_edges(
141+
graph: nx.Graph,
142+
refinement_levels: tuple[int],
143+
x_hops: int = 1,
144+
) -> nx.Graph:
145+
for resolution in refinement_levels:
146+
nodes = select_nodes_from_graph_at_resolution(graph, resolution)
147+
148+
for idx in nodes:
149+
# neighbours
150+
for idx_neighbour in h3.k_ring(idx, k=x_hops) & set(nodes):
151+
graph = add_edge(
152+
graph,
153+
h3.h3_to_center_child(idx, refinement_levels[-1]),
154+
h3.h3_to_center_child(idx_neighbour, refinement_levels[-1]),
155+
)
156+
157+
return graph
158+
159+
160+
def add_edges_to_children(
161+
graph: nx.Graph,
162+
refinement_levels: tuple[int],
163+
depth_children: Optional[int] = None,
164+
) -> nx.Graph:
165+
"""Adds edges to the children of the nodes at the specified resolution levels.
166+
167+
Parameters
168+
----------
169+
graph : nx.Graph
170+
graph to which the edges will be added
171+
refinement_levels : tuple[int]
172+
set of refinement levels
173+
depth_children : Optional[int], optional
174+
The number of resolution levels to consider for the connections of children. Defaults to 1, which includes
175+
connections up to the next resolution level, by default None
176+
"""
177+
if depth_children is None:
178+
depth_children = len(refinement_levels)
179+
180+
for i_level, resolution_parent in enumerate(refinement_levels[0:-1]):
181+
parent_nodes = select_nodes_from_graph_at_resolution(graph, resolution_parent)
182+
183+
for parent_idx in parent_nodes:
184+
# add own children
185+
for resolution_child in refinement_levels[i_level + 1 : i_level + depth_children + 1]:
186+
for child_idx in h3.h3_to_children(parent_idx, res=resolution_child):
187+
graph = add_edge(
188+
graph,
189+
h3.h3_to_center_child(parent_idx, refinement_levels[-1]),
190+
h3.h3_to_center_child(child_idx, refinement_levels[-1]),
191+
)
192+
193+
return graph
194+
195+
196+
def select_nodes_from_graph_at_resolution(graph: nx.Graph, resolution: int):
197+
parent_nodes = [node for node in graph.nodes if h3.h3_get_resolution(node) == resolution]
198+
return parent_nodes
199+
200+
201+
def add_edge(
202+
graph: nx.Graph,
203+
source_node_h3_idx: str,
204+
target_node_h3_idx: str,
205+
) -> nx.Graph:
206+
"""Add edge between two nodes to a graph.
207+
208+
The edge will only be added in case both target and source nodes are included in the graph.
209+
210+
Parameters
211+
----------
212+
graph : networkx.Graph
213+
The graph to add the nodes.
214+
source_node_h3_idx : str
215+
The H3 index of the tail of the edge.
216+
target_node_h3_idx : str
217+
The H3 index of the head of the edge.
218+
"""
219+
if not graph.has_node(source_node_h3_idx) or not graph.has_node(target_node_h3_idx):
220+
return graph
221+
222+
if source_node_h3_idx != target_node_h3_idx:
223+
source_location = np.deg2rad(h3.h3_to_geo(source_node_h3_idx))
224+
target_location = np.deg2rad(h3.h3_to_geo(target_node_h3_idx))
225+
graph.add_edge(
226+
source_node_h3_idx, target_node_h3_idx, weight=haversine_distances([source_location, target_location])[0][1]
227+
)
228+
229+
return graph

0 commit comments

Comments
 (0)