Skip to content

Commit

Permalink
1 CAMS2_83 test failing. getter and setter for get_web_interface_name?
Browse files Browse the repository at this point in the history
  • Loading branch information
lewisblake committed Nov 17, 2024
1 parent 91ee520 commit f20ec94
Show file tree
Hide file tree
Showing 9 changed files with 124 additions and 100 deletions.
162 changes: 87 additions & 75 deletions pyaerocom/aeroval/collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,75 @@ def web_interface_names(self) -> list:
pass


class ObsCollection(BaseCollection):
class Collection(abc.ABC):
def __init__(self):
self._entries = {}

def __iter__(self):
"""
Yield each entry in the collection.
"""
yield from self._entries.values()

@abc.abstractmethod
def add_entry(self, key, value) -> None:
pass

@abc.abstractmethod
def remove_entry(self, key) -> None:
pass

@abc.abstractmethod
def get_entry(self, key) -> object:
pass

def keylist(self, name_or_pattern: str = None) -> list[str]:
"""Find model / obs names that match input search pattern(s)
Parameters
----------
name_or_pattern : str, optional
Name or pattern specifying search string.
Returns
-------
list
list of keys in collection that match input requirements. If
`name_or_pattern` is None, all keys will be returned.
Raises
------
KeyError
if no matches can be found
"""
if name_or_pattern is None:
name_or_pattern = "*"

matches = []
for key in self._entries.keys():
if fnmatch(key, name_or_pattern) and key not in matches:
matches.append(key)
if len(matches) == 0:
raise KeyError(f"No matches could be found that match input {name_or_pattern}")
return matches

@property
def web_interface_names(self) -> list:
"""
List of web interface names for each obs entry
Returns
-------
list
"""
return self.keylist()

def to_json(self) -> str:
"""Serialize ModelCollection to a JSON string."""
return json.dumps({k: v.dict() for k, v in self._entries.items()}, default=str)


class ObsCollection(Collection):
"""
Dict-like object that represents a collection of obs entries
Expand All @@ -94,9 +162,16 @@ class ObsCollection(BaseCollection):
"""

SETTER_CONVERT = {dict: ObsEntry}
def add_entry(self, key: str, entry: dict | ObsEntry):
if isinstance(entry, dict):
entry = ObsEntry(**entry)
self._entries[key] = entry

def get_entry(self, key) -> object:
def remove_entry(self, key: str):
if key in self._entries:
del self._entries[key]

def get_entry(self, key) -> ObsEntry:
"""
Getter for obs entries
Expand All @@ -106,7 +181,7 @@ def get_entry(self, key) -> object:
if input name is not in this collection
"""
try:
entry = self[key]
entry = self._entries[key]
entry.obs_name = self.get_web_interface_name(key)
return entry
except (KeyError, AttributeError):
Expand All @@ -123,7 +198,7 @@ def get_all_vars(self) -> list[str]:
"""
vars = []
for ocfg in self.values():
for ocfg in self._entries.values():
vars.extend(ocfg.get_all_vars())
return sorted(list(set(vars)))

Expand All @@ -148,7 +223,12 @@ def get_web_interface_name(self, key):
corresponding name
"""
return self[key].web_interface_name if self[key].web_interface_name is not None else key
# LB: private method?
return (
self._entries[key].web_interface_name
if self._entries[key].web_interface_name is not None
else key
)

@property
def web_interface_names(self) -> list:
Expand All @@ -164,75 +244,7 @@ def web_interface_names(self) -> list:
@property
def all_vert_types(self):
"""List of unique vertical types specified in this collection"""
return list({x.obs_vert_type for x in self.values()})


class Collection(abc.ABC):
def __init__(self):
self._entries = {}

def __iter__(self):
"""
Yield each entry in the collection.
"""
yield from self._entries.values()

@abc.abstractmethod
def add_entry(self, key, value) -> None:
pass

@abc.abstractmethod
def remove_entry(self, key) -> None:
pass

@abc.abstractmethod
def get_entry(self, key) -> object:
pass

