Skip to content

ENH: Irrelevant states raise unreadable error #150

@MImmesberger

Description

@MImmesberger

Is your feature request related to a problem?

Probably very low priority: in a test model, I had a state that was completely irrelevant (`some_additional_state` in the example below), i.e. no model function uses it. This raises a hard-to-interpret error when solving the model. Below is a reproducible example.

Error:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[26], line 14
      3 params = {
      4     "beta": 0.98,
      5     "utility": {"disutility_of_work": 1.0},
      6     "next_wealth": {"interest_rate": 0.05},
      7     "labor_income": {"wage": 1.0},
      8 }
     10 initial_states = {
     11     "wealth": jnp.array([10.0]),
     12 }
---> 14 f = get_lcm_function(
     15     model=ISKHAKOV_ET_AL_2017_STRIPPED_DOWN,
     16     targets="solve",
     17 )[0]
     19 result = f(
     20     params=params,
     21 )
     23 result

File ~/.pixi/envs/py313-jax/lib/python3.13/site-packages/lcm/entry_point.py:127, in get_lcm_function(model, targets, debug_mode, jit)
    119     next_state_space_info = state_space_infos[period + 1]
    121 Q_and_F = get_Q_and_F(
    122     model=internal_model,
    123     next_state_space_info=next_state_space_info,
    124     period=period,
    125 )
--> 127 max_Q_over_a = get_max_Q_over_a(
    128     Q_and_F=Q_and_F,
    129     actions_names=tuple(state_action_space.continuous_actions)
    130     + tuple(state_action_space.discrete_actions),
    131     states_names=tuple(state_action_space.states),
    132 )
    134 argmax_and_max_Q_over_a = get_argmax_and_max_Q_over_a(
    135     Q_and_F=Q_and_F,
    136     actions_names=tuple(state_action_space.discrete_actions)
    137     + tuple(state_action_space.continuous_actions),
    138 )
    140 state_action_spaces[period] = state_action_space

File ~/.pixi/envs/py313-jax/lib/python3.13/site-packages/lcm/max_Q_over_a.py:75, in get_max_Q_over_a(Q_and_F, actions_names, states_names)
     70     Q_arr, F_arr = Q_and_F(
     71         params=params, next_V_arr=next_V_arr, **states_and_actions
     72     )
     73     return Q_arr.max(where=F_arr, initial=-jnp.inf)
---> 75 return productmap(max_Q_over_a, variables=states_names)

File ~/.pixi/envs/py313-jax/lib/python3.13/site-packages/lcm/dispatchers.py:176, in productmap(func, variables)
    170     raise ValueError(
    171         f"Same argument provided more than once in variables: {duplicates}",
    172     )
    174 func_callable_with_args = allow_args(func)
--> 176 vmapped = _base_productmap(func_callable_with_args, variables)
    178 # This raises a mypy error but is perfectly fine to do. See
    179 # https://github.com/python/mypy/issues/12472
    180 vmapped.__signature__ = inspect.signature(func_callable_with_args)  # type: ignore[attr-defined]

File ~/.pixi/envs/py313-jax/lib/python3.13/site-packages/lcm/dispatchers.py:204, in _base_productmap(func, product_axes)
    201 signature = inspect.signature(func)
    202 parameters = list(signature.parameters)
--> 204 positions = [parameters.index(ax) for ax in product_axes]
    206 vmap_specs = []
    207 # We iterate in reverse order such that the output dimensions are in the same order
    208 # as the input dimensions.

ValueError: 'some_additional_state' is not in list

Describe the solution you'd like

It would be helpful to raise a more readable error, e.g. one stating that `some_additional_state` is not used by any model function.

Reproducible example
from __future__ import annotations

from dataclasses import dataclass
from typing import TYPE_CHECKING

import jax.numpy as jnp

from lcm import DiscreteGrid, LinspaceGrid, Model
from lcm.entry_point import get_lcm_function

if TYPE_CHECKING:
    from lcm.typing import (
        BoolND,
        ContinuousAction,
        ContinuousState,
        DiscreteAction,
        DiscreteState,
        FloatND,
        Int1D,
        IntND,
    )

# ======================================================================================
# Model functions
# ======================================================================================


# --------------------------------------------------------------------------------------
# Categorical variables
# --------------------------------------------------------------------------------------
@dataclass
class RetirementStatus:
    """Category codes for the discrete ``retirement`` action.

    Used in the model below as ``DiscreteGrid(RetirementStatus)``.
    """

    # Code 0; the ``working`` function maps this to a working indicator of 1.
    working: int = 0
    # Code 1; the ``working`` function maps this to a working indicator of 0.
    retired: int = 1


@dataclass
class SomeAdditionalState:
    """Category codes for the discrete ``some_additional_state`` state.

    No function in the model takes ``some_additional_state`` as an argument.
    This deliberately irrelevant state triggers the error reported in this issue.
    """

    low: int = 0
    medium: int = 1
    high: int = 2


