Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions docs/inference/configs/inputs.rst
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,39 @@ You can specify the input as ``grib`` to read the data from a GRIB file.

For more options, see :ref:`grib-input`.

********
netcdf
********

You can specify the input as ``netcdf`` to read the data from a NetCDF
file.

.. literalinclude:: yaml/inputs_5.yaml
:language: yaml

.. note::

The ``netcdf`` input expects the data to be laid out in the same way as
for the GRIB input, i.e. with a values coordinate and with variable
names that match the ones the checkpoint was trained on.

For example, a GRIB file can be exported to NetCDF with
``earthkit-data`` as shown below and then read back in using the
``netcdf`` input. Note how the variable names are built by ``remapping``:

.. code:: python

import earthkit.data as ekd

ds = ekd.from_source("file", "input.grib")

ds_xr = ds.to_xarray(
variable_key="p_l",
remapping={"p_l": "{param}_{levelist}"},
add_earthkit_attrs=False,
)
ds_xr.to_netcdf("input.nc")

******
mars
******
Expand Down
2 changes: 2 additions & 0 deletions docs/inference/configs/yaml/inputs_5.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
input:
netcdf: /path/to/netcdf/file.nc
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ dependencies = [

optional-dependencies.all = [
"anemoi-datasets",
"anemoi-inference[huggingface,plot,tests,zarr]",
"anemoi-inference[huggingface,netcdf,plot,tests,zarr]",
"anemoi-utils[all]>=0.4.32",
]
optional-dependencies.cosmo = [
Expand All @@ -82,6 +82,7 @@ optional-dependencies.docs = [
]

optional-dependencies.huggingface = [ "huggingface-hub" ]
optional-dependencies.netcdf = [ "earthkit-data>=0.18.5" ]
optional-dependencies.plot = [ "earthkit-plots" ]

optional-dependencies.plugin = [ "ai-models>=0.7", "tqdm" ]
Expand Down
116 changes: 115 additions & 1 deletion src/anemoi/inference/inputs/ekd.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,13 @@
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.


import glob
import logging
import os
import re
from collections import defaultdict
from collections.abc import Callable
from functools import cached_property
from typing import Any

import earthkit.data as ekd
Expand All @@ -21,6 +23,7 @@
from numpy.typing import DTypeLike

from anemoi.inference.context import Context
from anemoi.inference.decorators import main_argument
from anemoi.inference.types import Date
from anemoi.inference.types import FloatArray
from anemoi.inference.types import State
Expand Down Expand Up @@ -450,3 +453,114 @@ def get_geography_info(key: str) -> str | None:

if geography_information:
state["_geography"] = geography_information


@main_argument("path")
class FieldlistInput(EkdInput):
    """Input backed by an earthkit-data FieldList read from disk.

    The ``path`` given at construction may name a single file, a directory
    or a glob pattern. It is resolved on first access to ``_fieldlist`` and
    the result is cached for the lifetime of the instance.
    """

    # Glob patterns used to discover files when ``path`` is a directory.
    # Annotated only: no value is assigned in this class, so a subclass has
    # to provide one before a directory path can be resolved.
    patterns: tuple[str, ...]

    def __init__(
        self,
        context: Context,
        *,
        path: str,
        **kwargs: Any,
    ) -> None:
        """Initialise the FieldlistInput.

        Parameters
        ----------
        context : Context
            The context in which the input is used.
        path : str
            Path, directory or glob pattern to file(s). Examples:
            - "/path/to/file.grib"
            - "/path/to/*.grib"
            - "/path/to/**/*.grib2"
            - "/path/to/directory/"
        **kwargs : Any
            Additional keyword arguments, forwarded unchanged to the parent
            initialiser.
        """
        super().__init__(context, **kwargs)
        self.path = path

    def create_input_state(self, *, date: Date | None, ref_date_index: int = -1, **kwargs) -> State:
        """Build the input state for ``date`` from the cached fieldlist.

        Parameters
        ----------
        date : Optional[Date]
            The date for which to create the input state.
        ref_date_index : int, default -1
            The reference date index to use.
        **kwargs : Any
            Additional keyword arguments. They are accepted but not passed
            on by this implementation.

        Returns
        -------
        State
            The created input state.
        """
        fields = self._fieldlist
        return self._create_input_state(fields, date=date, ref_date_index=ref_date_index)

    def load_forcings_state(self, *, dates: list[Date], current_state: State) -> State:
        """Load the forcings state for ``dates`` from the cached fieldlist.

        Parameters
        ----------
        dates : List[Date]
            List of dates for which to load the forcings.
        current_state : State
            The current state of the input.

        Returns
        -------
        State
            The loaded forcings state.
        """
        fields = self._fieldlist
        return self._load_forcings_state(fields, dates=dates, current_state=current_state)

    @cached_property
    def _fieldlist(self) -> ekd.FieldList:
        """Resolve ``self.path`` into an earthkit-data fieldlist.

        Three kinds of path are recognised, in this order: a glob pattern,
        a directory (searched recursively with ``self.patterns``) and a
        single file. Whenever nothing usable is found, a warning is logged
        and the earthkit-data ``"empty"`` source is returned instead of
        raising.
        """
        path = self.path

        # Glob pattern: expand it and keep regular files only.
        if glob.has_magic(path):
            files = sorted(hit for hit in glob.glob(path, recursive=True) if os.path.isfile(hit))
            if not files:
                LOG.warning("No files matched pattern %r", path)
                return ekd.from_source("empty")  # type: ignore[reportReturnType]
            return ekd.from_source("file", files)  # type: ignore[reportReturnType]

        # Directory: collect every match of every pattern at any depth. The
        # set removes files matched by more than one pattern.
        if os.path.isdir(path):
            candidates = {
                hit
                for pattern in self.patterns
                for hit in glob.glob(os.path.join(path, "**", pattern), recursive=True)
            }
            files = [name for name in sorted(candidates) if os.path.isfile(name)]
            if not files:
                LOG.warning("Directory %r contains no files which match patterns %r", path, self.patterns)
                return ekd.from_source("empty")  # type: ignore[reportReturnType]
            return ekd.from_source("file", files)  # type: ignore[reportReturnType]

        # Single file: a missing or zero-length file yields an empty source.
        try:
            size = os.path.getsize(path)
        except FileNotFoundError:
            LOG.warning("Path %r not found", path)
            return ekd.from_source("empty")  # type: ignore[reportReturnType]

        if size == 0:
            LOG.warning("File %r is empty", path)
            return ekd.from_source("empty")  # type: ignore[reportReturnType]

        return ekd.from_source("file", path)  # type: ignore[reportReturnType]
125 changes: 3 additions & 122 deletions src/anemoi/inference/inputs/gribfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,133 +8,14 @@
# nor does it submit to any jurisdiction.


import glob
import logging
import os
from functools import cached_property
from typing import Any

import earthkit.data as ekd

from anemoi.inference.context import Context
from anemoi.inference.types import Date
from anemoi.inference.types import State

from ..decorators import main_argument
from . import input_registry
from .ekd import FieldlistInput
from .grib import GribInput

LOG = logging.getLogger(__name__)


@input_registry.register("grib")
class GribFileInput(FieldlistInput, GribInput):
    """Handles grib files.

    The path handling (single file, directory or glob pattern), the cached
    fieldlist and the ``create_input_state`` / ``load_forcings_state``
    methods are inherited from ``FieldlistInput``. This class only
    contributes the GRIB-specific settings below.
    """

    trace_name = "grib file"

    # File-name patterns searched when ``path`` is a directory.
    # "*.grib1" is listed explicitly: the directory search this class used
    # before it was moved into FieldlistInput matched it, and leaving it out
    # would silently skip ``.grib1`` files that used to be loaded.
    patterns = ("*.grib", "*.grib1", "*.grib2", "*.grb", "*.grb2")
20 changes: 20 additions & 0 deletions src/anemoi/inference/inputs/netcdf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# (C) Copyright 2024 Anemoi contributors.
#
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
#
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.


from . import input_registry
from .ekd import FieldlistInput


@input_registry.register("netcdf")
class NetcdfFileInput(FieldlistInput):
    """Handles netcdf files.

    All behaviour comes from ``FieldlistInput``; this class only supplies
    the NetCDF-specific settings below.
    """

    # File-name patterns searched when the input path is a directory.
    patterns = ("*.nc", "*.netcdf")

    # Label identifying this input as a netcdf file.
    trace_name = "netcdf file"
4 changes: 2 additions & 2 deletions src/anemoi/inference/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
LOG = logging.getLogger(__name__)


class FieldListInput(GribInput):
class PluginInput(GribInput):
"""Handles earthkit-data fieldlists input fields."""

def __init__(self, context: Any, *, input_fields: Any, **kwargs) -> None:
Expand Down Expand Up @@ -195,7 +195,7 @@ def run(self) -> None:
input_kwargs = self.input.anemoi_plugin_input_kwargs()
output_kwargs = self.input.anemoi_plugin_input_kwargs()

input = FieldListInput(self.runner, input_fields=self.all_fields, **input_kwargs)
input = PluginInput(self.runner, input_fields=self.all_fields, **input_kwargs)
output = CallbackOutput(self.runner, write=self.write, **output_kwargs)

input_state = input.create_input_state(date=self.start_datetime)
Expand Down
Loading
Loading