def keylist(self, name_or_pattern: str = None) -> list[str]:
"""Find model / obs names that match input search pattern(s)
Parameters
----------
name_or_pattern : str, optional
Name or pattern specifying search string.
Returns
-------
list
list of keys in collection that match input requirements. If
`name_or_pattern` is None, all keys will be returned.
Raises
------
KeyError
if no matches can be found
"""
if name_or_pattern is None:
name_or_pattern = "*"

matches = []
for key in self._entries.keys():
if fnmatch(key, name_or_pattern) and key not in matches:
matches.append(key)
if len(matches) == 0:
raise KeyError(f"No matches could be found that match input {name_or_pattern}")
return matches

@property
def web_interface_names(self) -> list:
"""
List of web interface names for each obs entry
Returns
-------
list
"""
return self.keylist()

def to_json(self) -> str:
"""Serialize ModelCollection to a JSON string."""
return json.dumps({k: v.dict() for k, v in self._entries.items()}, default=str)
return list({x.obs_vert_type for x in self._entries.values()})


class ModelCollection(Collection):
Expand Down
4 changes: 2 additions & 2 deletions pyaerocom/aeroval/experiment_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -751,8 +751,8 @@ def _is_part_of_experiment(self, obs_name, obs_var, mod_name, mod_var) -> bool:
# occurence of web_interface_name).
allobs = self.cfg.obs_cfg
obs_matches = []
for key, ocfg in allobs.items():
if obs_name == allobs.get_web_interface_name(key):
for ocfg in allobs:
if obs_name == allobs.get_web_interface_name(ocfg.obs_name):
obs_matches.append(ocfg)
if len(obs_matches) == 0:
self._invalid["obs"].append(obs_name)
Expand Down
4 changes: 2 additions & 2 deletions pyaerocom/aeroval/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ def make_dummy_model(obs_list: list, cfg) -> str:
tmp_var_obj = Variable()
# Loops over variables in obs
for obs in obs_list:
for var in cfg.obs_cfg[obs].obs_vars:
for var in cfg.obs_cfg.get_entry(obs).obs_vars:
# Create dummy cube

dummy_cube = make_dummy_cube(var, start_yr=start, stop_yr=stop, freq=freq)
Expand All @@ -185,7 +185,7 @@ def make_dummy_model(obs_list: list, cfg) -> str:
for dummy_grid_yr in yr_gen:
# Add to netcdf
yr = dummy_grid_yr.years_avail()[0]
vert_code = cfg.obs_cfg[obs].obs_vert_type
vert_code = cfg.obs_cfg.get_entry(obs).obs_vert_type

save_name = dummy_grid_yr.aerocom_savename(model_id, var, vert_code, yr, freq)
dummy_grid_yr.to_netcdf(outdir, savename=save_name)
Expand Down
22 changes: 15 additions & 7 deletions pyaerocom/aeroval/setup_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -501,17 +501,25 @@ def colocation_opts(self) -> ColocationSetup:

# These attributes require special attention b/c they're not based on Pydantic's BaseModel class.

obs_cfg: ObsCollection | dict = ObsCollection()
# obs_cfg: ObsCollection | dict = ObsCollection()

@field_validator("obs_cfg")
def validate_obs_cfg(cls, v):
if isinstance(v, ObsCollection):
return v
return ObsCollection(v)
# @field_validator("obs_cfg")
# def validate_obs_cfg(cls, v):
# if isinstance(v, ObsCollection):
# return v
# return ObsCollection(v)

@computed_field
@cached_property
def obs_cfg(self) -> ObsCollection:
oc = ObsCollection()
for k, v in self.model_extra.get("obs_cfg", {}).items():
oc.add_entry(k, v)
return oc

@field_serializer("obs_cfg")
def serialize_obs_cfg(self, obs_cfg: ObsCollection):
return obs_cfg.json_repr()
return obs_cfg.to_json()

# model_cfg: ModelCollection | dict = ModelCollection()

