From 3f0caa63eb759ec9a791fa7806f3c1a280e174b1 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Wed, 23 Oct 2024 06:46:02 +0000 Subject: [PATCH] Add more attributes to typed variables --- CHANGELOG.md | 6 ++++++ dev/dev.py | 14 +++++++++++--- src/anemoi/transform/grids/unstructured.py | 3 +-- src/anemoi/transform/sources/mars.py | 12 +++++++++--- src/anemoi/transform/variables/variables.py | 17 ++++++++++------- tests/test_grids.py | 2 +- 6 files changed, 38 insertions(+), 16 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 063481c..fc0dbc9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,3 +7,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 Please add your functional changes to the appropriate section in the PR. Keep it human-readable, your future self will thank you! + +## [Unreleased] + +### Changed + +- Add more attributes to typed variables diff --git a/dev/dev.py b/dev/dev.py index 50210c5..5d51de1 100644 --- a/dev/dev.py +++ b/dev/dev.py @@ -6,13 +6,16 @@ mars = source_factory( "mars", +) + +r = dict( param=["u", "v", "t", "q"], grid=[20, 20], date="20200101/to/20200105", levelist=[1000, 850, 500], ) -data = mars.forward(None) +data = mars.forward(r) for f in data: print(f) @@ -35,11 +38,16 @@ ################ pipeline = workflow_factory("pipeline", filters=[mars, uv_2_ddff, ddff_2_uv]) -for f in pipeline(None): +for f in pipeline(r): print(f) ################ -pipeline = mars | uv_2_ddff | ddff_2_uv + + +pipeline = r | mars | uv_2_ddff | ddff_2_uv for f in pipeline: print(f) + + +ipipe = pipeline.to_infernece() diff --git a/src/anemoi/transform/grids/unstructured.py b/src/anemoi/transform/grids/unstructured.py index cde6008..7b6f6be 100644 --- a/src/anemoi/transform/grids/unstructured.py +++ b/src/anemoi/transform/grids/unstructured.py @@ -28,7 +28,6 @@ def __init__(self, latitudes, longitudes, uuidOfHGrid=None): assert isinstance(latitudes, np.ndarray), type(latitudes) assert isinstance(longitudes, np.ndarray), type(longitudes) - LOG.info(f"Latitudes: {len(latitudes)}, Longitudes: {len(longitudes)}") assert len(latitudes) == len(longitudes) self.uuidOfHGrid = uuidOfHGrid @@ -95,7 +94,7 @@ def from_grib(cls, latitudes_url_or_path, longitudes_url_or_path, latitudes_para return cls([UnstructuredGridField(Geography(latitudes, longitudes))]) @classmethod - def from_values(cls, latitudes, longitudes): + def from_values(cls, *, latitudes, longitudes): if isinstance(latitudes, (list, tuple)): latitudes = np.array(latitudes) diff --git a/src/anemoi/transform/sources/mars.py b/src/anemoi/transform/sources/mars.py index 1b93af3..5f627fb 100644 --- a/src/anemoi/transform/sources/mars.py +++ b/src/anemoi/transform/sources/mars.py @@ -18,12 +18,18 @@ class Mars(Source): """A demo source""" def __init__(self, **request): - self.request = request + pass def forward(self, data): - assert data is None + return ekd.from_source("mars", **data) - return ekd.from_source("mars", **self.request) + def __ror__(self, data): + + class Input: + def __init__(self, data): + self.data = data + + return Input(data) register_source("mars", Mars) diff --git a/src/anemoi/transform/variables/variables.py b/src/anemoi/transform/variables/variables.py index 2adfc7e..d53380e 100644 --- a/src/anemoi/transform/variables/variables.py +++ b/src/anemoi/transform/variables/variables.py @@ -5,7 +5,6 @@ # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. -import json from . import Variable @@ -16,11 +15,7 @@ class VariableFromMarsVocabulary(Variable): def __init__(self, name, data: dict) -> None: super().__init__(name) self.data = data - print(json.dumps(data, indent=4)) - if "mars" in self.data: - self.mars = self.data["mars"] - else: - self.mars = self.data + self.mars = self.data.get("mars", {}) @property def is_pressure_level(self): @@ -32,7 +27,15 @@ def level(self): @property def is_constant_in_time(self): - return self.data.get("is_constant_in_time", False) + return self.data.get("constant_in_time", False) + + @property + def is_from_input(self): + return "mars" in self.data + + @property + def is_computed_forcing(self): + return self.data.get("computed_forcing", False) class VariableFromDict(VariableFromMarsVocabulary): diff --git a/tests/test_grids.py b/tests/test_grids.py index 36558d9..3ee1698 100644 --- a/tests/test_grids.py +++ b/tests/test_grids.py @@ -16,7 +16,7 @@ tlon = "tlon" -def test_unstructured_from_url(): +def do_not_test_unstructured_from_url(): ds = UnstructuredGridFieldList.from_grib(latitude_url, longitudes_url, tlat, tlon) assert len(ds) == 1