Skip to content

Commit

Permalink
Merge branch 'develop' into feature/inspection_tool
Browse files Browse the repository at this point in the history
  • Loading branch information
JPXKQX committed Aug 5, 2024
2 parents 797e139 + bfdac7d commit 0263f9c
Show file tree
Hide file tree
Showing 16 changed files with 347 additions and 173 deletions.
15 changes: 15 additions & 0 deletions .github/workflows/changelog-pr-update.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
name: Check Changelog Update on PR
on:
pull_request:
types: [assigned, opened, synchronize, reopened, labeled, unlabeled]
branches:
- main
- develop
jobs:
Check-Changelog:
name: Check Changelog Action
runs-on: ubuntu-20.04
steps:
- uses: tarides/changelog-check-action@v2
with:
changelog: CHANGELOG.md
2 changes: 2 additions & 0 deletions .github/workflows/python-publish.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ jobs:
with:
python-version: 3.x
- uses: pre-commit/[email protected]
env:
SKIP: no-commit-to-branch

checks:
strategy:
Expand Down
22 changes: 22 additions & 0 deletions .github/workflows/readthedocs-pr-update.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
name: Read the Docs PR Preview
on:
pull_request_target:
types:
- opened
- synchronize
- reopened
# Execute this action only on PRs that touch
# documentation files.
paths:
- "docs/**"

permissions:
pull-requests: write

jobs:
documentation-links:
runs-on: ubuntu-latest
steps:
- uses: readthedocs/actions/preview@v1
with:
project-slug: "anemoi-graphs"
53 changes: 53 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
# Changelog

All notable changes to this project will be documented in this file.

The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

Please add your functional changes to the appropriate section in the PR.
Keep it human-readable, your future self will thank you!

## [Unreleased]

### Added
- HEALPixNodes - nodebuilder based on Hierarchical Equal Area isoLatitude Pixelation of a sphere
### Changed

### Removed

## [0.2.1] - Anemoi-graph Release, bug fix release

### Added

### Changed
- Fix The 'save_path' argument of the GraphCreator class is optional, allowing users to create graphs without saving them.

### Removed

## [0.2.0] - Anemoi-graph Release, Icosahedral graph building

### Added
- New node builders by iteratively refining an icosahedron: TriNodes, HexNodes.
- New edge builders for building multi-scale connections.
- Added Changelog

### Changed

### Removed

## [0.1.0] - Initial Release, Global graph building

### Added
- Documentation
- Initial implementation for global graph building on the fly from Zarr and NPZ datasets

### Changed

### Removed

<!-- Add Git Diffs for Links above -->
[unreleased]: https://github.com/ecmwf/anemoi-graphs/compare/0.2.1...HEAD
[0.2.1]: https://github.com/ecmwf/anemoi-graphs/compare/0.2.0...0.2.1
[0.2.0]: https://github.com/ecmwf/anemoi-graphs/compare/0.1.0...0.2.0
[0.1.0]: https://github.com/ecmwf/anemoi-graphs/releases/tag/0.1.0
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ dependencies = [
"anemoi-datasets[data]>=0.3.3",
"anemoi-utils>=0.3.6",
"h3>=3.7.6,<4",
"healpy>=1.17",
"hydra-core>=1.3",
"networkx>=3.1",
"torch>=2.2",
Expand Down
18 changes: 10 additions & 8 deletions src/anemoi/graphs/commands/create.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from pathlib import Path

from anemoi.graphs.create import GraphCreator
from anemoi.graphs.inspector import GraphDescription

Expand All @@ -17,17 +19,17 @@ def add_arguments(self, command_parser):
help="Overwrite existing files. This will delete the target graph if it already exists.",
)
command_parser.add_argument("--description", action="store_false", help="Show the description of the graph.")
command_parser.add_argument("config", help="Configuration yaml file defining the recipe to create the graph.")
command_parser.add_argument("path", help="Path to store the created graph.")
command_parser.add_argument(
"config", help="Configuration yaml file path defining the recipe to create the graph."
)
command_parser.add_argument("save_path", type=Path, help="Path to store the created graph.")

def run(self, args):
kwargs = vars(args)

