Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
114 changes: 64 additions & 50 deletions training/src/anemoi/training/utils/variables_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,50 @@ def _crack_variable_name(variable_name: str) -> tuple[str, str | None]:
return variable_name, None


def filter_variables(variables: dict[str, Variable], criteria: dict) -> list[Variable]:
    """Return the variables whose attributes satisfy every criterion.

    Parameters
    ----------
    variables : dict[str, Variable]
        Candidate variables; only the values of the mapping are inspected.
    criteria : dict
        Mapping of attribute name to the expected value. A list value matches
        when the attribute is one of its items; any other value must compare
        equal to the attribute. An attribute missing on a variable is read as
        ``None``.

    Returns
    -------
    list[Variable]
        The matching variables, in the iteration order of ``variables``.

    Raises
    ------
    AssertionError
        If ``criteria`` is not a dict.
    """
    # Validate once, up front: the check does not depend on the variable, and
    # doing it here also rejects a bad specification when ``variables`` is empty.
    # Raised explicitly (instead of through ``assert``) so that the check is not
    # stripped when Python runs with ``-O``; the exception type is unchanged.
    if not isinstance(criteria, dict):
        error_msg = f"variable_groups criteria must be a dict. Got: {type(criteria)}"
        raise AssertionError(error_msg)

    def _matches(var) -> bool:
        """Tell whether ``var`` satisfies all criteria."""
        for key, expected in criteria.items():
            actual = getattr(var, key, None)
            if isinstance(expected, list):
                if actual not in expected:
                    return False
            elif actual != expected:
                return False
        return True

    return [var for var in variables.values() if _matches(var)]


def build_variable_groups(
    metadata_variables: dict[str, Variable],
    variable_groups: dict,
    default_group: str | None = None,
) -> dict[str, list[Variable]]:
    """Build variable groups from a dictionary of group specifications.

    Parameters
    ----------
    metadata_variables : dict[str, Variable]
        Variables to distribute over the groups.
    variable_groups : dict
        Mapping of group name to group specification. A string or a list is
        shorthand for a criterion on the variable ``name``; any other
        specification is passed to ``filter_variables`` unchanged.
    default_group : str | None, optional
        When given, name of the group receiving every variable that was not
        selected by any of the specified groups.

    Returns
    -------
    dict[str, list[Variable]]
        Mapping of group name to the variables belonging to that group.
    """
    groups = {}
    for group_name, group_spec in variable_groups.items():
        # A bare string or list of strings represents the variable names.
        criteria = {"name": group_spec} if isinstance(group_spec, str | list) else group_spec
        groups[group_name] = filter_variables(metadata_variables, criteria)

    if default_group is not None:
        # Collect the already-grouped names once rather than once per variable.
        grouped_names = {var.name for members in groups.values() for var in members}
        groups[default_group] = [var for var in metadata_variables.values() if var.name not in grouped_names]

    return groups


