-
Notifications
You must be signed in to change notification settings - Fork 3
Open
Labels
Description
Is your feature request related to a problem?
Probably very low priority: in a test model, I had a state that was completely irrelevant, i.e., not used as an argument by any model function (`some_additional_state` in the example below). Solving the model then fails with a hard-to-interpret error. Below is a reproducible example.
Error:
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
Cell In[26], line 14
3 params = {
4 "beta": 0.98,
5 "utility": {"disutility_of_work": 1.0},
6 "next_wealth": {"interest_rate": 0.05},
7 "labor_income": {"wage": 1.0},
8 }
10 initial_states = {
11 "wealth": jnp.array([10.0]),
12 }
---> 14 f = get_lcm_function(
15 model=ISKHAKOV_ET_AL_2017_STRIPPED_DOWN,
16 targets="solve",
17 )[0]
19 result = f(
20 params=params,
21 )
23 result
File ~/.pixi/envs/py313-jax/lib/python3.13/site-packages/lcm/entry_point.py:127, in get_lcm_function(model, targets, debug_mode, jit)
119 next_state_space_info = state_space_infos[period + 1]
121 Q_and_F = get_Q_and_F(
122 model=internal_model,
123 next_state_space_info=next_state_space_info,
124 period=period,
125 )
--> 127 max_Q_over_a = get_max_Q_over_a(
128 Q_and_F=Q_and_F,
129 actions_names=tuple(state_action_space.continuous_actions)
130 + tuple(state_action_space.discrete_actions),
131 states_names=tuple(state_action_space.states),
132 )
134 argmax_and_max_Q_over_a = get_argmax_and_max_Q_over_a(
135 Q_and_F=Q_and_F,
136 actions_names=tuple(state_action_space.discrete_actions)
137 + tuple(state_action_space.continuous_actions),
138 )
140 state_action_spaces[period] = state_action_space
File ~/.pixi/envs/py313-jax/lib/python3.13/site-packages/lcm/max_Q_over_a.py:75, in get_max_Q_over_a(Q_and_F, actions_names, states_names)
70 Q_arr, F_arr = Q_and_F(
71 params=params, next_V_arr=next_V_arr, **states_and_actions
72 )
73 return Q_arr.max(where=F_arr, initial=-jnp.inf)
---> 75 return productmap(max_Q_over_a, variables=states_names)
File ~/.pixi/envs/py313-jax/lib/python3.13/site-packages/lcm/dispatchers.py:176, in productmap(func, variables)
170 raise ValueError(
171 f"Same argument provided more than once in variables: {duplicates}",
172 )
174 func_callable_with_args = allow_args(func)
--> 176 vmapped = _base_productmap(func_callable_with_args, variables)
178 # This raises a mypy error but is perfectly fine to do. See
179 # https://github.com/python/mypy/issues/12472
180 vmapped.__signature__ = inspect.signature(func_callable_with_args) # type: ignore[attr-defined]
File ~/.pixi/envs/py313-jax/lib/python3.13/site-packages/lcm/dispatchers.py:204, in _base_productmap(func, product_axes)
201 signature = inspect.signature(func)
202 parameters = list(signature.parameters)
--> 204 positions = [parameters.index(ax) for ax in product_axes]
206 vmap_specs = []
207 # We iterate in reverse order such that the output dimensions are in the same order
208 # as the input dimensions.
ValueError: 'some_additional_state' is not in list
Describe the solution you'd like
It would be helpful to raise a more readable error, e.g., one stating that the state `some_additional_state` is not used by any model function.
Reproducible example
from __future__ import annotations
from dataclasses import dataclass
from typing import TYPE_CHECKING
import jax.numpy as jnp
from lcm import DiscreteGrid, LinspaceGrid, Model
from lcm.entry_point import get_lcm_function
if TYPE_CHECKING:
from lcm.typing import (
BoolND,
ContinuousAction,
ContinuousState,
DiscreteAction,
DiscreteState,
FloatND,
Int1D,
IntND,
)
# ======================================================================================
# Model functions
# ======================================================================================
# --------------------------------------------------------------------------------------
# Categorical variables
# --------------------------------------------------------------------------------------
@dataclass
class RetirementStatus:
    """Integer codes of the discrete retirement choice."""

    working: int = 0  # code for continuing to work
    retired: int = 1  # code for retiring
@dataclass
class SomeAdditionalState:
    """Integer codes of the three categories of ``some_additional_state``."""

    low: int = 0  # lowest category
    medium: int = 1  # middle category
    high: int = 2  # highest category
# --------------------------------------------------------------------------------------
# Utility functions
# --------------------------------------------------------------------------------------
def utility(
    consumption: ContinuousAction, working: IntND, disutility_of_work: float
) -> FloatND:
    """Return log utility of consumption minus a linear disutility of working.

    Args:
        consumption: Consumption level(s); passed to ``jnp.log``.
        working: Working indicator(s), multiplied by ``disutility_of_work``.
        disutility_of_work: Utility cost per unit of ``working``.

    Returns:
        ``log(consumption) - disutility_of_work * working``.
    """
    consumption_utility = jnp.log(consumption)
    work_penalty = disutility_of_work * working
    return consumption_utility - work_penalty
