Skip to content

Commit e3e148b

Browse files
committed
[WIP] implement fixed constraints right away.
1 parent deb769d commit e3e148b

File tree

3 files changed

+65
-5
lines changed

3 files changed

+65
-5
lines changed

src/skillmodels/constraints.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,9 @@
77
import skillmodels.transition_functions as tf
88

99

10-
def get_constraints(dimensions, labels, anchoring_info, update_info, normalizations):
10+
def get_constraint_tuples(
11+
dimensions, labels, anchoring_info, update_info, normalizations
12+
):
1113
"""Generate constraints implied by the model specification.
1214
1315
The result can easily be converted to optimagic-style constraints.
@@ -71,6 +73,9 @@ def add_bounds(params_df, bounds_distance=0.0):
7173
)
7274
if "lower_bound" not in df.columns:
7375
df["lower_bound"] = -np.inf
76+
if "upper_bound" not in df.columns:
77+
df["upper_bound"] = np.inf
78+
7479
df.loc["meas_sds", "lower_bound"] = bounds_distance
7580
df.loc["shock_sds", "lower_bound"] = bounds_distance
7681

src/skillmodels/maximization_inputs.py

Lines changed: 58 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,15 @@
11
import functools
2+
from typing import Any
23

34
import jax
45
import jax.numpy as jnp
56
import numpy as np
7+
import optimagic as om
68
import pandas as pd
79

810
import skillmodels.likelihood_function as lf
911
import skillmodels.likelihood_function_debug as lfd
10-
from skillmodels.constraints import add_bounds, get_constraints
12+
from skillmodels.constraints import add_bounds, get_constraint_tuples
1113
from skillmodels.kalman_filters import calculate_sigma_scaling_factor_and_weights
1214
from skillmodels.params_index import get_params_index
1315
from skillmodels.parse_params import create_parsing_info
@@ -41,7 +43,7 @@ def get_maximization_inputs(model_dict, data):
4143
constraints (list): List of optimagic constraints that are implied by the
4244
model specification.
4345
params_template (pd.DataFrame): Parameter DataFrame with correct index and
44-
bounds but with empty value column.
46+
bounds. The value column is empty except for the fixed constraints.
4547
data_aug (pd.DataFrame): DataFrame with augmented data. If model contains
4648
investment factors, we double up the number of periods in order to add
4749
@@ -123,26 +125,32 @@ def debug_loglike(params):
123125
tmp["value"] = float(tmp["value"])
124126
return process_debug_data(debug_data=tmp, model=model)
125127

126-
constr = get_constraints(
128+
_constraints_tuples = get_constraint_tuples(
127129
dimensions=model["dimensions"],
128130
labels=model["labels"],
129131
anchoring_info=model["anchoring"],
130132
update_info=model["update_info"],
131133
normalizations=model["normalizations"],
132134
)
133135

136+
constraints = convert_old_style_constraints(_constraints_tuples)
137+
134138
params_template = pd.DataFrame(columns=["value"], index=p_index)
135139
params_template = add_bounds(
136140
params_template,
137141
model["estimation_options"]["bounds_distance"],
138142
)
143+
params_template = _fill_fixed_constraints(
144+
params_template=params_template,
145+
constraints_tuples=_constraints_tuples,
146+
)
139147

140148
out = {
141149
"loglike": loglike,
142150
"loglikeobs": loglikeobs,
143151
"debug_loglike": debug_loglike,
144152
"loglike_and_gradient": loglike_and_gradient,
145-
"constraints": constr,
153+
"constraints": constraints,
146154
"params_template": params_template,
147155
}
148156

@@ -218,3 +226,49 @@ def _get_jnp_params_vec(params, target_index):
218226

219227
vec = jnp.array(params.reindex(target_index)["value"].to_numpy())
220228
return vec
229+
230+
231+
def _sel(params, loc):
232+
return params.loc[loc]
233+
234+
235+
def convert_old_style_constraints(old_style):
236+
# Need this in many cases, anyhow -- so just impose to simplify code below!
237+
new_style = []
238+
for oc in old_style:
239+
if oc["type"] == "pairwise_equality":
240+
new_style.append(
241+
om.PairwiseEqualityConstraint(
242+
selectors=[functools.partial(_sel, loc=loc) for loc in oc["locs"]]
243+
)
244+
)
245+
else:
246+
sel = functools.partial(_sel, loc=oc["loc"])
247+
if oc["type"] == "fixed":
248+
new_style.append(om.FixedConstraint(selector=sel))
249+
elif oc["type"] == "equality":
250+
new_style.append(om.EqualityConstraint(selector=sel))
251+
elif oc["type"] == "probability":
252+
new_style.append(om.ProbabilityConstraint(selector=sel))
253+
elif oc["type"] == "increasing":
254+
new_style.append(om.IncreasingConstraint(selector=sel))
255+
else:
256+
raise TypeError(oc["type"])
257+
return new_style
258+
259+
260+
def _fill_fixed_constraints(
261+
params_template: pd.DataFrame,
262+
constraints_tuples: list[dict[str, Any]],
263+
) -> pd.DataFrame:
264+
params = params_template.copy()
265+
for constr in constraints_tuples:
266+
if constr["type"] == "fixed":
267+
params.loc[constr["loc"], "value"] = constr["value"]
268+
269+
# Check that fixed constraints are valid
270+
fixed = params[params["value"].notna()]
271+
invalid = fixed.query("value < lower_bound or value > upper_bound")
272+
if len(invalid) > 0:
273+
raise ValueError(f"Invalid fixed constraints:\n\n{invalid}")
274+
return params

tests/test_constraints.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ def test_add_bounds():
2828
)
2929
expected = df.copy(deep=True)
3030
expected["lower_bound"] = [0.1] * 5 + [0.1, -np.inf, 0.1, -np.inf, 0.1]
31+
expected["upper_bound"] = np.inf
3132

3233
calculated = add_bounds(df, 0.1)
3334
assert_frame_equal(calculated, expected)

0 commit comments

Comments
 (0)