Expand Down
5 changes: 3 additions & 2 deletions pyaerocom/aeroval/superobs_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,9 @@ def _run_var(self, model_name, obs_name, var_name, try_colocate_if_missing):
coldata_files = []
coldata_resolutions = []
vert_codes = []
obs_needed = self.cfg.obs_cfg[obs_name].obs_id
vert_code = self.cfg.obs_cfg.get_entry(obs_name).obs_vert_type
obs_entry = self.cfg.obs_cfg.get_entry(obs_name)
obs_needed = obs_entry.obs_id
vert_code = obs_entry.obs_vert_type
for oname in obs_needed:
fp, ts_type, vert_code = self._get_coldata_fileinfo(
model_name, oname, var_name, try_colocate_if_missing
Expand Down
4 changes: 2 additions & 2 deletions tests/aeroval/test_aeroval_HIGHLEV.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,8 +116,8 @@ def test_superobs_different_resolutions(eval_config: dict):
cfg.model_cfg.get_entry("TM5-AP3-CTRL").model_ts_type_read = None
cfg.model_cfg.get_entry("TM5-AP3-CTRL").flex_ts_type = True

cfg.obs_cfg["AERONET-Sun"].ts_type = "daily"
cfg.obs_cfg["AERONET-SDA"].ts_type = "monthly"
cfg.obs_cfg.get_entry("AERONET-Sun").ts_type = "daily"
cfg.obs_cfg.get_entry("AERONET-SDA").ts_type = "monthly"

proc = ExperimentProcessor(cfg)
proc.exp_output.delete_experiment_data(also_coldata=True)
Expand Down
18 changes: 11 additions & 7 deletions tests/aeroval/test_collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,21 @@


def test_obscollection():
oc = ObsCollection(model1=dict(obs_id="bla", obs_vars="od550aer", obs_vert_type="Column"))
oc = ObsCollection()
oc.add_entry("model1", dict(obs_id="bla", obs_vars="od550aer", obs_vert_type="Column"))
assert oc

oc["AN-EEA-MP"] = dict(
is_superobs=True,
obs_id=("AirNow", "EEA-NRT-rural", "MarcoPolo"),
obs_vars=["concpm10", "concpm25", "vmro3", "vmrno2"],
obs_vert_type="Surface",
oc.add_entry(
"AN-EEA-MP",
dict(
is_superobs=True,
obs_id=("AirNow", "EEA-NRT-rural", "MarcoPolo"),
obs_vars=["concpm10", "concpm25", "vmro3", "vmrno2"],
obs_vert_type="Surface",
),
)

assert "AN-EEA-MP" in oc
assert "AN-EEA-MP" in oc.keylist()


def test_modelcollection():
Expand Down
3 changes: 1 addition & 2 deletions tests/aeroval/test_experiment_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,10 +305,9 @@ def test_Experiment_Output_clean_json_files_CFG1_INVALIDMOD(eval_config: dict):
@pytest.mark.parametrize("cfg", ["cfgexp1"])
def test_Experiment_Output_clean_json_files_CFG1_INVALIDOBS(eval_config: dict):
cfg = EvalSetup(**eval_config)
cfg.obs_cfg["obs1"] = cfg.obs_cfg["AERONET-Sun"]
cfg.obs_cfg.add_entry("obs1", cfg.obs_cfg.get_entry("AERONET-Sun"))
proc = ExperimentProcessor(cfg)
proc.run()
cfg.obs_cfg.remove_entry("obs1")
modified = proc.exp_output.clean_json_files()
assert len(modified) == 13

Expand Down
2 changes: 1 addition & 1 deletion tests/aeroval/test_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,6 @@ def test__get_min_max_year_periods_error():
@pytest.mark.parametrize("cfg", ["cfgexp1"])
def test_make_dummy_model(eval_config: dict):
cfg = EvalSetup(**eval_config)
assert cfg.obs_cfg["AERONET-Sun"]
assert cfg.obs_cfg.get_entry("AERONET-Sun")
model_id = make_dummy_model(["AERONET-Sun"], cfg)
assert model_id == "dummy_model"

0 comments on commit f20ec94

Please sign in to comment.