Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[feature] allow multiple edge builders #70

Open
wants to merge 9 commits into
base: develop
Choose a base branch
from
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@ Keep it human-readable, your future self will thank you!

## [Unreleased](https://github.com/ecmwf/anemoi-graphs/compare/0.4.0...HEAD)

### Added
- feat: Support for multiple edge builders between two sets of nodes (#70)

## [0.4.0 - LAM and stretched graphs](https://github.com/ecmwf/anemoi-graphs/compare/0.3.0...0.4.0) - 2024-11-08

### Added
Expand Down Expand Up @@ -43,6 +46,8 @@ Keep it human-readable, your future self will thank you!
- ci: extened python versions to include 3.11 and 3.12
- Update copyright notice
- Fix `__version__` import in init
- The `edge_builder` field in the recipe is renamed to `edge_builders`. It now receives a list of edge builders. (#70)
- The `{source|target}_mask_attr_name` field is moved to inside the edge builder definition. (#70)

## [0.3.0 Anemoi-graphs, minor release](https://github.com/ecmwf/anemoi-graphs/compare/0.2.1...0.3.0) - 2024-09-03

Expand Down
14 changes: 7 additions & 7 deletions src/anemoi/graphs/create.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,13 +55,13 @@ def update_graph(self, graph: HeteroData) -> HeteroData:
)

for edges_cfg in self.config.get("edges", {}):
graph = instantiate(
edges_cfg.edge_builder,
edges_cfg.source_name,
edges_cfg.target_name,
source_mask_attr_name=edges_cfg.get("source_mask_attr_name", None),
target_mask_attr_name=edges_cfg.get("target_mask_attr_name", None),
).update_graph(graph, edges_cfg.get("attributes", {}))
for edge_builder_cfg in edges_cfg.edge_builders:
edge_builder = instantiate(
edge_builder_cfg, source_name=edges_cfg.source_name, target_name=edges_cfg.target_name
)
graph = edge_builder.register_edges(graph)

graph = edge_builder.register_attributes(graph, edges_cfg.get("attributes", {}))

return graph

Expand Down
31 changes: 20 additions & 11 deletions src/anemoi/graphs/edges/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from anemoi.graphs.nodes.builders.from_refined_icosahedron import LimitedAreaTriNodes
from anemoi.graphs.nodes.builders.from_refined_icosahedron import StretchedTriNodes
from anemoi.graphs.nodes.builders.from_refined_icosahedron import TriNodes
from anemoi.graphs.utils import concat_edges
from anemoi.graphs.utils import get_grid_reference_distance

LOGGER = logging.getLogger(__name__)
Expand Down Expand Up @@ -99,8 +100,19 @@ def register_edges(self, graph: HeteroData) -> HeteroData:
HeteroData
The graph with the registered edges.
"""
graph[self.name].edge_index = self.get_edge_index(graph)
graph[self.name].edge_type = type(self).__name__
edge_index = self.get_edge_index(graph)
edge_type = type(self).__name__

if "edge_index" in graph[self.name]:
# Expand current edge indices
graph[self.name].edge_index = concat_edges(graph[self.name].edge_index, edge_index)
if edge_type not in graph[self.name].edge_type:
graph[self.name].edge_type = graph[self.name].edge_type + "," + edge_type
return graph

# Register new edge indices
graph[self.name].edge_index = edge_index
graph[self.name].edge_type = edge_type
return graph

def register_attributes(self, graph: HeteroData, config: DotDict) -> HeteroData:
Expand Down Expand Up @@ -389,7 +401,6 @@ def __init__(self, source_name: str, target_name: str, x_hops: int, **kwargs):
assert isinstance(x_hops, int), "Number of x_hops must be an integer"
assert x_hops > 0, "Number of x_hops must be positive"
self.x_hops = x_hops
self.node_type = None

def add_edges_from_tri_nodes(self, nodes: NodeStorage) -> NodeStorage:
nodes["_nx_graph"] = tri_icosahedron.add_edges_to_nx_graph(
Expand Down Expand Up @@ -423,24 +434,22 @@ def add_edges_from_hex_nodes(self, nodes: NodeStorage) -> NodeStorage:
return nodes

def get_adjacency_matrix(self, source_nodes: NodeStorage, target_nodes: NodeStorage):
if self.node_type in [TriNodes.__name__, LimitedAreaTriNodes.__name__]:
if source_nodes.node_type in [TriNodes.__name__, LimitedAreaTriNodes.__name__]:
source_nodes = self.add_edges_from_tri_nodes(source_nodes)
elif self.node_type in [HexNodes.__name__, LimitedAreaHexNodes.__name__]:
elif source_nodes.node_type in [HexNodes.__name__, LimitedAreaHexNodes.__name__]:
source_nodes = self.add_edges_from_hex_nodes(source_nodes)
elif self.node_type == StretchedTriNodes.__name__:
elif source_nodes.node_type == StretchedTriNodes.__name__:
source_nodes = self.add_edges_from_stretched_tri_nodes(source_nodes)
else:
raise ValueError(f"Invalid node type {self.node_type}")
raise ValueError(f"Invalid node type {source_nodes.node_type}")

adjmat = nx.to_scipy_sparse_array(source_nodes["_nx_graph"], format="coo")

return adjmat

def update_graph(self, graph: HeteroData, attrs_config: DotDict | None = None) -> HeteroData:
self.node_type = graph[self.source_name].node_type
node_type = graph[self.source_name].node_type
valid_node_names = [n.__name__ for n in self.VALID_NODES]
assert (
self.node_type in valid_node_names
), f"{self.__class__.__name__} requires {','.join(valid_node_names)} nodes."
assert node_type in valid_node_names, f"{self.__class__.__name__} requires {','.join(valid_node_names)} nodes."

return super().update_graph(graph, attrs_config)
57 changes: 9 additions & 48 deletions src/anemoi/graphs/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,61 +63,22 @@ def get_grid_reference_distance(coords_rad: torch.Tensor, mask: torch.Tensor | N
return dists[dists > 0].max()


def add_margin(lats: np.ndarray, lons: np.ndarray, margin: float) -> tuple[np.ndarray, np.ndarray]:
"""Add a margin to the convex hull of the points considered.

For each point (lat, lon) add 8 points around it, each at a distance of `margin` from the original point.

Arguments
---------
lats : np.ndarray
Latitudes of the points considered.
lons : np.ndarray
Longitudes of the points considered.
margin : float
The margin to add to the convex hull.

Returns
-------
latitudes : np.ndarray
Latitudes of the points considered, including the margin.
longitudes : np.ndarray
Longitudes of the points considered, including the margin.
"""
assert margin >= 0, "Margin must be non-negative"
if margin == 0:
return lats, lons

latitudes, longitudes = [], []
for lat_sign in [-1, 0, 1]:
for lon_sign in [-1, 0, 1]:
latitudes.append(lats + lat_sign * margin)
longitudes.append(lons + lon_sign * margin)

return np.concatenate(latitudes), np.concatenate(longitudes)


def get_index_in_outer_join(vector: torch.Tensor, tensor: torch.Tensor) -> int:
"""Index position of vector.

Get the index position of a vector in a matrix.
def concat_edges(edge_indices1: torch.Tensor, edge_indices2: torch.Tensor) -> torch.Tensor:
"""Concat edges

Parameters
----------
vector : torch.Tensor of shape (N, )
Vector to get its position in the matrix.
tensor : torch.Tensor of shape (M, N,)
Tensor in which the position is searched.
edge_indices1: torch.Tensor
Edge indices of the first set of edges. Shape: (2, num_edges1)
edge_indices2: torch.Tensor
Edge indices of the second set of edges. Shape: (2, num_edges2)

Returns
-------
int
Index position of `vector` in `tensor`. -1 if `vector` is not in `tensor`.
torch.Tensor
Concatenated edge indices.
"""
mask = torch.all(tensor == vector, axis=1)
if mask.any():
return int(torch.where(mask)[0])
return -1
return torch.unique(torch.cat([edge_indices1, edge_indices2], axis=1), dim=1)


def haversine_distance(source_coords: np.ndarray, target_coords: np.ndarray) -> np.ndarray:
Expand Down
7 changes: 3 additions & 4 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,10 +91,9 @@ def config_file(tmp_path) -> tuple[str, str]:
{
"source_name": "test_nodes",
"target_name": "test_nodes",
"edge_builder": {
"_target_": "anemoi.graphs.edges.KNNEdges",
"num_nearest_neighbours": 3,
},
"edge_builders": [
{"_target_": "anemoi.graphs.edges.KNNEdges", "num_nearest_neighbours": 3},
],
"attributes": {
"dist_norm": {"_target_": "anemoi.graphs.edges.attributes.EdgeLength"},
"edge_dirs": {"_target_": "anemoi.graphs.edges.attributes.EdgeDirection"},
Expand Down
18 changes: 18 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import numpy as np
import torch

from anemoi.graphs.utils import concat_edges


def test_concat_edges():
edge_indices1 = torch.tensor([[0, 1, 2, 3], [-1, -2, -3, -4]], dtype=torch.int64)
edge_indices2 = torch.tensor(np.array([[0, 4], [-1, -5]]), dtype=torch.int64)
no_edges = torch.tensor([[], []], dtype=torch.int64)

result1 = concat_edges(edge_indices1, edge_indices2)
result2 = concat_edges(no_edges, edge_indices2)

expected1 = torch.tensor([[0, 1, 2, 3, 4], [-1, -2, -3, -4, -5]], dtype=torch.int64)

assert torch.allclose(result1, expected1)
assert torch.allclose(result2, edge_indices2)
Loading