|
15 | 15 |
|
16 | 16 | import networkx as nx
|
17 | 17 | import numpy as np
|
| 18 | +import scipy |
18 | 19 | import torch
|
19 | 20 | from anemoi.utils.config import DotDict
|
20 | 21 | from hydra.utils import instantiate
|
@@ -80,10 +81,8 @@ def get_edge_index(self, graph: HeteroData) -> torch.Tensor:
|
80 | 81 | source_nodes, target_nodes = self.prepare_node_data(graph)
|
81 | 82 |
|
82 | 83 | adjmat = self.get_adjacency_matrix(source_nodes, target_nodes)
|
83 |
| - |
84 | 84 | # Get source & target indices of the edges
|
85 | 85 | edge_index = np.stack([adjmat.col, adjmat.row], axis=0)
|
86 |
| - |
87 | 86 | return torch.from_numpy(edge_index).to(torch.int32)
|
88 | 87 |
|
89 | 88 | def register_edges(self, graph: HeteroData) -> HeteroData:
|
@@ -381,7 +380,13 @@ class MultiScaleEdges(BaseEdgeBuilder):
|
381 | 380 | Update the graph with the edges.
|
382 | 381 | """
|
383 | 382 |
|
384 |
| - VALID_NODES = [TriNodes, HexNodes, LimitedAreaTriNodes, LimitedAreaHexNodes, StretchedTriNodes] |
| 383 | + VALID_NODES = [ |
| 384 | + TriNodes, |
| 385 | + HexNodes, |
| 386 | + LimitedAreaTriNodes, |
| 387 | + LimitedAreaHexNodes, |
| 388 | + StretchedTriNodes, |
| 389 | + ] |
385 | 390 |
|
386 | 391 | def __init__(self, source_name: str, target_name: str, x_hops: int, **kwargs):
|
387 | 392 | super().__init__(source_name, target_name)
|
@@ -444,3 +449,129 @@ def update_graph(self, graph: HeteroData, attrs_config: DotDict | None = None) -
|
444 | 449 | ), f"{self.__class__.__name__} requires {','.join(valid_node_names)} nodes."
|
445 | 450 |
|
446 | 451 | return super().update_graph(graph, attrs_config)
|
| 452 | + |
| 453 | + |
| 454 | +class ICONTopologicalBaseEdgeBuilder(BaseEdgeBuilder): |
| 455 | + """Base class for computing edges based on ICON grid topology. |
| 456 | +
|
| 457 | + Attributes |
| 458 | + ---------- |
| 459 | + source_name : str |
| 460 | + The name of the source nodes. |
| 461 | + target_name : str |
| 462 | + The name of the target nodes. |
| 463 | + icon_mesh : str |
| 464 | + The name of the ICON mesh (defines both the processor mesh and the data) |
| 465 | + """ |
| 466 | + |
| 467 | + def __init__( |
| 468 | + self, |
| 469 | + source_name: str, |
| 470 | + target_name: str, |
| 471 | + icon_mesh: str, |
| 472 | + source_mask_attr_name: str | None = None, |
| 473 | + target_mask_attr_name: str | None = None, |
| 474 | + ): |
| 475 | + self.icon_mesh = icon_mesh |
| 476 | + super().__init__(source_name, target_name, source_mask_attr_name, target_mask_attr_name) |
| 477 | + |
| 478 | + def update_graph(self, graph: HeteroData, attrs_config: DotDict = None) -> HeteroData: |
| 479 | + """Update the graph with the edges.""" |
| 480 | + assert self.icon_mesh is not None, f"{self.__class__.__name__} requires initialized icon_mesh." |
| 481 | + self.icon_sub_graph = graph[self.icon_mesh][self.sub_graph_address] |
| 482 | + return super().update_graph(graph, attrs_config) |
| 483 | + |
| 484 | + def get_adjacency_matrix(self, source_nodes: NodeStorage, target_nodes: NodeStorage): |
| 485 | + """Parameters |
| 486 | + ---------- |
| 487 | + source_nodes : NodeStorage |
| 488 | + The source nodes. |
| 489 | + target_nodes : NodeStorage |
| 490 | + The target nodes. |
| 491 | + """ |
| 492 | + LOGGER.info(f"Using ICON topology {self.source_name}>{self.target_name}") |
| 493 | + nrows = self.icon_sub_graph.num_edges |
| 494 | + adj_matrix = scipy.sparse.coo_matrix( |
| 495 | + ( |
| 496 | + np.ones(nrows), |
| 497 | + ( |
| 498 | + self.icon_sub_graph.edge_vertices[:, self.vertex_index[0]], |
| 499 | + self.icon_sub_graph.edge_vertices[:, self.vertex_index[1]], |
| 500 | + ), |
| 501 | + ) |
| 502 | + ) |
| 503 | + return adj_matrix |
| 504 | + |
| 505 | + |
| 506 | +class ICONTopologicalProcessorEdges(ICONTopologicalBaseEdgeBuilder): |
| 507 | + """Computes edges based on ICON grid topology: processor grid built |
| 508 | + from ICON grid vertices. |
| 509 | + """ |
| 510 | + |
| 511 | + def __init__( |
| 512 | + self, |
| 513 | + source_name: str, |
| 514 | + target_name: str, |
| 515 | + icon_mesh: str, |
| 516 | + source_mask_attr_name: str | None = None, |
| 517 | + target_mask_attr_name: str | None = None, |
| 518 | + ): |
| 519 | + self.sub_graph_address = "_multi_mesh" |
| 520 | + self.vertex_index = (1, 0) |
| 521 | + super().__init__( |
| 522 | + source_name, |
| 523 | + target_name, |
| 524 | + icon_mesh, |
| 525 | + source_mask_attr_name, |
| 526 | + target_mask_attr_name, |
| 527 | + ) |
| 528 | + |
| 529 | + |
| 530 | +class ICONTopologicalEncoderEdges(ICONTopologicalBaseEdgeBuilder): |
| 531 | + """Computes encoder edges based on ICON grid topology: ICON cell |
| 532 | + circumcenters for mapped onto processor grid built from ICON grid |
| 533 | + vertices. |
| 534 | + """ |
| 535 | + |
| 536 | + def __init__( |
| 537 | + self, |
| 538 | + source_name: str, |
| 539 | + target_name: str, |
| 540 | + icon_mesh: str, |
| 541 | + source_mask_attr_name: str | None = None, |
| 542 | + target_mask_attr_name: str | None = None, |
| 543 | + ): |
| 544 | + self.sub_graph_address = "_cell_grid" |
| 545 | + self.vertex_index = (1, 0) |
| 546 | + super().__init__( |
| 547 | + source_name, |
| 548 | + target_name, |
| 549 | + icon_mesh, |
| 550 | + source_mask_attr_name, |
| 551 | + target_mask_attr_name, |
| 552 | + ) |
| 553 | + |
| 554 | + |
| 555 | +class ICONTopologicalDecoderEdges(ICONTopologicalBaseEdgeBuilder): |
| 556 | + """Computes encoder edges based on ICON grid topology: mapping from |
| 557 | + processor grid built from ICON grid vertices onto ICON cell |
| 558 | + circumcenters. |
| 559 | + """ |
| 560 | + |
| 561 | + def __init__( |
| 562 | + self, |
| 563 | + source_name: str, |
| 564 | + target_name: str, |
| 565 | + icon_mesh: str, |
| 566 | + source_mask_attr_name: str | None = None, |
| 567 | + target_mask_attr_name: str | None = None, |
| 568 | + ): |
| 569 | + self.sub_graph_address = "_cell_grid" |
| 570 | + self.vertex_index = (0, 1) |
| 571 | + super().__init__( |
| 572 | + source_name, |
| 573 | + target_name, |
| 574 | + icon_mesh, |
| 575 | + source_mask_attr_name, |
| 576 | + target_mask_attr_name, |
| 577 | + ) |
0 commit comments