Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Jit the simulation part of lcm #99

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 37 additions & 5 deletions src/lcm/entry_point.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import functools
import inspect
from collections.abc import Callable
from functools import partial
from typing import Literal, cast
Expand All @@ -9,13 +10,14 @@
from lcm.argmax import argmax
from lcm.discrete_problem import get_solve_discrete_problem
from lcm.dispatchers import productmap
from lcm.functools import all_as_kwargs
from lcm.input_processing import process_model
from lcm.logging import get_logger
from lcm.model_functions import (
get_utility_and_feasibility_function,
)
from lcm.next_state import get_next_state_function
from lcm.simulate import simulate
from lcm.simulate import _as_data_frame, _compute_targets, simulate
from lcm.solve_brute import solve
from lcm.state_space import create_state_choice_space
from lcm.typing import ParamsDict
Expand Down Expand Up @@ -170,7 +172,7 @@ def get_lcm_function(
solve_model = jax.jit(_solve_model) if jit else _solve_model

_next_state_simulate = get_next_state_function(model=_mod, target="simulate")
simulate_model = partial(
_simulate_model = partial(
simulate,
state_indexers=state_indexers,
continuous_choice_grids=continuous_choice_grids,
Expand All @@ -179,13 +181,43 @@ def get_lcm_function(
next_state=jax.jit(_next_state_simulate),
logger=logger,
)
simulate_model = jax.jit(_simulate_model) if jit else _simulate_model

if targets == "solve":
_target = solve_model

def _target(*args, **kwargs):
return solve_model(*args, **kwargs)
elif targets == "simulate":
_target = simulate_model

def _target(*args, **kwargs):
kwargs = all_as_kwargs(
args, kwargs, list(inspect.signature(simulate).parameters)
)
additional_targets = kwargs.get("additional_targets")
kwargs.pop("additional_targets", None)
_simulated = simulate_model(**kwargs)
return _as_data_frame(
_compute_targets(
_simulated, additional_targets, _mod.functions, kwargs["params"]
),
_mod.n_periods,
)
elif targets == "solve_and_simulate":
_target = partial(simulate_model, solve_model=solve_model)

def _target(*args, **kwargs):
kwargs = all_as_kwargs(
args, kwargs, list(inspect.signature(simulate).parameters)
)
additional_targets = kwargs.get("additional_targets")
kwargs.pop("additional_targets", None)
_solved = solve_model(kwargs["params"])
_simulated = simulate_model(**kwargs, vf_arr_list=_solved)
return _as_data_frame(
_compute_targets(
_simulated, additional_targets, _mod.functions, kwargs["params"]
),
_mod.n_periods,
)

return cast(Callable, _target), _mod.params

Expand Down
52 changes: 19 additions & 33 deletions src/lcm/simulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,7 @@ def simulate(
model: InternalModel,
next_state,
logger,
solve_model=None,
vf_arr_list=None,
additional_targets=None,
seed=12345,
):
"""Simulate the model forward in time.
Expand Down Expand Up @@ -59,13 +57,9 @@ def simulate(

"""
if vf_arr_list is None:
if solve_model is None:
raise ValueError(
"You need to provide either vf_arr_list or solve_model.",
)
# We do not need to convert the params here, because the solve_model function
# will do it.
vf_arr_list = solve_model(params)
raise ValueError(
"You need to provide either vf_arr_list or solve_model.",
)

logger.info("Starting simulation")

Expand Down Expand Up @@ -207,18 +201,7 @@ def simulate(

logger.info("Period: %s", period)

processed = _process_simulated_data(_simulation_results)

if additional_targets is not None:
calculated_targets = _compute_targets(
processed,
targets=additional_targets,
model_functions=model.functions,
params=params,
)
processed = {**processed, **calculated_targets}

return _as_data_frame(processed, n_periods=n_periods)
return _process_simulated_data(_simulation_results)


def solve_continuous_problem(
Expand Down Expand Up @@ -316,21 +299,24 @@ def _compute_targets(processed_results, targets, model_functions, params):
dict: Dict with computed targets.

"""
target_func = concatenate_functions(
functions=model_functions,
targets=targets,
return_type="dict",
)
if targets is not None:
target_func = concatenate_functions(
functions=model_functions,
targets=targets,
return_type="dict",
)

# get list of variables over which we want to vectorize the target function
variables = [
p for p in list(inspect.signature(target_func).parameters) if p != "params"
]
# get list of variables over which we want to vectorize the target function
variables = [
p for p in list(inspect.signature(target_func).parameters) if p != "params"
]

target_func = vmap_1d(target_func, variables=variables)

target_func = vmap_1d(target_func, variables=variables)
kwargs = {k: v for k, v in processed_results.items() if k in variables}

kwargs = {k: v for k, v in processed_results.items() if k in variables}
return target_func(params=params, **kwargs)
return {**processed_results, **target_func(params=params, **kwargs)}
return processed_results


def _process_simulated_data(results):
Expand Down
2 changes: 1 addition & 1 deletion tests/test_regression_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,4 +48,4 @@ def test_regression_test():
# Compare
# ==================================================================================
aaae(expected_solve, got_solve, decimal=5)
assert_frame_equal(expected_simulate, got_simulate)
assert_frame_equal(expected_simulate, got_simulate, check_like=True)
5 changes: 3 additions & 2 deletions tests/test_simulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,8 @@ def test_simulate_using_raw_inputs(simulate_inputs):
**simulate_inputs,
)

assert_array_equal(got.loc[0, :]["retirement"], 1)
assert_array_almost_equal(got.loc[0, :]["consumption"], jnp.array([1.0, 50.400803]))
assert_array_equal(got["retirement"], 1)
assert_array_almost_equal(got["consumption"], jnp.array([1.0, 50.400803]))


# ======================================================================================
Expand Down Expand Up @@ -336,6 +336,7 @@ def f_b(b, params): # noqa: ARG001
params={"disutility_of_work": -1.0},
)
expected = {
**processed_results,
"fa": jnp.arange(3) - 1.0,
"fb": 1 + jnp.arange(3),
}
Expand Down
Loading