Skip to content

Commit

Permalink
feat: hexagonal builders
Browse files Browse the repository at this point in the history
  • Loading branch information
theissenhelen committed Jun 28, 2024
1 parent ea88000 commit d0c20f6
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 0 deletions.
37 changes: 37 additions & 0 deletions src/anemoi/graphs/edges/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from anemoi.graphs import EARTH_RADIUS
from anemoi.graphs.utils import get_grid_reference_distance
from anemoi.graphs.nodes.builder import TriRefinedIcosahedralNodeBuilder
from anemoi.graphs.nodes.builder import HexRefinedIcosahedralNodeBuilder
from anemoi.graphs.generate import icosahedral

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -164,3 +165,39 @@ def get_adj_matrix(self, src_nodes: NodeStorage, dst_nodes: NodeStorage):
adjmat.row = sort_func2(sort_func1(adjmat.row))
adjmat.col = sort_func2(sort_func1(adjmat.col))
return adjmat


class HexagonalEdgeBuilder(BaseEdgeBuilder):
"""Computes hexagonal edges and adds them to a HeteroData graph."""

def __init__(self, src_name: str, dst_name: str, add_neighbouring_children: bool = False, depth_children: Optional[int] = 1):
super().__init__(src_name, dst_name)
self.add_neighbouring_children = add_neighbouring_children
self.depth_children = depth_children

def transform(self, graph: HeteroData, edge_name: str, attrs_config: Optional[DotDict] = None) -> HeteroData:
assert (
graph[self.src_name].node_type == HexRefinedIcosahedralNodeBuilder.__name__
), "IcosahedralConnection requires MultiScaleIcosahedral nodes."
assert graph[self.src_name] == graph[self.dst_name], "InheritConnection requires the same nodes for source and destination."

# TODO: Next assert doesn't exist anymore since filters were moved, make sure this is checked where appropriate
# assert filter_src is None and filter_dst is None, "InheritConnection does not support filtering with attributes."

return super().transform(graph, edge_name, attrs_config)

def get_adj_matrix(self, src_nodes: NodeStorage, dst_nodes: NodeStorage):

src_nodes["nx_graph"] = hexagonal.add_edges_to_nx_graph(
src_nodes["nx_graph"],
resolutions=src_nodes["resolutions"],
neighbour_children=self.add_neighbouring_children,
depth_children=self.depth_children,
)

adjmat = nx.to_scipy_sparse_array(src_nodes["nx_graph"], format="coo")
graph_2_sorted = dict(zip(src_nodes["node_ordering"], range(len(src_nodes.node_ordering))))
sort_func = np.vectorize(graph_2_sorted.get)
adjmat.row = sort_func(adjmat.row)
adjmat.col = sort_func(adjmat.col)
return adjmat
7 changes: 7 additions & 0 deletions src/anemoi/graphs/nodes/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from hydra.utils import instantiate
from torch_geometric.data import HeteroData
from anemoi.graphs.generate.icosahedral import create_icosahedral_nodes
from anemoi.graphs.generate.hexagonal import create_hexagonal_nodes

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -114,3 +115,9 @@ def create_nodes(self) -> np.ndarray:
# TODO: AOI mask builder is not used in the current implementation.
return create_icosahedral_nodes(resolutions=self.resolutions)


class HexRefinedIcosahedralNodeBuilder(RefinedIcosahedralNodeBuilder):
"""It depends on the h3 Python library."""

def create_nodes(self) -> np.ndarray:
return create_hexagonal_nodes(self.resolutions)

0 comments on commit d0c20f6

Please sign in to comment.