diff --git a/resources/ResourceFile_GenerateEventChains/parameter_values.csv b/resources/ResourceFile_GenerateEventChains/parameter_values.csv
new file mode 100644
index 0000000000..ebf20c5f79
--- /dev/null
+++ b/resources/ResourceFile_GenerateEventChains/parameter_values.csv
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:172a0c24c859aaafbad29f6016433cac7a7324efc582e6c4b19c74b6b97436e7
+size 420
diff --git a/src/scripts/analysis_data_generation/analysis_extract_data.py b/src/scripts/analysis_data_generation/analysis_extract_data.py
new file mode 100644
index 0000000000..9ee37cabef
--- /dev/null
+++ b/src/scripts/analysis_data_generation/analysis_extract_data.py
@@ -0,0 +1,557 @@
+"""Produce plots to show the health impact (deaths, DALYs) of the healthcare system (overall health impact) when
+running under different MODES and POLICIES (scenario_impact_of_actual_vs_funded.py)."""
+
+# short tclose -> ideal case
+# long tclose -> status quo
+import argparse
+from pathlib import Path
+from typing import Tuple
+
+import pandas as pd
+import matplotlib.pyplot as plt
+
+from tlo import Date
+from tlo.analysis.utils import extract_results, extract_event_chains
+from datetime import datetime
+from collections import Counter
+import ast
+
+# Time simulated to collect data
+start_date = Date(2010, 1, 1)
+end_date = start_date + pd.DateOffset(months=13)
+
+# Range of years considered
+min_year = 2010
+max_year = 2040
+
+
+def all_columns(_df):
+    return pd.Series(_df.all())
+
+
+def check_if_beyond_time_range_considered(progression_properties):
+    matching_keys = [key for key in progression_properties.keys() if "rt_date_to_remove_daly" in key]
+    if matching_keys:
+        for key in matching_keys:
+            if progression_properties[key] > end_date:
+                print("Beyond time range considered, need at least ", progression_properties[key])
+
+
+def print_filtered_df(df):
+    """
+    Prints rows of the DataFrame. (The filter excluding EventName 'StartOfSimulation' and 'Birth'
+    is currently disabled.)
+    """
+    pd.set_option('display.max_colwidth', None)
+    filtered = df  # [~df['EventName'].isin(['StartOfSimulation', 'Birth'])]
+
+    dict_cols = ["Info"]
+    max_items = 2
+    # Truncate dictionary columns for display
+    if dict_cols is not None:
+        for col in dict_cols:
+            def truncate_dict(d):
+                if isinstance(d, dict):
+                    items = list(d.items())[:max_items]  # keep only first `max_items`
+                    return dict(items)
+                return d
+            filtered[col] = filtered[col].apply(truncate_dict)
+    print(filtered)
+
+
+def apply(results_folder: Path, output_folder: Path, resourcefilepath: Path = None):
+    """Produce standard set of plots describing the effect of each TREATMENT_ID.
+    - We estimate the epidemiological impact as the EXTRA deaths that would occur if that treatment did not occur.
+    - We estimate the draw on healthcare system resources as the FEWER appointments when that treatment does not occur.
+ """ + pd.set_option('display.max_rows', None) + pd.set_option('display.max_colwidth', None) + + individual_event_chains = extract_event_chains(results_folder) + print_filtered_df(individual_event_chains[0]) + exit(-1) + + eval_env = { + 'datetime': datetime, # Add the datetime class to the eval environment + 'pd': pd, # Add pandas to handle Timestamp + 'Timestamp': pd.Timestamp, # Specifically add Timestamp for eval + 'NaT': pd.NaT, + 'nan': float('nan'), # Include NaN for eval (can also use pd.NA if preferred) + } + + initial_properties_of_interest = ['rt_MAIS_military_score','rt_ISS_score','rt_disability','rt_polytrauma','rt_injury_1','rt_injury_2','rt_injury_3','rt_injury_4','rt_injury_5','rt_injury_6', 'rt_imm_death','sy_injury','sy_severe_trauma','sex','li_urban', 'li_wealth', 'li_mar_stat', 'li_in_ed', 'li_ed_lev'] + + # Will be added through computation: age at time of RTI + # Will be added through computation: total duration of event + + initial_rt_event_properties = set() + + num_individuals = 1000 + num_runs = 1 + record = [] + # Include results folder in output file name + name_tag = str(results_folder).replace("outputs/", "") + + + + for p in range(0,num_individuals): + + print("At person = ", p, " out of ", num_individuals) + + individual_event_chains = extract_results( + results_folder, + module='tlo.simulation', + key='event_chains', + column=str(p), + do_scaling=False + ) + + for r in range(0,num_runs): + initial_properties = {} + key_first_event = {} + key_last_event = {} + first_event = {} + last_event = {} + properties = {} + average_disability = 0 + total_dt_included = 0 + dt_in_prev_disability = 0 + prev_disability_incurred = 0 + ind_Counter = {'0': Counter(), '1a': Counter(), '1b' : Counter(), '2' : Counter()} + # Count total appts + + list_for_individual = [] + for item,row in individual_event_chains.iterrows(): + value = individual_event_chains.loc[item,(0, r)] + if value !='' and isinstance(value, str): + evaluated = eval(value, eval_env) + list_for_individual.append(evaluated) + + for i in list_for_individual: + print(i) + + """ + # These are the properties of the individual before the start of the chain of events + initial_properties = list_for_individual[0] + + # Initialise first event by gathering parameters of interest from initial_properties + first_event = {key: initial_properties[key] for key in initial_properties_of_interest if key in initial_properties} + + # The changing or adding of properties from the first_event will be stored in progression_properties + progression_properties = {} + + for i in list_for_individual: + # Skip the initial_properties, or in other words only consider these if they are 'proper' events + if 'event' in i: + #print(i) + if 'RTIPolling' in i['event']: + + # Keep track of which properties are changed during polling events + for key,value in i.items(): + if 'rt_' in key: + initial_rt_event_properties.add(key) + + # Retain a copy of Polling event + polling_event = i.copy() + + # Update parameters of interest following RTI + key_first_event = {key: i[key] if key in i else value for key, value in first_event.items()} + + # Calculate age of individual at time of event + key_first_event['age_in_days_at_event'] = (i['rt_date_inj'] - initial_properties['date_of_birth']).days + + # Keep track of evolution in individual's properties + progression_properties = initial_properties.copy() + progression_properties.update(i) + + # Initialise chain of Dalys incurred + if 'rt_disability' in i: + prev_disability_incurred = i['rt_disability'] + 
prev_date = i['event_date'] + + else: + # Progress properties of individual, even if this event is a death + progression_properties.update(i) + + # If disability has changed as a result of this, recalculate and add previous to rolling average + if 'rt_disability' in i: + + dt_in_prev_disability = (i['event_date'] - prev_date).days + #print("Detected change in disability", i['rt_disability'], "after dt=", dt_in_prev_disability) + #print("Adding the following to the average", prev_disability_incurred, " x ", dt_in_prev_disability ) + average_disability += prev_disability_incurred*dt_in_prev_disability + total_dt_included += dt_in_prev_disability + # Update variables + prev_disability_incurred = i['rt_disability'] + prev_date = i['event_date'] + + # Update running footprint + if 'appt_footprint' in i and i['appt_footprint'] != 'Counter()': + footprint = i['appt_footprint'] + if 'Counter' in footprint: + footprint = footprint[len("Counter("):-1] + apply = eval(footprint, eval_env) + ind_Counter[i['level']].update(Counter(apply)) + + # If the individual has died, ensure chain of event is interrupted here and update rolling average of DALYs + if 'is_alive' in i and i['is_alive'] is False: + if ((i['event_date'] - polling_event['rt_date_inj']).days) > total_dt_included: + dt_in_prev_disability = (i['event_date'] - prev_date).days + average_disability += prev_disability_incurred*dt_in_prev_disability + total_dt_included += dt_in_prev_disability + break + + # check_if_beyond_time_range_considered(progression_properties) + + # Compute final properties of individual + key_last_event['is_alive_after_RTI'] = progression_properties['is_alive'] + key_last_event['duration_days'] = (progression_properties['event_date'] - polling_event['rt_date_inj']).days + + # If individual didn't die and the key_last_event didn't result in a final change in DALYs, ensure that the last change is recorded here + if not key_first_event['rt_imm_death'] and (total_dt_included < key_last_event['duration_days']): + #print("Number of events", len(list_for_individual)) + #for i in list_for_individual: + # if 'event' in i: + # print(i) + dt_in_prev_disability = (progression_properties['event_date'] - prev_date).days + average_disability += prev_disability_incurred*dt_in_prev_disability + total_dt_included += dt_in_prev_disability + + # Now calculate the average disability incurred, and store any permanent disability and total footprint + if not key_first_event['rt_imm_death'] and key_last_event['duration_days']> 0: + key_last_event['rt_disability_average'] = average_disability/key_last_event['duration_days'] + else: + key_last_event['rt_disability_average'] = 0.0 + + key_last_event['rt_disability_permanent'] = progression_properties['rt_disability'] + key_last_event.update({'total_footprint': ind_Counter}) + + if key_last_event['duration_days']!=total_dt_included: + print("The duration of event and total_dt_included don't match", key_last_event['duration_days'], total_dt_included) + exit(-1) + + properties = key_first_event | key_last_event + + record.append(properties) + """ + + df = pd.DataFrame(record) + df.to_csv("new_raw_data_" + name_tag + ".csv", index=False) + + print(df) + print(initial_rt_event_properties) + exit(-1) + #print(i) + + #dict = {} + #for i in [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]: + # dict[i] = [] + + #for i in [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]: + # event_chains = extract_results( + # results_folder, + # module='tlo.simulation'#, + # key='event_chains', + # column = str(i), + # 
#custom_generate_series=get_num_dalys_by_year, + # do_scaling=False + # ) + # print(event_chains) + # print(event_chains.index) + # print(event_chains.columns.levels) + + # for index, row in event_chains.iterrows(): + # if event_chains.iloc[index,0] is not None: + # if(event_chains.iloc[index,0]['person_ID']==i): #and 'event' in event_chains.iloc[index,0].keys()): + # dict[i].append(event_chains.iloc[index,0]) + #elif (event_chains.iloc[index,0]['person_ID']==i and 'event' not in event_chains.iloc[index,0].keys()): + #print(event_chains.iloc[index,0]['de_depr']) + # exit(-1) + #for item in dict[0]: + # print(item) + + #exit(-1) + + TARGET_PERIOD = (Date(min_year, 1, 1), Date(max_year, 1, 1)) + + # Definitions of general helper functions + lambda stub: output_folder / f"{stub.replace('*', '_star_')}.png" # noqa: E731 + + def target_period() -> str: + """Returns the target period as a string of the form YYYY-YYYY""" + return "-".join(str(t.year) for t in TARGET_PERIOD) + + def get_parameter_names_from_scenario_file() -> Tuple[str]: + """Get the tuple of names of the scenarios from `Scenario` class used to create the results.""" + from scripts.healthsystem.impact_of_actual_vs_funded.scenario_impact_of_actual_vs_funded import ( + ImpactOfHealthSystemMode, + ) + e = ImpactOfHealthSystemMode() + return tuple(e._scenarios.keys()) + + def get_num_deaths(_df): + """Return total number of Deaths (total within the TARGET_PERIOD) + """ + return pd.Series(data=len(_df.loc[pd.to_datetime(_df.date).between(*TARGET_PERIOD)])) + + def get_num_dalys(_df): + """Return total number of DALYs (Stacked) by label (total within the TARGET_PERIOD)""" + return pd.Series( + data=_df + .loc[_df.year.between(*[i.year for i in TARGET_PERIOD])] + .drop(columns=['date', 'sex', 'age_range', 'year']) + .sum().sum() + ) + + def get_num_dalys_by_cause(_df): + """Return number of DALYs by cause by label (total within the TARGET_PERIOD)""" + return pd.Series( + data=_df + .loc[_df.year.between(*[i.year for i in TARGET_PERIOD])] + .drop(columns=['date', 'sex', 'age_range', 'year']) + .sum() + ) + + def set_param_names_as_column_index_level_0(_df): + """Set the columns index (level 0) as the param_names.""" + ordered_param_names_no_prefix = {i: x for i, x in enumerate(param_names)} + names_of_cols_level0 = [ordered_param_names_no_prefix.get(col) for col in _df.columns.levels[0]] + assert len(names_of_cols_level0) == len(_df.columns.levels[0]) + _df.columns = _df.columns.set_levels(names_of_cols_level0, level=0) + return _df + + def find_difference_relative_to_comparison(_ser: pd.Series, + comparison: str, + scaled: bool = False, + drop_comparison: bool = True, + ): + """Find the difference in the values in a pd.Series with a multi-index, between the draws (level 0) + within the runs (level 1), relative to where draw = `comparison`. 
+ The comparison is `X - COMPARISON`.""" + return _ser \ + .unstack(level=0) \ + .apply(lambda x: (x - x[comparison]) / (x[comparison] if scaled else 1.0), axis=1) \ + .drop(columns=([comparison] if drop_comparison else [])) \ + .stack() + + + def get_counts_of_hsi_by_treatment_id(_df): + """Get the counts of the short TREATMENT_IDs occurring""" + _counts_by_treatment_id = _df \ + .loc[pd.to_datetime(_df['date']).between(*TARGET_PERIOD), 'TREATMENT_ID'] \ + .apply(pd.Series) \ + .sum() \ + .astype(int) + return _counts_by_treatment_id.groupby(level=0).sum() + + year_target = 2023 + def get_counts_of_hsi_by_treatment_id_by_year(_df): + """Get the counts of the short TREATMENT_IDs occurring""" + _counts_by_treatment_id = _df \ + .loc[pd.to_datetime(_df['date']).dt.year ==year_target, 'TREATMENT_ID'] \ + .apply(pd.Series) \ + .sum() \ + .astype(int) + return _counts_by_treatment_id.groupby(level=0).sum() + + def get_counts_of_hsi_by_short_treatment_id(_df): + """Get the counts of the short TREATMENT_IDs occurring (shortened, up to first underscore)""" + _counts_by_treatment_id = get_counts_of_hsi_by_treatment_id(_df) + _short_treatment_id = _counts_by_treatment_id.index.map(lambda x: x.split('_')[0] + "*") + return _counts_by_treatment_id.groupby(by=_short_treatment_id).sum() + + def get_counts_of_hsi_by_short_treatment_id_by_year(_df): + """Get the counts of the short TREATMENT_IDs occurring (shortened, up to first underscore)""" + _counts_by_treatment_id = get_counts_of_hsi_by_treatment_id_by_year(_df) + _short_treatment_id = _counts_by_treatment_id.index.map(lambda x: x.split('_')[0] + "*") + return _counts_by_treatment_id.groupby(by=_short_treatment_id).sum() + + + # Obtain parameter names for this scenario file + param_names = get_parameter_names_from_scenario_file() + print(param_names) + + # ================================================================================================ + # TIME EVOLUTION OF TOTAL DALYs + # Plot DALYs averted compared to the ``No Policy'' policy + + year_target = 2023 # This global variable will be passed to custom function + def get_num_dalys_by_year(_df): + """Return total number of DALYs (Stacked) by label (total within the TARGET_PERIOD)""" + return pd.Series( + data=_df + .loc[_df.year == year_target] + .drop(columns=['date', 'sex', 'age_range', 'year']) + .sum().sum() + ) + + ALL = {} + # Plot time trend show year prior transition as well to emphasise that until that point DALYs incurred + # are consistent across different policies + this_min_year = 2010 + for year in range(this_min_year, max_year+1): + year_target = year + num_dalys_by_year = extract_results( + results_folder, + module='tlo.methods.healthburden', + key='dalys_stacked', + custom_generate_series=get_num_dalys_by_year, + do_scaling=True + ).pipe(set_param_names_as_column_index_level_0) + ALL[year_target] = num_dalys_by_year + # Concatenate the DataFrames into a single DataFrame + concatenated_df = pd.concat(ALL.values(), keys=ALL.keys()) + concatenated_df.index = concatenated_df.index.set_names(['date', 'index_original']) + concatenated_df = concatenated_df.reset_index(level='index_original',drop=True) + dalys_by_year = concatenated_df + print(dalys_by_year) + dalys_by_year.to_csv('ConvertedOutputs/Total_DALYs_with_time.csv', index=True) + + # ================================================================================================ + # Print population under each scenario + pop_model = extract_results(results_folder, + module="tlo.methods.demography", + key="population", + 
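+                               # Note: `do_scaling=True` below multiplies the logged counts by the
+                               # simulation's population scaling factor, so values refer to the full
+                               # national population rather than the simulated sample.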
column="total", + index="date", + do_scaling=True + ).pipe(set_param_names_as_column_index_level_0) + + pop_model.index = pop_model.index.year + pop_model = pop_model[(pop_model.index >= this_min_year) & (pop_model.index <= max_year)] + print(pop_model) + assert dalys_by_year.index.equals(pop_model.index) + assert all(dalys_by_year.columns == pop_model.columns) + pop_model.to_csv('ConvertedOutputs/Population_with_time.csv', index=True) + + # ================================================================================================ + # DALYs BROKEN DOWN BY CAUSES AND YEAR + # DALYs by cause per year + # %% Quantify the health losses associated with all interventions combined. + + year_target = 2023 # This global variable will be passed to custom function + def get_num_dalys_by_year_and_cause(_df): + """Return total number of DALYs (Stacked) by label (total within the TARGET_PERIOD)""" + return pd.Series( + data=_df + .loc[_df.year == year_target] + .drop(columns=['date', 'sex', 'age_range', 'year']) + .sum() + ) + + ALL = {} + # Plot time trend show year prior transition as well to emphasise that until that point DALYs incurred + # are consistent across different policies + this_min_year = 2010 + for year in range(this_min_year, max_year+1): + year_target = year + num_dalys_by_year = extract_results( + results_folder, + module='tlo.methods.healthburden', + key='dalys_stacked', + custom_generate_series=get_num_dalys_by_year_and_cause, + do_scaling=True + ).pipe(set_param_names_as_column_index_level_0) + ALL[year_target] = num_dalys_by_year #summarize(num_dalys_by_year) + + # Concatenate the DataFrames into a single DataFrame + concatenated_df = pd.concat(ALL.values(), keys=ALL.keys()) + + concatenated_df.index = concatenated_df.index.set_names(['date', 'cause']) + + df_total = concatenated_df + df_total.to_csv('ConvertedOutputs/DALYS_by_cause_with_time.csv', index=True) + + ALL = {} + # Plot time trend show year prior transition as well to emphasise that until that point DALYs incurred + # are consistent across different policies + for year in range(min_year, max_year+1): + year_target = year + + hsi_delivered_by_year = extract_results( + results_folder, + module='tlo.methods.healthsystem.summary', + key='HSI_Event', + custom_generate_series=get_counts_of_hsi_by_short_treatment_id_by_year, + do_scaling=True + ).pipe(set_param_names_as_column_index_level_0) + ALL[year_target] = hsi_delivered_by_year + + # Concatenate the DataFrames into a single DataFrame + concatenated_df = pd.concat(ALL.values(), keys=ALL.keys()) + concatenated_df.index = concatenated_df.index.set_names(['date', 'cause']) + HSI_ran_by_year = concatenated_df + + del ALL + + ALL = {} + # Plot time trend show year prior transition as well to emphasise that until that point DALYs incurred + # are consistent across different policies + for year in range(min_year, max_year+1): + year_target = year + + hsi_not_delivered_by_year = extract_results( + results_folder, + module='tlo.methods.healthsystem.summary', + key='Never_ran_HSI_Event', + custom_generate_series=get_counts_of_hsi_by_short_treatment_id_by_year, + do_scaling=True + ).pipe(set_param_names_as_column_index_level_0) + ALL[year_target] = hsi_not_delivered_by_year + + # Concatenate the DataFrames into a single DataFrame + concatenated_df = pd.concat(ALL.values(), keys=ALL.keys()) + concatenated_df.index = concatenated_df.index.set_names(['date', 'cause']) + HSI_never_ran_by_year = concatenated_df + + HSI_never_ran_by_year = HSI_never_ran_by_year.fillna(0) #clean_df( 
+ HSI_ran_by_year = HSI_ran_by_year.fillna(0) + HSI_total_by_year = HSI_ran_by_year.add(HSI_never_ran_by_year, fill_value=0) + HSI_ran_by_year.to_csv('ConvertedOutputs/HSIs_ran_by_area_with_time.csv', index=True) + HSI_never_ran_by_year.to_csv('ConvertedOutputs/HSIs_never_ran_by_area_with_time.csv', index=True) + print(HSI_ran_by_year) + print(HSI_never_ran_by_year) + print(HSI_total_by_year) + +if __name__ == "__main__": + rfp = Path('resources') + + parser = argparse.ArgumentParser( + description="Produce plots to show the impact each set of treatments", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument( + "--output-path", + help=( + "Directory to write outputs to. If not specified (set to None) outputs " + "will be written to value of --results-path argument." + ), + type=Path, + default=None, + required=False, + ) + parser.add_argument( + "--resources-path", + help="Directory containing resource files", + type=Path, + default=Path('resources'), + required=False, + ) + parser.add_argument( + "--results-path", + type=Path, + help=( + "Directory containing results from running " + "src/scripts/analysis_data_generation/scenario_generate_chains.py " + ), + default=None, + required=False + ) + args = parser.parse_args() + assert args.results_path is not None + results_path = args.results_path + + output_path = results_path if args.output_path is None else args.output_path + + apply( + results_folder=results_path, + output_folder=output_path, + resourcefilepath=args.resources_path + ) diff --git a/src/scripts/analysis_data_generation/postprocess_events_chain.py b/src/scripts/analysis_data_generation/postprocess_events_chain.py new file mode 100644 index 0000000000..96c27a04b1 --- /dev/null +++ b/src/scripts/analysis_data_generation/postprocess_events_chain.py @@ -0,0 +1,156 @@ +import pandas as pd +from dateutil.relativedelta import relativedelta + +# Remove from every individual's event chain all events that were fired after death +def cut_off_events_after_death(df): + + events_chain = df.groupby('person_ID') + + filtered_data = pd.DataFrame() + + for name, group in events_chain: + + # Find the first non-NaN 'date_of_death' and its index + first_non_nan_index = group['date_of_death'].first_valid_index() + + if first_non_nan_index is not None: + # Filter out all rows after the first non-NaN index + filtered_group = group.loc[:first_non_nan_index] # Keep rows up to and including the first valid index + filtered_data = pd.concat([filtered_data, filtered_group]) + else: + # If there are no non-NaN values, keep the original group + filtered_data = pd.concat([filtered_data, group]) + + return filtered_data + +# Load into DataFrame +def load_csv_to_dataframe(file_path): + try: + # Load raw chains into df + df = pd.read_csv(file_path) + print("Raw event chains loaded successfully!") + return df + except FileNotFoundError: + print(f"Error: The file '{file_path}' was not found.") + except Exception as e: + print(f"An error occurred: {e}") + +file_path = 'output.csv' # Replace with the path to your CSV file + +output = load_csv_to_dataframe(file_path) + +# Some of the dates appeared not to be in datetime format. Correct here. 
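+# (`errors='coerce'` below turns unparseable entries into NaT rather than raising, e.g.
+#  pd.to_datetime(pd.Series(['2010-01-01', 'oops']), errors='coerce') -> [2010-01-01, NaT])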
+output['date_of_death'] = pd.to_datetime(output['date_of_death'], errors='coerce') +output['date_of_birth'] = pd.to_datetime(output['date_of_birth'], errors='coerce') +if 'hv_date_inf' in output.columns: + output['hv_date_inf'] = pd.to_datetime(output['hv_date_inf'], errors='coerce') + + +date_start = pd.to_datetime('2010-01-01') +if 'Other' in output['cause_of_death'].values: + print("ERROR: 'Other' was included in sim as possible cause of death") + exit(-1) + +# Choose which columns in individual properties to visualise +columns_to_print =['event','is_alive','hv_inf', 'hv_art','tb_inf', 'tb_date_active', 'event_date', 'when'] +#columns_to_print =['person_ID', 'date_of_birth', 'date_of_death', 'cause_of_death','hv_date_inf', 'hv_art','tb_inf', 'tb_date_active', 'event date', 'event'] + +# When checking which individuals led to *any* changes in individual properties, exclude these columns from comparison +columns_to_exclude_in_comparison = ['when', 'event', 'event_date', 'age_exact_years', 'age_years', 'age_days', 'age_range', 'level', 'appt_footprint'] + +# If considering epidemiology consistent with sim, add check here. +check_ages_of_those_HIV_inf = False +if check_ages_of_those_HIV_inf: + for index, row in output.iterrows(): + if pd.isna(row['hv_date_inf']): + continue # Skip this iteration + diff = relativedelta(output.loc[index, 'hv_date_inf'],output.loc[index, 'date_of_birth']) + if diff.years > 1 and diff.years<15: + print("Person contracted HIV infection at age younger than 15", diff) + +# Remove events after death +filtered_data = cut_off_events_after_death(output) + +print_raw_events = True # Print raw chain of events for each individual +print_selected_changes = False +print_all_changes = True +person_ID_of_interest = 494 + +pd.set_option('display.max_rows', None) + +for name, group in filtered_data.groupby('person_ID'): + list_of_dob = group['date_of_birth'] + + # Select individuals based on when they were born + if list_of_dob.iloc[0].year<2010: + + # Check that immutable properties are fixed for this individual, i.e. that events were collated properly: + all_identical_dob = group['date_of_birth'].nunique() == 1 + all_identical_sex = group['sex'].nunique() == 1 + if all_identical_dob is False or all_identical_sex is False: + print("Immutable properties are changing! 
This is not chain for single individual") + print(group) + exit(-1) + + print("----------------------------------------------------------------------") + print("person_ID ", group['person_ID'].iloc[0], "d.o.b ", group['date_of_birth'].iloc[0]) + print("Number of events for this individual ", group['person_ID'].iloc[0], "is :", len(group)/2) # Divide by 2 before printing Before/After for each event + number_of_events =len(group)/2 + number_of_changes=0 + if print_raw_events: + print(group) + + if print_all_changes: + # Check each row + comparison = group.drop(columns=columns_to_exclude_in_comparison).fillna(-99999).ne(group.drop(columns=columns_to_exclude_in_comparison).shift().fillna(-99999)) + + # Iterate over rows where any column has changed + for idx, row_changed in comparison.iloc[1:].iterrows(): + if row_changed.any(): # Check if any column changed in this row + number_of_changes+=1 + changed_columns = row_changed[row_changed].index.tolist() # Get the columns where changes occurred + print(f"Row {idx} - Changes detected in columns: {changed_columns}") + columns_output = ['event', 'event_date', 'appt_footprint', 'level'] + changed_columns + print(group.loc[idx, columns_output]) # Print only the changed columns + if group.loc[idx, 'when'] == 'Before': + print('-----> THIS CHANGE OCCURRED BEFORE EVENT!') + #print(group.loc[idx,columns_to_print]) + print() # For better readability + print("Number of changes is ", number_of_changes, "out of ", number_of_events, " events") + + if print_selected_changes: + tb_inf_condition = ( + ((group['tb_inf'].shift(1) == 'uninfected') & (group['tb_inf'] == 'active')) | + ((group['tb_inf'].shift(1) == 'latent') & (group['tb_inf'] == 'active')) | + ((group['tb_inf'].shift(1) == 'active') & (group['tb_inf'] == 'latent')) | + ((group['hv_inf'].shift(1) is False) & (group['hv_inf'] is True)) | + ((group['hv_art'].shift(1) == 'not') & (group['hv_art'] == 'on_not_VL_suppressed')) | + ((group['hv_art'].shift(1) == 'not') & (group['hv_art'] == 'on_VL_suppressed')) | + ((group['hv_art'].shift(1) == 'on_VL_suppressed') & (group['hv_art'] == 'on_not_VL_suppressed')) | + ((group['hv_art'].shift(1) == 'on_VL_suppressed') & (group['hv_art'] == 'not')) | + ((group['hv_art'].shift(1) == 'on_not_VL_suppressed') & (group['hv_art'] == 'on_VL_suppressed')) | + ((group['hv_art'].shift(1) == 'on_not_VL_suppressed') & (group['hv_art'] == 'not')) + ) + + alive_condition = ( + (group['is_alive'].shift(1) is True) & (group['is_alive'] is False) + ) + # Combine conditions for rows of interest + transition_condition = tb_inf_condition | alive_condition + + if list_of_dob.iloc[0].year >= 2010: + print("DETECTED OF INTEREST") + print(group[group['event'] == 'Birth'][columns_to_print]) + + # Filter the DataFrame based on the condition + filtered_transitions = group[transition_condition] + if not filtered_transitions.empty: + if list_of_dob.iloc[0].year < 2010: + print("DETECTED OF INTEREST") + print(filtered_transitions[columns_to_print]) + + +print("Number of individuals simulated ", filtered_data.groupby('person_ID').ngroups) + + + diff --git a/src/scripts/analysis_data_generation/scenario_generate_chains.py b/src/scripts/analysis_data_generation/scenario_generate_chains.py new file mode 100644 index 0000000000..0f53a1461b --- /dev/null +++ b/src/scripts/analysis_data_generation/scenario_generate_chains.py @@ -0,0 +1,119 @@ +"""This Scenario file run the model to generate event chans + +Run on the batch system using: +``` +tlo batch-submit + 
src/scripts/analysis_data_generation/scenario_generate_chains.py +``` + +or locally using: +``` + tlo scenario-run src/scripts/analysis_data_generation/scenario_generate_chains.py +``` + +""" +from pathlib import Path +from typing import Dict + +import pandas as pd + +from tlo import Date, logging +from tlo.analysis.utils import get_parameters_for_status_quo, mix_scenarios, get_filtered_treatment_ids +from tlo.methods.fullmodel import fullmodel +from tlo.methods.scenario_switcher import ImprovedHealthSystemAndCareSeekingScenarioSwitcher +from tlo.scenario import BaseScenario +from tlo.methods import ( + alri, + cardio_metabolic_disorders, + care_of_women_during_pregnancy, + contraception, + demography, + depression, + diarrhoea, + enhanced_lifestyle, + epi, + healthburden, + healthseekingbehaviour, + healthsystem, + hiv, + rti, + labour, + malaria, + newborn_outcomes, + postnatal_supervisor, + pregnancy_supervisor, + stunting, + symptommanager, + tb, + wasting, +) + +class GenerateEventChains(BaseScenario): + def __init__(self): + super().__init__() + self.seed = 42 + self.start_date = Date(2010, 1, 1) + self.end_date = self.start_date + pd.DateOffset(months=1) + self.pop_size = 1000 + self._scenarios = self._get_scenarios() + self.number_of_draws = len(self._scenarios) + self.runs_per_draw = 3 + self.generate_event_chains = True + + def log_configuration(self): + return { + 'filename': 'generate_event_chains', + 'directory': Path('./outputs'), # <- (specified only for local running) + 'custom_levels': { + '*': logging.WARNING, + 'tlo.methods.demography': logging.INFO, + 'tlo.methods.events': logging.INFO, + 'tlo.methods.demography.detail': logging.WARNING, + 'tlo.methods.healthburden': logging.INFO, + 'tlo.methods.healthsystem.summary': logging.INFO, + 'tlo.methods.collect_event_chains': logging.INFO + } + } + + def modules(self): + return ( + fullmodel() + ) + + def draw_parameters(self, draw_number, rng): + if draw_number < self.number_of_draws: + return list(self._scenarios.values())[draw_number] + else: + return + + def _get_scenarios(self) -> Dict[str, Dict]: + + return { + "Baseline": + mix_scenarios( + self._baseline(), + { + "CollectEventChains": { + "generate_event_chains": True, + }, + } + ), + + } + + def _baseline(self) -> Dict: + #Return the Dict with values for the parameter changes that define the baseline scenario. + return mix_scenarios( + get_parameters_for_status_quo(), + { + "HealthSystem": { + "mode_appt_constraints": 1, # <-- Mode 1 prior to change to preserve calibration + "cons_availability": "all", + } + }, + ) + +if __name__ == '__main__': + from tlo.cli import scenario_run + + scenario_run([__file__]) diff --git a/src/tlo/analysis/utils.py b/src/tlo/analysis/utils.py index f6aff47faa..94bc541d30 100644 --- a/src/tlo/analysis/utils.py +++ b/src/tlo/analysis/utils.py @@ -366,6 +366,137 @@ def generate_series(dataframe: pd.DataFrame) -> pd.Series: return _concat +def unpack_dict_rows(df, non_dict_cols=None): + """ + Reconstruct a full DataFrame from rows where most columns are dictionaries. + Non-dict columns (e.g., 'date') are propagated to all reconstructed rows. 
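+
+    For example (hypothetical values), a logged row such as
+        {'E': {'0': 7, '1': 7}, 'date': d, 'EventName': {'0': 'Birth', '1': 'Birth'}, 'A': {...}, 'V': {...}}
+    is expanded into two rows, one per dictionary entry, with 'date' propagated to both.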
+ + Parameters: + df: pd.DataFrame + non_dict_cols: list of columns that are NOT dictionaries + """ + if non_dict_cols is None: + non_dict_cols = [] + + original_cols = ['E', 'date', 'EventName', 'A', 'V'] + + reconstructed_rows = [] + + for _, row in df.iterrows(): + # Determine dict columns for this row + dict_cols = [col for col in original_cols if col not in non_dict_cols] + + if not dict_cols: + # No dict columns, just append row + reconstructed_rows.append(row.to_dict()) + continue + + # Use the first dict column to get the block length + first_dict_col = dict_cols[0] + block_length = len(row[first_dict_col]) + + # Build each expanded row + for i in range(block_length): + new_row = {} + for col in original_cols: + cell = row[col] + if col in dict_cols: + # Access the dict using string or integer keys + new_row[col] = cell.get(str(i), cell.get(i)) + else: + # Propagate non-dict value + new_row[col] = cell + reconstructed_rows.append(new_row) + + # Build DataFrame in original column order + out = pd.DataFrame(reconstructed_rows)[original_cols] + + return out.reset_index(drop=True) + + +def print_filtered_df(df): + """ + Prints rows of the DataFrame excluding EventName 'Initialise' and 'Birth'. + """ + pd.set_option('display.max_colwidth', None) + filtered = df#[~df['EventName'].isin(['StartOfSimulation', 'Birth'])] + + dict_cols = ["Info"] + max_items = 2 + # Step 2: Truncate dictionary columns for display + if dict_cols is not None: + for col in dict_cols: + def truncate_dict(d): + if isinstance(d, dict): + items = list(d.items())[:max_items] # keep only first `max_items` + return dict(items) + return d + filtered[col] = filtered[col].apply(truncate_dict) + print(filtered) + + +def extract_event_chains(results_folder: Path, + ) -> dict: + """Utility function to collect chains of events. Individuals across runs of the same draw will be combined into unique df. + Returns dictionary where keys are draws, and each draw is associated with a dataframe of format 'E', 'EventDate', 'EventName', 'Info' where 'Info' is a dictionary that combines A&Vs for a particular individual + date + event name combination. + """ + module = 'tlo.collect_event_chains' + key = 'event_chains' + + # get number of draws and numbers of runs + info = get_scenario_info(results_folder) + + # Collect results from each draw/run. Individuals across runs of the same draw will be combined into unique df. + res = dict() + + for draw in range(info['number_of_draws']): + + # All individuals in same draw will be combined across runs, so their ID will be offset. + dfs_from_runs = [] + ID_offset = 0 + + for run in range(info['runs_per_draw']): + + try: + df: pd.DataFrame = load_pickled_dataframes(results_folder, draw, run, module)[module][key] + + recon = unpack_dict_rows(df, ['date']) + print(recon) + #del recon['EventDate'] + # For now convert value to string in all cases to facilitate manipulation. This can be reversed later. + recon['V'] = recon['V'].apply(str) + # Collapse into 'E', 'EventDate', 'EventName', 'Info' format where 'Info' is dict listing attributes (e.g. 
{a1:v1, a2:v2, a3:v3, ...} ) + df_collapsed = ( + recon.groupby(['E', 'date', 'EventName']) + .apply(lambda g: dict(zip(g['A'], g['V']))) + .reset_index(name='Info') + ) + df_final = df_collapsed.sort_values(by=['E','date'], ascending=True).reset_index(drop=True) + birth_count = (df_final['EventName'] == 'Birth').sum() + + print("Birth count for run ", run, "is ", birth_count) + df_final['E'] = df_final['E'] + ID_offset + + # Calculate ID offset for next run + ID_offset = (max(df_final['E']) + 1) + + # Append these chains to list + dfs_from_runs.append(df_final) + + except KeyError: + # Some logs could not be found - probably because this run failed. + # Simply to not append anything to the df collecting chains. + print("Run failed") + + # Combine all dfs into a single DataFrame + res[draw] = pd.concat(dfs_from_runs, ignore_index=True) + + # Optionally, sort by 'E' and 'EventDate' after combining + res[draw] = res[draw].sort_values(by=['E', 'date']).reset_index(drop=True) + + return res + + def compute_summary_statistics( results: pd.DataFrame, central_measure: Union[Literal["mean", "median"], None] = None, diff --git a/src/tlo/events.py b/src/tlo/events.py index 9dd34c9448..56acb82f43 100644 --- a/src/tlo/events.py +++ b/src/tlo/events.py @@ -9,6 +9,7 @@ if TYPE_CHECKING: from tlo import Simulation +from tlo.notify import notifier class Priority(Enum): """Enumeration for the Priority, which is used in sorting the events in the simulation queue.""" @@ -22,7 +23,6 @@ def __lt__(self, other): return self.value < other.value return NotImplemented - class Event: """Base event class, from which all others inherit. @@ -63,8 +63,15 @@ def apply(self, target): def run(self): """Make the event happen.""" + + # Dispatch notification that event is about to run + notifier.dispatch("event.about_to_run", data={"target": self.target, "module" : self.module.name, "link_info" : {"EventName": type(self).__name__}}) + self.apply(self.target) self.post_apply_hook() + + # Dispatch notification that event has just ran + notifier.dispatch("event.has_just_ran", data={"target": self.target, "link_info" : {"EventName": type(self).__name__}}) class RegularEvent(Event): diff --git a/src/tlo/methods/collect_event_chains.py b/src/tlo/methods/collect_event_chains.py new file mode 100644 index 0000000000..712d8c045e --- /dev/null +++ b/src/tlo/methods/collect_event_chains.py @@ -0,0 +1,357 @@ +from tlo.notify import notifier + +from pathlib import Path +from typing import Optional, List +from tlo import Module, Parameter, Types, logging, population +from tlo.population import Population +import pandas as pd + +from tlo.util import df_to_EAV, convert_chain_links_into_EAV, read_csv_files + +import copy + +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) + +class CollectEventChains(Module): + + def __init__( + self, + name: Optional[str] = None, + generate_event_chains: Optional[bool] = None, + modules_of_interest: Optional[List[str]] = None, + events_to_ignore: Optional[List[str]] = None + + ): + super().__init__(name) + + self.generate_event_chains = generate_event_chains + self.modules_of_interest = modules_of_interest + self.events_to_ignore = events_to_ignore + + # This is how I am passing data from fnc taking place before event to the one after + # It doesn't seem very elegant but not sure how else to go about it + self.print_chains = False + self.df_before = [] + self.row_before = pd.Series() + self.mni_instances_before = False + self.mni_row_before = {} + self.entire_mni_before = {} + + PARAMETERS 
= {
+        # Options within module
+        "generate_event_chains": Parameter(
+            Types.BOOL, "Whether or not we want to collect chains of events for individuals"
+        ),
+        "modules_of_interest": Parameter(
+            Types.LIST, "Restrict the events collected to specific modules. If *, print for all modules"
+        ),
+        "events_to_ignore": Parameter(
+            Types.LIST, "Events to be ignored when collecting chains"
+        ),
+    }
+
+    def initialise_simulation(self, sim):
+        notifier.add_listener("simulation.pop_has_been_initialised", self.on_notification_pop_has_been_initialised)
+        notifier.add_listener("simulation.on_birth", self.on_notification_of_birth)
+        notifier.add_listener("event.about_to_run", self.on_notification_event_about_to_run)
+        notifier.add_listener("event.has_just_ran", self.on_notification_event_has_just_ran)
+
+    def read_parameters(self, resourcefilepath: Optional[Path] = None):
+        self.load_parameters_from_dataframe(pd.read_csv(resourcefilepath/"ResourceFile_GenerateEventChains/parameter_values.csv"))
+
+    def initialise_population(self, population):
+        # Use parameter file values by default, if not overwritten
+        self.generate_event_chains = self.parameters['generate_event_chains'] \
+            if self.generate_event_chains is None \
+            else self.generate_event_chains
+
+        self.modules_of_interest = self.parameters['modules_of_interest'] \
+            if self.modules_of_interest is None \
+            else self.modules_of_interest
+
+        self.events_to_ignore = self.parameters['events_to_ignore'] \
+            if self.events_to_ignore is None \
+            else self.events_to_ignore
+
+        # If modules of interest is '*', set by default to all modules included in the simulation
+        if self.modules_of_interest == ['*']:
+            self.modules_of_interest = list(self.sim.modules.keys())
+
+    def get_generate_event_chains(self) -> bool:
+        """Returns `generate_event_chains`. (Should be equal to what is specified by the parameter, but
+        overwritten with what was provided in the constructor argument if one was specified -- provided for
+        backward compatibility/debugging.)"""
+        return self.parameters['generate_event_chains'] \
+            if self.generate_event_chains is None \
+            else self.generate_event_chains
+
+    def on_birth(self, mother, child):
+        # Could the notification of birth simply take place here?
+        pass
+
+    def on_notification_pop_has_been_initialised(self, data):
+        # When logging events for each individual to reconstruct chains, only the changes in individual properties
+        # will be logged. At the start of the simulation, and when a new individual is born, we therefore want to
+        # store all of their properties as a starting point.
+        if self.parameters['generate_event_chains']:
+
+            # EDNAV structure to capture status of individuals at the start of the simulation
+            ednav = df_to_EAV(self.sim.population.props, self.sim.date, 'StartOfSimulation')
+
+            logger.info(key='event_chains',
+                        data=ednav.to_dict(),
+                        description='Links forming chains of events for simulated individuals')
+
+    def on_notification_of_birth(self, data):
+
+        if self.parameters['generate_event_chains']:
+            # When an individual is born, store their initial properties to provide a starting point for the chain
+            # of property changes that this individual will undergo as a result of events taking place.
+            link_info = data['link_info']
+            link_info.update(self.sim.population.props.loc[data['target']].to_dict())
+            chain_links = {}
+            chain_links[data['target']] = link_info
+
+            ednav = convert_chain_links_into_EAV(chain_links)
+
+            logger.info(key='event_chains',
+                        data=ednav.to_dict(),
+                        description='Links forming chains of events for simulated individuals')
+
+    def on_notification_event_about_to_run(self, data):
+        """Do this when notified that an event is about to run. This function checks whether the event should be
+        logged as part of the event chains and, if so, stores the required information before the event occurs."""
+
+        # Only log event if
+        # 1) generate_event_chains is set to True
+        # 2) the event belongs to modules of interest and
+        # 3) the event is not in the list of events to ignore
+        if not self.generate_event_chains or (data['module'] not in self.modules_of_interest) or (data['link_info']['EventName'] in self.events_to_ignore):
+            return
+        else:
+            # Initialise these variables
+            self.print_chains = False
+            self.df_before = []
+            self.row_before = pd.Series()
+            self.mni_instances_before = False
+            self.mni_row_before = {}
+            self.entire_mni_before = {}
+
+            self.print_chains = True
+
+            # Target is single individual
+            if not isinstance(data['target'], Population):
+
+                # Save row for comparison after event has occurred
+                self.row_before = self.sim.population.props.loc[abs(data['target'])].copy().fillna(-99999)
+
+                # Check if individual is already in mni dictionary, if so copy her original status
+                if 'PregnancySupervisor' in self.sim.modules:
+                    mni = self.sim.modules['PregnancySupervisor'].mother_and_newborn_info
+                    if data['target'] in mni:
+                        self.mni_instances_before = True
+                        self.mni_row_before = mni[data['target']].copy()
+                else:
+                    self.mni_row_before = None
+
+            else:
+
+                # This will be a population-wide event. In order to find individuals for which this led to
+                # a meaningful change, make a copy of the whole population dataframe/mni before the event occurs.
+                self.df_before = self.sim.population.props.copy()
+                if 'PregnancySupervisor' in self.sim.modules:
+                    self.entire_mni_before = copy.deepcopy(self.sim.modules['PregnancySupervisor'].mother_and_newborn_info)
+                else:
+                    self.entire_mni_before = None
+
+        return
+
+    def on_notification_event_has_just_ran(self, data):
+        """If print_chains=True, this function logs the event and identifies and logs any property changes that
+        have occurred to one or multiple individuals as a result of the event taking place.
""" + + if not self.print_chains: + return + else: + + chain_links = {} + + # Target is single individual + if not isinstance(data["target"], Population): + + # Copy full new status for individual + row_after = self.sim.population.props.loc[abs(data['target'])].fillna(-99999) + + # Check if individual is in mni after the event + mni_instances_after = False + if 'PregnancySupervisor' in self.sim.modules: + mni = self.sim.modules['PregnancySupervisor'].mother_and_newborn_info + if data['target'] in mni: + mni_instances_after = True + else: + mni_instances_after = None + + # Create and store event for this individual, regardless of whether any property change occurred + link_info = data['link_info'] + + # Store (if any) property changes as a result of the event for this individual + for key in self.row_before.index: + if self.row_before[key] != row_after[key]: # Note: used fillna previously, so this is safe + link_info[key] = row_after[key] + + if 'PregnancySupervisor' in self.sim.modules: + # Now check and store changes in the mni dictionary, accounting for following cases: + # Individual is in mni dictionary before and after + if self.mni_instances_before and mni_instances_after: + for key in self.mni_row_before: + if self.mni_values_differ(self.mni_row_before[key], mni[data['target']][key]): + link_info[key] = mni[data['target']][key] + # Individual is only in mni dictionary before event + elif self.mni_instances_before and not mni_instances_after: + default = self.sim.modules['PregnancySupervisor'].default_all_mni_values + for key in self.mni_row_before: + if self.mni_values_differ(self.mni_row_before[key], default[key]): + link_info[key] = default[key] + # Individual is only in mni dictionary after event + elif mni_instances_after and not self.mni_instances_before: + default = self.sim.modules['PregnancySupervisor'].default_all_mni_values + for key in default: + if self.mni_values_differ(default[key], mni[data['target']][key]): + link_info[key] = mni[data['target']][key] + # Else, no need to do anything + + # Add individual to the chain links + chain_links[data['target']] = link_info + + else: + # Target is entire population. Identify individuals for which properties have changed + # and store their changes. 
+ + # Population frame after event + df_after = self.sim.population.props + if 'PregnancySupervisor' in self.sim.modules: + entire_mni_after = copy.deepcopy(self.sim.modules['PregnancySupervisor'].mother_and_newborn_info) + else: + entire_mni_after = None + + # Create and store the event and dictionary of changes for affected individuals + chain_links = self.compare_population_dataframe_and_mni(self.df_before, df_after, self.entire_mni_before, entire_mni_after) + + if chain_links: + # Convert chain_links into EAV + ednav = convert_chain_links_into_EAV(chain_links) + + logger.info(key='event_chains', + data= ednav.to_dict(), + description='Links forming chains of events for simulated individuals') + + # Reset variables + self.print_chains = False + self.df_before = [] + self.row_before = pd.Series() + self.mni_instances_before = False + self.mni_row_before = {} + self.entire_mni_before = {} + + return + + def mni_values_differ(self, v1, v2): + + if isinstance(v1, list) and isinstance(v2, list): + return v1 != v2 # simple element-wise comparison + + if pd.isna(v1) and pd.isna(v2): + return False # treat both NaT/NaN as equal + return v1 != v2 + + def compare_entire_mni_dicts(self,entire_mni_before, entire_mni_after): + diffs = {} + + all_individuals = set(entire_mni_before.keys()) | set(entire_mni_after.keys()) + + for person in all_individuals: + if person not in entire_mni_before: # but is afterward + for key in entire_mni_after[person]: + if self.mni_values_differ(entire_mni_after[person][key],self.sim.modules['PregnancySupervisor'].default_all_mni_values[key]): + if person not in diffs: + diffs[person] = {} + diffs[person][key] = entire_mni_after[person][key] + + elif person not in entire_mni_after: # but is beforehand + for key in entire_mni_before[person]: + if self.mni_values_differ(entire_mni_before[person][key],self.sim.modules['PregnancySupervisor'].default_all_mni_values[key]): + if person not in diffs: + diffs[person] = {} + diffs[person][key] = self.sim.modules['PregnancySupervisor'].default_all_mni_values[key] + + else: # person is in both + # Compare properties + for key in entire_mni_before[person]: + if self.mni_values_differ(entire_mni_before[person][key],entire_mni_after[person][key]): + if person not in diffs: + diffs[person] = {} + diffs[person][key] = entire_mni_after[person][key] + + return diffs + + def compare_population_dataframe_and_mni(self,df_before, df_after, entire_mni_before, entire_mni_after): + """ This function compares the population dataframe and mni dictionary before/after a population-wide event has occurred. + It allows us to identify the individuals for which this event led to a significant (i.e. property) change, and to store the properties which have changed as a result of it. 
""" + + # Create a mask of where values are different + diff_mask = (df_before != df_after) & ~(df_before.isna() & df_after.isna()) + if 'PregnancySupervisor' in self.sim.modules: + diff_mni = self.compare_entire_mni_dicts(entire_mni_before, entire_mni_after) + else: + diff_mni = [] + + # Create an empty list to store changes for each of the individuals + chain_links = {} + len_of_diff = len(diff_mask) + + # Loop through each row of the mask + persons_changed = [] + + for idx, row in diff_mask.iterrows(): + changed_cols = row.index[row].tolist() + + if changed_cols: # Proceed only if there are changes in the row + persons_changed.append(idx) + # Create a dictionary for this person + # First add event info + link_info = { + 'EventName': type(self).__name__, + } + + # Store the new values from df_after for the changed columns + for col in changed_cols: + link_info[col] = df_after.at[idx, col] + + if idx in diff_mni: + # This person has also undergone changes in the mni dictionary, so add these here + for key in diff_mni[idx]: + link_info[col] = diff_mni[idx][key] + + # Append the event and changes to the individual key + chain_links[idx] = link_info + + if 'PregnancySupervisor' in self.sim.modules: + # For individuals which only underwent changes in mni dictionary, save changes here + if len(diff_mni)>0: + for key in diff_mni: + if key not in persons_changed: + # If individual hadn't been previously added due to changes in pop df, add it here + link_info = { + 'EventName': type(self).__name__, + } + + for key_prop in diff_mni[key]: + link_info[key_prop] = diff_mni[key][key_prop] + + chain_links[key] = link_info + + return chain_links + + + diff --git a/src/tlo/methods/fullmodel.py b/src/tlo/methods/fullmodel.py index 3f0c79434e..3c710c7dd2 100644 --- a/src/tlo/methods/fullmodel.py +++ b/src/tlo/methods/fullmodel.py @@ -8,6 +8,7 @@ cardio_metabolic_disorders, care_of_women_during_pregnancy, cervical_cancer, + collect_event_chains, contraception, copd, demography, @@ -116,6 +117,7 @@ def fullmodel( copd.Copd, depression.Depression, epilepsy.Epilepsy, + collect_event_chains.CollectEventChains, ] return [ module_class( diff --git a/src/tlo/methods/hsi_event.py b/src/tlo/methods/hsi_event.py index cc5e0a4fb2..32620f6c28 100644 --- a/src/tlo/methods/hsi_event.py +++ b/src/tlo/methods/hsi_event.py @@ -7,6 +7,8 @@ from tlo import Date, logging from tlo.events import Event +from tlo.notify import notifier + if TYPE_CHECKING: from tlo import Module, Simulation @@ -193,13 +195,37 @@ def _run_after_hsi_event(self) -> None: item_codes=self._EQUIPMENT, facility_id=self.facility_info.id ) + def run(self, squeeze_factor): """Make the event happen.""" + + # Dispatch notification that HSI event is about to run + notifier.dispatch("event.about_to_run", data={"target": self.target, "module" : self.module.name, "link_info" : {"EventName": type(self).__name__}}) + updated_appt_footprint = self.apply(self.target, squeeze_factor) self.post_apply_hook() self._run_after_hsi_event() + + # Dispatch notification that HSI event has just ran + if updated_appt_footprint is not None: + footprint = updated_appt_footprint + else: + footprint = self.EXPECTED_APPT_FOOTPRINT + try: + level = self.facility_info.level + except: + level = "N/A" + + notifier.dispatch("event.has_just_ran", + data={"target": self.target, + "link_info" : {"EventName": type(self).__name__, + "footprint": footprint, + "level": level + }}) + return updated_appt_footprint + def get_consumables( self, diff --git 
a/src/tlo/methods/pregnancy_helper_functions.py b/src/tlo/methods/pregnancy_helper_functions.py index 76076a5dcf..d8e01ef205 100644 --- a/src/tlo/methods/pregnancy_helper_functions.py +++ b/src/tlo/methods/pregnancy_helper_functions.py @@ -531,91 +531,15 @@ def update_mni_dictionary(self, individual_id): if self == self.sim.modules['PregnancySupervisor']: - mni[individual_id] = {'delete_mni': False, # if True, mni deleted in report_daly_values function - 'didnt_seek_care': False, - 'cons_not_avail': False, - 'comp_not_avail': False, - 'hcw_not_avail': False, - 'ga_anc_one': 0, - 'anc_ints': [], - 'abortion_onset': pd.NaT, - 'abortion_haem_onset': pd.NaT, - 'abortion_sep_onset': pd.NaT, - 'eclampsia_onset': pd.NaT, - 'mild_mod_aph_onset': pd.NaT, - 'severe_aph_onset': pd.NaT, - 'chorio_onset': pd.NaT, - 'chorio_in_preg': False, # use in predictor in newborn linear models - 'ectopic_onset': pd.NaT, - 'ectopic_rupture_onset': pd.NaT, - 'gest_diab_onset': pd.NaT, - 'gest_diab_diagnosed_onset': pd.NaT, - 'gest_diab_resolution': pd.NaT, - 'mild_anaemia_onset': pd.NaT, - 'mild_anaemia_resolution': pd.NaT, - 'moderate_anaemia_onset': pd.NaT, - 'moderate_anaemia_resolution': pd.NaT, - 'severe_anaemia_onset': pd.NaT, - 'severe_anaemia_resolution': pd.NaT, - 'mild_anaemia_pp_onset': pd.NaT, - 'mild_anaemia_pp_resolution': pd.NaT, - 'moderate_anaemia_pp_onset': pd.NaT, - 'moderate_anaemia_pp_resolution': pd.NaT, - 'severe_anaemia_pp_onset': pd.NaT, - 'severe_anaemia_pp_resolution': pd.NaT, - 'hypertension_onset': pd.NaT, - 'hypertension_resolution': pd.NaT, - 'obstructed_labour_onset': pd.NaT, - 'sepsis_onset': pd.NaT, - 'uterine_rupture_onset': pd.NaT, - 'mild_mod_pph_onset': pd.NaT, - 'severe_pph_onset': pd.NaT, - 'secondary_pph_onset': pd.NaT, - 'vesicovaginal_fistula_onset': pd.NaT, - 'vesicovaginal_fistula_resolution': pd.NaT, - 'rectovaginal_fistula_onset': pd.NaT, - 'rectovaginal_fistula_resolution': pd.NaT, - 'test_run': False, # used by labour module when running some model tests - 'pred_syph_infect': pd.NaT, # date syphilis is predicted to onset - 'new_onset_spe': False, - 'cs_indication': 'none' - } + mni[individual_id] = self.sim.modules['PregnancySupervisor'].default_mni_values.copy() elif self == self.sim.modules['Labour']: - labour_variables = {'labour_state': None, - # Term Labour (TL), Early Preterm (EPTL), Late Preterm (LPTL) or Post Term (POTL) - 'birth_weight': 'normal_birth_weight', - 'birth_size': 'average_for_gestational_age', - 'delivery_setting': None, # home_birth, health_centre, hospital - 'twins': df.at[individual_id, 'ps_multiple_pregnancy'], - 'twin_count': 0, - 'twin_one_comps': False, - 'pnc_twin_one': 'none', - 'bf_status_twin_one': 'none', - 'eibf_status_twin_one': False, - 'an_placental_abruption': df.at[individual_id, 'ps_placental_abruption'], - 'corticosteroids_given': False, - 'clean_birth_practices': False, - 'abx_for_prom_given': False, - 'abx_for_pprom_given': False, - 'endo_pp': False, - 'retained_placenta': False, - 'uterine_atony': False, - 'amtsl_given': False, - 'cpd': False, - 'mode_of_delivery': 'vaginal_delivery', - 'neo_will_receive_resus_if_needed': False, - # vaginal_delivery, instrumental, caesarean_section - 'hsi_cant_run': False, # True (T) or False (F) - 'sought_care_for_complication': False, # True (T) or False (F) - 'sought_care_labour_phase': 'none', - 'referred_for_cs': False, # True (T) or False (F) - 'referred_for_blood': False, # True (T) or False (F) - 'received_blood_transfusion': False, # True (T) or False (F) - 'referred_for_surgery': 
False, # True (T) or False (F)' - 'death_in_labour': False, # True (T) or False (F) - 'single_twin_still_birth': False, # True (T) or False (F) - 'will_receive_pnc': 'none', - 'passed_through_week_one': False} - - mni[individual_id].update(labour_variables) + + labour_default = self.sim.modules['PregnancySupervisor'].default_labour_values.copy() + mni[individual_id].update(labour_default) + + # Update from default based on individual case + mni[individual_id]['twins'] = df.at[individual_id, 'ps_multiple_pregnancy'] + mni[individual_id]['an_placental_abruption'] = df.at[individual_id, 'ps_placental_abruption'] + + diff --git a/src/tlo/methods/pregnancy_supervisor.py b/src/tlo/methods/pregnancy_supervisor.py index 1be38175f7..d5071bc459 100644 --- a/src/tlo/methods/pregnancy_supervisor.py +++ b/src/tlo/methods/pregnancy_supervisor.py @@ -60,6 +60,96 @@ def __init__(self, name=None): # This variable will store a Bitset handler for the property ps_abortion_complications self.abortion_complications = None + + self.default_mni_values = {'delete_mni': False, # if True, mni deleted in report_daly_values function + 'didnt_seek_care': False, + 'cons_not_avail': False, + 'comp_not_avail': False, + 'hcw_not_avail': False, + 'ga_anc_one': 0, + 'anc_ints': [], + 'abortion_onset': pd.NaT, + 'abortion_haem_onset': pd.NaT, + 'abortion_sep_onset': pd.NaT, + 'eclampsia_onset': pd.NaT, + 'mild_mod_aph_onset': pd.NaT, + 'severe_aph_onset': pd.NaT, + 'chorio_onset': pd.NaT, + 'chorio_in_preg': False, # use in predictor in newborn linear models + 'ectopic_onset': pd.NaT, + 'ectopic_rupture_onset': pd.NaT, + 'gest_diab_onset': pd.NaT, + 'gest_diab_diagnosed_onset': pd.NaT, + 'gest_diab_resolution': pd.NaT, + 'none_anaemia_onset': pd.NaT, + 'none_anaemia_resolution': pd.NaT, + 'mild_anaemia_onset': pd.NaT, + 'mild_anaemia_resolution': pd.NaT, + 'moderate_anaemia_onset': pd.NaT, + 'moderate_anaemia_resolution': pd.NaT, + 'severe_anaemia_onset': pd.NaT, + 'severe_anaemia_resolution': pd.NaT, + 'mild_anaemia_pp_onset': pd.NaT, + 'mild_anaemia_pp_resolution': pd.NaT, + 'moderate_anaemia_pp_onset': pd.NaT, + 'moderate_anaemia_pp_resolution': pd.NaT, + 'severe_anaemia_pp_onset': pd.NaT, + 'severe_anaemia_pp_resolution': pd.NaT, + 'hypertension_onset': pd.NaT, + 'hypertension_resolution': pd.NaT, + 'obstructed_labour_onset': pd.NaT, + 'sepsis_onset': pd.NaT, + 'uterine_rupture_onset': pd.NaT, + 'mild_mod_pph_onset': pd.NaT, + 'severe_pph_onset': pd.NaT, + 'secondary_pph_onset': pd.NaT, + 'vesicovaginal_fistula_onset': pd.NaT, + 'vesicovaginal_fistula_resolution': pd.NaT, + 'rectovaginal_fistula_onset': pd.NaT, + 'rectovaginal_fistula_resolution': pd.NaT, + 'test_run': False, # used by labour module when running some model tests + 'pred_syph_infect': pd.NaT, # date syphilis is predicted to onset + 'new_onset_spe': False, + 'cs_indication': 'none' + } + self.default_labour_values = {'labour_state': None, + # Term Labour (TL), Early Preterm (EPTL), Late Preterm (LPTL) or Post Term (POTL) + 'birth_weight': 'normal_birth_weight', + 'birth_size': 'average_for_gestational_age', + 'delivery_setting': None, # home_birth, health_centre, hospital + 'twins': None, + 'twin_count': 0, + 'twin_one_comps': False, + 'pnc_twin_one': 'none', + 'bf_status_twin_one': 'none', + 'eibf_status_twin_one': False, + 'an_placental_abruption': None, + 'corticosteroids_given': False, + 'clean_birth_practices': False, + 'abx_for_prom_given': False, + 'abx_for_pprom_given': False, + 'endo_pp': False, + 'retained_placenta': False, + 'uterine_atony': 
diff --git a/src/tlo/methods/rti.py b/src/tlo/methods/rti.py
index 994b8d1054..92f79f7538 100644
--- a/src/tlo/methods/rti.py
+++ b/src/tlo/methods/rti.py
@@ -2776,7 +2776,7 @@ class RTIPollingEvent(RegularEvent, PopulationScopeEventMixin):

     def __init__(self, module):
         """Schedule to take place every month
        """
-        super().__init__(module, frequency=DateOffset(months=1))
+        super().__init__(module, frequency=DateOffset(months=1000))  # Single polling event
         p = module.parameters
         # Parameters which transition the model between states
         self.base_1m_prob_rti = (p['base_rate_injrti'] / 12)
@@ -2864,9 +2864,15 @@ def apply(self, population):
                    .when('.between(70,79)', self.rr_injrti_age7079),
                    Predictor('li_ex_alc').when(True, self.rr_injrti_excessalcohol)
                    )
+        # if self.sim.generate_event_chains is True and self.sim.generate_event_chains_overwrite_epi is True:
+        #     pred = 1.0
+        # else:
        pred = eq.predict(df.loc[rt_current_non_ind])
+
+
        random_draw_in_rti = self.module.rng.random_sample(size=len(rt_current_non_ind))
        selected_for_rti = rt_current_non_ind[pred > random_draw_in_rti]
+
        # Update to say they have been involved in a rti
        df.loc[selected_for_rti, 'rt_road_traffic_inc'] = True
        # Set the date that people were injured to now
@@ -4874,6 +4880,7 @@ def __init__(self, module, person_id):
         self.treated_code = 'none'

     def apply(self, person_id, squeeze_factor):
+        self._number_of_times_this_event_has_run += 1
         df = self.sim.population.props
         rng = self.module.rng
@@ -4922,10 +4929,12 @@ def apply(self, person_id, squeeze_factor):
         # injury is being treated in this surgery
         # find untreated injury codes that are treated with major surgery
         relevant_codes = np.intersect1d(injuries_to_be_treated, surgically_treated_codes)
+
         # check that the person sent here has an appropriate code(s)
         assert len(relevant_codes) > 0
         # choose a code at random
         self.treated_code = rng.choice(relevant_codes)
+
         if request_outcome:
             # check the people sent here hasn't died due to rti, have had their injuries diagnosed and been through
             # RTI_Med
@@ -5012,7 +5021,9 @@ def apply(self, person_id, squeeze_factor):

         # ------------------------------------- Perm disability from amputation ------------------------------------
         codes = ['782', '782a', '782b', '782c', '783', '882', '883', '884']
+
         if self.treated_code in codes:
+
             # Track whether they are permanently disabled
             df.at[person_id, 'rt_perm_disability'] = True
             # Find the column and code where the permanent injury is stored
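Note: with `DateOffset(months=1000)` the polling event's second occurrence falls decades past any plausible end date, so it effectively fires once, matching the "Single polling event" comment. A quick pandas check of the arithmetic (dates here are illustrative):

    import pandas as pd

    start = pd.Timestamp(2010, 1, 1)
    print(start + pd.DateOffset(months=1))     # 2010-02-01: the previous monthly cadence
    print(start + pd.DateOffset(months=1000))  # 2093-05-01: beyond a multi-decade run
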
diff --git a/src/tlo/notify.py b/src/tlo/notify.py
new file mode 100644
index 0000000000..b1b4434ba9
--- /dev/null
+++ b/src/tlo/notify.py
@@ -0,0 +1,72 @@
+"""
+A dead simple synchronous notification dispatcher.
+
+Usage
+-----
+# In the notifying class/module
+from tlo.notify import notifier
+
+notifier.dispatch("simulation.on_start", data={"one": 1, "two": 2})
+
+# In the listening class/module
+from tlo.notify import notifier
+
+def on_notification(data):
+    print("Received notification:", data)
+
+notifier.add_listener("simulation.on_start", on_notification)
+"""
+
+
+class Notifier:
+    """
+    A simple synchronous notification dispatcher supporting listeners.
+    """
+
+    def __init__(self):
+        self.listeners = {}
+
+    def add_listener(self, notification_key, listener):
+        """
+        Register a listener for a specific notification.
+
+        :param notification_key: The identifier to listen for.
+        :param listener: A callable to be invoked when the notification is dispatched.
+        """
+        if notification_key not in self.listeners:
+            self.listeners[notification_key] = []
+        self.listeners[notification_key].append(listener)
+
+    def remove_listener(self, notification_key, listener):
+        """
+        Remove a previously registered listener for a notification.
+
+        :param notification_key: The identifier.
+        :param listener: The listener callable to remove.
+        """
+        if notification_key in self.listeners:
+            self.listeners[notification_key].remove(listener)
+            if not self.listeners[notification_key]:
+                del self.listeners[notification_key]
+
+    def dispatch(self, notification_key, data=None):
+        """
+        Dispatch a notification to all registered listeners.
+
+        :param notification_key: The identifier.
+        :param data: Optional data to pass to each listener.
+        """
+        if notification_key in self.listeners:
+            for listener in self.listeners[notification_key]:
+                listener(data)
+
+    def clear_listeners(self):
+        """
+        Clear all registered listeners. Essential because the notifier is a global
+        singleton, e.g. if you are running multiple tests or simulations in the same process.
+        """
+        self.listeners.clear()
+
+
+# Create a global notifier instance
+notifier = Notifier()
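The module docstring above shows the basic pattern; the full listener lifecycle against this `Notifier` API looks like the following sketch (runnable as-is once the module exists):

    from tlo.notify import notifier

    received = []

    def on_start(data):
        received.append(data)

    notifier.add_listener("simulation.on_start", on_start)
    notifier.dispatch("simulation.on_start", data={"n": 1000})  # listeners run synchronously
    notifier.remove_listener("simulation.on_start", on_start)   # empty keys are pruned
    notifier.clear_listeners()                                  # reset the global singleton
    assert received == [{"n": 1000}]
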
+ """ + self.listeners.clear() + + +# Create a global notifier instance +notifier = Notifier() diff --git a/src/tlo/simulation.py b/src/tlo/simulation.py index d2560f92d9..ded5960e6e 100644 --- a/src/tlo/simulation.py +++ b/src/tlo/simulation.py @@ -9,9 +9,10 @@ from collections import Counter, OrderedDict from pathlib import Path from typing import TYPE_CHECKING, Optional - import numpy as np +from tlo.notify import notifier + try: import dill @@ -26,6 +27,7 @@ topologically_sort_modules, ) from tlo.events import Event, IndividualScopeEventMixin +from tlo.notify import notifier from tlo.progressbar import ProgressBar if TYPE_CHECKING: @@ -106,7 +108,6 @@ def __init__( self.end_date = None self.output_file = None self.population: Optional[Population] = None - self.show_progress_bar = show_progress_bar self.resourcefilepath = Path(resourcefilepath) @@ -116,6 +117,8 @@ def __init__( self._custom_log_levels = None self._log_filepath = self._configure_logging(**log_config) + # clear notifier listeners from any previous simulation in this process + notifier.clear_listeners() # random number generator seed_from = "auto" if seed is None else "user" @@ -282,13 +285,15 @@ def make_initial_population(self, *, n: int) -> None: key="debug", data=f"{module.name}.initialise_population() {time.time() - start1} s", ) - + + # Dispatch notification that pop has been initialised + notifier.dispatch("simulation.pop_has_been_initialised", data={}) + end = time.time() logger.info(key="info", data=f"make_initial_population() {end - start} s") - + def initialise(self, *, end_date: Date) -> None: """Initialise all modules in simulation. - :param end_date: Date to end simulation on - accessible to modules to allow initialising data structures which may depend (in size for example) on the date range being simulated. @@ -298,6 +303,7 @@ def initialise(self, *, end_date: Date) -> None: raise SimulationPreviouslyInitialisedError(msg) self.date = self.start_date self.end_date = end_date # store the end_date so that others can reference it + for module in self.modules.values(): module.initialise_simulation(self) self._initialised = True @@ -382,6 +388,7 @@ def run_simulation_to(self, *, to_date: Date) -> None: self._update_progress_bar(progress_bar, date) self.fire_single_event(event, date) self.date = to_date + if self.show_progress_bar: progress_bar.stop() @@ -436,8 +443,13 @@ def do_birth(self, mother_id: int) -> int: child_id = self.population.do_birth() for module in self.modules.values(): module.on_birth(mother_id, child_id) + + # Dispatch notification that birth is about to occur + notifier.dispatch("simulation.on_birth", data={'target': child_id, 'link_info' : {'EventName': 'Birth'}}) + return child_id + def find_events_for_person(self, person_id: int) -> list[tuple[Date, Event]]: """Find the events in the queue for a particular person. 
diff --git a/src/tlo/util.py b/src/tlo/util.py
index efe17a9920..189f994353 100644
--- a/src/tlo/util.py
+++ b/src/tlo/util.py
@@ -94,6 +94,33 @@ def transition_states(initial_series: pd.Series, prob_matrix: pd.DataFrame, rng:
     return final_states

+def df_to_EAV(df, date, event_name):
+    """Convert a wide dataframe of properties into entity-attribute-value (EAV) form."""
+    eav = df.stack().reset_index()
+    eav.columns = ['E', 'A', 'V']
+    eav['EventName'] = event_name
+    eav = eav[["E", "EventName", "A", "V"]]
+
+    return eav
+
+
+def convert_chain_links_into_EAV(chain_links):
+    df = pd.DataFrame.from_dict(chain_links, orient="index")
+    id_cols = ["EventName"]
+
+    eav = df.reset_index().melt(
+        id_vars=["index"] + id_cols,  # index = person ID
+        var_name="A",
+        value_name="V"
+    )
+
+    eav.rename(columns={"index": "E"}, inplace=True)
+
+    eav = eav[["E", "EventName", "A", "V"]]
+
+    return eav
+
+
 def sample_outcome(probs: pd.DataFrame, rng: np.random.RandomState):
     """
     Helper function to randomly sample an outcome for each individual in a population from a set of probabilities
     that are specific to each individual.
diff --git a/tests/test_notify.py b/tests/test_notify.py
new file mode 100644
index 0000000000..ad5e828bbf
--- /dev/null
+++ b/tests/test_notify.py
@@ -0,0 +1,23 @@
+from tlo.notify import notifier
+
+
+def test_notifier():
+    # in listening code
+    received_data = []
+
+    def callback(data):
+        received_data.append(data)
+
+    notifier.add_listener("test.signal", callback)
+
+    # in emitting code
+    notifier.dispatch("test.signal", data={"value": 42})
+
+    assert len(received_data) == 1
+    assert received_data[0] == {"value": 42}
+
+    # Unsubscribe and check that no further calls are received
+    notifier.remove_listener("test.signal", callback)
+    notifier.dispatch("test.signal", data={"value": 100})
+
+    assert len(received_data) == 1  # No new data
diff --git a/tests/test_rti.py b/tests/test_rti.py
index 2b65595782..711215b8cf 100644
--- a/tests/test_rti.py
+++ b/tests/test_rti.py
@@ -66,6 +66,7 @@ def test_run(seed):

     check_dtypes(sim)

+
 @pytest.mark.slow
 def test_all_injuries_run(seed):
     """
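For reference, the shape `df_to_EAV` produces — a sketch under the assumption that the function is used on a small properties frame (note that its `date` argument is currently unused by the implementation); the printed frame is approximate:

    import pandas as pd
    from tlo.util import df_to_EAV

    props = pd.DataFrame({'is_alive': [True, True], 'sex': ['F', 'M']})
    eav = df_to_EAV(props, date=pd.Timestamp(2010, 1, 1), event_name='StartOfSimulation')
    print(eav)
    #    E          EventName         A     V
    # 0  0  StartOfSimulation  is_alive  True
    # 1  0  StartOfSimulation       sex     F
    # 2  1  StartOfSimulation  is_alive  True
    # 3  1  StartOfSimulation       sex     M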