# --------------------------------------------------------------------------------------
# Utility functions
# --------------------------------------------------------------------------------------
def utility(
    consumption: ContinuousAction, working: IntND, disutility_of_work: float
) -> FloatND:
    """Log utility of consumption net of a linear disutility of working.

    ``working`` is the 0/1 working indicator produced by the ``working`` function.
    """
    log_consumption = jnp.log(consumption)
    work_penalty = disutility_of_work * working
    return log_consumption - work_penalty


# --------------------------------------------------------------------------------------
# Auxiliary variables
# --------------------------------------------------------------------------------------
def labor_income(working: IntND, wage: float | FloatND) -> FloatND:
    """Labor income: the wage, earned only when ``working`` is 1."""
    earned = working * wage
    return earned


def working(retirement: DiscreteAction) -> IntND:
    """Working indicator: 1 for ``RetirementStatus.working`` (0), 0 for ``retired`` (1)."""
    is_working = 1 - retirement
    return is_working


def wage(age: int | IntND) -> float | FloatND:
    """Wage that starts at 1 and rises linearly by 0.1 per year of age."""
    age_premium = 0.1 * age
    return 1 + age_premium


def age(_period: int | Int1D) -> int | IntND:
    """Age in a given model period; period 0 corresponds to age 18."""
    starting_age = 18
    return _period + starting_age


# --------------------------------------------------------------------------------------
# State transitions
# --------------------------------------------------------------------------------------
def next_wealth(
    wealth: ContinuousState,
    consumption: ContinuousAction,
    labor_income: FloatND,
    interest_rate: float,
) -> ContinuousState:
    """Next-period wealth.

    Savings (wealth minus consumption) earn the gross interest factor, and this
    period's labor income is added on top.
    """
    savings = wealth - consumption
    gross_return = 1 + interest_rate
    return gross_return * savings + labor_income


# --------------------------------------------------------------------------------------
# Constraints
# --------------------------------------------------------------------------------------
def borrowing_constraint(
    consumption: ContinuousAction | DiscreteAction, wealth: ContinuousState
) -> BoolND:
    """Feasibility check: consumption may not exceed current wealth."""
    return wealth >= consumption


# ======================================================================================
# Model specifications
# ======================================================================================

# Minimal model reproducing the issue: it declares a state,
# ``some_additional_state``, that no function in ``functions`` takes as an
# argument. Building the solver for this model fails in ``productmap`` with
# ``ValueError: 'some_additional_state' is not in list``.
ISKHAKOV_ET_AL_2017_STRIPPED_DOWN = Model(
    description=(
        "Starts from Iskhakov et al. (2017), removes absorbing retirement constraint "
        "and the lagged_retirement state, and adds wage function that depends on age."
    ),
    n_periods=3,
    functions={
        "utility": utility,
        "next_wealth": next_wealth,
        # Transition for the irrelevant state.
        # NOTE(review): the lambda's argument is named ``x``, which is not a state,
        # action, or function name in this model. Confirm whether this is
        # intended, or whether it should take ``some_additional_state``.
        "next_some_additional_state": lambda x: x,
        "borrowing_constraint": borrowing_constraint,
        "labor_income": labor_income,
        "working": working,
        "wage": wage,
        "age": age,
    },
    actions={
        "retirement": DiscreteGrid(RetirementStatus),
        "consumption": LinspaceGrid(
            start=1,
            stop=400,
            n_points=500,
        ),
    },
    states={
        "wealth": LinspaceGrid(
            start=1,
            stop=400,
            n_points=100,
        ),
        # Not read by any model function; this is the state that triggers the error.
        "some_additional_state": DiscreteGrid(SomeAdditionalState),
    },
)


# Model parameters: a top-level ``beta`` (presumably the discount factor —
# confirm) plus per-function dicts keyed by the function whose arguments they supply.
params = {
    "beta": 0.98,
    "utility": {"disutility_of_work": 1.0},
    "next_wealth": {"interest_rate": 0.05},
    # NOTE(review): ``wage`` is also registered as a model function, which feeds
    # ``labor_income``'s ``wage`` argument. Confirm whether this entry is used.
    "labor_income": {"wage": 1.0},
}

# Initial states. Not used below: the "solve" target is called with ``params`` only.
initial_states = {
    "wealth": jnp.array([10.0]),
}

# Build the solve function (the first element of the indexable result of
# ``get_lcm_function``). The ValueError in this issue is raised by this call,
# before ``f`` is ever invoked.
f = get_lcm_function(
    model=ISKHAKOV_ET_AL_2017_STRIPPED_DOWN,
    targets="solve",
)[0]

result = f(
    params=params,
)

Metadata

Metadata

Assignees

No one assigned

    Type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions