Skip to content

Commit 79faf00

Browse files
mnjowe and tamuri authored
Add labels to disease module parameters (#1610)
Co-authored-by: Asif Tamuri <[email protected]>
1 parent 5e6b48a commit 79faf00

File tree

4 files changed

+86
-8
lines changed

4 files changed

+86
-8
lines changed

src/tlo/core.py

Lines changed: 38 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
import json
1010
from enum import Enum, auto
11-
from typing import TYPE_CHECKING, Any, Dict, FrozenSet, List, Optional
11+
from typing import TYPE_CHECKING, Any, Dict, FrozenSet, List
1212

1313
import numpy as np
1414
import pandas as pd
@@ -120,6 +120,14 @@ def __repr__(self) -> str:
120120

121121
class Parameter(Specifiable):
    """Used to specify parameters for disease modules etc.

    On top of the type, description and categories handled by the parent
    ``Specifiable`` initialiser, a parameter carries a ``metadata`` dictionary
    holding additional information about the parameter.
    """

    def __init__(
        self,
        type_: Types,
        description: str,
        categories: List[str] | None = None,
        *,
        metadata: Dict[str, Any] | None = None,
    ):
        """Create a parameter specification.

        :param type_: Member of ``Types`` giving the data type of the parameter.
        :param description: Description of the parameter.
        :param categories: Optional list of categories; forwarded unchanged to
            the ``Specifiable`` initialiser.
        :param metadata: Optional, keyword-only dictionary of extra information
            about the parameter. Defaults to an empty dictionary.
        """
        super().__init__(type_, description, categories)
        # Store a shallow copy: ``self.metadata`` is updated in place later on,
        # and those updates must not leak into the dictionary the caller passed.
        self.metadata = dict(metadata) if metadata else {}
123131

124132

125133
class Property(Specifiable):
@@ -321,27 +329,41 @@ def load_parameters_from_dataframe(self, resource: pd.DataFrame) -> None:
321329
322330
:param DataFrame resource: DataFrame with a column of the parameter_name and a column of `value`
323331
"""
332+
324333
resource.set_index('parameter_name', inplace=True)
325334
skipped_data_types = ('DATA_FRAME', 'SERIES')
335+
acceptable_labels = ['unassigned', 'undetermined', 'universal', 'local', 'scenario']
336+
param_defaults = {'param_label': 'unassigned', 'prior_min': None, 'prior_max': None }
337+
338+
for _col in param_defaults.keys():
339+
if _col not in resource.columns:
340+
resource[_col] = param_defaults[_col]
326341
# for each supported parameter, convert to the correct type
327342
for parameter_name in resource.index[resource.index.notnull()]:
328343
parameter_definition = self.PARAMETERS[parameter_name]
329-
330344
if parameter_definition.type_.name in skipped_data_types:
331345
continue
332346

333347
# For each parameter, raise error if the value can't be coerced
334-
parameter_value = resource.at[parameter_name, 'value']
348+
parameter_value, prior_min, prior_max = resource.loc[parameter_name, ['value', 'prior_min', 'prior_max']]
349+
parameter_label = resource.at[parameter_name, 'param_label']
350+
assert parameter_label in acceptable_labels, f'unrecognised parameter label {parameter_label}'
351+
335352
error_message = (
336-
f"The value of '{parameter_value}' for parameter '{parameter_name}' "
337-
f"could not be parsed as a {parameter_definition.type_.name} data type"
353+
f"some values are not of type {parameter_definition.type_.name} and "
354+
f"could not be parsed as a {parameter_definition.type_.name} data type. "
355+
f"parameter name is {parameter_name}, values {[parameter_value, prior_min, prior_max]}"
338356
)
339357
if parameter_definition.python_type is list:
340358
try:
341359
# chose json.loads instead of save_eval
342360
# because it raises error instead of joining two strings without a comma
343361
parameter_value = json.loads(parameter_value)
344362
assert isinstance(parameter_value, list)
363+
if pd.notnull(prior_min):
364+
assert isinstance(json.loads(prior_min), list)
365+
if pd.notnull(prior_max):
366+
assert isinstance(json.loads(prior_max), list)
345367
except (json.decoder.JSONDecodeError, TypeError, AssertionError) as exception:
346368
raise ValueError(error_message) from exception
347369
elif parameter_definition.python_type == pd.Categorical:
@@ -358,11 +380,22 @@ def load_parameters_from_dataframe(self, resource: pd.DataFrame) -> None:
358380
# All other data types, assign to the python_type defined in Parameter class
359381
try:
    # Coerce the value to the python type declared for this parameter.
    parameter_value = parameter_definition.python_type(parameter_value)
    # ``python_type`` is itself a class, so ``isinstance(python_type, pd.Timestamp)``
    # could never be true and the guard was dead. Compare the type directly so
    # that the prior bounds of timestamp parameters are not coerced.
    if parameter_definition.python_type is not pd.Timestamp:
        # The coerced bounds are discarded: the call only checks that each
        # non-null bound can be parsed as the same type as the value.
        if pd.notnull(prior_min):
            parameter_definition.python_type(prior_min)
        if pd.notnull(prior_max):
            parameter_definition.python_type(prior_max)
except Exception as exception:
    # Any coercion failure is reported uniformly as a ValueError.
    raise ValueError(error_message) from exception
363390

364391
# Save the values to the parameters
365392
self.parameters[parameter_name] = parameter_value
393+
# Assign metadata to the Parameter object
394+
parameter_definition.metadata.update(
395+
param_label=parameter_label,
396+
prior_min=prior_min,
397+
prior_max=prior_max
398+
)
366399

367400
def read_parameters(self, data_folder: str | Path) -> None:
368401
"""Read parameter values from file, if required.

src/tlo/simulation.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import heapq
77
import itertools
88
import time
9-
from collections import OrderedDict
9+
from collections import Counter, OrderedDict
1010
from pathlib import Path
1111
from typing import TYPE_CHECKING, Optional
1212

@@ -116,6 +116,7 @@ def __init__(
116116
self._custom_log_levels = None
117117
self._log_filepath = self._configure_logging(**log_config)
118118

119+
119120
# random number generator
120121
seed_from = "auto" if seed is None else "user"
121122
self._seed = seed
@@ -307,8 +308,15 @@ def finalise(self, wall_clock_time: Optional[float] = None) -> None:
307308
:param wall_clock_time: Optional argument specifying total time taken to
308309
simulate, to be written out to log before closing.
309310
"""
310-
for module in self.modules.values():
311+
for module_name, module in self.modules.items():
311312
module.on_simulation_end()
313+
if hasattr(module, "PARAMETERS"):
314+
# collect the module's parameter labels
315+
labels = [p.metadata.get("param_label", "not_init_via_load_param") for p in module.PARAMETERS.values()]
316+
labels = Counter(labels)
317+
for label, count in labels.items():
318+
logger.info(key="parameter_stats", data={"module": module_name, "label": label, "count": count})
319+
312320
if wall_clock_time is not None:
313321
logger.info(key="info", data=f"simulate() {wall_clock_time} s")
314322
self.close_output_file()

tests/test_core.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,43 @@ def test_bools(self):
166166
assert self.module.parameters['bool_true'] is True
167167
assert self.module.parameters['bool_false'] is False
168168

169+
def test_unacceptable_labels(self):
    """A parameter label outside the accepted set is rejected.

    Loading the resource should raise an assertion error.
    """
    # ``assign`` returns a new frame, leaving the shared fixture untouched
    mislabelled = self.resource.assign(param_label='free')
    with pytest.raises(AssertionError, match="unrecognised parameter label"):
        self.module.load_parameters_from_dataframe(mislabelled)
177+
178+
def test_unacceptable_lower_value(self):
    """An unparseable lower prior bound is rejected.

    Loading the resource should raise a value error.
    """
    # ``assign`` returns a new frame, leaving the shared fixture untouched
    bad_lower = self.resource.assign(prior_min='a')
    with pytest.raises(ValueError):
        self.module.load_parameters_from_dataframe(bad_lower)
186+
187+
def test_unacceptable_upper_value(self):
    """An unparseable upper prior bound is rejected.

    Loading the resource should raise a value error.
    """
    # ``assign`` returns a new frame, leaving the shared fixture untouched
    bad_upper = self.resource.assign(prior_max='b')
    with pytest.raises(ValueError):
        self.module.load_parameters_from_dataframe(bad_upper)
195+
196+
def test_list_type_parameter_value_has_list_type_lower_upper_value(self):
    """List-typed parameters need list-typed lower and upper values.

    An integer lower value and a float upper value are assigned; loading
    should raise a value error for parameter values of type list.
    """
    # ``assign`` returns a new frame, leaving the shared fixture untouched
    scalar_bounds = self.resource.assign(prior_min=1, prior_max=2.0)
    with pytest.raises(ValueError, match='some values are not of type LIST'):
        self.module.load_parameters_from_dataframe(scalar_bounds)
205+
169206

170207
class TestLoadParametersFromDataframe_Bools_From_Csv:
171208
"""Tests for the load_parameters_from_dataframe method, including handling of bools when loading from csv"""

tests/test_simulation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -259,7 +259,7 @@ def _check_parsed_logs_are_equal(
259259
if key == "_metadata":
260260
assert module_logs_1[key] == module_logs_2[key]
261261
elif (module_name, key) not in module_name_key_pairs_to_skip:
262-
assert module_logs_1[key].equals(module_logs_2[key])
262+
assert module_logs_1[key].equals(module_logs_2[key]), f"{module_name} log {key} not equal"
263263

264264

265265
@pytest.mark.slow

0 commit comments

Comments
 (0)