Skip to content

Commit

Permalink
feat: add displot of node adjacencies
Browse files Browse the repository at this point in the history
  • Loading branch information
JPXKQX committed Aug 16, 2024
1 parent 7be6cea commit 71694d6
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 10 deletions.
6 changes: 1 addition & 5 deletions src/anemoi/graphs/descriptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ def describe(self, show_attribute_distributions: Optional[bool] = True) -> None:
print(f"📦 Path : {self.path}")
print(f"💽 Size : {bytes(self.total_size)} ({self.total_size})")
print()
print("🪩 Nodes summary")
print("🪩 Nodes summary")
print()
print(
table(
Expand Down Expand Up @@ -211,7 +211,3 @@ def describe(self, show_attribute_distributions: Optional[bool] = True) -> None:
print()
print("🔋 Graph ready.")
print()


if __name__ == "__main__":
GraphDescriptor("graph.pt").describe()
2 changes: 2 additions & 0 deletions src/anemoi/graphs/inspector.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from anemoi.graphs.plotting.displots import plot_distribution_edge_attributes
from anemoi.graphs.plotting.displots import plot_distribution_node_attributes
from anemoi.graphs.plotting.displots import plot_distribution_node_derived_attributes
from anemoi.graphs.plotting.interactive_html import plot_interactive_nodes
from anemoi.graphs.plotting.interactive_html import plot_interactive_subgraph
from anemoi.graphs.plotting.interactive_html import plot_isolated_nodes
Expand Down Expand Up @@ -52,6 +53,7 @@ def inspect(self):

if self.show_attribute_distributions:
LOGGER.info("Saving distribution plots of node ande edge attributes ...")
plot_distribution_node_derived_attributes(self.graph, self.output_path / "distribution_node_adjancency.png")
plot_distribution_edge_attributes(self.graph, self.output_path / "distribution_edge_attributes.png")
plot_distribution_node_attributes(self.graph, self.output_path / "distribution_node_attributes.png")

Expand Down
38 changes: 33 additions & 5 deletions src/anemoi/graphs/plotting/displots.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,12 @@
from typing import Union

import matplotlib.pyplot as plt
import torch
from torch_geometric.data import HeteroData
from torch_geometric.data.storage import EdgeStorage
from torch_geometric.data.storage import NodeStorage

from anemoi.graphs.plotting.prepare import compute_node_adjacencies
from anemoi.graphs.plotting.prepare import get_edge_attribute_dims
from anemoi.graphs.plotting.prepare import get_node_attribute_dims

Expand All @@ -35,6 +37,33 @@ def plot_distribution_edge_attributes(graph: HeteroData, out_file: Optional[Unio
plot_distribution_attributes(graph.edge_items(), num_edges, attr_dims, "Edge", out_file)


def plot_distribution_node_derived_attributes(graph, outfile: Optional[Union[str, Path]] = None):
"""Figure with the distribution of the node derived attributes.
Each row represents a node type and each column an attribute dimension.
"""
node_adjs = {}
node_attr_dims = {}
for source_name, _, target_name in graph.edge_types:
node_adj_tensor = compute_node_adjacencies(graph, source_name, target_name)
node_adj_tensor = torch.from_numpy(node_adj_tensor.reshape((node_adj_tensor.shape[0], -1)))
node_adj_key = f"# edges from {source_name}"

# Store node adjacencies
if target_name in node_adjs:
node_adjs[target_name] = node_adjs[target_name] | {node_adj_key: node_adj_tensor}
else:
node_adjs[target_name] = {node_adj_key: node_adj_tensor}

# Store attribute dimension
if node_adj_key not in node_attr_dims:
node_attr_dims[node_adj_key] = node_adj_tensor.shape[1]

node_adj_list = [(k, v) for k, v in node_adjs.items()]

plot_distribution_attributes(node_adj_list, len(node_adjs), node_attr_dims, "Node", outfile)


def plot_distribution_attributes(
graph_items: Union[NodeStorage, EdgeStorage],
num_items: int,
Expand Down Expand Up @@ -62,11 +91,10 @@ def plot_distribution_attributes(
for dim in range(attr_values):
if attr_name in item_store:
axs[i, j + dim].hist(item_store[attr_name][:, dim].float(), bins=50)
if j + dim == 0:
axs[i, j + dim].set_ylabel("".join(item_name).replace("to", " --> "))
if i == 0:
axs[i, j + dim].set_title(attr_name if attr_values == 1 else f"{attr_name}_{dim}")
elif i == num_items - 1:

axs[i, j + dim].set_ylabel("".join(item_name).replace("to", " --> "))
axs[i, j + dim].set_title(attr_name if attr_values == 1 else f"{attr_name}_{dim}")
if i == num_items - 1:
axs[i, j + dim].set_xlabel(attr_name if attr_values == 1 else f"{attr_name}_{dim}")
else:
axs[i, j + dim].set_axis_off()
Expand Down

0 comments on commit 71694d6

Please sign in to comment.