Skip to content

Commit b058ff0

Browse files
authored
Add task to create analytical solution. (#1)
1 parent 6d20f3e commit b058ff0

File tree

4 files changed

+381
-1
lines changed

4 files changed

+381
-1
lines changed

.gitignore

+1-1
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ wheels/
2727
.installed.cfg
2828
*.egg
2929
MANIFEST
30-
*build/
30+
*bld/
3131

3232
# PyInstaller
3333
# Usually these files are written by a python script from a template

src/lcm_dev/analytical_solution.py

+324
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,324 @@
1+
"""Implementation of analytical solution by Iskhakov et al (2017)."""
2+
from functools import partial
3+
4+
import numpy as np
5+
from scipy.optimize import root_scalar
6+
7+
8+
def _u(c, work_dec, delta):
9+
"""Utility function.
10+
11+
Args:
12+
c (float): consumption
13+
work_dec (float): work indicator (True or False)
14+
delta (float): disutility of work
15+
Returns:
16+
float: utility
17+
18+
"""
19+
u = np.log(c) - work_dec * delta if c > 0 else -np.inf
20+
21+
return u
22+
23+
24+
def _generate_policy_function_vector(wage, r, beta, tau):
25+
"""Gererate consumption policy function vector given tau.
26+
27+
This function returns the functions that are used in the
28+
piecewise consumption function.
29+
30+
Args:
31+
wage (float): income
32+
r (float): interest rate
33+
beta (float): discount factor
34+
tau (int): periods left until end of life
35+
36+
Returns:
37+
dict: consumption policy dict
38+
39+
"""
40+
policy_vec_worker = [lambda m: m]
41+
42+
# Generate liquidity constraint kink functions
43+
for i in range(1, tau + 1):
44+
policy_vec_worker.append(
45+
lambda m, i=i: (
46+
m + wage * (np.sum([(1 + r) ** (-j) for j in range(1, i + 1)]))
47+
)
48+
/ (np.sum([beta**j for j in range(0, i + 1)])),
49+
)
50+
51+
# Generate retirement discontinuity functions
52+
for i in reversed(range(1, tau)):
53+
policy_vec_worker.append(
54+
lambda m, i=i, tau=tau: (
55+
m + wage * (np.sum([(1 + r) ** (-j) for j in range(1, i + 1)]))
56+
)
57+
/ (np.sum([beta**j for j in range(0, tau + 1)])),
58+
)
59+
policy_vec_worker.append(
60+
lambda m, tau=tau: m / (np.sum([beta**j for j in range(0, tau + 1)])),
61+
)
62+
63+
# Generate function for retirees
64+
policy_retiree = lambda m, tau=tau: m / ( # noqa: E731
65+
np.sum([beta**j for j in range(0, tau + 1)])
66+
)
67+
68+
return {"worker": policy_vec_worker, "retired": policy_retiree}
69+
70+
71+
def _compute_wealth_tresholds(v_prime, wage, r, beta, delta, tau, consumption_policy):
72+
"""Compute wealth treshold for piecewise consumption function.
73+
74+
Args:
75+
v_prime (function): continuation value of value function
76+
wage (float): labor income
77+
r (float): interest rate
78+
beta (float): discount factor
79+
delta (float): disutility of work
80+
tau (int): periods left until end of life
81+
consumption_policy (list): consumption policy vector
82+
83+
Returns:
84+
list: list of wealth thresholds
85+
86+
"""
87+
# Liquidity constraint threshold
88+
wealth_thresholds = [-np.inf, wage / ((1 + r) * beta)]
89+
90+
# Retirement threshold
91+
k = delta * np.sum([beta**j for j in range(0, tau + 1)]) ** (-1)
92+
ret_threshold = ((wage / (1 + r)) * np.exp(-k)) / (1 - np.exp(-k))
93+
94+
# Other kinks and discontinuities: Root finding
95+
for i in range(0, (tau - 1) * 2):
96+
c_l = consumption_policy[i + 1]
97+
c_u = consumption_policy[i + 2]
98+
99+
def root_fct(m, c_l=c_l, c_u=c_u):
100+
return (
101+
_u(c=c_l(m), work_dec=True, delta=delta)
102+
- _u(c=c_u(m), work_dec=True, delta=delta)
103+
+ beta * v_prime((1 + r) * (m - c_l(m)) + wage, work_status=True)
104+
- beta * v_prime((1 + r) * (m - c_u(m)) + wage, work_status=True)
105+
)
106+
107+
sol = root_scalar(
108+
root_fct,
109+
method="brentq",
110+
bracket=[wealth_thresholds[i + 1], ret_threshold],
111+
xtol=1e-10,
112+
rtol=1e-10,
113+
maxiter=1000,
114+
)
115+
assert sol.converged
116+
wealth_thresholds.append(sol.root)
117+
118+
# Add retirement threshold
119+
wealth_thresholds.append(ret_threshold)
120+
121+
# Add upper bound
122+
wealth_thresholds.append(np.inf)
123+
124+
return wealth_thresholds
125+
126+
127+
def _evaluate_piecewise_conditions(m, wealth_thresholds):
128+
"""Determine correct sub-function of policy function given wealth m.
129+
130+
Args:
131+
m (float): current wealth level
132+
wealth_thresholds (list): list of wealth thresholds
133+
Returns:
134+
list: list of booleans
135+
136+
"""
137+
cond_list = [
138+
m >= lb and m < ub
139+
for lb, ub in zip(wealth_thresholds[:-1], wealth_thresholds[1:])
140+
]
141+
return cond_list
142+
143+
144+
def _work_decision(m, work_status, wealth_thresholds):
145+
"""Determine work decision given current wealth level.
146+
147+
Args:
148+
m (float): current wealth level
149+
work_status (bool): work status from last period
150+
wealth_thresholds (list): list of wealth thresholds
151+
Returns:
152+
bool: work decision
153+
154+
"""
155+
return m < wealth_thresholds[-2] if work_status is not False else False
156+
157+
158+
def _consumption(m, work_status, policy_dict, wt):
159+
"""Determine consumption given current wealth level.
160+
161+
Args:
162+
m (float): current wealth level
163+
work_status (bool): work status from last period
164+
policy_dict (dict): dictionary of consumption policy functions
165+
wt (list): list of wealth thresholds
166+
Returns:
167+
float: consumption
168+
169+
"""
170+
if work_status is False:
171+
cons = policy_dict["retired"](m)
172+
173+
else:
174+
condlist = _evaluate_piecewise_conditions(m, wealth_thresholds=wt)
175+
cons = np.piecewise(x=m, condlist=condlist, funclist=policy_dict["worker"])
176+
return cons
177+
178+
179+
def _value_function(
180+
m,
181+
work_status,
182+
work_dec_func,
183+
c_pol,
184+
v_prime,
185+
beta,
186+
delta,
187+
tau,
188+
r,
189+
wage,
190+
):
191+
"""Determine value function given current wealth level and retirement status.
192+
193+
Args:
194+
m (float): current wealth level
195+
work_status (bool): work decision from last period
196+
work_dec_func (function): work decision function
197+
c_pol (function): consumption policy function
198+
v_prime (function): continuation value of value function
199+
beta (float): discount factor
200+
delta (float): disutility of work
201+
tau (int): periods left until end of life
202+
r (float): interest rate
203+
wage (float): labor income
204+
Returns:
205+
float: value function
206+
207+
"""
208+
if m == 0:
209+
v = -np.inf
210+
elif work_status is False:
211+
a = np.log(m) * np.sum([beta**j for j in range(0, tau + 1)])
212+
b = -np.log(np.sum([beta**j for j in range(0, tau + 1)]))
213+
c = np.sum([beta**j for j in range(0, tau + 1)])
214+
d = beta * (np.log(beta) + np.log(1 + r))
215+
e = np.sum(
216+
[
217+
beta**j * np.sum([beta**i for i in range(0, tau - j)])
218+
for j in range(0, tau)
219+
],
220+
)
221+
v = a + b * c + d * e
222+
else:
223+
work_dec = work_dec_func(m=m, work_status=work_status)
224+
cons = c_pol(m=m, work_status=work_status)
225+
226+
inst_util = _u(c=cons, work_dec=work_dec, delta=delta)
227+
cont_val = v_prime((1 + r) * (m - cons) + wage * work_dec, work_status=work_dec)
228+
229+
v = inst_util + beta * cont_val
230+
231+
return v
232+
233+
234+
def _construct_model(delta, num_periods, param_dict):
235+
"""Construct model given parameters via backward inducton.
236+
237+
Args:
238+
delta (float): disutility of work
239+
num_periods (int): length of life
240+
param_dict (dict): dictionary of parameters
241+
Returns:
242+
list: list of value functions
243+
244+
"""
245+
c_pol = [None] * num_periods
246+
v = [None] * num_periods
247+
work_dec_func = [None] * num_periods
248+
249+
for t in reversed(range(0, num_periods)):
250+
if t == num_periods - 1:
251+
v[t] = (
252+
lambda m, work_status: np.log(m) if m > 0 else -np.inf # noqa: ARG005
253+
)
254+
c_pol[t] = lambda m, work_status: m # noqa: ARG005
255+
work_dec_func[t] = lambda m, work_status: False # noqa: ARG005
256+
else:
257+
# Time left until retirement
258+
param_dict["tau"] = num_periods - t - 1
259+
260+
# Generate consumption function
261+
policy_dict = _generate_policy_function_vector(**param_dict)
262+
263+
wt = _compute_wealth_tresholds(
264+
v_prime=v[t + 1],
265+
consumption_policy=policy_dict["worker"],
266+
delta=delta,
267+
**param_dict,
268+
)
269+
270+
c_pol[t] = partial(_consumption, policy_dict=policy_dict, wt=wt)
271+
272+
# Determine retirement status
273+
work_dec_func[t] = partial(
274+
_work_decision,
275+
wealth_thresholds=wt,
276+
)
277+
278+
# Calculate V
279+
v[t] = partial(
280+
_value_function,
281+
work_dec_func=work_dec_func[t],
282+
c_pol=c_pol[t],
283+
v_prime=v[t + 1],
284+
delta=delta,
285+
**param_dict,
286+
)
287+
return v
288+
289+
290+
def analytical_solution(grid, beta, wage, r, delta, num_periods):
291+
"""Compute value function analytically on a grid.
292+
293+
Args:
294+
grid (list): grid of wealth levels
295+
beta (float): discount factor
296+
wage (float): labor income
297+
r (float): interest rate
298+
delta (float): disutility of work
299+
num_periods (int): length of life
300+
Returns:
301+
list: values of value function
302+
303+
"""
304+
# Unpack parameters
305+
306+
param_dict = {
307+
"beta": beta,
308+
"wage": wage,
309+
"r": r,
310+
"tau": None,
311+
}
312+
313+
v_fct = _construct_model(
314+
delta=delta,
315+
num_periods=num_periods,
316+
param_dict=param_dict,
317+
)
318+
319+
v = {
320+
k: [list(map(v_fct[t], grid, [v] * len(grid))) for t in range(0, num_periods)]
321+
for (k, v) in [["worker", True], ["retired", False]]
322+
}
323+
324+
return v

src/lcm_dev/config.py

+5
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
from pathlib import Path
2+
3+
SRC = Path(__file__).parent.resolve()
4+
ROOT = SRC.joinpath("..", "..").resolve()
5+
BLD = ROOT.joinpath("bld").resolve()
+51
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
"""Task creating the analytical solution."""
2+
3+
import pickle
4+
5+
import numpy as np
6+
import pytask
7+
8+
from lcm_dev.analytical_solution import analytical_solution
9+
from lcm_dev.config import BLD
10+
11+
models = {
12+
"iskhakov_2017": {
13+
"beta": 0.98,
14+
"delta": 1.0,
15+
"wage": float(20),
16+
"r": 0.0,
17+
"num_periods": 5,
18+
},
19+
"low_delta": {
20+
"beta": 0.98,
21+
"delta": 0.1,
22+
"wage": float(20),
23+
"r": 0.0,
24+
"num_periods": 3,
25+
},
26+
"high_wage": {
27+
"beta": 0.98,
28+
"delta": 1.0,
29+
"wage": float(100),
30+
"r": 0.0,
31+
"num_periods": 5,
32+
},
33+
}
34+
35+
wealth_grid = np.linspace(1, 100, 10_000)
36+
37+
for model, params in models.items():
38+
39+
@pytask.mark.task(
40+
id=model,
41+
kwargs={
42+
"produces": BLD / "analytical_solution" / f"{model}.p",
43+
"params": params,
44+
},
45+
)
46+
def task_create_analytical_solution(produces, params):
47+
"""Store analytical solution in a pickle file."""
48+
pickle.dump(
49+
analytical_solution(grid=wealth_grid, **params),
50+
produces.open("wb"),
51+
)

0 commit comments

Comments
 (0)