diff --git a/CHANGELOG.md b/CHANGELOG.md index 7743f30..52e9e94 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,7 +9,10 @@ Please add your functional changes to the appropriate section in the PR. Keep it human-readable, your future self will thank you! ## [Unreleased](https://github.com/ecmwf/anemoi-graphs/compare/0.4.0...HEAD) + +### Added - feat: Define node sets and edges based on an ICON icosahedral mesh (#53) +- feat: Support for multiple edge builders between two sets of nodes (#70) ## [0.4.0 - LAM and stretched graphs](https://github.com/ecmwf/anemoi-graphs/compare/0.3.0...0.4.0) - 2024-11-08 @@ -44,6 +47,8 @@ Keep it human-readable, your future self will thank you! - ci: extened python versions to include 3.11 and 3.12 - Update copyright notice - Fix `__version__` import in init +- The `edge_builder` field in the recipe is renamed to `edge_builders`. It now receives a list of edge builders. (#70) +- The `{source|target}_mask_attr_name` field is moved to inside the edge builder definition. (#70) ## [0.3.0 Anemoi-graphs, minor release](https://github.com/ecmwf/anemoi-graphs/compare/0.2.1...0.3.0) - 2024-09-03 diff --git a/src/anemoi/graphs/create.py b/src/anemoi/graphs/create.py index c111cc3..4726bbd 100644 --- a/src/anemoi/graphs/create.py +++ b/src/anemoi/graphs/create.py @@ -55,13 +55,13 @@ def update_graph(self, graph: HeteroData) -> HeteroData: ) for edges_cfg in self.config.get("edges", {}): - graph = instantiate( - edges_cfg.edge_builder, - edges_cfg.source_name, - edges_cfg.target_name, - source_mask_attr_name=edges_cfg.get("source_mask_attr_name", None), - target_mask_attr_name=edges_cfg.get("target_mask_attr_name", None), - ).update_graph(graph, edges_cfg.get("attributes", {})) + for edge_builder_cfg in edges_cfg.edge_builders: + edge_builder = instantiate( + edge_builder_cfg, source_name=edges_cfg.source_name, target_name=edges_cfg.target_name + ) + graph = edge_builder.register_edges(graph) + + graph = edge_builder.register_attributes(graph, edges_cfg.get("attributes", {})) return graph diff --git a/src/anemoi/graphs/edges/builder.py b/src/anemoi/graphs/edges/builder.py index 92c3003..9a9d6f9 100644 --- a/src/anemoi/graphs/edges/builder.py +++ b/src/anemoi/graphs/edges/builder.py @@ -33,6 +33,7 @@ from anemoi.graphs.nodes.builders.from_refined_icosahedron import LimitedAreaTriNodes from anemoi.graphs.nodes.builders.from_refined_icosahedron import StretchedTriNodes from anemoi.graphs.nodes.builders.from_refined_icosahedron import TriNodes +from anemoi.graphs.utils import concat_edges from anemoi.graphs.utils import get_grid_reference_distance LOGGER = logging.getLogger(__name__) @@ -98,8 +99,19 @@ def register_edges(self, graph: HeteroData) -> HeteroData: HeteroData The graph with the registered edges. """ - graph[self.name].edge_index = self.get_edge_index(graph) - graph[self.name].edge_type = type(self).__name__ + edge_index = self.get_edge_index(graph) + edge_type = type(self).__name__ + + if "edge_index" in graph[self.name]: + # Expand current edge indices + graph[self.name].edge_index = concat_edges(graph[self.name].edge_index, edge_index) + if edge_type not in graph[self.name].edge_type: + graph[self.name].edge_type = graph[self.name].edge_type + "," + edge_type + return graph + + # Register new edge indices + graph[self.name].edge_index = edge_index + graph[self.name].edge_type = edge_type return graph def register_attributes(self, graph: HeteroData, config: DotDict) -> HeteroData: @@ -394,7 +406,6 @@ def __init__(self, source_name: str, target_name: str, x_hops: int, **kwargs): assert isinstance(x_hops, int), "Number of x_hops must be an integer" assert x_hops > 0, "Number of x_hops must be positive" self.x_hops = x_hops - self.node_type = None def add_edges_from_tri_nodes(self, nodes: NodeStorage) -> NodeStorage: nodes["_nx_graph"] = tri_icosahedron.add_edges_to_nx_graph( @@ -428,25 +439,23 @@ def add_edges_from_hex_nodes(self, nodes: NodeStorage) -> NodeStorage: return nodes def get_adjacency_matrix(self, source_nodes: NodeStorage, target_nodes: NodeStorage): - if self.node_type in [TriNodes.__name__, LimitedAreaTriNodes.__name__]: + if source_nodes.node_type in [TriNodes.__name__, LimitedAreaTriNodes.__name__]: source_nodes = self.add_edges_from_tri_nodes(source_nodes) - elif self.node_type in [HexNodes.__name__, LimitedAreaHexNodes.__name__]: + elif source_nodes.node_type in [HexNodes.__name__, LimitedAreaHexNodes.__name__]: source_nodes = self.add_edges_from_hex_nodes(source_nodes) - elif self.node_type == StretchedTriNodes.__name__: + elif source_nodes.node_type == StretchedTriNodes.__name__: source_nodes = self.add_edges_from_stretched_tri_nodes(source_nodes) else: - raise ValueError(f"Invalid node type {self.node_type}") + raise ValueError(f"Invalid node type {source_nodes.node_type}") adjmat = nx.to_scipy_sparse_array(source_nodes["_nx_graph"], format="coo") return adjmat def update_graph(self, graph: HeteroData, attrs_config: DotDict | None = None) -> HeteroData: - self.node_type = graph[self.source_name].node_type + node_type = graph[self.source_name].node_type valid_node_names = [n.__name__ for n in self.VALID_NODES] - assert ( - self.node_type in valid_node_names - ), f"{self.__class__.__name__} requires {','.join(valid_node_names)} nodes." + assert node_type in valid_node_names, f"{self.__class__.__name__} requires {','.join(valid_node_names)} nodes." return super().update_graph(graph, attrs_config) diff --git a/src/anemoi/graphs/utils.py b/src/anemoi/graphs/utils.py index a68d6e7..5aa4b79 100644 --- a/src/anemoi/graphs/utils.py +++ b/src/anemoi/graphs/utils.py @@ -63,61 +63,22 @@ def get_grid_reference_distance(coords_rad: torch.Tensor, mask: torch.Tensor | N return dists[dists > 0].max() -def add_margin(lats: np.ndarray, lons: np.ndarray, margin: float) -> tuple[np.ndarray, np.ndarray]: - """Add a margin to the convex hull of the points considered. - - For each point (lat, lon) add 8 points around it, each at a distance of `margin` from the original point. - - Arguments - --------- - lats : np.ndarray - Latitudes of the points considered. - lons : np.ndarray - Longitudes of the points considered. - margin : float - The margin to add to the convex hull. - - Returns - ------- - latitudes : np.ndarray - Latitudes of the points considered, including the margin. - longitudes : np.ndarray - Longitudes of the points considered, including the margin. - """ - assert margin >= 0, "Margin must be non-negative" - if margin == 0: - return lats, lons - - latitudes, longitudes = [], [] - for lat_sign in [-1, 0, 1]: - for lon_sign in [-1, 0, 1]: - latitudes.append(lats + lat_sign * margin) - longitudes.append(lons + lon_sign * margin) - - return np.concatenate(latitudes), np.concatenate(longitudes) - - -def get_index_in_outer_join(vector: torch.Tensor, tensor: torch.Tensor) -> int: - """Index position of vector. - - Get the index position of a vector in a matrix. +def concat_edges(edge_indices1: torch.Tensor, edge_indices2: torch.Tensor) -> torch.Tensor: + """Concat edges Parameters ---------- - vector : torch.Tensor of shape (N, ) - Vector to get its position in the matrix. - tensor : torch.Tensor of shape (M, N,) - Tensor in which the position is searched. + edge_indices1: torch.Tensor + Edge indices of the first set of edges. Shape: (2, num_edges1) + edge_indices2: torch.Tensor + Edge indices of the second set of edges. Shape: (2, num_edges2) Returns ------- - int - Index position of `vector` in `tensor`. -1 if `vector` is not in `tensor`. + torch.Tensor + Concatenated edge indices. """ - mask = torch.all(tensor == vector, axis=1) - if mask.any(): - return int(torch.where(mask)[0]) - return -1 + return torch.unique(torch.cat([edge_indices1, edge_indices2], axis=1), dim=1) def haversine_distance(source_coords: np.ndarray, target_coords: np.ndarray) -> np.ndarray: diff --git a/tests/conftest.py b/tests/conftest.py index 59cf21d..51c90dc 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -91,10 +91,9 @@ def config_file(tmp_path) -> tuple[str, str]: { "source_name": "test_nodes", "target_name": "test_nodes", - "edge_builder": { - "_target_": "anemoi.graphs.edges.KNNEdges", - "num_nearest_neighbours": 3, - }, + "edge_builders": [ + {"_target_": "anemoi.graphs.edges.KNNEdges", "num_nearest_neighbours": 3}, + ], "attributes": { "dist_norm": {"_target_": "anemoi.graphs.edges.attributes.EdgeLength"}, "edge_dirs": {"_target_": "anemoi.graphs.edges.attributes.EdgeDirection"}, diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 0000000..8e7a8c5 --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,18 @@ +import numpy as np +import torch + +from anemoi.graphs.utils import concat_edges + + +def test_concat_edges(): + edge_indices1 = torch.tensor([[0, 1, 2, 3], [-1, -2, -3, -4]], dtype=torch.int64) + edge_indices2 = torch.tensor(np.array([[0, 4], [-1, -5]]), dtype=torch.int64) + no_edges = torch.tensor([[], []], dtype=torch.int64) + + result1 = concat_edges(edge_indices1, edge_indices2) + result2 = concat_edges(no_edges, edge_indices2) + + expected1 = torch.tensor([[0, 1, 2, 3, 4], [-1, -2, -3, -4, -5]], dtype=torch.int64) + + assert torch.allclose(result1, expected1) + assert torch.allclose(result2, edge_indices2)