class ExtractVariableGroupAndLevel:
"""Extract the group and level of a variable from dataset metadata and training-config file.

Expand Down Expand Up @@ -75,16 +119,20 @@ def __init__(
assert "default" in variable_groups, "Default group not defined in variable_groups"
self.default_group = variable_groups.pop("default")

self.variable_groups = variable_groups

self.metadata_variables: dict[str, Variable] = {
name: Variable.from_dict(name, val) if not isinstance(val, Variable) else val
for name, val in (metadata_variables or {}).items()
}
self.variable_group_spec = variable_groups
self.variable_groups = build_variable_groups(
self.metadata_variables,
variable_groups,
default_group=self.default_group,
)

def get_group_specification(self, group_name: str) -> GROUP_SPEC | dict[str, GROUP_SPEC]:
"""Get the specification of a group."""
return self.variable_groups[group_name]
return self.variable_group_spec.get(group_name, "default")

def get_group(self, variable_name: str) -> str:
"""Get the group of a variable.
Expand All @@ -99,53 +147,19 @@ def get_group(self, variable_name: str) -> str:
group : str
Group of the variable
"""
for group_name, group_spec in self.variable_groups.items():
if isinstance(group_spec, list | str):
# simple group
if self.get_param(variable_name) in (group_spec if isinstance(group_spec, list) else [group_spec]):
LOG.debug(
"Variable %r is in group %r",
variable_name,
group_name,
)
return group_name

elif isinstance(group_spec, dict):
# complex group
if variable_name not in self.metadata_variables:
if group_spec.keys() != {"param"}:
error_msg = (
f"Variable {variable_name} not found in metadata and `variable_groups` "
" must be a simple list or a dictionary with only the `param` key."
"\nPlease either provide metadata for the variable or simplify the `variable_groups`."
)
raise ValueError(error_msg)

if self.get_param(variable_name) in (
group_spec["param"] if isinstance(group_spec["param"], list) else [group_spec["param"]]
):
LOG.debug(
"Variable %r is in group %r through specification : %r.",
variable_name,
group_name,
group_spec,
)
return group_name
else:
var_metadata = self.metadata_variables.get(variable_name)
if all(
getattr(var_metadata, key) in (val if isinstance(val, list) else [val])
for key, val in group_spec.items()
):
LOG.debug(
"Variable %r is in group %r through specification : %r.",
variable_name,
group_name,
group_spec,
)
return group_name

return self.default_group
for group_name, variables in self.variable_groups.items():
if variable_name in [v.name for v in variables]:
LOG.debug(
"Variable %r is in group %r through specification : %r.",
variable_name,
group_name,
self.get_group_specification(group_name),
)
return group_name

raise ValueError(
f"{self.__class__.__name__} only supports variables found in the dataset. Variable {variable_name} not found.",
)

def _is_metadata_trusted(self, variable_name: str) -> bool:
"""Check if the metadata for a variable is trusted.
Expand Down
14 changes: 13 additions & 1 deletion training/tests/unit/utils/test_variable_grouping.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ def mocked_variable_metadata() -> dict[str, Variable]:
"q_100": MockedVariable("q", "pl", "100"),
"q_200": MockedVariable("q", "pl", "200"),
"q_500": MockedVariable("q", "pl", "500"),
"q": MockedVariable("q", "sfc", None),
"z_500": MockedVariable("z", "pl", "500"),
"z_ml_500": MockedVariable("z", "ml", "500"),
"t_500": MockedVariable("t", "pl", "500"),
Expand All @@ -62,7 +63,7 @@ def mocked_variable_metadata() -> dict[str, Variable]:
}
COMPLEX_METADATA_LESS_GROUPS = {
"default": "default",
"pl": {"param": ["q", "z"]},
"pl": {"param": ["q", "z"], "is_surface_level": True},
}


Expand Down Expand Up @@ -96,6 +97,7 @@ def mocked_variable_metadata() -> dict[str, Variable]:
# Complex metadata-less groups
(COMPLEX_METADATA_LESS_GROUPS, "q_100", "pl"),
(COMPLEX_METADATA_LESS_GROUPS, "q_500", "pl"),
(COMPLEX_METADATA_LESS_GROUPS, "q", "sfc"),
(COMPLEX_METADATA_LESS_GROUPS, "z_500", "pl"),
(COMPLEX_METADATA_LESS_GROUPS, "z_123", "pl"),
(COMPLEX_METADATA_LESS_GROUPS, "2t", "default"),
Expand All @@ -113,6 +115,16 @@ def test_group_matching(
assert ExtractVariableGroupAndLevel(groups, variable_metadata).get_group(variable) == expected_group


@pytest.mark.parametrize("variable_name", ["non_existing", "asjndbf", "q_123", "z"])
def test_group_variable_not_found(variable_name: str) -> None:
    """Test that a ValueError is raised when the variable is not found in any group.

    The extractor is built with empty variable metadata, so none of the names can be resolved.
    'z' is used in the group specification to group all z_{pl}, but it is not a variable itself.
    """
    # Build the extractor outside the ``raises`` block so that only ``get_group`` is allowed
    # to raise; an error at construction time must not make the test pass.
    extractor = ExtractVariableGroupAndLevel(COMPLEX_METADATA_LESS_GROUPS, {})
    with pytest.raises(ValueError, match="not found"):
        extractor.get_group(variable_name)


@pytest.fixture
def mocked_variable_lacking_metadata() -> dict[str, Variable]:
    """Provide an empty variable-metadata mapping."""
    no_metadata: dict[str, Variable] = {}
    return no_metadata
Expand Down
Loading