Skip to content

Commit e4038a0

Browse files
hmgaudeckerclaude
andcommitted
Refactor piecewise polynomial to use explicit interval notation (GEP-08)
Replace numbered dict format with ordered list of interval dicts: - Add interval_utils.py: validate_intervals(), extend_intervals_to_real_line() - Support partial domains (NaN for out-of-domain values) - Transpose rates shape from (n_coefficients, n_intervals) to (n_intervals, n_coefficients) - Rename keys: rate_linear->slope, rate_quadratic->quadratic, rate_cubic->cubic, intercept_at_lower_threshold->intercept - Handle list-based YAML specs and updates_previous merging - Add portion dependency for interval parsing/validation - Update mettsim YAML files and tests Co-Authored-By: Claude Opus 4.5 <[email protected]>
1 parent d36b549 commit e4038a0

File tree

11 files changed

+481
-280
lines changed

11 files changed

+481
-280
lines changed

pixi.lock

Lines changed: 76 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ dependencies = [
2828
"optree>=0.16.0",
2929
"pandas",
3030
"plotly>=6.2.0",
31+
"portion",
3132
"pygments",
3233
"pygraphviz",
3334
"pytest",
@@ -101,6 +102,7 @@ numpy_groupies = "*"
101102
numpydoc = "*"
102103
openpyxl = "*"
103104
pandas = ">=2.3"
105+
portion = "*"
104106
prek = "*"
105107
pygments = "*"
106108
pygraphviz = "*"

src/ttsim/interface_dag_elements/fail_if.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -829,8 +829,10 @@ def _param_with_active_periods(
829829
) -> list[_ParamWithActivePeriod]:
830830
"""Return parameter with active periods."""
831831

832-
def _remove_note_and_reference(entry: dict[str | int, Any]) -> dict[str | int, Any]:
832+
def _remove_note_and_reference(entry: dict[str | int, Any] | list) -> dict[str | int, Any] | list:
833833
"""Remove note and reference from a parameter specification."""
834+
if isinstance(entry, list):
835+
return entry
834836
entry.pop("note", None)
835837
entry.pop("reference", None)
836838
return entry

src/ttsim/interface_dag_elements/policy_environment.py

Lines changed: 51 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
from ttsim.tt.column_objects_param_function import (
2929
DEFAULT_END_DATE,
3030
)
31+
from ttsim.tt.interval_utils import extend_intervals_to_real_line
3132
from ttsim.tt.piecewise_polynomial import get_piecewise_parameters
3233

3334
if TYPE_CHECKING:
@@ -188,7 +189,7 @@ def _get_one_param(
188189
cleaned_spec["value"] = get_piecewise_parameters(
189190
leaf_name=leaf_name,
190191
func_type=param_type, # ty: ignore[invalid-argument-type]
191-
parameter_dict=cleaned_spec["value"],
192+
parameter_list=cleaned_spec["value"],
192193
xnp=xnp,
193194
)
194195
return PiecewisePolynomialParam(**cleaned_spec)
@@ -242,17 +243,24 @@ def _clean_one_param_spec(
242243
out["reference_period"] = spec.get("reference_period", None)
243244
out["name"] = spec["name"]
244245
out["description"] = spec["description"]
245-
current_spec: dict[str | int, Any] = copy.deepcopy(spec[policy_dates[idx - 1]])
246+
raw_current = copy.deepcopy(spec[policy_dates[idx - 1]])
247+
if isinstance(raw_current, list):
248+
# List-based spec (no note/reference at this level)
249+
out["note"] = None
250+
out["reference"] = None
251+
out["value"] = _get_param_value([spec[d] for d in policy_dates[:idx]])
252+
return out
253+
254+
current_spec: dict[str | int, Any] = raw_current
246255
out["note"] = current_spec.pop("note", None)
247256
out["reference"] = current_spec.pop("reference", None)
248-
if len(current_spec) == 0:
257+
if not current_spec:
249258
return None
250259
if len(current_spec) == 1 and "updates_previous" in current_spec:
251260
raise ValueError(
252261
"'updates_previous' cannot be specified as the only element, found:\n\n"
253262
f"{spec}\n\n",
254263
)
255-
# Parameter ceased to exist
256264
if spec["type"] == "scalar":
257265
if "updates_previous" in current_spec:
258266
raise ValueError(
@@ -265,14 +273,24 @@ def _clean_one_param_spec(
265273

266274

267275
def _get_param_value(
268-
relevant_specs: list[dict[str | int, Any]],
269-
) -> dict[str | int, Any]:
276+
relevant_specs: list[dict[str | int, Any] | list[dict[str, Any]]],
277+
) -> dict[str | int, Any] | list[dict[str, Any]]:
270278
"""Get the value of a parameter.
271279
272280
Implementation is a recursion in order to handle the 'updates_previous' machinery.
273281
282+
Supports both dict-based and list-based (piecewise) specs. When the raw spec
283+
is a list (no reference/note fields), it's used directly. When it's a dict
284+
with integer keys (has reference/note alongside), the integer-keyed entries
285+
are converted to a list.
286+
274287
"""
275-
current_spec = relevant_specs[-1].copy()
288+
raw_spec = relevant_specs[-1]
289+
if isinstance(raw_spec, list):
290+
# Already a list (YAML date entry was a plain list)
291+
return raw_spec
292+
293+
current_spec = raw_spec.copy()
276294
updates_previous = current_spec.pop("updates_previous", False)
277295
current_spec.pop("note", None)
278296
current_spec.pop("reference", None)
@@ -282,8 +300,33 @@ def _get_param_value(
282300
"'updates_previous' cannot be missing in the initial spec, found "
283301
f"{relevant_specs}"
284302
)
303+
base = _get_param_value(relevant_specs=relevant_specs[:-1])
304+
if isinstance(base, list):
305+
# List-based spec: convert list to dict with integer keys for merging
306+
base_dict = dict(enumerate(base))
307+
merged = upsert_tree(
308+
base=base_dict,
309+
to_upsert=current_spec, # ty: ignore[invalid-argument-type]
310+
)
311+
result = [
312+
merged[i] for i in sorted(k for k in merged if isinstance(k, int))
313+
]
314+
return extend_intervals_to_real_line(result)
285315
return upsert_tree(
286-
base=_get_param_value(relevant_specs=relevant_specs[:-1]), # ty: ignore[invalid-argument-type]
316+
base=base, # ty: ignore[invalid-argument-type]
287317
to_upsert=current_spec, # ty: ignore[invalid-argument-type]
288318
)
319+
320+
# Convert integer-keyed dict to list when keys are consecutive ints 0..n-1
321+
# and values are dicts (piecewise polynomial interval specs).
322+
# Do not convert when values are scalars (consecutive_int_lookup_table)
323+
# or keys are non-consecutive (partial overlay specs like {3: {...}}).
324+
if (
325+
current_spec
326+
and all(isinstance(k, int) for k in current_spec)
327+
and all(isinstance(v, dict) for v in current_spec.values())
328+
and sorted(current_spec) == list(range(len(current_spec)))
329+
):
330+
return [current_spec[i] for i in range(len(current_spec))]
331+
289332
return current_spec

src/ttsim/tt/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
policy_function,
1818
policy_input,
1919
)
20+
from ttsim.tt.interval_utils import extend_intervals_to_real_line
2021
from ttsim.tt.param_objects import (
2122
ConsecutiveIntLookupTableParam,
2223
ConsecutiveIntLookupTableParamValue,
@@ -63,6 +64,7 @@
6364
"agg_by_group_function",
6465
"agg_by_p_id_function",
6566
"convert_sparse_to_consecutive_int_lookup_table",
67+
"extend_intervals_to_real_line",
6668
"get_consecutive_int_lookup_table_param_value",
6769
"get_month_based_phase_inout_of_age_thresholds_param_value",
6870
"get_piecewise_parameters",

src/ttsim/tt/interval_utils.py

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
"""Utilities for validating interval notation in piecewise specs."""
2+
3+
from __future__ import annotations
4+
5+
from typing import TYPE_CHECKING, Any
6+
7+
import numpy
8+
import portion
9+
10+
if TYPE_CHECKING:
11+
from types import ModuleType
12+
13+
from jaxtyping import Array, Float
14+
15+
16+
def validate_intervals(intervals: list[portion.Interval], leaf_name: str) -> None:
17+
"""Validate that intervals are ascending, non-overlapping, and contiguous."""
18+
if not intervals:
19+
raise ValueError(f"No intervals provided for {leaf_name}.")
20+
21+
for i in range(1, len(intervals)):
22+
prev, curr = intervals[i - 1], intervals[i]
23+
if curr.lower <= prev.lower:
24+
raise ValueError(
25+
f"Intervals for {leaf_name} are not in ascending order: "
26+
f"interval {i - 1} has lower bound {prev.lower}, "
27+
f"interval {i} has lower bound {curr.lower}."
28+
)
29+
if not (prev & curr).empty:
30+
raise ValueError(
31+
f"Overlapping intervals for {leaf_name}: "
32+
f"interval {i - 1} = {prev} and interval {i} = {curr}."
33+
)
34+
if prev.upper != curr.lower:
35+
raise ValueError(
36+
f"Gap between intervals for {leaf_name}: "
37+
f"interval {i - 1} upper = {prev.upper}, "
38+
f"interval {i} lower = {curr.lower}."
39+
)
40+
41+
42+
def _bound_to_float(v: object) -> float:
43+
"""Convert a portion bound (including portion.inf) to a Python float."""
44+
if v == -portion.inf:
45+
return float("-inf")
46+
if v == portion.inf:
47+
return float("inf")
48+
return float(v)
49+
50+
51+
def intervals_to_thresholds(
52+
intervals: list[portion.Interval], xnp: ModuleType
53+
) -> tuple[
54+
Float[Array, " n"],
55+
Float[Array, " n"],
56+
Float[Array, " n_plus_1"],
57+
]:
58+
"""Extract threshold arrays from parsed intervals."""
59+
lower = numpy.array([_bound_to_float(iv.lower) for iv in intervals])
60+
upper = numpy.array([_bound_to_float(iv.upper) for iv in intervals])
61+
all_bounds = sorted(set(lower) | set(upper))
62+
return xnp.array(lower), xnp.array(upper), xnp.array(all_bounds)
63+
64+
65+
def extend_intervals_to_real_line(
66+
items: list[dict[str, Any]],
67+
) -> list[dict[str, Any]]:
68+
"""Extend intervals so adjacent intervals are contiguous.
69+
70+
After merging specs (e.g., via updates_previous), changing one interval's
71+
bounds can leave adjacent intervals with stale boundaries. This propagates
72+
each interval's upper bound as the next interval's lower bound.
73+
"""
74+
if not items or not any("interval" in item for item in items):
75+
return items
76+
77+
result = [item.copy() for item in items]
78+
for i in range(len(result) - 1):
79+
if "interval" not in result[i] or "interval" not in result[i + 1]:
80+
continue
81+
curr = portion.from_string(result[i]["interval"], conv=float)
82+
next_ = portion.from_string(result[i + 1]["interval"], conv=float)
83+
if curr.upper != next_.lower:
84+
complement = portion.CLOSED if curr.right == portion.OPEN else portion.OPEN
85+
fixed = portion.Interval.from_atomic(
86+
complement, curr.upper, next_.upper, next_.right
87+
)
88+
result[i + 1] = {
89+
**result[i + 1],
90+
"interval": portion.to_string(fixed),
91+
}
92+
93+
return result

src/ttsim/tt/param_objects.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -181,16 +181,18 @@ class PiecewisePolynomialParamValue:
181181
"""The parameters expected by `piecewise_polynomial`.
182182
183183
thresholds:
184-
Thresholds defining the pieces / different segments on the real line.
184+
Boundary points defining the pieces / different segments.
185185
intercepts:
186-
Intercepts of the polynomial on each segment.
186+
Intercepts of the polynomial on each segment (one per interval).
187187
rates:
188-
Slope and higher-order coefficients of the polynomial on each segment.
188+
Coefficients of the polynomial on each segment, shape
189+
(n_intervals, n_coefficients). For piecewise_constant, this is
190+
(n_intervals, 1) with all zeros.
189191
"""
190192

191-
thresholds: Float[Array, " n_segments"]
192-
intercepts: Float[Array, " n_segments"]
193-
rates: Float[Array, " n_segments"]
193+
thresholds: Float[Array, " n_thresholds"]
194+
intercepts: Float[Array, " n_intervals"]
195+
rates: Float[Array, "n_intervals n_coefficients"]
194196

195197

196198
def get_consecutive_int_lookup_table_param_value(

0 commit comments

Comments
 (0)