Skip to content
This repository was archived by the owner on Dec 20, 2024. It is now read-only.

Commit 7fd9c74

Browse files
JPXKQXtheissenhelenJesperDramsch
authored
Clean nodes after building the graph (#23)
* feat: clean graph of unneeded attributes after creation Co-authored-by: Mario Santa Cruz <[email protected]> Co-authored-by: Helen Theissen <[email protected]> Co-authored-by: Jesper Dramsch <[email protected]>
1 parent d510eb2 commit 7fd9c74

File tree

9 files changed

+86
-124
lines changed

9 files changed

+86
-124
lines changed

src/anemoi/graphs/create.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,22 @@ def generate_graph(self) -> HeteroData:
5959

6060
return graph
6161

62+
def clean(self, graph: HeteroData) -> HeteroData:
63+
"""Clean the hidden attributes of the nodes and edges."""
64+
for nodes_name in graph.node_types:
65+
node_attrs = list(graph[nodes_name].keys())
66+
for node_attr_name in node_attrs:
67+
if node_attr_name.startswith("_"):
68+
del graph[nodes_name][node_attr_name]
69+
70+
for edge_key in graph.edge_types:
71+
edge_attrs = graph[edge_key].keys()
72+
for edge_attr_name in edge_attrs:
73+
if edge_attr_name.startswith("_"):
74+
del graph[edge_key][edge_attr_name]
75+
76+
return graph
77+
6278
def save(self, graph: HeteroData) -> None:
6379
"""Save the graph to the output path."""
6480
if not os.path.exists(self.path) or self.overwrite:
@@ -69,6 +85,7 @@ def create(self) -> HeteroData:
6985
"""Create the graph and save it to the output path."""
7086
self.init()
7187
graph = self.generate_graph()
88+
graph = self.clean(graph)
7289
self.save(graph)
7390
return graph
7491

src/anemoi/graphs/edges/builder.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -276,26 +276,26 @@ def __init__(self, source_name: str, target_name: str, x_hops: int):
276276
self.x_hops = x_hops
277277

278278
def adjacency_from_tri_nodes(self, source_nodes: NodeStorage):
279-
source_nodes["nx_graph"] = icosahedral.add_edges_to_nx_graph(
280-
source_nodes["nx_graph"],
281-
resolutions=source_nodes["resolutions"],
279+
source_nodes["_nx_graph"] = icosahedral.add_edges_to_nx_graph(
280+
source_nodes["_nx_graph"],
281+
resolutions=source_nodes["_resolutions"],
282282
x_hops=self.x_hops,
283283
) # HeteroData refuses to accept None
284284

285285
adjmat = nx.to_scipy_sparse_array(
286-
source_nodes["nx_graph"], nodelist=list(range(len(source_nodes["nx_graph"]))), format="coo"
286+
source_nodes["_nx_graph"], nodelist=list(range(len(source_nodes["_nx_graph"]))), format="coo"
287287
)
288288
return adjmat
289289

290290
def adjacency_from_hex_nodes(self, source_nodes: NodeStorage):
291291

292-
source_nodes["nx_graph"] = hexagonal.add_edges_to_nx_graph(
293-
source_nodes["nx_graph"],
294-
resolutions=source_nodes["resolutions"],
292+
source_nodes["_nx_graph"] = hexagonal.add_edges_to_nx_graph(
293+
source_nodes["_nx_graph"],
294+
resolutions=source_nodes["_resolutions"],
295295
x_hops=self.x_hops,
296296
)
297297

298-
adjmat = nx.to_scipy_sparse_array(source_nodes["nx_graph"], format="coo")
298+
adjmat = nx.to_scipy_sparse_array(source_nodes["_nx_graph"], format="coo")
299299
return adjmat
300300

301301
def get_adjacency_matrix(self, source_nodes: NodeStorage, target_nodes: NodeStorage):
@@ -311,7 +311,7 @@ def get_adjacency_matrix(self, source_nodes: NodeStorage, target_nodes: NodeStor
311311
return adjmat
312312

313313
def post_process_adjmat(self, nodes: NodeStorage, adjmat):
314-
graph_sorted = {node_pos: i for i, node_pos in enumerate(nodes["node_ordering"])}
314+
graph_sorted = {node_pos: i for i, node_pos in enumerate(nodes["_node_ordering"])}
315315
sort_func = np.vectorize(graph_sorted.get)
316316
adjmat.row = sort_func(adjmat.row)
317317
adjmat.col = sort_func(adjmat.col)

src/anemoi/graphs/nodes/builder.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -226,9 +226,9 @@ def get_coordinates(self) -> torch.Tensor:
226226
def create_nodes(self) -> np.ndarray: ...
227227

228228
def register_attributes(self, graph: HeteroData, config: DotDict) -> HeteroData:
229-
graph[self.name]["resolutions"] = self.resolutions
230-
graph[self.name]["nx_graph"] = self.nx_graph
231-
graph[self.name]["node_ordering"] = self.node_ordering
229+
graph[self.name]["_resolutions"] = self.resolutions
230+
graph[self.name]["_nx_graph"] = self.nx_graph
231+
graph[self.name]["_node_ordering"] = self.node_ordering
232232
return super().register_attributes(graph, config)
233233

234234

tests/edges/test_icosahedral_edges.py

Lines changed: 0 additions & 64 deletions
This file was deleted.

tests/edges/test_multiscale_edges.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from anemoi.graphs.nodes import TriNodes
77

88

9-
class TestIcosahedralEdgesInit:
9+
class TestMultiScaleEdgesInit:
1010
def test_init(self):
1111
"""Test MultiScaleEdges initialization."""
1212
assert isinstance(MultiScaleEdges("test_nodes", "test_nodes", 1), MultiScaleEdges)
@@ -23,7 +23,7 @@ def test_fail_init_diff_nodes(self):
2323
MultiScaleEdges("test_nodes", "test_nodes2", 0)
2424

2525

26-
class TestIcosahedralEdgesTransform:
26+
class TestMultiScaleEdgesTransform:
2727

2828
@pytest.fixture()
2929
def tri_ico_graph(self) -> HeteroData:

tests/nodes/test_hex_nodes.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ def test_update_graph():
2828
node_builder = HexNodes(0, "test_nodes")
2929
graph = HeteroData()
3030
graph = node_builder.update_graph(graph, {})
31-
assert "resolutions" in graph["test_nodes"]
32-
assert "nx_graph" in graph["test_nodes"]
33-
assert "node_ordering" in graph["test_nodes"]
34-
assert len(graph["test_nodes"]["node_ordering"]) == graph["test_nodes"].num_nodes
31+
assert "_resolutions" in graph["test_nodes"]
32+
assert "_nx_graph" in graph["test_nodes"]
33+
assert "_node_ordering" in graph["test_nodes"]
34+
assert len(graph["test_nodes"]["_node_ordering"]) == graph["test_nodes"].num_nodes

tests/nodes/test_tri_nodes.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ def test_update_graph():
2828
node_builder = TriNodes(1, "test_nodes")
2929
graph = HeteroData()
3030
graph = node_builder.update_graph(graph, {})
31-
assert "resolutions" in graph["test_nodes"]
32-
assert "nx_graph" in graph["test_nodes"]
33-
assert "node_ordering" in graph["test_nodes"]
34-
assert len(graph["test_nodes"]["node_ordering"]) == graph["test_nodes"].num_nodes
31+
assert "_resolutions" in graph["test_nodes"]
32+
assert "_nx_graph" in graph["test_nodes"]
33+
assert "_node_ordering" in graph["test_nodes"]
34+
assert len(graph["test_nodes"]["_node_ordering"]) == graph["test_nodes"].num_nodes

tests/test_create.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
# (C) Copyright 2024 European Centre for Medium-Range Weather Forecasts.
2+
# This software is licensed under the terms of the Apache Licence Version 2.0
3+
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
4+
# In applying this licence, ECMWF does not waive the privileges and immunities
5+
# granted to it by virtue of its status as an intergovernmental organisation
6+
# nor does it submit to any jurisdiction.
7+
8+
9+
from pathlib import Path
10+
11+
import torch
12+
from torch_geometric.data import HeteroData
13+
14+
from anemoi.graphs.create import GraphCreator
15+
16+
17+
class TestGraphCreator:
18+
19+
def test_generate_graph(self, config_file: tuple[Path, str], mock_grids_path: tuple[str, int]):
20+
"""Test GraphCreator workflow."""
21+
tmp_path, config_name = config_file
22+
graph_path = tmp_path / "graph.pt"
23+
config_path = tmp_path / config_name
24+
25+
GraphCreator(graph_path, config_path).create()
26+
27+
graph = torch.load(graph_path)
28+
assert isinstance(graph, HeteroData)
29+
assert "test_nodes" in graph.node_types
30+
assert ("test_nodes", "to", "test_nodes") in graph.edge_types
31+
32+
for nodes in graph.node_stores:
33+
for node_attr in nodes.node_attrs():
34+
assert isinstance(nodes[node_attr], torch.Tensor)
35+
assert nodes[node_attr].dtype in [torch.int32, torch.float32]
36+
37+
for edges in graph.edge_stores:
38+
for edge_attr in edges.edge_attrs():
39+
assert isinstance(edges[edge_attr], torch.Tensor)
40+
assert edges[edge_attr].dtype in [torch.int32, torch.float32]
41+
42+
for nodes in graph.node_stores:
43+
for node_attr in nodes.node_attrs():
44+
assert not node_attr.startswith("_")
45+
for edges in graph.edge_stores:
46+
for edge_attr in edges.edge_attrs():
47+
assert not edge_attr.startswith("_")

tests/test_graphs.py

Lines changed: 0 additions & 38 deletions
This file was deleted.

0 commit comments

Comments
 (0)