Skip to content

ENH: Hierarchical parameters #155

@MImmesberger

Description

@MImmesberger

Is your feature request related to a problem?

In many applications, the same parameter enters several model functions. Currently, there is no way to specify such a parameter globally, i.e., once for all functions that use it.

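Below is a modified version of `ISKHAKOV_ET_AL_2017_STRIPPED_DOWN` in which both the `age` and the `wage` function take the parameter `model_start_age`. Because parameters are currently tied to individual functions, `model_start_age` shows up twice in the resulting parameter template, even though it is conceptually a single value: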

{'beta': nan,
 'utility': {'disutility_of_work': nan},
 'next_wealth': {'interest_rate': nan},
 'borrowing_constraint': {},
 'labor_income': {},
 'working': {},
 'wage': {'model_start_age': nan},
 'age': {'model_start_age': nan}}

Describe the solution you'd like

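Use `dags.tree` to support both global parameters, which are specified once and passed to every function that takes them (e.g., a single top-level `model_start_age` entry), and parameters that are tied to a specific function.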

Additional context

from __future__ import annotations

from dataclasses import dataclass
from typing import TYPE_CHECKING

import jax.numpy as jnp

from lcm import DiscreteGrid, LinspaceGrid, Model

if TYPE_CHECKING:
    from lcm.typing import (
        BoolND,
        ContinuousAction,
        ContinuousState,
        DiscreteAction,
        DiscreteState,
        FloatND,
        Int1D,
        IntND,
    )

# ======================================================================================
# Model functions
# ======================================================================================


# --------------------------------------------------------------------------------------
# Categorical variables
# --------------------------------------------------------------------------------------
@dataclass
class RetirementStatus:
    """Category codes for the discrete ``retirement`` action.

    Passed to ``DiscreteGrid`` in the model definition below; each field's
    default value is the integer code used for that category.
    """

    working: int = 0  # code 0: the agent keeps working (``working()`` maps it to 1)
    retired: int = 1  # code 1: the agent retires (``working()`` maps it to 0)


# --------------------------------------------------------------------------------------
# Utility functions
# --------------------------------------------------------------------------------------
def utility(
    consumption: ContinuousAction, working: IntND, disutility_of_work: float
) -> FloatND:
    """Return per-period utility: log consumption minus a penalty for working.

    Args:
        consumption: Consumption chosen in the current period.
        working: Work indicator (1 when working, 0 when retired, as produced
            by ``working()``).
        disutility_of_work: Utility cost subtracted when the agent works.

    Returns:
        ``log(consumption) - disutility_of_work * working``.
    """
    work_penalty = disutility_of_work * working
    return jnp.log(consumption) - work_penalty


# --------------------------------------------------------------------------------------
# Auxiliary variables
# --------------------------------------------------------------------------------------
def labor_income(working: IntND, wage: float | FloatND) -> FloatND:
    """Return this period's earnings: the wage when working, zero otherwise.

    ``working`` acts as a 0/1 multiplier on ``wage``.
    """
    earnings = working * wage
    return earnings


def working(retirement: DiscreteAction) -> IntND:
    """Turn the retirement choice into a working indicator.

    With the ``RetirementStatus`` codes (working=0, retired=1) this yields 1
    for a working agent and 0 for a retired one.
    """
    is_working = 1 - retirement
    return is_working


def wage(age: int | IntND, model_start_age: int) -> float | FloatND:
    """Return the wage: 1 at ``model_start_age``, rising by 0.1 per year of age.

    Args:
        age: Current age of the agent.
        model_start_age: Age at which the model starts.
    """
    years_in_model = age - model_start_age
    return 1 + 0.1 * years_in_model


def age(_period: int | Int1D, model_start_age: int) -> int | IntND:
    """Convert the model period into the agent's age.

    Period 0 corresponds to ``model_start_age``; each later period adds one.
    """
    current_age = _period + model_start_age
    return current_age


# --------------------------------------------------------------------------------------
# State transitions
# --------------------------------------------------------------------------------------
def next_wealth(
    wealth: ContinuousState,
    consumption: ContinuousAction,
    labor_income: FloatND,
    interest_rate: float,
) -> ContinuousState:
    """Return wealth carried into the next period.

    Whatever is not consumed earns ``interest_rate``; this period's labor
    income is then added on top of the gross return on savings.
    """
    savings = wealth - consumption
    gross_return = 1 + interest_rate
    return gross_return * savings + labor_income


# --------------------------------------------------------------------------------------
# Constraints
# --------------------------------------------------------------------------------------
def borrowing_constraint(
    consumption: ContinuousAction | DiscreteAction, wealth: ContinuousState
) -> BoolND:
    """Return True where consumption does not exceed current wealth (no borrowing)."""
    return wealth >= consumption


# Stripped-down Iskhakov et al. (2017) variant used to illustrate the issue:
# both `wage` and `age` take a `model_start_age` argument, so that parameter is
# requested twice in the params template (once under each function).
ISKHAKOV_ET_AL_2017_STRIPPED_DOWN = Model(
    description=(
        "Starts from Iskhakov et al. (2017), removes absorbing retirement constraint "
        "and the lagged_retirement state, and adds wage function that depends on age."
    ),
    n_periods=3,
    # Model functions; the keys match the argument names other functions take
    # (e.g. `labor_income(working, wage)` consumes `working` and `wage`).
    functions={
        "utility": utility,
        "next_wealth": next_wealth,
        "borrowing_constraint": borrowing_constraint,
        "labor_income": labor_income,
        "working": working,
        "wage": wage,
        "age": age,
    },
    # Choices: a two-category retirement decision and a consumption grid.
    actions={
        "retirement": DiscreteGrid(RetirementStatus),
        "consumption": LinspaceGrid(start=1, stop=400, n_points=500),
    },
    # The only state variable is wealth.
    states={
        "wealth": LinspaceGrid(start=1, stop=400, n_points=100),
    },
)


# Inspect the parameter template. As a bare expression statement its value is
# only displayed in an interactive session (REPL / notebook); a script discards it.
ISKHAKOV_ET_AL_2017_STRIPPED_DOWN.params_template

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions