-
Notifications
You must be signed in to change notification settings - Fork 3
Open
Description
Is your feature request related to a problem?
In many applications, a single parameter enters several model functions. Currently, there is no way to specify such global parameters, so the parameter template lists a separate copy of the parameter under every function that uses it.
Below is a modified version of `ISKHAKOV_ET_AL_2017_STRIPPED_DOWN` that adds the parameter `model_start_age` to both the `age` and `wage` functions. The resulting parameter template looks like this:
{'beta': nan,
'utility': {'disutility_of_work': nan},
'next_wealth': {'interest_rate': nan},
'borrowing_constraint': {},
'labor_income': {},
'working': {},
'wage': {'model_start_age': nan},
'age': {'model_start_age': nan}}

Describe the solution you'd like
Use `dags.tree` to support both global parameters and parameters that are tied to a specific function.
Additional context
from __future__ import annotations

from dataclasses import dataclass
from typing import TYPE_CHECKING

import jax.numpy as jnp

from lcm import DiscreteGrid, LinspaceGrid, Model

if TYPE_CHECKING:
    from lcm.typing import (
        BoolND,
        ContinuousAction,
        ContinuousState,
        DiscreteAction,
        DiscreteState,
        FloatND,
        Int1D,
        IntND,
    )
# ======================================================================================
# Model functions
# ======================================================================================
# --------------------------------------------------------------------------------------
# Categorical variables
# --------------------------------------------------------------------------------------
@dataclass
class RetirementStatus:
    """Integer codes for the categories of the discrete ``retirement`` action.

    Used as the category definition of ``DiscreteGrid(RetirementStatus)``.
    Field order matters: it is the declaration order of the categories.
    """

    working: int = 0  # agent keeps working
    retired: int = 1  # agent is retired
# --------------------------------------------------------------------------------------
# Utility functions
# --------------------------------------------------------------------------------------
def utility(
    consumption: ContinuousAction, working: IntND, disutility_of_work: float
) -> FloatND:
    """Log utility of consumption minus a linear disutility of working.

    Args:
        consumption: Consumption level; the log is only finite for positive values.
        working: Work indicator (1 when working, 0 when retired).
        disutility_of_work: Utility cost subtracted per unit of ``working``.

    Returns:
        ``log(consumption) - disutility_of_work * working``.
    """
    log_consumption = jnp.log(consumption)
    work_penalty = disutility_of_work * working
    return log_consumption - work_penalty
# --------------------------------------------------------------------------------------
# Auxiliary variables
# --------------------------------------------------------------------------------------
def labor_income(working: IntND, wage: float | FloatND) -> FloatND:
    """Labor earnings: the wage if the agent works, zero otherwise.

    Args:
        working: Work indicator (1 when working, 0 when retired).
        wage: Wage rate.

    Returns:
        The product ``working * wage``.
    """
    earnings = working * wage
    return earnings
def working(retirement: DiscreteAction) -> IntND:
    """Work indicator derived from the retirement choice.

    With the ``RetirementStatus`` codes (working=0, retired=1), this maps
    0 -> 1 (works) and 1 -> 0 (retired).
    """
    is_working = 1 - retirement
    return is_working
def wage(age: int | IntND, model_start_age: int) -> float | FloatND:
    """Wage that starts at 1 and rises by 0.1 per year since ``model_start_age``.

    Args:
        age: Current age.
        model_start_age: Age at the first model period.

    Returns:
        ``1 + 0.1 * (age - model_start_age)``.
    """
    years_since_start = age - model_start_age
    return 1 + 0.1 * years_since_start
def age(_period: int | Int1D, model_start_age: int) -> int | IntND:
    """Age in a given model period, offset by the age at model start.

    Args:
        _period: Model period index.
        model_start_age: Age at the first model period.

    Returns:
        ``_period + model_start_age``.
    """
    age_in_period = _period + model_start_age
    return age_in_period
# --------------------------------------------------------------------------------------
# State transitions
# --------------------------------------------------------------------------------------
def next_wealth(
    wealth: ContinuousState,
    consumption: ContinuousAction,
    labor_income: FloatND,
    interest_rate: float,
) -> ContinuousState:
    """Wealth transition: savings earn interest, then labor income is added.

    Args:
        wealth: Current wealth.
        consumption: Consumption chosen this period.
        labor_income: Labor earnings received this period.
        interest_rate: Net interest rate on savings.

    Returns:
        ``(1 + interest_rate) * (wealth - consumption) + labor_income``.
    """
    savings = wealth - consumption
    gross_return = 1 + interest_rate
    return gross_return * savings + labor_income
# --------------------------------------------------------------------------------------
# Constraints
# --------------------------------------------------------------------------------------
def borrowing_constraint(
    consumption: ContinuousAction | DiscreteAction, wealth: ContinuousState
) -> BoolND:
    """Feasibility check: consumption may not exceed current wealth.

    Returns:
        True (elementwise for arrays) where ``consumption <= wealth``.
    """
    is_feasible = consumption <= wealth
    return is_feasible
# Three-period example model: one discrete action (retirement), one continuous
# action (consumption), and one continuous state (wealth). Both `wage` and `age`
# take `model_start_age`, which is why it appears twice in the params template.
ISKHAKOV_ET_AL_2017_STRIPPED_DOWN = Model(
    description=(
        "Starts from Iskhakov et al. (2017), removes absorbing retirement constraint "
        "and the lagged_retirement state, and adds wage function that depends on age."
    ),
    n_periods=3,
    # Model functions keyed by name; insertion order matches the original listing.
    functions=dict(
        utility=utility,
        next_wealth=next_wealth,
        borrowing_constraint=borrowing_constraint,
        labor_income=labor_income,
        working=working,
        wage=wage,
        age=age,
    ),
    actions=dict(
        retirement=DiscreteGrid(RetirementStatus),
        consumption=LinspaceGrid(start=1, stop=400, n_points=500),
    ),
    states=dict(
        wealth=LinspaceGrid(start=1, stop=400, n_points=100),
    ),
)
# Inspect the generated parameter template (its output is shown in the issue text).
# Note that `model_start_age` is listed separately under both 'wage' and 'age'.
ISKHAKOV_ET_AL_2017_STRIPPED_DOWN.params_template
Metadata
Metadata
Assignees
Labels
No labels