c = GraphCreator(**kwargs)
c.create()
graph_creator = GraphCreator(config=args.config)
graph_creator.create(save_path=args.save_path, overwrite=args.overwrite)

if kwargs.get("description", False):
GraphDescription(kwargs["path"]).describe()
if args.description:
GraphDescription(args.path).describe()


command = Create
107 changes: 66 additions & 41 deletions src/anemoi/graphs/create.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import logging
import os
from itertools import chain
from pathlib import Path
from typing import Optional
from typing import Union

import torch
from anemoi.utils.config import DotDict
Expand All @@ -14,26 +17,9 @@ class GraphCreator:

def __init__(
self,
config=None,
path=None,
cache=None,
print=print,
overwrite=False,
**kwargs,
config: Union[Path, DotDict],
):
if isinstance(config, str) or isinstance(config, os.PathLike):
self.config = DotDict.from_file(config)
else:
self.config = config

self.path = path # Output path
self.cache = cache
self.print = print
self.overwrite = overwrite

def init(self):
if self._path_readable() and not self.overwrite:
raise Exception(f"{self.path} already exists. Use overwrite=True to overwrite.")
self.config = DotDict.from_file(config) if isinstance(config, Path) else config

def generate_graph(self) -> HeteroData:
"""Generate the graph.
Expand All @@ -59,30 +45,69 @@ def generate_graph(self) -> HeteroData:

return graph

def save(self, graph: HeteroData) -> None:
"""Save the graph to the output path."""
if self.path is None:
LOGGER.info("No output path specified. The graph will not be saved.")
elif not self.path.exists() or self.overwrite:
self.path.parent.mkdir(parents=True, exist_ok=True)
torch.save(graph, self.path)
LOGGER.info(f"Graph saved at {self.path}.")
def clean(self, graph: HeteroData) -> HeteroData:
"""Remove private attributes used during creation from the graph.
Parameters
----------
graph : HeteroData
generated graph
Returns
-------
HeteroData
cleaned graph
"""
for type_name in chain(graph.node_types, graph.edge_types):
for attr_name in graph[type_name].keys():
if attr_name.startswith("_"):
del graph[type_name][attr_name]

return graph

def save(self, graph: HeteroData, save_path: Path, overwrite: bool = False) -> None:
"""Save the generated graph to the output path.
Parameters
----------
graph : HeteroData
generated graph
save_path : Path
location to save the graph
overwrite : bool, optional
whether to overwrite existing graph file, by default False
"""
save_path = Path(save_path)

if not save_path.exists() or overwrite:
save_path.parent.mkdir(parents=True, exist_ok=True)
torch.save(graph, save_path)
LOGGER.info(f"Graph saved at {save_path}.")
else:
LOGGER.info("Graph already exists. Use overwrite=True to overwrite.")

def create(self) -> HeteroData:
"""Create the graph and save it to the output path."""
self.init()
def create(self, save_path: Optional[Path] = None, overwrite: bool = False) -> HeteroData:
"""Create the graph and save it to the output path.
Parameters
----------
save_path : Path, optional
location to save the graph, by default None
overwrite : bool, optional
whether to overwrite existing graph file, by default False
Returns
-------
HeteroData
created graph object
"""

graph = self.generate_graph()
self.save(graph)
return graph
graph = self.clean(graph)

def _path_readable(self) -> bool:
"""Check if the output path is readable."""
import torch
if save_path is None:
LOGGER.warning("No output path specified. The graph will not be saved.")
else:
self.save(graph, save_path, overwrite)

try:
torch.load(self.path)
return True
except FileNotFoundError:
return False
return graph
18 changes: 9 additions & 9 deletions src/anemoi/graphs/edges/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,26 +276,26 @@ def __init__(self, source_name: str, target_name: str, x_hops: int):
self.x_hops = x_hops

def adjacency_from_tri_nodes(self, source_nodes: NodeStorage):
source_nodes["nx_graph"] = icosahedral.add_edges_to_nx_graph(
source_nodes["nx_graph"],
resolutions=source_nodes["resolutions"],
source_nodes["_nx_graph"] = icosahedral.add_edges_to_nx_graph(
source_nodes["_nx_graph"],
resolutions=source_nodes["_resolutions"],
x_hops=self.x_hops,
) # HeteroData refuses to accept None

adjmat = nx.to_scipy_sparse_array(
source_nodes["nx_graph"], nodelist=list(range(len(source_nodes["nx_graph"]))), format="coo"
source_nodes["_nx_graph"], nodelist=list(range(len(source_nodes["_nx_graph"]))), format="coo"
)
return adjmat

def adjacency_from_hex_nodes(self, source_nodes: NodeStorage):

source_nodes["nx_graph"] = hexagonal.add_edges_to_nx_graph(
source_nodes["nx_graph"],
resolutions=source_nodes["resolutions"],
source_nodes["_nx_graph"] = hexagonal.add_edges_to_nx_graph(
source_nodes["_nx_graph"],
resolutions=source_nodes["_resolutions"],
x_hops=self.x_hops,
)

adjmat = nx.to_scipy_sparse_array(source_nodes["nx_graph"], format="coo")
adjmat = nx.to_scipy_sparse_array(source_nodes["_nx_graph"], format="coo")
return adjmat

def get_adjacency_matrix(self, source_nodes: NodeStorage, target_nodes: NodeStorage):
Expand All @@ -311,7 +311,7 @@ def get_adjacency_matrix(self, source_nodes: NodeStorage, target_nodes: NodeStor
return adjmat

def post_process_adjmat(self, nodes: NodeStorage, adjmat):
graph_sorted = {node_pos: i for i, node_pos in enumerate(nodes["node_ordering"])}
graph_sorted = {node_pos: i for i, node_pos in enumerate(nodes["_node_ordering"])}
sort_func = np.vectorize(graph_sorted.get)
adjmat.row = sort_func(adjmat.row)
adjmat.col = sort_func(adjmat.col)
Expand Down
57 changes: 54 additions & 3 deletions src/anemoi/graphs/nodes/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,9 +226,9 @@ def get_coordinates(self) -> torch.Tensor:
def create_nodes(self) -> np.ndarray: ...

def register_attributes(self, graph: HeteroData, config: DotDict) -> HeteroData:
graph[self.name]["resolutions"] = self.resolutions
graph[self.name]["nx_graph"] = self.nx_graph
graph[self.name]["node_ordering"] = self.node_ordering
graph[self.name]["_resolutions"] = self.resolutions
graph[self.name]["_nx_graph"] = self.nx_graph
graph[self.name]["_node_ordering"] = self.node_ordering
return super().register_attributes(graph, config)


Expand All @@ -244,3 +244,54 @@ class HexNodes(IcosahedralNodes):

def create_nodes(self) -> np.ndarray:
return create_hexagonal_nodes(self.resolutions)


class HEALPixNodes(BaseNodeBuilder):
"""Nodes from HEALPix grid.
HEALPix is an acronym for Hierarchical Equal Area isoLatitude Pixelization of a sphere.
Attributes
----------
resolution : int
The resolution of the grid.
name : str
The name of the nodes.
Methods
-------
get_coordinates()
Get the lat-lon coordinates of the nodes.
register_nodes(graph, name)
Register the nodes in the graph.
register_attributes(graph, name, config)
Register the attributes in the nodes of the graph specified.
update_graph(graph, name, attr_config)
Update the graph with new nodes and attributes.
"""

def __init__(self, resolution: int, name: str) -> None:
"""Initialize the HEALPixNodes builder."""
self.resolution = resolution
super().__init__(name)

assert isinstance(resolution, int), "Resolution must be an integer."
assert resolution > 0, "Resolution must be positive."

def get_coordinates(self) -> torch.Tensor:
"""Get the coordinates of the nodes.
Returns
-------
torch.Tensor of shape (N, 2)
Coordinates of the nodes.
"""
import healpy as hp

spatial_res_degrees = hp.nside2resol(2**self.resolution, arcmin=True) / 60
LOGGER.info(f"Creating HEALPix nodes with resolution {spatial_res_degrees:.2} deg.")

npix = hp.nside2npix(2**self.resolution)
hpxlon, hpxlat = hp.pix2ang(2**self.resolution, range(npix), nest=True, lonlat=True)

return self.reshape_coords(hpxlat, hpxlon)
Loading

0 comments on commit 0263f9c

Please sign in to comment.