# --------------------------------------------------------------------------------------
# Auxiliary variables
# --------------------------------------------------------------------------------------
def labor_income(working: IntND, wage: float | FloatND) -> FloatND:
    """Return labor income as the product of the working indicator and the wage."""
    income = working * wage
    return income
def working(retirement: DiscreteAction) -> IntND:
    """Turn the retirement choice into a working indicator (``1 - retirement``).

    With the ``RetirementStatus`` codes (working=0, retired=1), this yields 1 when
    the agent keeps working and 0 when the agent retires.
    """
    is_working = 1 - retirement
    return is_working
def wage(age: int | IntND) -> float | FloatND:
    """Return the wage as an affine function of age: ``1 + 0.1 * age``."""
    age_premium = 0.1 * age
    return 1 + age_premium
def age(_period: int | Int1D) -> int | IntND:
    """Map the model period index to age; period 0 corresponds to age 18."""
    age_in_first_period = 18
    return _period + age_in_first_period
# --------------------------------------------------------------------------------------
# State transitions
# --------------------------------------------------------------------------------------
def next_wealth(
    wealth: ContinuousState,
    consumption: ContinuousAction,
    labor_income: FloatND,
    interest_rate: float,
) -> ContinuousState:
    """Compute next-period wealth.

    Wealth left after consumption grows by ``1 + interest_rate``, and labor
    income is then added on top.

    Args:
        wealth: Current wealth.
        consumption: Consumption chosen this period.
        labor_income: Labor income added to next-period wealth.
        interest_rate: Return on unconsumed wealth.

    Returns:
        ``(1 + interest_rate) * (wealth - consumption) + labor_income``.
    """
    gross_return = 1 + interest_rate
    savings = wealth - consumption
    return gross_return * savings + labor_income
# --------------------------------------------------------------------------------------
# Constraints
# --------------------------------------------------------------------------------------
def borrowing_constraint(
    consumption: ContinuousAction | DiscreteAction, wealth: ContinuousState
) -> BoolND:
    """Return True where consumption does not exceed current wealth."""
    return wealth >= consumption
# ======================================================================================
# Model specifications
# ======================================================================================
# Stripped-down model from Iskhakov et al. (2017) used to reproduce the ValueError
# shown in the traceback: the state `some_additional_state` is declared in `states`
# but is not an argument of any function listed in `functions`.
ISKHAKOV_ET_AL_2017_STRIPPED_DOWN = Model(
    description=(
        "Starts from Iskhakov et al. (2017), removes absorbing retirement constraint "
        "and the lagged_retirement state, and adds wage function that depends on age."
    ),
    n_periods=3,
    functions={
        "utility": utility,
        "next_wealth": next_wealth,
        # NOTE(review): the lambda's argument is named `x`, which is not a declared
        # state, action, or function. Presumably the intended transition is the
        # identity `lambda some_additional_state: some_additional_state`; confirm
        # how lcm resolves `x` before relying on this entry.
        "next_some_additional_state": lambda x: x,
        "borrowing_constraint": borrowing_constraint,
        "labor_income": labor_income,
        "working": working,
        "wage": wage,
        "age": age,
    },
    actions={
        "retirement": DiscreteGrid(RetirementStatus),
        "consumption": LinspaceGrid(
            start=1,
            stop=400,
            n_points=500,
        ),
    },
    states={
        "wealth": LinspaceGrid(
            start=1,
            stop=400,
            n_points=100,
        ),
        # Not an argument of any function above. Per the traceback, state names
        # are passed to `productmap` as the mapped variables, and `productmap`
        # looks each one up in the signature of the function being mapped. This
        # name is missing there, hence
        # "ValueError: 'some_additional_state' is not in list".
        "some_additional_state": DiscreteGrid(SomeAdditionalState),
    },
)
# Model parameters: "beta" at the top level, the rest nested under the name of the
# function whose argument they fill.
# NOTE(review): `wage` is also defined as a model function; confirm which of the
# two lcm uses for `labor_income`.
params = dict(
    beta=0.98,
    utility=dict(disutility_of_work=1.0),
    next_wealth=dict(interest_rate=0.05),
    labor_income=dict(wage=1.0),
)

# Initial states; these are not passed to the solve call below.
initial_states = dict(
    wealth=jnp.array([10.0]),
)

# Per the traceback, the ValueError is raised inside this get_lcm_function call.
# Its first returned element is used as the solver.
f = get_lcm_function(
    model=ISKHAKOV_ET_AL_2017_STRIPPED_DOWN,
    targets="solve",
)[0]
result = f(
    params=params,
)