|
1 | 1 | import functools
|
| 2 | +from typing import Any |
2 | 3 |
|
3 | 4 | import jax
|
4 | 5 | import jax.numpy as jnp
|
5 | 6 | import numpy as np
|
| 7 | +import optimagic as om |
6 | 8 | import pandas as pd
|
7 | 9 |
|
8 | 10 | import skillmodels.likelihood_function as lf
|
9 | 11 | import skillmodels.likelihood_function_debug as lfd
|
10 |
| -from skillmodels.constraints import add_bounds, get_constraints |
| 12 | +from skillmodels.constraints import add_bounds, get_constraint_tuples |
11 | 13 | from skillmodels.kalman_filters import calculate_sigma_scaling_factor_and_weights
|
12 | 14 | from skillmodels.params_index import get_params_index
|
13 | 15 | from skillmodels.parse_params import create_parsing_info
|
@@ -41,7 +43,7 @@ def get_maximization_inputs(model_dict, data):
|
41 | 43 | constraints (list): List of optimagic constraints that are implied by the
|
42 | 44 | model specification.
|
43 | 45 | params_template (pd.DataFrame): Parameter DataFrame with correct index and
|
44 |
| - bounds but with empty value column. |
| 46 | + bounds. The value column is empty except for the fixed constraints. |
45 | 47 | data_aug (pd.DataFrame): DataFrame with augmented data. If model contains
|
46 | 48 | investment factors, we double up the number of periods in order to add
|
47 | 49 |
|
@@ -123,26 +125,32 @@ def debug_loglike(params):
|
123 | 125 | tmp["value"] = float(tmp["value"])
|
124 | 126 | return process_debug_data(debug_data=tmp, model=model)
|
125 | 127 |
|
126 |
| - constr = get_constraints( |
| 128 | + _constraints_tuples = get_constraint_tuples( |
127 | 129 | dimensions=model["dimensions"],
|
128 | 130 | labels=model["labels"],
|
129 | 131 | anchoring_info=model["anchoring"],
|
130 | 132 | update_info=model["update_info"],
|
131 | 133 | normalizations=model["normalizations"],
|
132 | 134 | )
|
133 | 135 |
|
| 136 | + constraints = convert_old_style_constraints(_constraints_tuples) |
| 137 | + |
134 | 138 | params_template = pd.DataFrame(columns=["value"], index=p_index)
|
135 | 139 | params_template = add_bounds(
|
136 | 140 | params_template,
|
137 | 141 | model["estimation_options"]["bounds_distance"],
|
138 | 142 | )
|
| 143 | + params_template = _fill_fixed_constraints( |
| 144 | + params_template=params_template, |
| 145 | + constraints_tuples=_constraints_tuples, |
| 146 | + ) |
139 | 147 |
|
140 | 148 | out = {
|
141 | 149 | "loglike": loglike,
|
142 | 150 | "loglikeobs": loglikeobs,
|
143 | 151 | "debug_loglike": debug_loglike,
|
144 | 152 | "loglike_and_gradient": loglike_and_gradient,
|
145 |
| - "constraints": constr, |
| 153 | + "constraints": constraints, |
146 | 154 | "params_template": params_template,
|
147 | 155 | }
|
148 | 156 |
|
@@ -218,3 +226,49 @@ def _get_jnp_params_vec(params, target_index):
|
218 | 226 |
|
219 | 227 | vec = jnp.array(params.reindex(target_index)["value"].to_numpy())
|
220 | 228 | return vec
|
| 229 | + |
| 230 | + |
| 231 | +def _sel(params, loc): |
| 232 | + return params.loc[loc] |
| 233 | + |
| 234 | + |
| 235 | +def convert_old_style_constraints(old_style): |
| 236 | + # Need this in many cases, anyhow -- so just impose to simplify code below! |
| 237 | + new_style = [] |
| 238 | + for oc in old_style: |
| 239 | + if oc["type"] == "pairwise_equality": |
| 240 | + new_style.append( |
| 241 | + om.PairwiseEqualityConstraint( |
| 242 | + selectors=[functools.partial(_sel, loc=loc) for loc in oc["locs"]] |
| 243 | + ) |
| 244 | + ) |
| 245 | + else: |
| 246 | + sel = functools.partial(_sel, loc=oc["loc"]) |
| 247 | + if oc["type"] == "fixed": |
| 248 | + new_style.append(om.FixedConstraint(selector=sel)) |
| 249 | + elif oc["type"] == "equality": |
| 250 | + new_style.append(om.EqualityConstraint(selector=sel)) |
| 251 | + elif oc["type"] == "probability": |
| 252 | + new_style.append(om.ProbabilityConstraint(selector=sel)) |
| 253 | + elif oc["type"] == "increasing": |
| 254 | + new_style.append(om.IncreasingConstraint(selector=sel)) |
| 255 | + else: |
| 256 | + raise TypeError(oc["type"]) |
| 257 | + return new_style |
| 258 | + |
| 259 | + |
| 260 | +def _fill_fixed_constraints( |
| 261 | + params_template: pd.DataFrame, |
| 262 | + constraints_tuples: list[dict[str, Any]], |
| 263 | +) -> pd.DataFrame: |
| 264 | + params = params_template.copy() |
| 265 | + for constr in constraints_tuples: |
| 266 | + if constr["type"] == "fixed": |
| 267 | + params.loc[constr["loc"], "value"] = constr["value"] |
| 268 | + |
| 269 | + # Check that fixed constraints are valid |
| 270 | + fixed = params[params["value"].notna()] |
| 271 | + invalid = fixed.query("value < lower_bound or value > upper_bound") |
| 272 | + if len(invalid) > 0: |
| 273 | + raise ValueError(f"Invalid fixed constraints:\n\n{invalid}") |
| 274 | + return params |
0 commit comments