diff --git a/resources/ResourceFile_PregnancyCohort.xlsx b/resources/ResourceFile_PregnancyCohort.xlsx
new file mode 100644
index 0000000000..bd4b0af086
--- /dev/null
+++ b/resources/ResourceFile_PregnancyCohort.xlsx
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9f41deeb8fd44fda6cc451a7e23390c98bba449e704ccd62e258caaf2e6e6606
+size 18595782
diff --git a/resources/maternal cohort/ResourceFile_All2024PregnanciesCohortModel.xlsx b/resources/maternal cohort/ResourceFile_All2024PregnanciesCohortModel.xlsx
new file mode 100644
index 0000000000..500fcfd169
--- /dev/null
+++ b/resources/maternal cohort/ResourceFile_All2024PregnanciesCohortModel.xlsx
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:39435a18741e06862f5c0b29d47c331c2c15d22f2954d204c5cf9e3fd80f20bc
+size 20542399
diff --git a/src/scripts/analysis_data_generation/analysis_extract_data.py b/src/scripts/analysis_data_generation/analysis_extract_data.py
new file mode 100644
index 0000000000..9ee37cabef
--- /dev/null
+++ b/src/scripts/analysis_data_generation/analysis_extract_data.py
@@ -0,0 +1,557 @@
+"""Produce plots to show the health impact (deaths, dalys) of the healthcare system (overall health impact) when
+running under different MODES and POLICIES (scenario_impact_of_actual_vs_funded.py)"""
+
+# short tclose -> ideal case
+# long tclose -> status quo
+import argparse
+from pathlib import Path
+from typing import Tuple
+
+import pandas as pd
+import matplotlib.pyplot as plt
+
+from tlo import Date
+from tlo.analysis.utils import extract_results, extract_event_chains
+from datetime import datetime
+from collections import Counter
+import ast
+
+# Time simulated to collect data
+start_date = Date(2010, 1, 1)
+end_date = start_date + pd.DateOffset(months=13)
+
+# Range of years considered
+min_year = 2010
+max_year = 2040
+
+
+def all_columns(_df):
+    return pd.Series(_df.all())
+
+def check_if_beyond_time_range_considered(progression_properties):
+
matching_keys = [key for key in progression_properties.keys() if "rt_date_to_remove_daly" in key] + if matching_keys: + for key in matching_keys: + if progression_properties[key] > end_date: + print("Beyond time range considered, need at least ",progression_properties[key]) + +def print_filtered_df(df): + """ + Prints rows of the DataFrame excluding EventName 'Initialise' and 'Birth'. + """ + pd.set_option('display.max_colwidth', None) + filtered = df#[~df['EventName'].isin(['StartOfSimulation', 'Birth'])] + + dict_cols = ["Info"] + max_items = 2 + # Step 2: Truncate dictionary columns for display + if dict_cols is not None: + for col in dict_cols: + def truncate_dict(d): + if isinstance(d, dict): + items = list(d.items())[:max_items] # keep only first `max_items` + return dict(items) + return d + filtered[col] = filtered[col].apply(truncate_dict) + print(filtered) + + +def apply(results_folder: Path, output_folder: Path, resourcefilepath: Path = None, ): + """Produce standard set of plots describing the effect of each TREATMENT_ID. + - We estimate the epidemiological impact as the EXTRA deaths that would occur if that treatment did not occur. + - We estimate the draw on healthcare system resources as the FEWER appointments when that treatment does not occur. 
+ """ + pd.set_option('display.max_rows', None) + pd.set_option('display.max_colwidth', None) + + individual_event_chains = extract_event_chains(results_folder) + print_filtered_df(individual_event_chains[0]) + exit(-1) + + eval_env = { + 'datetime': datetime, # Add the datetime class to the eval environment + 'pd': pd, # Add pandas to handle Timestamp + 'Timestamp': pd.Timestamp, # Specifically add Timestamp for eval + 'NaT': pd.NaT, + 'nan': float('nan'), # Include NaN for eval (can also use pd.NA if preferred) + } + + initial_properties_of_interest = ['rt_MAIS_military_score','rt_ISS_score','rt_disability','rt_polytrauma','rt_injury_1','rt_injury_2','rt_injury_3','rt_injury_4','rt_injury_5','rt_injury_6', 'rt_imm_death','sy_injury','sy_severe_trauma','sex','li_urban', 'li_wealth', 'li_mar_stat', 'li_in_ed', 'li_ed_lev'] + + # Will be added through computation: age at time of RTI + # Will be added through computation: total duration of event + + initial_rt_event_properties = set() + + num_individuals = 1000 + num_runs = 1 + record = [] + # Include results folder in output file name + name_tag = str(results_folder).replace("outputs/", "") + + + + for p in range(0,num_individuals): + + print("At person = ", p, " out of ", num_individuals) + + individual_event_chains = extract_results( + results_folder, + module='tlo.simulation', + key='event_chains', + column=str(p), + do_scaling=False + ) + + for r in range(0,num_runs): + initial_properties = {} + key_first_event = {} + key_last_event = {} + first_event = {} + last_event = {} + properties = {} + average_disability = 0 + total_dt_included = 0 + dt_in_prev_disability = 0 + prev_disability_incurred = 0 + ind_Counter = {'0': Counter(), '1a': Counter(), '1b' : Counter(), '2' : Counter()} + # Count total appts + + list_for_individual = [] + for item,row in individual_event_chains.iterrows(): + value = individual_event_chains.loc[item,(0, r)] + if value !='' and isinstance(value, str): + evaluated = eval(value, eval_env) 
+ list_for_individual.append(evaluated) + + for i in list_for_individual: + print(i) + + """ + # These are the properties of the individual before the start of the chain of events + initial_properties = list_for_individual[0] + + # Initialise first event by gathering parameters of interest from initial_properties + first_event = {key: initial_properties[key] for key in initial_properties_of_interest if key in initial_properties} + + # The changing or adding of properties from the first_event will be stored in progression_properties + progression_properties = {} + + for i in list_for_individual: + # Skip the initial_properties, or in other words only consider these if they are 'proper' events + if 'event' in i: + #print(i) + if 'RTIPolling' in i['event']: + + # Keep track of which properties are changed during polling events + for key,value in i.items(): + if 'rt_' in key: + initial_rt_event_properties.add(key) + + # Retain a copy of Polling event + polling_event = i.copy() + + # Update parameters of interest following RTI + key_first_event = {key: i[key] if key in i else value for key, value in first_event.items()} + + # Calculate age of individual at time of event + key_first_event['age_in_days_at_event'] = (i['rt_date_inj'] - initial_properties['date_of_birth']).days + + # Keep track of evolution in individual's properties + progression_properties = initial_properties.copy() + progression_properties.update(i) + + # Initialise chain of Dalys incurred + if 'rt_disability' in i: + prev_disability_incurred = i['rt_disability'] + prev_date = i['event_date'] + + else: + # Progress properties of individual, even if this event is a death + progression_properties.update(i) + + # If disability has changed as a result of this, recalculate and add previous to rolling average + if 'rt_disability' in i: + + dt_in_prev_disability = (i['event_date'] - prev_date).days + #print("Detected change in disability", i['rt_disability'], "after dt=", dt_in_prev_disability) + 
#print("Adding the following to the average", prev_disability_incurred, " x ", dt_in_prev_disability ) + average_disability += prev_disability_incurred*dt_in_prev_disability + total_dt_included += dt_in_prev_disability + # Update variables + prev_disability_incurred = i['rt_disability'] + prev_date = i['event_date'] + + # Update running footprint + if 'appt_footprint' in i and i['appt_footprint'] != 'Counter()': + footprint = i['appt_footprint'] + if 'Counter' in footprint: + footprint = footprint[len("Counter("):-1] + apply = eval(footprint, eval_env) + ind_Counter[i['level']].update(Counter(apply)) + + # If the individual has died, ensure chain of event is interrupted here and update rolling average of DALYs + if 'is_alive' in i and i['is_alive'] is False: + if ((i['event_date'] - polling_event['rt_date_inj']).days) > total_dt_included: + dt_in_prev_disability = (i['event_date'] - prev_date).days + average_disability += prev_disability_incurred*dt_in_prev_disability + total_dt_included += dt_in_prev_disability + break + + # check_if_beyond_time_range_considered(progression_properties) + + # Compute final properties of individual + key_last_event['is_alive_after_RTI'] = progression_properties['is_alive'] + key_last_event['duration_days'] = (progression_properties['event_date'] - polling_event['rt_date_inj']).days + + # If individual didn't die and the key_last_event didn't result in a final change in DALYs, ensure that the last change is recorded here + if not key_first_event['rt_imm_death'] and (total_dt_included < key_last_event['duration_days']): + #print("Number of events", len(list_for_individual)) + #for i in list_for_individual: + # if 'event' in i: + # print(i) + dt_in_prev_disability = (progression_properties['event_date'] - prev_date).days + average_disability += prev_disability_incurred*dt_in_prev_disability + total_dt_included += dt_in_prev_disability + + # Now calculate the average disability incurred, and store any permanent disability and total 
footprint + if not key_first_event['rt_imm_death'] and key_last_event['duration_days']> 0: + key_last_event['rt_disability_average'] = average_disability/key_last_event['duration_days'] + else: + key_last_event['rt_disability_average'] = 0.0 + + key_last_event['rt_disability_permanent'] = progression_properties['rt_disability'] + key_last_event.update({'total_footprint': ind_Counter}) + + if key_last_event['duration_days']!=total_dt_included: + print("The duration of event and total_dt_included don't match", key_last_event['duration_days'], total_dt_included) + exit(-1) + + properties = key_first_event | key_last_event + + record.append(properties) + """ + + df = pd.DataFrame(record) + df.to_csv("new_raw_data_" + name_tag + ".csv", index=False) + + print(df) + print(initial_rt_event_properties) + exit(-1) + #print(i) + + #dict = {} + #for i in [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]: + # dict[i] = [] + + #for i in [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]: + # event_chains = extract_results( + # results_folder, + # module='tlo.simulation'#, + # key='event_chains', + # column = str(i), + # #custom_generate_series=get_num_dalys_by_year, + # do_scaling=False + # ) + # print(event_chains) + # print(event_chains.index) + # print(event_chains.columns.levels) + + # for index, row in event_chains.iterrows(): + # if event_chains.iloc[index,0] is not None: + # if(event_chains.iloc[index,0]['person_ID']==i): #and 'event' in event_chains.iloc[index,0].keys()): + # dict[i].append(event_chains.iloc[index,0]) + #elif (event_chains.iloc[index,0]['person_ID']==i and 'event' not in event_chains.iloc[index,0].keys()): + #print(event_chains.iloc[index,0]['de_depr']) + # exit(-1) + #for item in dict[0]: + # print(item) + + #exit(-1) + + TARGET_PERIOD = (Date(min_year, 1, 1), Date(max_year, 1, 1)) + + # Definitions of general helper functions + lambda stub: output_folder / f"{stub.replace('*', '_star_')}.png" # noqa: E731 + + def target_period() -> str: + """Returns the target period as a string of the 
form YYYY-YYYY""" + return "-".join(str(t.year) for t in TARGET_PERIOD) + + def get_parameter_names_from_scenario_file() -> Tuple[str]: + """Get the tuple of names of the scenarios from `Scenario` class used to create the results.""" + from scripts.healthsystem.impact_of_actual_vs_funded.scenario_impact_of_actual_vs_funded import ( + ImpactOfHealthSystemMode, + ) + e = ImpactOfHealthSystemMode() + return tuple(e._scenarios.keys()) + + def get_num_deaths(_df): + """Return total number of Deaths (total within the TARGET_PERIOD) + """ + return pd.Series(data=len(_df.loc[pd.to_datetime(_df.date).between(*TARGET_PERIOD)])) + + def get_num_dalys(_df): + """Return total number of DALYs (Stacked) by label (total within the TARGET_PERIOD)""" + return pd.Series( + data=_df + .loc[_df.year.between(*[i.year for i in TARGET_PERIOD])] + .drop(columns=['date', 'sex', 'age_range', 'year']) + .sum().sum() + ) + + def get_num_dalys_by_cause(_df): + """Return number of DALYs by cause by label (total within the TARGET_PERIOD)""" + return pd.Series( + data=_df + .loc[_df.year.between(*[i.year for i in TARGET_PERIOD])] + .drop(columns=['date', 'sex', 'age_range', 'year']) + .sum() + ) + + def set_param_names_as_column_index_level_0(_df): + """Set the columns index (level 0) as the param_names.""" + ordered_param_names_no_prefix = {i: x for i, x in enumerate(param_names)} + names_of_cols_level0 = [ordered_param_names_no_prefix.get(col) for col in _df.columns.levels[0]] + assert len(names_of_cols_level0) == len(_df.columns.levels[0]) + _df.columns = _df.columns.set_levels(names_of_cols_level0, level=0) + return _df + + def find_difference_relative_to_comparison(_ser: pd.Series, + comparison: str, + scaled: bool = False, + drop_comparison: bool = True, + ): + """Find the difference in the values in a pd.Series with a multi-index, between the draws (level 0) + within the runs (level 1), relative to where draw = `comparison`. 
+ The comparison is `X - COMPARISON`.""" + return _ser \ + .unstack(level=0) \ + .apply(lambda x: (x - x[comparison]) / (x[comparison] if scaled else 1.0), axis=1) \ + .drop(columns=([comparison] if drop_comparison else [])) \ + .stack() + + + def get_counts_of_hsi_by_treatment_id(_df): + """Get the counts of the short TREATMENT_IDs occurring""" + _counts_by_treatment_id = _df \ + .loc[pd.to_datetime(_df['date']).between(*TARGET_PERIOD), 'TREATMENT_ID'] \ + .apply(pd.Series) \ + .sum() \ + .astype(int) + return _counts_by_treatment_id.groupby(level=0).sum() + + year_target = 2023 + def get_counts_of_hsi_by_treatment_id_by_year(_df): + """Get the counts of the short TREATMENT_IDs occurring""" + _counts_by_treatment_id = _df \ + .loc[pd.to_datetime(_df['date']).dt.year ==year_target, 'TREATMENT_ID'] \ + .apply(pd.Series) \ + .sum() \ + .astype(int) + return _counts_by_treatment_id.groupby(level=0).sum() + + def get_counts_of_hsi_by_short_treatment_id(_df): + """Get the counts of the short TREATMENT_IDs occurring (shortened, up to first underscore)""" + _counts_by_treatment_id = get_counts_of_hsi_by_treatment_id(_df) + _short_treatment_id = _counts_by_treatment_id.index.map(lambda x: x.split('_')[0] + "*") + return _counts_by_treatment_id.groupby(by=_short_treatment_id).sum() + + def get_counts_of_hsi_by_short_treatment_id_by_year(_df): + """Get the counts of the short TREATMENT_IDs occurring (shortened, up to first underscore)""" + _counts_by_treatment_id = get_counts_of_hsi_by_treatment_id_by_year(_df) + _short_treatment_id = _counts_by_treatment_id.index.map(lambda x: x.split('_')[0] + "*") + return _counts_by_treatment_id.groupby(by=_short_treatment_id).sum() + + + # Obtain parameter names for this scenario file + param_names = get_parameter_names_from_scenario_file() + print(param_names) + + # ================================================================================================ + # TIME EVOLUTION OF TOTAL DALYs + # Plot DALYs averted compared to the 
``No Policy'' policy + + year_target = 2023 # This global variable will be passed to custom function + def get_num_dalys_by_year(_df): + """Return total number of DALYs (Stacked) by label (total within the TARGET_PERIOD)""" + return pd.Series( + data=_df + .loc[_df.year == year_target] + .drop(columns=['date', 'sex', 'age_range', 'year']) + .sum().sum() + ) + + ALL = {} + # Plot time trend show year prior transition as well to emphasise that until that point DALYs incurred + # are consistent across different policies + this_min_year = 2010 + for year in range(this_min_year, max_year+1): + year_target = year + num_dalys_by_year = extract_results( + results_folder, + module='tlo.methods.healthburden', + key='dalys_stacked', + custom_generate_series=get_num_dalys_by_year, + do_scaling=True + ).pipe(set_param_names_as_column_index_level_0) + ALL[year_target] = num_dalys_by_year + # Concatenate the DataFrames into a single DataFrame + concatenated_df = pd.concat(ALL.values(), keys=ALL.keys()) + concatenated_df.index = concatenated_df.index.set_names(['date', 'index_original']) + concatenated_df = concatenated_df.reset_index(level='index_original',drop=True) + dalys_by_year = concatenated_df + print(dalys_by_year) + dalys_by_year.to_csv('ConvertedOutputs/Total_DALYs_with_time.csv', index=True) + + # ================================================================================================ + # Print population under each scenario + pop_model = extract_results(results_folder, + module="tlo.methods.demography", + key="population", + column="total", + index="date", + do_scaling=True + ).pipe(set_param_names_as_column_index_level_0) + + pop_model.index = pop_model.index.year + pop_model = pop_model[(pop_model.index >= this_min_year) & (pop_model.index <= max_year)] + print(pop_model) + assert dalys_by_year.index.equals(pop_model.index) + assert all(dalys_by_year.columns == pop_model.columns) + pop_model.to_csv('ConvertedOutputs/Population_with_time.csv', index=True) + + 
# ================================================================================================ + # DALYs BROKEN DOWN BY CAUSES AND YEAR + # DALYs by cause per year + # %% Quantify the health losses associated with all interventions combined. + + year_target = 2023 # This global variable will be passed to custom function + def get_num_dalys_by_year_and_cause(_df): + """Return total number of DALYs (Stacked) by label (total within the TARGET_PERIOD)""" + return pd.Series( + data=_df + .loc[_df.year == year_target] + .drop(columns=['date', 'sex', 'age_range', 'year']) + .sum() + ) + + ALL = {} + # Plot time trend show year prior transition as well to emphasise that until that point DALYs incurred + # are consistent across different policies + this_min_year = 2010 + for year in range(this_min_year, max_year+1): + year_target = year + num_dalys_by_year = extract_results( + results_folder, + module='tlo.methods.healthburden', + key='dalys_stacked', + custom_generate_series=get_num_dalys_by_year_and_cause, + do_scaling=True + ).pipe(set_param_names_as_column_index_level_0) + ALL[year_target] = num_dalys_by_year #summarize(num_dalys_by_year) + + # Concatenate the DataFrames into a single DataFrame + concatenated_df = pd.concat(ALL.values(), keys=ALL.keys()) + + concatenated_df.index = concatenated_df.index.set_names(['date', 'cause']) + + df_total = concatenated_df + df_total.to_csv('ConvertedOutputs/DALYS_by_cause_with_time.csv', index=True) + + ALL = {} + # Plot time trend show year prior transition as well to emphasise that until that point DALYs incurred + # are consistent across different policies + for year in range(min_year, max_year+1): + year_target = year + + hsi_delivered_by_year = extract_results( + results_folder, + module='tlo.methods.healthsystem.summary', + key='HSI_Event', + custom_generate_series=get_counts_of_hsi_by_short_treatment_id_by_year, + do_scaling=True + ).pipe(set_param_names_as_column_index_level_0) + ALL[year_target] = 
hsi_delivered_by_year + + # Concatenate the DataFrames into a single DataFrame + concatenated_df = pd.concat(ALL.values(), keys=ALL.keys()) + concatenated_df.index = concatenated_df.index.set_names(['date', 'cause']) + HSI_ran_by_year = concatenated_df + + del ALL + + ALL = {} + # Plot time trend show year prior transition as well to emphasise that until that point DALYs incurred + # are consistent across different policies + for year in range(min_year, max_year+1): + year_target = year + + hsi_not_delivered_by_year = extract_results( + results_folder, + module='tlo.methods.healthsystem.summary', + key='Never_ran_HSI_Event', + custom_generate_series=get_counts_of_hsi_by_short_treatment_id_by_year, + do_scaling=True + ).pipe(set_param_names_as_column_index_level_0) + ALL[year_target] = hsi_not_delivered_by_year + + # Concatenate the DataFrames into a single DataFrame + concatenated_df = pd.concat(ALL.values(), keys=ALL.keys()) + concatenated_df.index = concatenated_df.index.set_names(['date', 'cause']) + HSI_never_ran_by_year = concatenated_df + + HSI_never_ran_by_year = HSI_never_ran_by_year.fillna(0) #clean_df( + HSI_ran_by_year = HSI_ran_by_year.fillna(0) + HSI_total_by_year = HSI_ran_by_year.add(HSI_never_ran_by_year, fill_value=0) + HSI_ran_by_year.to_csv('ConvertedOutputs/HSIs_ran_by_area_with_time.csv', index=True) + HSI_never_ran_by_year.to_csv('ConvertedOutputs/HSIs_never_ran_by_area_with_time.csv', index=True) + print(HSI_ran_by_year) + print(HSI_never_ran_by_year) + print(HSI_total_by_year) + +if __name__ == "__main__": + rfp = Path('resources') + + parser = argparse.ArgumentParser( + description="Produce plots to show the impact each set of treatments", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument( + "--output-path", + help=( + "Directory to write outputs to. If not specified (set to None) outputs " + "will be written to value of --results-path argument." 
+ ), + type=Path, + default=None, + required=False, + ) + parser.add_argument( + "--resources-path", + help="Directory containing resource files", + type=Path, + default=Path('resources'), + required=False, + ) + parser.add_argument( + "--results-path", + type=Path, + help=( + "Directory containing results from running " + "src/scripts/analysis_data_generation/scenario_generate_chains.py " + ), + default=None, + required=False + ) + args = parser.parse_args() + assert args.results_path is not None + results_path = args.results_path + + output_path = results_path if args.output_path is None else args.output_path + + apply( + results_folder=results_path, + output_folder=output_path, + resourcefilepath=args.resources_path + ) diff --git a/src/scripts/analysis_data_generation/postprocess_events_chain.py b/src/scripts/analysis_data_generation/postprocess_events_chain.py new file mode 100644 index 0000000000..96c27a04b1 --- /dev/null +++ b/src/scripts/analysis_data_generation/postprocess_events_chain.py @@ -0,0 +1,156 @@ +import pandas as pd +from dateutil.relativedelta import relativedelta + +# Remove from every individual's event chain all events that were fired after death +def cut_off_events_after_death(df): + + events_chain = df.groupby('person_ID') + + filtered_data = pd.DataFrame() + + for name, group in events_chain: + + # Find the first non-NaN 'date_of_death' and its index + first_non_nan_index = group['date_of_death'].first_valid_index() + + if first_non_nan_index is not None: + # Filter out all rows after the first non-NaN index + filtered_group = group.loc[:first_non_nan_index] # Keep rows up to and including the first valid index + filtered_data = pd.concat([filtered_data, filtered_group]) + else: + # If there are no non-NaN values, keep the original group + filtered_data = pd.concat([filtered_data, group]) + + return filtered_data + +# Load into DataFrame +def load_csv_to_dataframe(file_path): + try: + # Load raw chains into df + df = 
pd.read_csv(file_path) + print("Raw event chains loaded successfully!") + return df + except FileNotFoundError: + print(f"Error: The file '{file_path}' was not found.") + except Exception as e: + print(f"An error occurred: {e}") + +file_path = 'output.csv' # Replace with the path to your CSV file + +output = load_csv_to_dataframe(file_path) + +# Some of the dates appeared not to be in datetime format. Correct here. +output['date_of_death'] = pd.to_datetime(output['date_of_death'], errors='coerce') +output['date_of_birth'] = pd.to_datetime(output['date_of_birth'], errors='coerce') +if 'hv_date_inf' in output.columns: + output['hv_date_inf'] = pd.to_datetime(output['hv_date_inf'], errors='coerce') + + +date_start = pd.to_datetime('2010-01-01') +if 'Other' in output['cause_of_death'].values: + print("ERROR: 'Other' was included in sim as possible cause of death") + exit(-1) + +# Choose which columns in individual properties to visualise +columns_to_print =['event','is_alive','hv_inf', 'hv_art','tb_inf', 'tb_date_active', 'event_date', 'when'] +#columns_to_print =['person_ID', 'date_of_birth', 'date_of_death', 'cause_of_death','hv_date_inf', 'hv_art','tb_inf', 'tb_date_active', 'event date', 'event'] + +# When checking which individuals led to *any* changes in individual properties, exclude these columns from comparison +columns_to_exclude_in_comparison = ['when', 'event', 'event_date', 'age_exact_years', 'age_years', 'age_days', 'age_range', 'level', 'appt_footprint'] + +# If considering epidemiology consistent with sim, add check here. 
+check_ages_of_those_HIV_inf = False +if check_ages_of_those_HIV_inf: + for index, row in output.iterrows(): + if pd.isna(row['hv_date_inf']): + continue # Skip this iteration + diff = relativedelta(output.loc[index, 'hv_date_inf'],output.loc[index, 'date_of_birth']) + if diff.years > 1 and diff.years<15: + print("Person contracted HIV infection at age younger than 15", diff) + +# Remove events after death +filtered_data = cut_off_events_after_death(output) + +print_raw_events = True # Print raw chain of events for each individual +print_selected_changes = False +print_all_changes = True +person_ID_of_interest = 494 + +pd.set_option('display.max_rows', None) + +for name, group in filtered_data.groupby('person_ID'): + list_of_dob = group['date_of_birth'] + + # Select individuals based on when they were born + if list_of_dob.iloc[0].year<2010: + + # Check that immutable properties are fixed for this individual, i.e. that events were collated properly: + all_identical_dob = group['date_of_birth'].nunique() == 1 + all_identical_sex = group['sex'].nunique() == 1 + if all_identical_dob is False or all_identical_sex is False: + print("Immutable properties are changing! 
This is not chain for single individual") + print(group) + exit(-1) + + print("----------------------------------------------------------------------") + print("person_ID ", group['person_ID'].iloc[0], "d.o.b ", group['date_of_birth'].iloc[0]) + print("Number of events for this individual ", group['person_ID'].iloc[0], "is :", len(group)/2) # Divide by 2 before printing Before/After for each event + number_of_events =len(group)/2 + number_of_changes=0 + if print_raw_events: + print(group) + + if print_all_changes: + # Check each row + comparison = group.drop(columns=columns_to_exclude_in_comparison).fillna(-99999).ne(group.drop(columns=columns_to_exclude_in_comparison).shift().fillna(-99999)) + + # Iterate over rows where any column has changed + for idx, row_changed in comparison.iloc[1:].iterrows(): + if row_changed.any(): # Check if any column changed in this row + number_of_changes+=1 + changed_columns = row_changed[row_changed].index.tolist() # Get the columns where changes occurred + print(f"Row {idx} - Changes detected in columns: {changed_columns}") + columns_output = ['event', 'event_date', 'appt_footprint', 'level'] + changed_columns + print(group.loc[idx, columns_output]) # Print only the changed columns + if group.loc[idx, 'when'] == 'Before': + print('-----> THIS CHANGE OCCURRED BEFORE EVENT!') + #print(group.loc[idx,columns_to_print]) + print() # For better readability + print("Number of changes is ", number_of_changes, "out of ", number_of_events, " events") + + if print_selected_changes: + tb_inf_condition = ( + ((group['tb_inf'].shift(1) == 'uninfected') & (group['tb_inf'] == 'active')) | + ((group['tb_inf'].shift(1) == 'latent') & (group['tb_inf'] == 'active')) | + ((group['tb_inf'].shift(1) == 'active') & (group['tb_inf'] == 'latent')) | + ((group['hv_inf'].shift(1) is False) & (group['hv_inf'] is True)) | + ((group['hv_art'].shift(1) == 'not') & (group['hv_art'] == 'on_not_VL_suppressed')) | + ((group['hv_art'].shift(1) == 'not') & 
(group['hv_art'] == 'on_VL_suppressed')) |
+            ((group['hv_art'].shift(1) == 'on_VL_suppressed') & (group['hv_art'] == 'on_not_VL_suppressed')) |
+            ((group['hv_art'].shift(1) == 'on_VL_suppressed') & (group['hv_art'] == 'not')) |
+            ((group['hv_art'].shift(1) == 'on_not_VL_suppressed') & (group['hv_art'] == 'on_VL_suppressed')) |
+            ((group['hv_art'].shift(1) == 'on_not_VL_suppressed') & (group['hv_art'] == 'not'))
+        )
+
+        alive_condition = (
+            (group['is_alive'].shift(1) is True) & (group['is_alive'] is False)
+        )
+        # Combine conditions for rows of interest
+        transition_condition = tb_inf_condition | alive_condition
+
+        if list_of_dob.iloc[0].year >= 2010:
+            print("DETECTED OF INTEREST")
+            print(group[group['event'] == 'Birth'][columns_to_print])
+
+        # Filter the DataFrame based on the condition
+        filtered_transitions = group[transition_condition]
+        if not filtered_transitions.empty:
+            if list_of_dob.iloc[0].year < 2010:
+                print("DETECTED OF INTEREST")
+                print(filtered_transitions[columns_to_print])
+
+
+print("Number of individuals simulated ", filtered_data.groupby('person_ID').ngroups)
+
+
+
diff --git a/src/scripts/analysis_data_generation/scenario_generate_chains.py b/src/scripts/analysis_data_generation/scenario_generate_chains.py
new file mode 100644
index 0000000000..b327a23323
--- /dev/null
+++ b/src/scripts/analysis_data_generation/scenario_generate_chains.py
@@ -0,0 +1,188 @@
+"""This Scenario file runs the model to generate event chains
+
+Run on the batch system using:
+```
+tlo batch-submit
+    src/scripts/analysis_data_generation/scenario_generate_chains.py
+```
+
+or locally using:
+```
+    tlo scenario-run src/scripts/analysis_data_generation/scenario_generate_chains.py
+```
+
+"""
+from pathlib import Path
+from typing import Dict
+
+import pandas as pd
+
+from tlo import Date, logging
+from tlo.analysis.utils import get_parameters_for_status_quo, mix_scenarios, get_filtered_treatment_ids
+from tlo.methods.fullmodel import fullmodel
+from
# NOTE(review): this chunk is collapsed git-diff residue.  The leading import
# statement of this file was truncated immediately before this point; it is
# completed below (the missing text was almost certainly just "from ").
from tlo.methods.scenario_switcher import ImprovedHealthSystemAndCareSeekingScenarioSwitcher
from tlo.scenario import BaseScenario
from tlo.methods import (
    alri,
    cardio_metabolic_disorders,
    care_of_women_during_pregnancy,
    contraception,
    demography,
    depression,
    diarrhoea,
    enhanced_lifestyle,
    epi,
    healthburden,
    healthseekingbehaviour,
    healthsystem,
    hiv,
    rti,
    labour,
    malaria,
    newborn_outcomes,
    postnatal_supervisor,
    pregnancy_supervisor,
    stunting,
    symptommanager,
    tb,
    wasting,
)
# NOTE(review): the names Path, Dict, pd, Date, logging,
# get_filtered_treatment_ids, mix_scenarios and get_parameters_for_status_quo
# are used below but their imports were in the truncated file header --
# verify against the original file (expected: `from pathlib import Path`,
# `from typing import Dict`, `import pandas as pd`,
# `from tlo import Date, logging`, and
# `from tlo.analysis.utils import get_filtered_treatment_ids,
#  get_parameters_for_status_quo, mix_scenarios`).


class GenerateDataChains(BaseScenario):
    """Scenario that runs a reduced module set for one simulated month with
    ``generate_event_chains = True`` so that per-person event chains are logged.
    """

    def __init__(self):
        super().__init__()
        self.seed = 42
        self.start_date = Date(2010, 1, 1)
        self.end_date = self.start_date + pd.DateOffset(months=1)  # short 1-month run
        self.pop_size = 1000
        self._scenarios = self._get_scenarios()
        self.number_of_draws = len(self._scenarios)
        self.runs_per_draw = 3
        self.generate_event_chains = True

    def log_configuration(self):
        """Warnings only, except the log streams analysed downstream."""
        return {
            'filename': 'generate_event_chains',
            'directory': Path('./outputs'),  # <- (specified only for local running)
            'custom_levels': {
                '*': logging.WARNING,
                'tlo.methods.demography': logging.INFO,
                'tlo.methods.events': logging.INFO,
                'tlo.methods.demography.detail': logging.WARNING,
                'tlo.methods.healthburden': logging.INFO,
                'tlo.methods.healthsystem.summary': logging.INFO,
            }
        }

    def modules(self):
        """Reduced module list (not the full model), centred on maternal and
        newborn health plus the infrastructure modules they require."""
        return [demography.Demography(resourcefilepath=self.resources),
                enhanced_lifestyle.Lifestyle(resourcefilepath=self.resources),
                healthburden.HealthBurden(resourcefilepath=self.resources),
                symptommanager.SymptomManager(resourcefilepath=self.resources, spurious_symptoms=False),
                pregnancy_supervisor.PregnancySupervisor(resourcefilepath=self.resources),
                labour.Labour(resourcefilepath=self.resources),
                care_of_women_during_pregnancy.CareOfWomenDuringPregnancy(resourcefilepath=self.resources),
                contraception.Contraception(resourcefilepath=self.resources),
                newborn_outcomes.NewbornOutcomes(resourcefilepath=self.resources),
                postnatal_supervisor.PostnatalSupervisor(resourcefilepath=self.resources),
                hiv.Hiv(resourcefilepath=self.resources),
                tb.Tb(resourcefilepath=self.resources),
                epi.Epi(resourcefilepath=self.resources),
                healthseekingbehaviour.HealthSeekingBehaviour(resourcefilepath=self.resources),
                healthsystem.HealthSystem(resourcefilepath=self.resources,
                                          mode_appt_constraints=1,
                                          cons_availability='all')]

    # (A superseded, commented-out draw_parameters/_get_scenarios pair -- which
    #  also contained a dict-literal typo, `{"Everything": ["*", "Nothing": []}`
    #  -- has been removed here; recover it from version control if needed.)

    def draw_parameters(self, draw_number, rng):
        """Return the parameter overrides for `draw_number` (None if out of range)."""
        if draw_number < self.number_of_draws:
            return list(self._scenarios.values())[draw_number]
        else:
            return

    def _get_scenarios(self) -> Dict[str, Dict]:
        """Return parameter-override dicts keyed by scenario name.

        Builds the 'Service_Availability' lists (everything, nothing, and one
        list per omitted RTI treatment) but currently exposes only the scenario
        in which 'Rti_ShockTreatment*' is unavailable.
        """
        treatments = get_filtered_treatment_ids(depth=2)
        treatments_RTI = [item for item in treatments if 'Rti' in item]

        service_availability = dict({"Everything": ["*"], "Nothing": []})
        service_availability.update(
            {f"No {t.replace('_*', '*')}": [x for x in treatments if x != t] for t in treatments_RTI}
        )
        print(service_availability.keys())

        return {
            # =========== STATUS QUO ============
            "Baseline":
                mix_scenarios(
                    self._baseline(),
                    {
                        "HealthSystem": {
                            "Service_Availability": service_availability["No Rti_ShockTreatment*"],
                        },
                    }
                ),
        }

    def _baseline(self) -> Dict:
        """Parameter changes that define the baseline scenario."""
        return mix_scenarios(
            get_parameters_for_status_quo(),
            {
                "HealthSystem": {
                    "mode_appt_constraints": 1,  # <-- Mode 1 prior to change to preserve calibration
                    "cons_availability": "all",
                }
            },
        )


if __name__ == '__main__':
    from tlo.cli import scenario_run

    scenario_run([__file__])


# --- next file in the original diff:
#     src/scripts/maternal_perinatal_analyses/cohort_analysis/cohort_interventions_group_scenario.py
import numpy as np
import pandas as pd

from pathlib import Path

from tlo import Date, logging
from tlo.methods import mnh_cohort_module
from tlo.methods.fullmodel import fullmodel
from tlo.scenario import BaseScenario


class BaselineScenario(BaseScenario):
    """2024 MNH cohort scenario toggling *groups* of interventions fully off
    (availability 0.0) or fully on (1.0); draw 0 is the unmodified baseline."""

    def __init__(self):
        super().__init__()
        self.seed = 562661
        self.start_date = Date(2024, 1, 1)
        self.end_date = Date(2025, 1, 2)
        self.pop_size = 12_000
        self.number_of_draws = 11  # baseline + 5 groups x {off, on}
        self.runs_per_draw = 20

    def log_configuration(self):
        return {
            'filename': 'block_intervention_group_test', 'directory': './outputs',
            "custom_levels": {
                "*": logging.WARNING,
                "tlo.methods.demography": logging.INFO,
                "tlo.methods.demography.detail": logging.INFO,
                "tlo.methods.contraception": logging.INFO,
                "tlo.methods.healthsystem.summary": logging.INFO,
                "tlo.methods.healthburden": logging.INFO,
                "tlo.methods.labour": logging.INFO,
                "tlo.methods.labour.detail": logging.INFO,
                "tlo.methods.newborn_outcomes": logging.INFO,
                "tlo.methods.care_of_women_during_pregnancy": logging.INFO,
                "tlo.methods.pregnancy_supervisor": logging.INFO,
                "tlo.methods.postnatal_supervisor": logging.INFO,
            }
        }

    def modules(self):
        return [*fullmodel(resourcefilepath=self.resources,
                           module_kwargs={'Schisto': {'mda_execute': False}}),
                mnh_cohort_module.MaternalNewbornHealthCohort(resourcefilepath=self.resources)]

    def draw_parameters(self, draw_number, rng):
        """Draw 0: unmodified 2024 baseline.  Draws 1-10: each intervention
        group appears twice, paired with availability 0.0 (min) then 1.0 (max).
        """
        if draw_number == 0:
            return {'PregnancySupervisor': {
                'analysis_year': 2024}}

        # Each group listed twice, aligned with avail_for_draw's 0.0/1.0 pairs.
        intervention_groups = [
            ['oral_antihypertensives', 'iv_antihypertensives', 'mgso4'],
            ['oral_antihypertensives', 'iv_antihypertensives', 'mgso4'],

            ['amtsl', 'pph_treatment_uterotonics', 'pph_treatment_mrrp'],
            ['amtsl', 'pph_treatment_uterotonics', 'pph_treatment_mrrp'],

            ['post_abortion_care_core', 'ectopic_pregnancy_treatment'],
            ['post_abortion_care_core', 'ectopic_pregnancy_treatment'],

            ['caesarean_section', 'blood_transfusion', 'pph_treatment_surgery'],
            ['caesarean_section', 'blood_transfusion', 'pph_treatment_surgery'],

            ['abx_for_prom', 'sepsis_treatment', 'birth_kit'],
            ['abx_for_prom', 'sepsis_treatment', 'birth_kit'],
        ]

        avail_for_draw = [0.0, 1.0,
                          0.0, 1.0,
                          0.0, 1.0,
                          0.0, 1.0,
                          0.0, 1.0]

        return {'PregnancySupervisor': {
            'analysis_year': 2024,
            'interventions_analysis': True,
            'interventions_under_analysis': intervention_groups[draw_number - 1],
            'intervention_analysis_availability': avail_for_draw[draw_number - 1]}}


if __name__ == '__main__':
    from tlo.cli import scenario_run
    scenario_run([__file__])


# --- next file in the original diff:
#     src/scripts/maternal_perinatal_analyses/cohort_analysis/cohort_interventions_scenario.py
import numpy as np
import pandas as pd

from pathlib import Path

from tlo import Date, logging
from tlo.methods import mnh_cohort_module
from tlo.methods.fullmodel import fullmodel
from tlo.scenario import BaseScenario


class BaselineScenario(BaseScenario):
    """2024 MNH cohort scenario toggling *single* interventions fully off
    (availability 0.0) or fully on (1.0); draw 0 is the unmodified baseline."""

    def __init__(self):
        super().__init__()
        self.seed = 796967
        self.start_date = Date(2024, 1, 1)
        self.end_date = Date(2025, 1, 2)
        self.pop_size = 30_000
        self.number_of_draws = 7  # baseline + 3 interventions x {off, on}
        self.runs_per_draw = 60

    def log_configuration(self):
        return {
            'filename': 'block_intervention_big_pop_test', 'directory': './outputs',
            "custom_levels": {
                "*": logging.WARNING,
                "tlo.methods.demography": logging.INFO,
                "tlo.methods.demography.detail": logging.INFO,
                "tlo.methods.contraception": logging.INFO,
                "tlo.methods.healthsystem.summary": logging.INFO,
                "tlo.methods.healthburden": logging.INFO,
                "tlo.methods.labour": logging.INFO,
                "tlo.methods.labour.detail": logging.INFO,
                "tlo.methods.newborn_outcomes": logging.INFO,
                "tlo.methods.care_of_women_during_pregnancy": logging.INFO,
                "tlo.methods.pregnancy_supervisor": logging.INFO,
                "tlo.methods.postnatal_supervisor": logging.INFO,
            }
        }

    def modules(self):
        return [*fullmodel(resourcefilepath=self.resources,
                           module_kwargs={'Schisto': {'mda_execute': False}}),
                mnh_cohort_module.MaternalNewbornHealthCohort(resourcefilepath=self.resources)]

    def draw_parameters(self, draw_number, rng):
        """Draw 0: unmodified 2024 baseline.  Draws 1-6: each intervention
        appears twice, paired with availability 0.0 (min) then 1.0 (max)."""
        if draw_number == 0:
            return {'PregnancySupervisor': {
                'analysis_year': 2024}}

        interventions_for_analysis = ['bp_measurement', 'bp_measurement',
                                      'post_abortion_care_core', 'post_abortion_care_core',
                                      'ectopic_pregnancy_treatment', 'ectopic_pregnancy_treatment']

        avail_for_draw = [0.0, 1.0,
                          0.0, 1.0,
                          0.0, 1.0]

        return {'PregnancySupervisor': {
            'analysis_year': 2024,
            'interventions_analysis': True,
            'interventions_under_analysis': [interventions_for_analysis[draw_number - 1]],
            'intervention_analysis_availability': avail_for_draw[draw_number - 1]}}


if __name__ == '__main__':
    from tlo.cli import scenario_run
    scenario_run([__file__])

# (diff residue: the header of the next file,
#  src/scripts/maternal_perinatal_analyses/cohort_analysis/dummy_cohort_azure_calib.py,
#  followed at this point.)
# --- file (from the original diff):
#     src/scripts/maternal_perinatal_analyses/cohort_analysis/dummy_cohort_azure_calib.py
import os

import pandas as pd

import matplotlib.pyplot as plt
import numpy as np

from tlo.analysis.utils import extract_results, get_scenario_outputs, summarize

outputspath = './outputs/sejjj49@ucl.ac.uk/'

scenario = 'block_intervention_test-2024-11-06T145016Z'
ordered_interventions = ['oral_antihypertensives', 'iv_antihypertensives', 'mgso4', 'post_abortion_care_core']

intervention_groups = []
draws = []  # NOTE(review): empty, so the mmrs_min comprehension below yields {}

results_folder = get_scenario_outputs(scenario, outputspath)[-1]


def get_ps_data_frames(key, results_folder):
    """Return [crude per-run frame, summarised frame] for the
    pregnancy-supervisor log stream `key`."""
    def sort_df(_df):
        # Drop the log 'date' column and keep the single logged row.
        _x = _df.drop(columns=['date'], inplace=False)
        return _x.iloc[0]

    results_df = extract_results(
        results_folder,
        module="tlo.methods.pregnancy_supervisor",
        key=key,
        custom_generate_series=sort_df,
        do_scaling=False
    )
    results_df_summ = summarize(results_df)

    return [results_df, results_df_summ]


all_dalys_dfs = extract_results(
    results_folder,
    module="tlo.methods.healthburden",
    key="dalys_stacked",
    custom_generate_series=(
        lambda df: df.drop(
            columns=['date', 'sex', 'age_range']).groupby(['year']).sum().stack()),
    do_scaling=True)

mat_disorders_all = all_dalys_dfs.loc[(slice(None), 'Maternal Disorders'), :]

mat_dalys_df = mat_disorders_all.loc[2024]
mat_dalys_df_sum = summarize(mat_dalys_df)

results = {k: get_ps_data_frames(k, results_folder)[0] for k in
           ['mat_comp_incidence', 'nb_comp_incidence', 'deaths_and_stillbirths', 'service_coverage',
            'yearly_mnh_counter_dict']}

results_sum = {k: get_ps_data_frames(k, results_folder)[1] for k in
               ['mat_comp_incidence', 'nb_comp_incidence', 'deaths_and_stillbirths', 'service_coverage',
                'yearly_mnh_counter_dict']}


def get_data(df, key, draw):
    """Return the (lower, mean, upper) triple for row `key` under `draw`
    from a summarised frame."""
    return (df.loc[key, (draw, 'lower')],
            df.loc[key, (draw, 'mean')],
            df.loc[key, (draw, 'upper')])


# BUGFIX: the original called get_data with only (df, draw), omitting the
# `key` argument (TypeError) here and in the dicts below.  'direct_mmr' is the
# row used for MMR in the sibling analysis script -- confirm it was intended.
mmrs_min = {f'{k}_min': get_data(results_sum['deaths_and_stillbirths'], 'direct_mmr', d)
            for k, d in zip(ordered_interventions, draws)}
mmrs_max = {}


def get_mmr_diffs(all_results, draws):
    """Summarised per-run differences of deaths/stillbirths vs baseline (draw 0),
    keyed by draw.

    BUGFIX: the original ignored its first parameter and read the module-level
    `results` instead, and discarded the result of `fillna(0)` (not in-place).
    """
    diff_results = {}
    baseline = all_results['deaths_and_stillbirths'][0]

    for draw in draws:
        diff_df = all_results['deaths_and_stillbirths'][draw] - baseline
        diff_df.columns = pd.MultiIndex.from_tuples([(draw, v) for v in range(len(diff_df.columns))],
                                                    names=['draw', 'run'])
        results_diff = summarize(diff_df).fillna(0)
        diff_results[draw] = results_diff

    return diff_results


# (The original defined `mmrs` twice; the first, mostly commented-out
#  definition was immediately overwritten and has been removed.)
mmrs = {'baseline': get_data(results_sum['deaths_and_stillbirths'], 'direct_mmr', 0),
        'oral_antihypertensives_min': get_data(results_sum['deaths_and_stillbirths'], 'direct_mmr', 1),
        'oral_antihypertensives_max': get_data(results_sum['deaths_and_stillbirths'], 'direct_mmr', 2),
        'iv_antihypertensives_min': get_data(results_sum['deaths_and_stillbirths'], 'direct_mmr', 3),
        'iv_antihypertensives_max': get_data(results_sum['deaths_and_stillbirths'], 'direct_mmr', 4),
        'amtsl_min': get_data(results_sum['deaths_and_stillbirths'], 'direct_mmr', 5),
        'amtsl_max': get_data(results_sum['deaths_and_stillbirths'], 'direct_mmr', 6),
        'mgso4_min': get_data(results_sum['deaths_and_stillbirths'], 'direct_mmr', 7),
        'mgso4_max': get_data(results_sum['deaths_and_stillbirths'], 'direct_mmr', 8),
        'post_abortion_care_core_min': get_data(results_sum['deaths_and_stillbirths'], 'direct_mmr', 9),
        'post_abortion_care_core_max': get_data(results_sum['deaths_and_stillbirths'], 'direct_mmr', 10),
        'caesarean_section_min': get_data(results_sum['deaths_and_stillbirths'], 'direct_mmr', 11),
        'caesarean_section_max': get_data(results_sum['deaths_and_stillbirths'], 'direct_mmr', 12),
        'ectopic_pregnancy_treatment_min': get_data(results_sum['deaths_and_stillbirths'], 'direct_mmr', 13),
        'ectopic_pregnancy_treatment_max': get_data(results_sum['deaths_and_stillbirths'], 'direct_mmr', 14),
        }

diff_results = get_mmr_diffs(results, [7, 8])

results_diff = {'mgso4_min': get_data(diff_results[7], 'direct_mmr', 7),
                'mgso4_max': get_data(diff_results[8], 'direct_mmr', 8)}

# todo: compare deaths with demography logging...

data = mmrs

# Extract means and asymmetric error-bar lengths
labels = data.keys()
means = [vals[1] for vals in data.values()]
lower_errors = [vals[1] - vals[0] for vals in data.values()]
upper_errors = [vals[2] - vals[1] for vals in data.values()]
errors = [lower_errors, upper_errors]

# Create bar chart with error bars
fig, ax = plt.subplots()
ax.bar(labels, means, yerr=errors, capsize=5, alpha=0.7, ecolor='black')
ax.set_ylabel('MMR')
ax.set_title('Average MMR under each scenario')

# Adjust label size
plt.xticks(fontsize=8, rotation=90)
plt.tight_layout()
plt.show()

# (Large commented-out tornado-plot / interval-plot experiments were removed
#  here; recover them from version control if needed.)

# (diff residue: the header of the next file,
#  src/scripts/maternal_perinatal_analyses/cohort_analysis/int_analysis_script.py,
#  followed at this point.)
# --- file (from the original diff):
#     src/scripts/maternal_perinatal_analyses/cohort_analysis/int_analysis_script.py
import os
import scipy.stats as st
from scipy.stats import t, norm, shapiro

import pandas as pd

import matplotlib.pyplot as plt
import numpy as np

from tlo.analysis.utils import extract_results, get_scenario_outputs, summarize, create_pickles_locally, \
    get_scenario_info

outputspath = './outputs/sejjj49@ucl.ac.uk/'

scenario = 'block_intervention_big_pop_test-2024-11-27T110117Z'
results_folder = get_scenario_outputs(scenario, outputspath)[-1]
# create_pickles_locally(results_folder, compressed_file_name_prefix='block_intervention_big_pop_test')

interventions = ['bp_measurement', 'post_abortion_care_core', 'ectopic_pregnancy_treatment']

# Scenario labels: 'baseline' then '<intervention>_min' / '<intervention>_max',
# matching the draw ordering in the scenario file (draw 0 = baseline).
int_analysis = ['baseline']
for i in interventions:
    int_analysis.append(f'{i}_min')
    int_analysis.append(f'{i}_max')

draws = [x for x in range(len(int_analysis))]


def summarize_confidence_intervals(results: pd.DataFrame) -> pd.DataFrame:
    """Utility function to compute summary statistics.

    Finds the mean value and a t-based 95% confidence interval across the runs
    of each draw; output columns are a (draw, stat) MultiIndex like
    ``tlo.analysis.utils.summarize``.

    NOTE(review): ``DataFrame.groupby(axis=1)`` is deprecated in recent pandas;
    kept as-is to preserve behaviour on the pandas version in use.
    """
    grouped = results.groupby(axis=1, by='draw', sort=False)
    mean = grouped.mean()
    sem = grouped.sem()  # Standard error of the mean

    # Critical value for a 95% confidence level; degrees of freedom taken from
    # the largest group (i.e. the run count).
    n = grouped.size().max()
    critical_value = t.ppf(0.975, df=n - 1)  # Two-tailed critical value

    margin_of_error = critical_value * sem

    lower = mean - margin_of_error
    upper = mean + margin_of_error

    summary = pd.concat({'mean': mean, 'lower': lower, 'upper': upper}, axis=1)

    # Reorder columns to (draw, stat) and sort, as in the original code.
    summary.columns = summary.columns.swaplevel(1, 0)
    summary.columns.names = ['draw', 'stat']
    summary = summary.sort_index(axis=1)

    return summary


def get_ps_data_frames(key, results_folder):
    """Return crude and summarised frames for pregnancy-supervisor log `key`."""
    def sort_df(_df):
        # Drop the log 'date' column and keep the single logged row.
        _x = _df.drop(columns=['date'], inplace=False)
        return _x.iloc[0]

    results_df = extract_results(
        results_folder,
        module="tlo.methods.pregnancy_supervisor",
        key=key,
        custom_generate_series=sort_df,
        do_scaling=False
    )
    results_df_summ = summarize_confidence_intervals(results_df)

    return {'crude': results_df, 'summarised': results_df_summ}


results = {k: get_ps_data_frames(k, results_folder) for k in
           ['mat_comp_incidence', 'nb_comp_incidence', 'deaths_and_stillbirths', 'service_coverage',
            'yearly_mnh_counter_dict']}

direct_deaths = extract_results(
    results_folder,
    module="tlo.methods.demography",
    key="death",
    custom_generate_series=(
        lambda df: df.loc[(df['label'] == 'Maternal Disorders')].assign(
            year=df['date'].dt.year).groupby(['year'])['year'].count()),
    do_scaling=False)

br = extract_results(
    results_folder,
    module="tlo.methods.demography",
    key="on_birth",
    custom_generate_series=(
        lambda df: df.assign(
            year=df['date'].dt.year).groupby(['year'])['year'].count()),
    do_scaling=False
)

dd_sum = summarize_confidence_intervals(direct_deaths)
dd_mmr = (direct_deaths / br) * 100_000  # deaths per 100k live births
dd_mr_sum = summarize_confidence_intervals(dd_mmr)

all_dalys_dfs = extract_results(
    results_folder,
    module="tlo.methods.healthburden",
    key="dalys_stacked",
    custom_generate_series=(
        lambda df: df.drop(
            columns=['date', 'sex', 'age_range']).groupby(['year']).sum().stack()),
    do_scaling=False)

mat_disorders_all = all_dalys_dfs.loc[(slice(None), 'Maternal Disorders'), :]

mat_dalys_df = mat_disorders_all.loc[2024]
mat_dalys_df_sum = summarize_confidence_intervals(mat_dalys_df)

results.update({'dalys': {'crude': mat_dalys_df, 'summarised': mat_dalys_df_sum}})


# Summarised results
def get_data(df, key, draw):
    """Return the (lower, mean, upper) triple for row `key` under `draw`."""
    return (df.loc[key, (draw, 'lower')],
            df.loc[key, (draw, 'mean')],
            df.loc[key, (draw, 'upper')])


dalys_by_scenario = {k: get_data(results['dalys']['summarised'], 'Maternal Disorders', d) for k, d in zip(
    int_analysis, draws)}

mmr_by_scenario = {k: get_data(results['deaths_and_stillbirths']['summarised'], 'direct_mmr', d) for k, d in zip(
    int_analysis, draws)}

mmr_by_scenario_oth_log = {k: get_data(dd_mr_sum, 2024, d) for k, d in zip(
    int_analysis, draws)}


def barcharts(data, y_label, title):
    """Bar chart of means with asymmetric 95%-CI error bars."""
    labels = data.keys()
    means = [vals[1] for vals in data.values()]
    lower_errors = [vals[1] - vals[0] for vals in data.values()]
    upper_errors = [vals[2] - vals[1] for vals in data.values()]
    errors = [lower_errors, upper_errors]

    fig, ax = plt.subplots()
    ax.bar(labels, means, yerr=errors, capsize=5, alpha=0.7, ecolor='black')
    ax.set_ylabel(y_label)
    ax.set_title(title)

    # Adjust label size
    plt.xticks(fontsize=8, rotation=90)
    plt.tight_layout()
    plt.show()


barcharts(dalys_by_scenario, 'DALYs', 'Total Maternal Disorders DALYs by scenario')
barcharts(mmr_by_scenario, 'MMR', 'Total MMR by scenario')
# BUGFIX: the original plotted mmr_by_scnario twice in a row; the
# otherwise-unused demography-log MMR dict was presumably intended for the
# third chart -- confirm.
barcharts(mmr_by_scenario_oth_log, 'MMR', 'Total MMR by scenario (demography log)')


# Difference results
def get_diffs(df_key, result_key, ints, draws):
    """Per-draw differences vs baseline (draw 0) for results[df_key], keyed by
    scenario label; values are the summarised (lower, mean, upper) row.

    BUGFIX: the original discarded the result of fillna(0) (not in-place) and
    returned the loop-local diff_df of the final iteration, which every caller
    discarded via ``[0]``; the function now returns only the dict.
    """
    diff_results = {}
    baseline = results[df_key]['crude'][0]

    for draw, label in zip(draws, ints):
        diff_df = results[df_key]['crude'][draw] - baseline
        diff_df.columns = pd.MultiIndex.from_tuples([(draw, v) for v in range(len(diff_df.columns))],
                                                    names=['draw', 'run'])
        results_diff = summarize_confidence_intervals(diff_df).fillna(0)
        diff_results[label] = results_diff.loc[result_key].values

    return diff_results


# Same differencing, applied to the demography-log MMR (rows indexed by year).
diff_results = {}
baseline = dd_mmr[0]

for draw, label in zip(draws, int_analysis):
    diff_df = dd_mmr[draw] - baseline
    diff_df.columns = pd.MultiIndex.from_tuples([(draw, v) for v in range(len(diff_df.columns))],
                                                names=['draw', 'run'])
    results_diff = summarize_confidence_intervals(diff_df).fillna(0)  # BUGFIX: fillna result was discarded
    diff_results[label] = results_diff.loc[2024].values

mat_deaths = get_diffs('deaths_and_stillbirths', 'direct_maternal_deaths', int_analysis, draws)
mmr_diffs = get_diffs('deaths_and_stillbirths', 'direct_mmr', int_analysis, draws)
dalys_diffs = get_diffs('dalys', 'Maternal Disorders', int_analysis, draws)
mat_deaths_2 = diff_results


def get_diff_plots(data, outcome):
    """Point-with-CI plot of crude differences from the baseline scenario."""
    categories = list(data.keys())
    mins = [arr[0] for arr in data.values()]
    means = [arr[1] for arr in data.values()]
    maxs = [arr[2] for arr in data.values()]

    # Error bars (distance from mean to each end of the interval)
    errors = [(mean - min_val, max_val - mean) for mean, min_val, max_val in zip(means, mins, maxs)]
    errors = np.array(errors).T

    # todo: the error bars are slightly off...

    plt.figure(figsize=(12, 6))
    plt.errorbar(categories, means, yerr=errors, fmt='o', capsize=5)
    plt.axhline(0, color='gray', linestyle='--')  # horizontal reference line at y=0
    plt.xticks(rotation=90)
    plt.xlabel('Scenarios')
    plt.ylabel('Crude Difference from Baseline Scenario')
    plt.title(f'Difference of {outcome} from Baseline Scenario')
    plt.grid(True)
    plt.tight_layout()
    plt.show()


get_diff_plots(mmr_diffs, 'MMR')
get_diff_plots(mat_deaths, 'Maternal Deaths (crude)')
get_diff_plots(mat_deaths_2, 'MMR (demog log)')
get_diff_plots(dalys_diffs, 'Maternal DALYs')


# NORMALITY OF MMR ESTIMATES ACROSS RUNS (NOT DIFFERENCES)
for draw in draws:
    data = results['deaths_and_stillbirths']['crude'].loc['direct_mmr', draw].values

    # Shapiro-Wilk test for normality of the per-run MMR estimates
    stat, p_value = shapiro(data)

    # Histogram of the per-run estimates
    plt.hist(data, bins=15, density=True, alpha=0.6, color='skyblue', edgecolor='black')

    # Overlay a fitted normal curve
    mean, std = np.mean(data), np.std(data)
    xmin, xmax = plt.xlim()
    x = np.linspace(xmin, xmax, 100)
    p = norm.pdf(x, mean, std)
    plt.axvline(mean, color='green', linestyle='-', linewidth=2, label='Data Mean')
    plt.plot(x, p, 'r--', linewidth=2, label='Normal Curve')

    plt.title(f'MMR data Histogram with Normality Test (p-value = {p_value:.4f}) (Draw {draw})')
    plt.xlabel('Value')
    plt.ylabel('Density')
    plt.legend()
    plt.show()

    print(f"Shapiro-Wilk Test Statistic: {stat:.4f}, p-value: {p_value:.4f}")
    if p_value > 0.05:
        print("Result: Data likely follows a normal distribution (p > 0.05).")
    else:
        print("Result: Data likely does not follow a normal distribution (p ≤ 0.05).")

# (diff residue: the header of the next file,
#  src/scripts/maternal_perinatal_analyses/cohort_analysis/local_run_cohort_calibration.py,
#  followed at this point.)
b/src/scripts/maternal_perinatal_analyses/cohort_analysis/local_run_cohort_calibration.py new file mode 100644 index 0000000000..a93cdfb61a --- /dev/null +++ b/src/scripts/maternal_perinatal_analyses/cohort_analysis/local_run_cohort_calibration.py @@ -0,0 +1,116 @@ +import os +from pathlib import Path + +import pandas as pd + +from tlo import Date, Simulation, logging +from tlo.methods import mnh_cohort_module +from tlo.methods.fullmodel import fullmodel +from tlo.analysis.utils import parse_log_file + + +resourcefilepath = Path('./resources') +outputpath = Path("./outputs/cohort_testing") # folder for convenience of storing outputs +population_size = 2000 + +sim = Simulation(start_date=Date(2024, 1, 1), + seed=456, + log_config={"filename": "log_cohort_calibration", + "custom_levels": {"*": logging.DEBUG}, + "directory": outputpath}) + +sim.register(*fullmodel(resourcefilepath=resourcefilepath), + mnh_cohort_module.MaternalNewbornHealthCohort(resourcefilepath=resourcefilepath)) + +sim.make_initial_population(n=population_size) +sim.simulate(end_date=Date(2025, 1, 1)) + +output = parse_log_file(sim.log_filepath) + +# output = parse_log_file( +# '/Users/j_collins/PycharmProjects/TLOmodel/outputs/log_cohort_calibration__2024-10-04T101535.log') + +# Make output dataframe +results = pd.DataFrame(columns=['model', 'data', 'source'], + index= ['deaths', + 'MMR', + 'DALYs', + 'twins', + 'ectopic', + 'abortion', + 'miscarriage', + 'syphilis', + 'anaemia_an', + 'anaemia_pn' + 'gdm', + 'PROM', + 'pre_eclampsia', + 'gest-htn', + 'severe_gest-htn', + 'severe pre-eclampsia', + 'eclampsia', + 'praevia', + 'abruption', + 'aph', + 'OL', + 'UR', + 'sepsis', + 'PPH']) + +# total_pregnancies = population_size +total_pregnancies = 2000 +total_births = len(output['tlo.methods.demography']['on_birth']) +prop_live_births = (total_births/total_pregnancies) * 100 + +# Mortality/DALY +deaths_df = output['tlo.methods.demography']['death'] +prop_deaths_df = 
output['tlo.methods.demography.detail']['properties_of_deceased_persons'] + +dir_mat_deaths = deaths_df.loc[(deaths_df['label'] == 'Maternal Disorders')] +init_indir_mat_deaths = prop_deaths_df.loc[(prop_deaths_df['is_pregnant'] | prop_deaths_df['la_is_postpartum']) & + (prop_deaths_df['cause_of_death'].str.contains('Malaria|Suicide|ever_stroke|diabetes|' + 'chronic_ischemic_hd|ever_heart_attack|' + 'chronic_kidney_disease') | + (prop_deaths_df['cause_of_death'] == 'TB'))] + +hiv_mat_deaths = prop_deaths_df.loc[(prop_deaths_df['is_pregnant'] | prop_deaths_df['la_is_postpartum']) & + (prop_deaths_df['cause_of_death'].str.contains('AIDS_non_TB|AIDS_TB'))] + +indir_mat_deaths = len(init_indir_mat_deaths) + (len(hiv_mat_deaths) * 0.3) +total_deaths = len(dir_mat_deaths) + indir_mat_deaths + +# TOTAL_DEATHS +results.at['deaths', 'model'] = total_deaths +results.at['MMR', 'model'] = (total_deaths / total_births) * 100_000 +results.at['DALYs', 'model'] = output['tlo.methods.healthburden']['dalys_stacked']['Maternal Disorders'].sum() + +# Maternal conditions +an_comps = output['tlo.methods.pregnancy_supervisor']['maternal_complication'] +la_comps = output['tlo.methods.labour']['maternal_complication'] +pn_comps = output['tlo.methods.postnatal_supervisor']['maternal_complication'] + +twin_births = len(output['tlo.methods.newborn_outcomes']['twin_birth']) + +total_completed_pregnancies = (len(an_comps.loc[an_comps['type'] == 'ectopic_unruptured']) + + len(an_comps.loc[an_comps['type'] == 'induced_abortion']) + + len(an_comps.loc[an_comps['type'] == 'spontaneous_abortion']) + + (total_births - twin_births) + + len(output['tlo.methods.pregnancy_supervisor']['antenatal_stillbirth']) + + len(output['tlo.methods.labour']['intrapartum_stillbirth'])) + +print(total_completed_pregnancies) # this value may be less than the starting population size due to antenatal +# maternal deaths + +# Twins (todo) + +# Ectopic +results.at['ectopic', 'model'] = (len(an_comps.loc[an_comps['type'] == 
'ectopic_unruptured']) / total_pregnancies) * 1000 +results.at['ectopic', 'data'] = 10.0 +results.at['ectopic', 'source'] = 'Panelli et al.' + +# Abortion + + +# Miscarriage + +# Health system diff --git a/src/scripts/maternal_perinatal_analyses/scenario_files/full_model_long_run_cohort.py b/src/scripts/maternal_perinatal_analyses/scenario_files/full_model_long_run_cohort.py new file mode 100644 index 0000000000..b0a5072bc7 --- /dev/null +++ b/src/scripts/maternal_perinatal_analyses/scenario_files/full_model_long_run_cohort.py @@ -0,0 +1,35 @@ +from tlo import Date, logging +from tlo.methods.fullmodel import fullmodel + +from tlo.scenario import BaseScenario + + +class FullModelRunForCohort(BaseScenario): + def __init__(self): + super().__init__() + self.seed = 537184 + self.start_date = Date(2010, 1, 1) + self.end_date = Date(2025, 1, 1) + self.pop_size = 200_000 + self.number_of_draws = 1 + self.runs_per_draw = 1 + + def log_configuration(self): + return { + 'filename': 'fullmodel_200k_cohort', 'directory': './outputs', + "custom_levels": { + "*": logging.WARNING, + "tlo.methods.contraception": logging.DEBUG, + } + } + + def modules(self): + return fullmodel(resourcefilepath=self.resources) + + def draw_parameters(self, draw_number, rng): + return {} + + +if __name__ == '__main__': + from tlo.cli import scenario_run + scenario_run([__file__]) diff --git a/src/tlo/analysis/utils.py b/src/tlo/analysis/utils.py index f6aff47faa..bc8784ae66 100644 --- a/src/tlo/analysis/utils.py +++ b/src/tlo/analysis/utils.py @@ -364,6 +364,167 @@ def generate_series(dataframe: pd.DataFrame) -> pd.Series: _concat = pd.concat(res, axis=1) _concat.columns.names = ['draw', 'run'] # name the levels of the columns multi-index return _concat + + +import pandas as pd + +def old_unpack_dict_rows(df): + """ + Reconstruct a full dataframe from rows whose columns contain dictionaries + mapping local-row-index → value. Preserves original column order. 
+ """ + original_cols = ['E', 'EventDate', 'EventName', 'A', 'V'] + reconstructed_rows = [] + + for _, row in df.iterrows(): + # Determine how many rows this block has (using the first dict column) + first_dict_col = next(col for col in original_cols if isinstance(row[col], dict)) + block_length = len(row[first_dict_col]) + + # Build each reconstructed row + for i in range(block_length): + new_row = {} + for col in original_cols: + cell = row[col] + if not isinstance(cell, dict): + raise ValueError(f"Column {col} does not contain a dictionary") + new_row[col] = cell.get(str(i)) + reconstructed_rows.append(new_row) + + # Build DataFrame and enforce the original column order + out = pd.DataFrame(reconstructed_rows)[original_cols] + return out.reset_index(drop=True) + + +def unpack_dict_rows(df, non_dict_cols=None): + """ + Reconstruct a full DataFrame from rows where most columns are dictionaries. + Non-dict columns (e.g., 'date') are propagated to all reconstructed rows. + + Parameters: + df: pd.DataFrame + non_dict_cols: list of columns that are NOT dictionaries + """ + if non_dict_cols is None: + non_dict_cols = [] + + original_cols = ['E', 'date', 'EventName', 'A', 'V'] + + reconstructed_rows = [] + + for _, row in df.iterrows(): + # Determine dict columns for this row + dict_cols = [col for col in original_cols if col not in non_dict_cols] + + if not dict_cols: + # No dict columns, just append row + reconstructed_rows.append(row.to_dict()) + continue + + # Use the first dict column to get the block length + first_dict_col = dict_cols[0] + block_length = len(row[first_dict_col]) + + # Build each expanded row + for i in range(block_length): + new_row = {} + for col in original_cols: + cell = row[col] + if col in dict_cols: + # Access the dict using string or integer keys + new_row[col] = cell.get(str(i), cell.get(i)) + else: + # Propagate non-dict value + new_row[col] = cell + reconstructed_rows.append(new_row) + + # Build DataFrame in original column order + out 
= pd.DataFrame(reconstructed_rows)[original_cols] + + return out.reset_index(drop=True) + + +def print_filtered_df(df): + """ + Prints rows of the DataFrame excluding EventName 'Initialise' and 'Birth'. + """ + pd.set_option('display.max_colwidth', None) + filtered = df#[~df['EventName'].isin(['StartOfSimulation', 'Birth'])] + + dict_cols = ["Info"] + max_items = 2 + # Step 2: Truncate dictionary columns for display + if dict_cols is not None: + for col in dict_cols: + def truncate_dict(d): + if isinstance(d, dict): + items = list(d.items())[:max_items] # keep only first `max_items` + return dict(items) + return d + filtered[col] = filtered[col].apply(truncate_dict) + print(filtered) + + +def extract_event_chains(results_folder: Path, + ) -> dict: + """Utility function to collect chains of events. Individuals across runs of the same draw will be combined into unique df. + Returns dictionary where keys are draws, and each draw is associated with a dataframe of format 'E', 'EventDate', 'EventName', 'Info' where 'Info' is a dictionary that combines A&Vs for a particular individual + date + event name combination. + """ + module = 'tlo.simulation' + key = 'event_chains' + + # get number of draws and numbers of runs + info = get_scenario_info(results_folder) + + # Collect results from each draw/run. Individuals across runs of the same draw will be combined into unique df. + res = dict() + + for draw in range(info['number_of_draws']): + + # All individuals in same draw will be combined across runs, so their ID will be offset. + dfs_from_runs = [] + ID_offset = 0 + + for run in range(info['runs_per_draw']): + + try: + df: pd.DataFrame = load_pickled_dataframes(results_folder, draw, run, module)[module][key] + + recon = unpack_dict_rows(df, ['date']) + print(recon) + #del recon['EventDate'] + # For now convert value to string in all cases to facilitate manipulation. This can be reversed later. 
+ recon['V'] = recon['V'].apply(str) + # Collapse into 'E', 'EventDate', 'EventName', 'Info' format where 'Info' is dict listing attributes (e.g. {a1:v1, a2:v2, a3:v3, ...} ) + df_collapsed = ( + recon.groupby(['E', 'date', 'EventName']) + .apply(lambda g: dict(zip(g['A'], g['V']))) + .reset_index(name='Info') + ) + df_final = df_collapsed.sort_values(by=['E','date'], ascending=True).reset_index(drop=True) + birth_count = (df_final['EventName'] == 'Birth').sum() + + print("Birth count for run ", run, "is ", birth_count) + df_final['E'] = df_final['E'] + ID_offset + + # Calculate ID offset for next run + ID_offset = (max(df_final['E']) + 1) + + # Append these chains to list + dfs_from_runs.append(df_final) + + except KeyError: + # Some logs could not be found - probably because this run failed. + # Simply to not append anything to the df collecting chains. + print("Run failed") + + # Combine all dfs into a single DataFrame + res[draw] = pd.concat(dfs_from_runs, ignore_index=True) + + # Optionally, sort by 'E' and 'EventDate' after combining + res[draw] = res[draw].sort_values(by=['E', 'date']).reset_index(drop=True) + + return res def compute_summary_statistics( diff --git a/src/tlo/events.py b/src/tlo/events.py index 9dd34c9448..8149070e45 100644 --- a/src/tlo/events.py +++ b/src/tlo/events.py @@ -4,11 +4,27 @@ from enum import Enum from typing import TYPE_CHECKING -from tlo import DateOffset +from tlo import DateOffset, logging if TYPE_CHECKING: from tlo import Simulation +import pandas as pd + +from tlo.util import convert_chain_links_into_EAV + +import copy + +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) + +logger_chain = logging.getLogger('tlo.simulation') +logger_chain.setLevel(logging.INFO) + +logger_summary = logging.getLogger(f"{__name__}.summary") +logger_summary.setLevel(logging.INFO) + +debug_chains = True class Priority(Enum): """Enumeration for the Priority, which is used in sorting the events in the simulation queue.""" @@ 
-60,12 +76,272 @@ def apply(self, target): :param target: the target of the event """ raise NotImplementedError + + def mni_values_differ(self, v1, v2): + + if isinstance(v1, list) and isinstance(v2, list): + return v1 != v2 # simple element-wise comparison + + if pd.isna(v1) and pd.isna(v2): + return False # treat both NaT/NaN as equal + return v1 != v2 + + def compare_entire_mni_dicts(self,entire_mni_before, entire_mni_after): + diffs = {} + + all_individuals = set(entire_mni_before.keys()) | set(entire_mni_after.keys()) + + for person in all_individuals: + if person not in entire_mni_before: # but is afterward + for key in entire_mni_after[person]: + if self.mni_values_differ(entire_mni_after[person][key],self.sim.modules['PregnancySupervisor'].default_all_mni_values[key]): + if person not in diffs: + diffs[person] = {} + diffs[person][key] = entire_mni_after[person][key] + + elif person not in entire_mni_after: # but is beforehand + for key in entire_mni_before[person]: + if self.mni_values_differ(entire_mni_before[person][key],self.sim.modules['PregnancySupervisor'].default_all_mni_values[key]): + if person not in diffs: + diffs[person] = {} + diffs[person][key] = self.sim.modules['PregnancySupervisor'].default_all_mni_values[key] + + else: # person is in both + # Compare properties + for key in entire_mni_before[person]: + if self.mni_values_differ(entire_mni_before[person][key],entire_mni_after[person][key]): + if person not in diffs: + diffs[person] = {} + diffs[person][key] = entire_mni_after[person][key] + + return diffs + + def compare_population_dataframe_and_mni(self,df_before, df_after, entire_mni_before, entire_mni_after): + """ This function compares the population dataframe and mni dictionary before/after a population-wide event has occurred. + It allows us to identify the individuals for which this event led to a significant (i.e. property) change, and to store the properties which have changed as a result of it. 
""" + + # Create a mask of where values are different + diff_mask = (df_before != df_after) & ~(df_before.isna() & df_after.isna()) + if 'PregnancySupervisor' in self.sim.modules: + diff_mni = self.compare_entire_mni_dicts(entire_mni_before, entire_mni_after) + else: + diff_mni = [] + + # Create an empty list to store changes for each of the individuals + chain_links = {} + len_of_diff = len(diff_mask) + + # Loop through each row of the mask + persons_changed = [] + + for idx, row in diff_mask.iterrows(): + changed_cols = row.index[row].tolist() + + if changed_cols: # Proceed only if there are changes in the row + persons_changed.append(idx) + # Create a dictionary for this person + # First add event info + link_info = { + 'EventName': type(self).__name__, + } + + # Store the new values from df_after for the changed columns + for col in changed_cols: + link_info[col] = df_after.at[idx, col] + + if idx in diff_mni: + # This person has also undergone changes in the mni dictionary, so add these here + for key in diff_mni[idx]: + link_info[col] = diff_mni[idx][key] + + # Append the event and changes to the individual key + chain_links[idx] = link_info + + if 'PregnancySupervisor' in self.sim.modules: + # For individuals which only underwent changes in mni dictionary, save changes here + if len(diff_mni)>0: + for key in diff_mni: + if key not in persons_changed: + # If individual hadn't been previously added due to changes in pop df, add it here + link_info = { + 'EventName': type(self).__name__, + } + + for key_prop in diff_mni[key]: + link_info[key_prop] = diff_mni[key][key_prop] + + chain_links[key] = link_info + + # Ensure the partial death events are cleared after event + #df = self.sim.population.props + #df.loc[:,'death_weight'] = pd.Series([[] for _ in range(len(df))]) + #df.loc[:,'cause_of_partial_death'] = pd.Series([[] for _ in range(len(df))]) + #df.loc[:,'date_of_partial_death'] = pd.Series([[] for _ in range(len(df))]) + + return chain_links + + + def 
store_chains_to_do_before_event(self) -> tuple[bool, pd.Series, pd.DataFrame, dict, dict, bool]: + """ This function checks whether this event should be logged as part of the event chains, and if so stored required information before the event has occurred. """ + + # Initialise these variables + print_chains = False + df_before = [] + row_before = pd.Series() + mni_instances_before = False + mni_row_before = {} + entire_mni_before = {} + + # Only print event if it belongs to modules of interest and if it is not in the list of events to ignore + if all(sub not in str(self) for sub in self.sim.generate_event_chains_ignore_events): + + # Will eventually use this once I can actually GET THE NAME OF THE SELF + #if not set(self.sim.generate_event_chains_ignore_events).intersection(str(self)): + + print_chains = True + + # Target is single individual + if self.target != self.sim.population: + + # Save row for comparison after event has occurred + row_before = self.sim.population.props.loc[[abs(self.target)], :].copy(deep=True).iloc[0] + row_before = row_before.fillna(-99999) + # Recursively deep copy all object columns (like lists, Series, dicts) + for col in row_before.index: + val = row_before[col] + if isinstance(val, (list, dict, pd.Series)): + row_before[col] = copy.deepcopy(val) + + # Check if individual is already in mni dictionary, if so copy her original status + if 'PregnancySupervisor' in self.sim.modules: + mni = self.sim.modules['PregnancySupervisor'].mother_and_newborn_info + if self.target in mni: + mni_instances_before = True + mni_row_before = mni[self.target].copy() + else: + mni_row_before = None + + else: + + # This will be a population-wide event. In order to find individuals for which this led to + # a meaningful change, make a copy of the while pop dataframe/mni before the event has occurred. 
+ df_before = self.sim.population.props.copy() + if 'PregnancySupervisor' in self.sim.modules: + entire_mni_before = copy.deepcopy(self.sim.modules['PregnancySupervisor'].mother_and_newborn_info) + else: + entire_mni_before = None + + return print_chains, row_before, df_before, mni_row_before, entire_mni_before, mni_instances_before + + def store_chains_to_do_after_event(self, row_before, df_before, mni_row_before, entire_mni_before, mni_instances_before) -> dict: + """ If print_chains=True, this function logs the event and identifies and logs the any property changes that have occured to one or multiple individuals as a result of the event taking place. """ + + chain_links = {} + + # Target is single individual + if self.target != self.sim.population: + + # Copy full new status for individual + row_after = self.sim.population.props.loc[abs(self.target)].fillna(-99999) + + # Check if individual is in mni after the event + mni_instances_after = False + if 'PregnancySupervisor' in self.sim.modules: + mni = self.sim.modules['PregnancySupervisor'].mother_and_newborn_info + if self.target in mni: + mni_instances_after = True + else: + mni_instances_after = None + + # Create and store event for this individual, regardless of whether any property change occurred + link_info = { + 'EventName' : type(self).__name__, + } + + # Store (if any) property changes as a result of the event for this individual + for key in row_before.index: + if row_before[key] != row_after[key]: # Note: used fillna previously, so this is safe + link_info[key] = row_after[key] + + if 'PregnancySupervisor' in self.sim.modules: + # Now check and store changes in the mni dictionary, accounting for following cases: + # Individual is in mni dictionary before and after + if mni_instances_before and mni_instances_after: + for key in mni_row_before: + if self.mni_values_differ(mni_row_before[key], mni[self.target][key]): + link_info[key] = mni[self.target][key] + # Individual is only in mni dictionary 
before event + elif mni_instances_before and not mni_instances_after: + default = self.sim.modules['PregnancySupervisor'].default_all_mni_values + for key in mni_row_before: + if self.mni_values_differ(mni_row_before[key], default[key]): + link_info[key] = default[key] + # Individual is only in mni dictionary after event + elif mni_instances_after and not mni_instances_before: + default = self.sim.modules['PregnancySupervisor'].default_all_mni_values + for key in default: + if self.mni_values_differ(default[key], mni[self.target][key]): + link_info[key] = mni[self.target][key] + # Else, no need to do anything + + # Add individual to the chain links + chain_links[self.target] = link_info + + # Ensure the partial death events are cleared after event + #df = self.sim.population.props + #df.at[self.target,'death_weight'] = [] + #df.at[self.target,'cause_of_partial_death'] = [] + #df.at[self.target,'date_of_partial_death'] = [] + else: + # Target is entire population. Identify individuals for which properties have changed + # and store their changes. + + # Population frame after event + df_after = self.sim.population.props + if 'PregnancySupervisor' in self.sim.modules: + entire_mni_after = copy.deepcopy(self.sim.modules['PregnancySupervisor'].mother_and_newborn_info) + else: + entire_mni_after = None + + # Create and store the event and dictionary of changes for affected individuals + chain_links = self.compare_population_dataframe_and_mni(df_before, df_after, entire_mni_before, entire_mni_after) + + return chain_links + def run(self): """Make the event happen.""" + + # Collect relevant information before event takes place + if self.sim.generate_event_chains: + print_chains, row_before, df_before, mni_row_before, entire_mni_before, mni_instances_before = self.store_chains_to_do_before_event() + self.apply(self.target) self.post_apply_hook() - + + # Collect event info + meaningful property changes of individuals. 
Combined, these will constitute a 'link' + # in the individual's event chain. + if self.sim.generate_event_chains and print_chains: + chain_links = self.store_chains_to_do_after_event(row_before, df_before, mni_row_before, entire_mni_before, mni_instances_before) + + if chain_links: + # Convert chain_links into EAV + ednav = convert_chain_links_into_EAV(chain_links) + + logger_chain.info(key='event_chains', + data= ednav.to_dict(), + description='Links forming chains of events for simulated individuals') + """ + # Create empty logger for entire pop + pop_dict = {i: '' for i in range(FACTOR_POP_DICT)} # Always include all possible individuals + pop_dict.update(chain_links) + + # Log chain_links here + if len(chain_links)>0: + logger_chain.info(key='event_chains', + data= pop_dict, + description='Links forming chains of events for simulated individuals') + """ class RegularEvent(Event): """An event that automatically reschedules itself at a fixed frequency.""" diff --git a/src/tlo/methods/care_of_women_during_pregnancy.py b/src/tlo/methods/care_of_women_during_pregnancy.py index 4cdf3339bb..029e84c6e5 100644 --- a/src/tlo/methods/care_of_women_during_pregnancy.py +++ b/src/tlo/methods/care_of_women_during_pregnancy.py @@ -733,7 +733,9 @@ def screening_interventions_delivered_at_every_contact(self, hsi_event): self, int_name='urine_dipstick', hsi_event=hsi_event, q_param=[params['prob_intervention_delivered_urine_ds']], + equipment={'Urine dip Stick'}, + cons=self.item_codes_preg_consumables['urine_dipstick'], dx_test='urine_dipstick_protein') @@ -2559,6 +2561,7 @@ def apply(self, person_id, squeeze_factor): q_param=[l_params['prob_hcw_avail_retained_prod'], l_params['mean_hcw_competence_hp']], cons=pac_cons, opt_cons=pac_opt_cons, + equipment={'D&C set', 'Suction Curettage machine', 'Drip stand', 'Infusion pump', 'Evacuation set'}) if pac_delivered: diff --git a/src/tlo/methods/contraception.py b/src/tlo/methods/contraception.py index 6f56a55cc7..7aaa64c181 100644 
--- a/src/tlo/methods/contraception.py +++ b/src/tlo/methods/contraception.py @@ -1106,6 +1106,12 @@ def set_new_pregnancy(self, women_id: list): }, description='pregnancy following the failure of contraceptive method') + person = df.loc[w] + + logger.debug(key='properties_of_pregnant_person', + data=person.to_dict(), + description='values of all properties at the time of pregnancy for newwly pregnany persons') + class ContraceptionLoggingEvent(RegularEvent, PopulationScopeEventMixin): def __init__(self, module): diff --git a/src/tlo/methods/demography.py b/src/tlo/methods/demography.py index 2acaad75eb..6ceaabb88b 100644 --- a/src/tlo/methods/demography.py +++ b/src/tlo/methods/demography.py @@ -120,8 +120,13 @@ def __init__(self, name=None, equal_allocation_by_district: bool = False): # as optional if they can be undefined for a given individual. PROPERTIES = { 'is_alive': Property(Types.BOOL, 'Whether this individual is alive'), + 'aliveness_weight': Property(Types.REAL, 'If allowing individual to survive, track decreasing weight'), + 'death_weight': Property(Types.LIST, 'If allowing individual to survive, track weight of death'), 'date_of_birth': Property(Types.DATE, 'Date of birth of this individual'), 'date_of_death': Property(Types.DATE, 'Date of death of this individual'), + 'date_of_partial_death': Property(Types.LIST, 'Date of latest partial death of this individual'), + 'cause_of_partial_death': Property(Types.LIST, 'Date of latest partial death of this individual'), + 'sex': Property(Types.CATEGORICAL, 'Male or female', categories=['M', 'F']), 'mother_id': Property(Types.INT, 'Unique identifier of mother of this individual'), @@ -273,9 +278,14 @@ def initialise_population(self, population): # Assign the characteristics df.is_alive.values[:] = True + alive_count = sum(df.is_alive) + df.aliveness_weight.values[:] = 1 + df.loc[df.is_alive, 'death_weight'] = pd.Series([[] for _ in range(alive_count)]) df.loc[df.is_alive, 'date_of_birth'] = 
demog_char_to_assign['date_of_birth'] df.loc[df.is_alive, 'date_of_death'] = pd.NaT + df.loc[df.is_alive, 'date_of_partial_death'] = pd.Series([[] for _ in range(alive_count)]) df.loc[df.is_alive, 'cause_of_death'] = np.nan + df.loc[df.is_alive, 'cause_of_partial_death'] = pd.Series([[] for _ in range(alive_count)]) df.loc[df.is_alive, 'sex'] = demog_char_to_assign['Sex'] df.loc[df.is_alive, 'mother_id'] = DEFAULT_MOTHER_ID # Motherless, and their characterists are not inherited df.loc[df.is_alive, 'district_num_of_residence'] = demog_char_to_assign['District_Num'].values[:] @@ -324,9 +334,10 @@ def initialise_simulation(self, sim): # Launch the repeating event that will store statistics about the population structure sim.schedule_event(DemographyLoggingEvent(self), sim.date) - # Create (and store pointer to) the OtherDeathPoll and schedule first occurrence immediately - self.other_death_poll = OtherDeathPoll(self) - sim.schedule_event(self.other_death_poll, sim.date) + if sim.generate_event_chains is False: + # Create (and store pointer to) the OtherDeathPoll and schedule first occurrence immediately + self.other_death_poll = OtherDeathPoll(self) + sim.schedule_event(self.other_death_poll, sim.date) # Log the initial population scaling-factor (to the logger of this module and that of `tlo.methods.population`) for _logger in (logger, logger_scale_factor): @@ -368,9 +379,13 @@ def on_birth(self, mother_id, child_id): child = { 'is_alive': True, + 'aliveness_weight': 1.0, + 'death_weight': [], 'date_of_birth': self.sim.date, 'date_of_death': pd.NaT, 'cause_of_death': np.nan, + 'date_of_partial_death': [], + 'cause_of_partial_death': [], 'sex': 'M' if rng.random_sample() < fraction_of_births_male else 'F', 'mother_id': mother_id, 'district_num_of_residence': _district_num_of_residence, @@ -494,6 +509,7 @@ def process_causes_of_death(self): data=mapper_from_gbd_causes ) + def do_death(self, individual_id: int, cause: str, originating_module: Module): """Register and 
log the death of an individual from a specific cause. * 'individual_id' is the index in the population.props dataframe to the (one) person. @@ -780,6 +796,26 @@ def apply(self, population): self.sim.date + DateOffset(days=self.module.rng.randint(0, 30))) +class InstantaneousPartialDeath(Event, IndividualScopeEventMixin): + """ + Call the do_death function to cause the person to die. + + Note that no checking is done here. (Checking is done within `do_death` which can also be called directly.) + + The 'individual_id' is the index in the population.props dataframe. It is for _one_ person only. + The 'cause' is the cause that is defined by the disease module (aka, "tlo cause"). + The 'module' passed to this event is the disease module that is causing the death. + """ + + def __init__(self, module, individual_id, cause, weight): + super().__init__(module, person_id=individual_id) + self.cause = cause + self.weight = weight + + def apply(self, individual_id): + self.sim.modules['Demography'].do_partial_death(individual_id, cause=self.cause, originating_module=self.module, weight=self.weight) + + class InstantaneousDeath(Event, IndividualScopeEventMixin): """ Call the do_death function to cause the person to die. @@ -822,8 +858,8 @@ def apply(self, population): logger.info( key='population', data={'total': sum(sex_count), - 'male': sex_count['M'], - 'female': sex_count['F'] + 'male': sex_count['M'] if 'M' in sex_count.index else 0, + 'female': sex_count['F'] if 'F' in sex_count.index else 0, }) # (nb. 
if you groupby both sex and age_range, you weirdly lose categories where size==0, so diff --git a/src/tlo/methods/epi.py b/src/tlo/methods/epi.py index 4bc298aefc..eed98ac96e 100644 --- a/src/tlo/methods/epi.py +++ b/src/tlo/methods/epi.py @@ -181,6 +181,7 @@ def initialise_simulation(self, sim): # add an event to log to screen sim.schedule_event(EpiLoggingEvent(self), sim.date + DateOffset(years=1)) + # TODO: check with Tara shes happy with this (could come in as its own PR) # HPV vaccine given from 2018 onwards if self.sim.date.year < 2018: sim.schedule_event(HpvScheduleEvent(self), Date(2018, 1, 1)) diff --git a/src/tlo/methods/hiv.py b/src/tlo/methods/hiv.py index 8b40e37a34..f7cc2eadb2 100644 --- a/src/tlo/methods/hiv.py +++ b/src/tlo/methods/hiv.py @@ -1523,15 +1523,16 @@ def per_capita_testing_rate(self): df = self.sim.population.props - if not self.stored_test_numbers: - # If it's the first year, set previous_test_numbers to 0 + # get number of tests performed in last time period + if self.sim.date.year == (self.sim.start_date.year + 1): + number_tests_new = df.hv_number_tests.sum() previous_test_numbers = 0 else: # For subsequent years, retrieve the last stored number previous_test_numbers = self.stored_test_numbers[-1] - # Calculate number of tests now performed - cumulative, include those who have died - number_tests_new = df.hv_number_tests.sum() + # Calculate number of tests now performed - cumulative, include those who have died + number_tests_new = df.hv_number_tests.sum() # Store the number of tests performed in this year for future reference self.stored_test_numbers.append(number_tests_new) @@ -1682,6 +1683,37 @@ def do_at_generic_first_appt( # Main Polling Event # --------------------------------------------------------------------------- +class HivPollingEventForDataGeneration(RegularEvent, PopulationScopeEventMixin): + """ The HIV Polling Events for Data Generation + * Ensures that + """ + + def __init__(self, module): + super().__init__( + 
module, frequency=DateOffset(years=120) + ) # repeats every 12 months, but this can be changed + + def apply(self, population): + + df = population.props + + # Make everyone who is alive and not infected (no-one should be) susceptible + susc_idx = df.loc[ + df.is_alive + & ~df.hv_inf + ].index + + n_susceptible = len(susc_idx) + print("Number of individuals susceptible", n_susceptible) + # Schedule the date of infection for each new infection: + for i in susc_idx: + date_of_infection = self.sim.date + pd.DateOffset( + # Ensure that individual will be infected before end of sim + days=self.module.rng.randint(0, 365*(int(self.sim.end_date.year - self.sim.date.year)+1)) + ) + self.sim.schedule_event( + HivInfectionEvent(self.module, i), date_of_infection + ) class HivRegularPollingEvent(RegularEvent, PopulationScopeEventMixin): """ The HIV Regular Polling Events @@ -1703,6 +1735,7 @@ def apply(self, population): fraction_of_year_between_polls = self.frequency.months / 12 beta = p["beta"] * fraction_of_year_between_polls + # ----------------------------------- HORIZONTAL TRANSMISSION ----------------------------------- def horizontal_transmission(to_sex, from_sex): # Count current number of alive 15-80 year-olds at risk of transmission @@ -1778,6 +1811,7 @@ def horizontal_transmission(to_sex, from_sex): HivInfectionEvent(self.module, idx), date_of_infection ) + # ----------------------------------- SPONTANEOUS TESTING ----------------------------------- def spontaneous_testing(current_year): @@ -1902,6 +1936,8 @@ def vmmc_for_child(): vmmc_for_child() + + # --------------------------------------------------------------------------- # Natural History Events # --------------------------------------------------------------------------- diff --git a/src/tlo/methods/hsi_event.py b/src/tlo/methods/hsi_event.py index cc5e0a4fb2..0fd2a29416 100644 --- a/src/tlo/methods/hsi_event.py +++ b/src/tlo/methods/hsi_event.py @@ -7,17 +7,27 @@ from tlo import Date, logging from 
tlo.events import Event +from tlo.population import Population +from tlo.util import convert_chain_links_into_EAV +import pandas as pd if TYPE_CHECKING: from tlo import Module, Simulation from tlo.methods.healthsystem import HealthSystem +# Pointing to the logger in events +logger_chains = logging.getLogger("tlo.simulation") +logger_chains.setLevel(logging.INFO) + logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) logger_summary = logging.getLogger(f"{__name__}.summary") logger_summary.setLevel(logging.INFO) +debug_chains = True + + # Declare the level which will be used to represent the merging of levels '1b' and '2' LABEL_FOR_MERGED_FACILITY_LEVELS_1B_AND_2 = "2" @@ -194,12 +204,146 @@ def _run_after_hsi_event(self) -> None: facility_id=self.facility_info.id ) + def values_differ(self, v1, v2): + + if isinstance(v1, list) and isinstance(v2, list): + return v1 != v2 # simple element-wise comparison + + if pd.isna(v1) and pd.isna(v2): + return False # treat both NaT/NaN as equal + return v1 != v2 + + + def store_chains_to_do_before_event(self) -> tuple[bool, pd.Series, dict, bool]: + """ This function checks whether this event should be logged as part of the event chains, and if so stored required information before the event has occurred. 
""" + + # Initialise these variables + print_chains = False + row_before = pd.Series() + mni_instances_before = False + mni_row_before = {} + + # Only print event if it belongs to modules of interest and if it is not in the list of events to ignore + if (self.module in self.sim.generate_event_chains_modules_of_interest) and all(sub not in str(self) for sub in self.sim.generate_event_chains_ignore_events): + + # Will eventually use this once I can actually GET THE NAME OF THE SELF + #if not set(self.sim.generate_event_chains_ignore_events).intersection(str(self)): + + print_chains = True + + # Target is single individual + if self.target != self.sim.population: + + # Save row for comparison after event has occurred + row_before = self.sim.population.props.loc[abs(self.target)].copy().fillna(-99999) + + # Check if individual is in mni dictionary before the event, if so store its original status + if 'PregnancySupervisor' in self.sim.modules: + mni = self.sim.modules['PregnancySupervisor'].mother_and_newborn_info + if self.target in mni: + mni_instances_before = True + mni_row_before = mni[self.target].copy() + + else: + print("ERROR: there shouldn't be pop-wide HSI event") + exit(-1) + + return print_chains, row_before, mni_row_before, mni_instances_before + + def store_chains_to_do_after_event(self, row_before, footprint, mni_row_before, mni_instances_before) -> dict: + """ If print_chains=True, this function logs the event and identifies and logs the any property changes that have occured to one or multiple individuals as a result of the event taking place. 
""" + + # For HSI event, this will only ever occur for individual events + chain_links = {} + + row_after = self.sim.population.props.loc[abs(self.target)].fillna(-99999) + + mni_instances_after = False + if 'PregnancySupervisor' in self.sim.modules: + mni = self.sim.modules['PregnancySupervisor'].mother_and_newborn_info + if self.target in mni: + mni_instances_after = True + + # Create and store dictionary of changes. Note that person_ID, event, event_date, appt_foot, and level + # will be stored regardless of whether individual experienced property changes or not. + + # Add event details + try: + record_footprint = str(footprint) + record_level = self.facility_info.level + except: + record_footprint = 'N/A' + record_level = 'N/A' + + link_info = { + 'EventName' : type(self).__name__, + 'appt_footprint' : record_footprint, + 'level' : record_level, + } + + # Add changes to properties + for key in row_before.index: + if row_before[key] != row_after[key]: # Note: used fillna previously + link_info[key] = row_after[key] + + if 'PregnancySupervisor' in self.sim.modules: + # Now store changes in the mni dictionary, accounting for following cases: + # Individual is in mni dictionary before and after + if mni_instances_before and mni_instances_after: + for key in mni_row_before: + if self.values_differ(mni_row_before[key], mni[self.target][key]): + link_info[key] = mni[self.target][key] + # Individual is only in mni dictionary before event + elif mni_instances_before and not mni_instances_after: + default = self.sim.modules['PregnancySupervisor'].default_all_mni_values + for key in mni_row_before: + if self.values_differ(mni_row_before[key], default[key]): + link_info[key] = default[key] + # Individual is only in mni dictionary after event + elif mni_instances_after and not mni_instances_before: + default = self.sim.modules['PregnancySupervisor'].default_all_mni_values + for key in default: + if self.values_differ(default[key], mni[self.target][key]): + link_info[key] = 
mni[self.target][key] + + chain_links[self.target] = link_info + + return chain_links + + def run(self, squeeze_factor): """Make the event happen.""" + + + if self.sim.generate_event_chains and self.target != self.sim.population: + print_chains, row_before, mni_row_before, mni_instances_before = self.store_chains_to_do_before_event() + + footprint = self.EXPECTED_APPT_FOOTPRINT + updated_appt_footprint = self.apply(self.target, squeeze_factor) self.post_apply_hook() self._run_after_hsi_event() + + + if self.sim.generate_event_chains and self.target != self.sim.population: + + # If the footprint has been updated when the event ran, change it here + if updated_appt_footprint is not None: + footprint = updated_appt_footprint + + if print_chains: + chain_links = self.store_chains_to_do_after_event(row_before, str(footprint), mni_row_before, mni_instances_before) + + if chain_links: + + # Convert chain_links into EAV + ednav = convert_chain_links_into_EAV(chain_links) + logger_chain.info(key='event_chains', + data = ednav, + description='Links forming chains of events for simulated individuals') + return updated_appt_footprint + def get_consumables( self, diff --git a/src/tlo/methods/labour.py b/src/tlo/methods/labour.py index 8cd97da57f..4eb933cf30 100644 --- a/src/tlo/methods/labour.py +++ b/src/tlo/methods/labour.py @@ -864,6 +864,7 @@ def get_and_store_labour_item_codes(self): {ic('Infant resuscitator, clear plastic + mask + bag_each_CMST'): 1} def initialise_simulation(self, sim): + # We call the following function to store the required consumables for the simulation run within the appropriate # dictionary self.get_and_store_labour_item_codes() @@ -1483,25 +1484,37 @@ def apply_risk_of_early_postpartum_death(self, individual_id): # Check the right women are at risk of death self.postpartum_characteristics_checker(individual_id) - # Function checks df for any potential cause of death, uses CFR parameters to determine risk of death - # (either from one or multiple 
causes) and if death occurs returns the cause - potential_cause_of_death = pregnancy_helper_functions.check_for_risk_of_death_from_cause_maternal( - self, individual_id=individual_id, timing='postnatal') + survived = True + + if self.sim.generate_event_chains: + + risks = pregnancy_helper_functions.get_risk_of_death_from_cause_maternal( + self, individual_id=individual_id, timing='postnatal') + + pregnancy_helper_functions.apply_multiple_partial_deaths(self, risks, individual_id=individual_id) + + else: + + # Function checks df for any potential cause of death, uses CFR parameters to determine risk of death + # (either from one or multiple causes) and if death occurs returns the cause + potential_cause_of_death = pregnancy_helper_functions.check_for_risk_of_death_from_cause_maternal( + self, individual_id=individual_id, timing='postnatal') - # Log df row containing complications and treatments to calculate met need + # Log df row containing complications and treatments to calculate met need - # If a cause is returned death is scheduled - if potential_cause_of_death: - pregnancy_helper_functions.log_mni_for_maternal_death(self, individual_id) - self.sim.modules['PregnancySupervisor'].mnh_outcome_counter['direct_mat_death'] += 1 + # If a cause is returned death is scheduled + if potential_cause_of_death: + survived = False + pregnancy_helper_functions.log_mni_for_maternal_death(self, individual_id) + self.sim.modules['PregnancySupervisor'].mnh_outcome_counter['direct_mat_death'] += 1 - self.sim.modules['Demography'].do_death(individual_id=individual_id, cause=potential_cause_of_death, - originating_module=self.sim.modules['Labour']) + self.sim.modules['Demography'].do_death(individual_id=individual_id, cause=potential_cause_of_death, + originating_module=self.sim.modules['Labour']) # If she hasn't died from any complications, we reset some key properties that resolve after risk of death # has been applied - else: + if survived: if df.at[individual_id, 
'pn_htn_disorders'] == 'eclampsia': df.at[individual_id, 'pn_htn_disorders'] = 'severe_pre_eclamp' @@ -2547,19 +2560,31 @@ def apply(self, individual_id): self.module.labour_characteristics_checker(individual_id) - # Function checks df for any potential cause of death, uses CFR parameters to determine risk of death - # (either from one or multiple causes) and if death occurs returns the cause - potential_cause_of_death = pregnancy_helper_functions.check_for_risk_of_death_from_cause_maternal( - self.module, individual_id=individual_id, timing='intrapartum') + survived = True + + if self.sim.generate_event_chains: + + risks = pregnancy_helper_functions.get_risk_of_death_from_cause_maternal( + self.module, individual_id=individual_id, timing='intrapartum') + + pregnancy_helper_functions.apply_multiple_partial_deaths(self, risks, individual_id=individual_id) + + else: + + # Function checks df for any potential cause of death, uses CFR parameters to determine risk of death + # (either from one or multiple causes) and if death occurs returns the cause + potential_cause_of_death = pregnancy_helper_functions.check_for_risk_of_death_from_cause_maternal( + self.module, individual_id=individual_id, timing='intrapartum') - # If a cause is returned death is scheduled - if potential_cause_of_death: - pregnancy_helper_functions.log_mni_for_maternal_death(self.module, individual_id) - self.sim.modules['PregnancySupervisor'].mnh_outcome_counter['direct_mat_death'] += 1 - self.sim.modules['Demography'].do_death(individual_id=individual_id, cause=potential_cause_of_death, - originating_module=self.sim.modules['Labour']) + # If a cause is returned death is scheduled + if potential_cause_of_death: + survived = False + pregnancy_helper_functions.log_mni_for_maternal_death(self.module, individual_id) + self.sim.modules['PregnancySupervisor'].mnh_outcome_counter['direct_mat_death'] += 1 + self.sim.modules['Demography'].do_death(individual_id=individual_id, cause=potential_cause_of_death, 
+ originating_module=self.sim.modules['Labour']) - mni[individual_id]['death_in_labour'] = True + mni[individual_id]['death_in_labour'] = True # Next we determine if she will experience a stillbirth during her delivery outcome_of_still_birth_equation = self.module.predict(self.module.la_linear_models['intrapartum_still_birth'], @@ -2606,7 +2631,7 @@ def apply(self, individual_id): mni[individual_id]['didnt_seek_care'] = False # Finally, reset some treatment variables - if not potential_cause_of_death: + if survived: df.at[individual_id, 'la_maternal_hypertension_treatment'] = False df.at[individual_id, 'ac_iv_anti_htn_treatment'] = False df.at[individual_id, 'ac_mag_sulph_treatment'] = False diff --git a/src/tlo/methods/mnh_cohort_module.py b/src/tlo/methods/mnh_cohort_module.py new file mode 100644 index 0000000000..77ba618748 --- /dev/null +++ b/src/tlo/methods/mnh_cohort_module.py @@ -0,0 +1,175 @@ +from pathlib import Path + +import numpy as np +import pandas as pd +from tlo import DateOffset, Module, Parameter, Property, Types, logging +from tlo.methods import Metadata +from tlo.analysis.utils import parse_log_file +from tlo.events import Event, IndividualScopeEventMixin + + +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) + +class MaternalNewbornHealthCohort(Module): + """ + When registered this module overrides the population data frame with a cohort of pregnant women. Cohort properties + are sourced from a long run of the full model in which the properties of all newly pregnant women per year were + logged. The cohort represents women in 2024. The maximum population size is 13,000. 
+ """ + + METADATA = { + Metadata.DISEASE_MODULE, + Metadata.USES_SYMPTOMMANAGER, + Metadata.USES_HEALTHSYSTEM, + Metadata.USES_HEALTHBURDEN + } + + CAUSES_OF_DEATH = {} + CAUSES_OF_DISABILITY = {} + PARAMETERS = {} + PROPERTIES = {} + + def __init__(self, name=None, resourcefilepath=None): + super().__init__(name) + self.resourcefilepath = resourcefilepath + + def read_parameters(self, data_folder): + pass + + def initialise_population(self, population): + """Set our property values for the initial population. + + This method is called by the simulation when creating the initial population, and is + responsible for assigning initial values, for every individual, of those properties + 'owned' by this module, i.e. those declared in the PROPERTIES dictionary above. + + :param population: the population of individuals + """ + + # Read in excel sheet with cohort + all_preg_df = pd.read_excel(Path(f'{self.resourcefilepath}/maternal cohort') / + 'ResourceFile_All2024PregnanciesCohortModel.xlsx') + + # Only select rows equal to the desired population size + if len(self.sim.population.props) <= len(all_preg_df): + preg_pop = all_preg_df.loc[0:(len(self.sim.population.props))-1] + #preg_pop = all_preg_df.iloc[0:(len(self.sim.population.props))-1].copy() + + else: + # Calculate the number of rows needed to reach the desired length + additional_rows = len(self.sim.population.props) - len(all_preg_df) + + # Initialize an empty DataFrame for additional rows + rows_to_add = pd.DataFrame(columns=all_preg_df.columns) + + # Loop to fill the required additional rows + while additional_rows > 0: + if additional_rows >= len(all_preg_df): + rows_to_add = pd.concat([rows_to_add, all_preg_df], ignore_index=True) + additional_rows -= len(all_preg_df) + else: + rows_to_add = pd.concat([rows_to_add, all_preg_df.iloc[:additional_rows]], ignore_index=True) + additional_rows = 0 + + # Concatenate the original DataFrame with the additional rows + preg_pop = pd.concat([all_preg_df, rows_to_add], 
ignore_index=True) + + + # Set the dtypes and index of the cohort dataframe + props_dtypes = self.sim.population.props.dtypes + + preg_pop.loc[:,'aliveness_weight'] = 1 + preg_pop.loc[:,'death_weight'] = [[] for _ in range(len(preg_pop))] + preg_pop.loc[:,'cause_of_partial_death'] = [[] for _ in range(len(preg_pop))] + preg_pop.loc[:,'date_of_partial_death'] = [[] for _ in range(len(preg_pop))] + + print(preg_pop.columns) + missing_cols = [col for col in self.sim.population.props.columns if col not in preg_pop.columns] + if missing_cols: + print("Missing columns") + for col in missing_cols: + print(col) + exit(-1) + + # Reorder columns + preg_pop = preg_pop[self.sim.population.props.columns] + + preg_pop_final = preg_pop.astype(props_dtypes.to_dict()) + + preg_pop_final.index.name = 'person' + + # For the below columns we manually overwrite the dtypes + for column in ['rt_injuries_for_minor_surgery', 'rt_injuries_for_major_surgery', + 'rt_injuries_to_heal_with_time', 'rt_injuries_for_open_fracture_treatment', + 'rt_injuries_left_untreated', 'rt_injuries_to_cast']: + preg_pop_final[column] = [[] for _ in range(len(preg_pop_final))] + + # Set the population.props dataframe to the new cohort + self.sim.population.props = preg_pop_final + + # Update key pregnancy properties + df = self.sim.population.props + population = df.loc[df.is_alive] + df.loc[population.index, 'date_of_last_pregnancy'] = self.sim.start_date + df.loc[population.index, 'co_contraception'] = "not_using" + + # import tableone + # columns = ['age_years', 'la_parity', 'region_of_residence', 'li_wealth', 'li_bmi', 'li_mar_stat', 'li_ed_lev', + # 'li_urban', 'ps_prev_spont_abortion', 'ps_prev_stillbirth', 'ps_prev_pre_eclamp', 'ps_prev_gest_diab'] + # categorical = ['region_of_residence', 'li_wealth', 'li_bmi' ,'li_mar_stat', 'li_ed_lev', 'li_urban', + # 'ps_prev_spont_abortion', 'ps_prev_stillbirth', 'ps_prev_pre_eclamp', 'ps_prev_gest_diab'] + # continuous = ['age_years', 'la_parity'] + # + # rename = 
{'age_years': 'Age (years)', + # 'la_parity': 'Parity', + # 'district_of_residence': 'District', + # 'li_wealth': 'Wealth Qunitle', + # 'li_bmi': 'BMI level', + # 'li_mar_stat': 'Marital Status', + # 'li_ed_lev': 'Education Level', + # 'li_urban': 'Urban/Rural', + # 'ps_prev_spont_abortion': 'Previous Miscarriage', + # 'ps_prev_stillbirth': 'Previous Stillbirth', + # 'ps_prev_pre_eclamp': 'Previous Pre-eclampsia', + # 'ps_prev_gest_diab': 'Previous Gestational Diabetes', + # } + # from tableone import TableOne + # + # mytable = TableOne(self.sim.population.props[columns], categorical=categorical, + # continuous=continuous, rename=rename, pval=False) + + def initialise_simulation(self, sim): + """Get ready for simulation start. + + This method is called just before the main simulation loop begins, and after all + modules have read their parameters and the initial population has been created. + It is a good place to add initial events to the event queue. + + """ + df = self.sim.population.props + + # Clear HSI queue for events scheduled during initialisation + sim.modules['HealthSystem'].HSI_EVENT_QUEUE.clear() + + # Clear the individual event queue for events scheduled during initialisation + updated_event_queue = [item for item in self.sim.event_queue.queue + if not isinstance(item[3], IndividualScopeEventMixin)] + self.sim.event_queue.queue = updated_event_queue + + # Prevent additional pregnancies from occurring during the cohort tun + self.sim.modules['Contraception'].processed_params['p_pregnancy_with_contraception_per_month'].iloc[:] = 0 + self.sim.modules['Contraception'].processed_params['p_pregnancy_no_contraception_per_month'].iloc[:] = 0 + + # Set labour date for cohort women + for person in df.index: + self.sim.modules['Labour'].set_date_of_labour(person) + + def on_birth(self, mother_id, child_id): + pass + + def report_daly_values(self): + pass + + def on_hsi_alert(self, person_id, treatment_id): + pass diff --git 
a/src/tlo/methods/postnatal_supervisor.py b/src/tlo/methods/postnatal_supervisor.py index 109b7d5694..7de5a3d355 100644 --- a/src/tlo/methods/postnatal_supervisor.py +++ b/src/tlo/methods/postnatal_supervisor.py @@ -642,18 +642,36 @@ def log_new_progressed_cases(disease): (df['pn_postnatal_period_in_weeks'] == week) & (df['pn_htn_disorders'] == 'severe_gest_htn')] - die_from_htn = pd.Series(self.rng.random_sample(len(at_risk_of_death_htn)) < - params['weekly_prob_death_severe_gest_htn'], index=at_risk_of_death_htn.index) - # Those women who die the on_death function in demography is applied - for person in die_from_htn.loc[die_from_htn].index: - self.sim.modules['PregnancySupervisor'].mnh_outcome_counter['severe_gestational_hypertension_m_death'] += 1 - self.sim.modules['PregnancySupervisor'].mnh_outcome_counter['direct_mat_death'] += 1 + if self.sim.generate_event_chains: + + # If generating chains of events, all women at risk will partially die, partially survive + for person in at_risk_of_death_htn.index: + + original_aliveness_weight = df.loc[person,'aliveness_weight'] + df.loc[person,'aliveness_weight'] *= (1. 
- params['weekly_prob_death_severe_gest_htn']) + # Individual is partially dead + death_weight = original_aliveness_weight * params['weekly_prob_death_severe_gest_htn'] + + df.loc[person, 'date_of_partial_death'].append(str(self.sim.date)) + df.loc[person, 'death_weight'].append(death_weight) + df.loc[person, 'cause_of_partial_death'].append('severe_gestational_hypertension') - self.sim.modules['Demography'].do_death(individual_id=person, cause='severe_gestational_hypertension', - originating_module=self.sim.modules['PostnatalSupervisor']) + else: + + die_from_htn = pd.Series(self.rng.random_sample(len(at_risk_of_death_htn)) < + params['weekly_prob_death_severe_gest_htn'], index=at_risk_of_death_htn.index) + + for person in die_from_htn.loc[die_from_htn].index: + + # Those women who die the on_death function in demography is applied + self.sim.modules['PregnancySupervisor'].mnh_outcome_counter['severe_gestational_hypertension_m_death'] += 1 + self.sim.modules['PregnancySupervisor'].mnh_outcome_counter['direct_mat_death'] += 1 - del self.sim.modules['PregnancySupervisor'].mother_and_newborn_info[person] + self.sim.modules['Demography'].do_death(individual_id=person, cause='severe_gestational_hypertension', + originating_module=self.sim.modules['PostnatalSupervisor']) + + del self.sim.modules['PregnancySupervisor'].mother_and_newborn_info[person] # ----------------------------------------- CARE SEEKING ------------------------------------------------------ # We now use the pn_emergency_event_mother property that has just been set for women who are experiencing @@ -800,21 +818,33 @@ def apply_risk_of_maternal_or_neonatal_death_postnatal(self, mother_or_child, in # Create a list of all the causes that may cause death in the individual (matched to GBD labels) if mother_or_child == 'mother': - # Function checks df for any potential cause of death, uses CFR parameters to determine risk of death - # (either from one or multiple causes) and if death occurs returns the cause 
- potential_cause_of_death = pregnancy_helper_functions.check_for_risk_of_death_from_cause_maternal( - self, individual_id=individual_id, timing='postnatal') - - # If a cause is returned death is scheduled - if potential_cause_of_death: - mni[individual_id]['didnt_seek_care'] = True - pregnancy_helper_functions.log_mni_for_maternal_death(self, individual_id) - self.sim.modules['PregnancySupervisor'].mnh_outcome_counter['direct_mat_death'] += 1 - self.sim.modules['Demography'].do_death(individual_id=individual_id, cause=potential_cause_of_death, - originating_module=self.sim.modules['PostnatalSupervisor']) - del mni[individual_id] + survived = True + + if self.sim.generate_event_chains: + + # If collecting events, woman will always partially survive and partially die from all possible causes + risks = pregnancy_helper_functions.get_risk_of_death_from_cause_maternal( + self, individual_id=individual_id, timing='postnatal') + + pregnancy_helper_functions.apply_multiple_partial_deaths(self, risks, individual_id=individual_id) else: + # Function checks df for any potential cause of death, uses CFR parameters to determine risk of death + # (either from one or multiple causes) and if death occurs returns the cause + potential_cause_of_death = pregnancy_helper_functions.check_for_risk_of_death_from_cause_maternal( + self, individual_id=individual_id, timing='postnatal') + + # If a cause is returned death is scheduled + if potential_cause_of_death: + survived = False + mni[individual_id]['didnt_seek_care'] = True + pregnancy_helper_functions.log_mni_for_maternal_death(self, individual_id) + self.sim.modules['PregnancySupervisor'].mnh_outcome_counter['direct_mat_death'] += 1 + self.sim.modules['Demography'].do_death(individual_id=individual_id, cause=potential_cause_of_death, + originating_module=self.sim.modules['PostnatalSupervisor']) + del mni[individual_id] + + if survived: # Reset variables for women who survive if mother.pn_postpartum_haem_secondary: 
df.at[individual_id, 'pn_postpartum_haem_secondary'] = False @@ -829,6 +859,8 @@ def apply_risk_of_maternal_or_neonatal_death_postnatal(self, mother_or_child, in if mother_or_child == 'child': # Neonates can have either early or late onset sepsis, not both at once- so we use either equation # depending on this neonates current condition + + survived = True if child.pn_sepsis_early_neonatal: risk_of_death = params['cfr_early_onset_neonatal_sepsis'] elif child.pn_sepsis_late_neonatal: @@ -840,14 +872,29 @@ def apply_risk_of_maternal_or_neonatal_death_postnatal(self, mother_or_child, in else: cause = 'early_onset_sepsis' - # If this neonate will die then we make the appropriate changes - if self.rng.random_sample() < risk_of_death: - self.sim.modules['PregnancySupervisor'].mnh_outcome_counter[f'{cause}_n_death'] += 1 - self.sim.modules['Demography'].do_death(individual_id=individual_id, cause=cause, - originating_module=self.sim.modules['PostnatalSupervisor']) + if self.sim.generate_event_chains: + + # This neonate will always partially die, partially survive + original_aliveness_weight = df.loc[individual_id,'aliveness_weight'] + # Individual is partially less alive + df.loc[individual_id,'aliveness_weight'] *= (1. 
def apply_multiple_partial_deaths(self, risks, individual_id):
    """Apply every candidate cause of death to *individual_id* as a partial death.

    Used when generating event chains: instead of sampling a single outcome,
    the individual's ``aliveness_weight`` is scaled down by the combined risk
    and each cause is recorded with its share of the lost weight, so chains
    can continue to follow the surviving fraction of the individual.

    :param self: module providing ``self.sim`` (population frame and date)
    :param risks: dict mapping cause-of-death name -> risk of death from that cause
    :param individual_id: row index of the individual in ``population.props``
    """
    df = self.sim.population.props

    # Combined risk across all causes. NOTE(review): assumes the summed risks
    # are <= 1; a larger total would drive aliveness_weight negative — confirm
    # the upstream CFR parameters guarantee this.
    total_risk_of_death = sum(risks.values())

    # The individual becomes partially less alive ...
    original_aliveness_weight = df.loc[individual_id, 'aliveness_weight']
    df.loc[individual_id, 'aliveness_weight'] *= (1. - total_risk_of_death)

    # ... and partially dead from each cause, in proportion to that cause's
    # risk relative to the weight held BEFORE this event.
    for cause, risk in risks.items():
        death_weight = original_aliveness_weight * risk
        df.loc[individual_id, 'date_of_partial_death'].append(str(self.sim.date))
        df.loc[individual_id, 'death_weight'].append(death_weight)
        df.loc[individual_id, 'cause_of_partial_death'].append(cause)
+ + +def check_for_risk_of_death_from_cause_maternal(self, individual_id, timing): + """ + This function calculates the risk of death associated with one or more causes being experience by an individual and + determines if they will die and which of a number of competing cause is the primary cause of death + :param individual_id: individual_id of woman at risk of death + return: cause of death or False + """ + risks = get_risk_of_death_from_cause_maternal(self, individual_id, timing) + + if len(risks)>0: # Call return the result from calculate_risk_of_death_from_causes function return calculate_risk_of_death_from_causes(self, risks, target='m') @@ -455,12 +489,10 @@ def apply_effect_of_anaemia(cause): return False -def check_for_risk_of_death_from_cause_neonatal(self, individual_id): +def get_risk_of_death_from_cause_neonatal(self, individual_id): """ - This function calculates the risk of death associated with one or more causes being experience by an individual and - determines if they will die and which of a number of competing cause is the primary cause of death - :param individual_id: individual_id of woman at risk of death - return: cause of death or False + This function calculates the risk of death associated with one or more causes being experience by an individual + return: causes and associated risks """ params = self.current_parameters df = self.sim.population.props @@ -505,10 +537,11 @@ def check_for_risk_of_death_from_cause_neonatal(self, individual_id): if self.congeintal_anomalies.has_all(individual_id, 'other'): causes.append('other_anomaly') + risks = dict() + # If this list is not empty, use either CFR parameters or linear models to calculate risk of death from each # complication they experiencing and store in a dictionary, using each cause as the key if causes: - risks = dict() for cause in causes: if f'{cause}_death' in self.nb_linear_models.keys(): risk = {cause: self.nb_linear_models[f'{cause}_death'].predict( @@ -517,7 +550,21 @@ def 
check_for_risk_of_death_from_cause_neonatal(self, individual_id): risk = {cause: params[f'cfr_{cause}']} risks.update(risk) + + return risks + + +def check_for_risk_of_death_from_cause_neonatal(self, individual_id): + """ + This function calculates the risk of death associated with one or more causes being experience by an individual and + determines if they will die and which of a number of competing cause is the primary cause of death + :param individual_id: individual_id of woman at risk of death + return: cause of death or False + """ + + risks = get_risk_of_death_from_cause_neonatal(self, individual_id) + if len(risks)>0: # Return the result from calculate_risk_of_death_from_causes function (returns primary cause of death or False) return calculate_risk_of_death_from_causes(self, risks, target='n') @@ -531,91 +578,15 @@ def update_mni_dictionary(self, individual_id): if self == self.sim.modules['PregnancySupervisor']: - mni[individual_id] = {'delete_mni': False, # if True, mni deleted in report_daly_values function - 'didnt_seek_care': False, - 'cons_not_avail': False, - 'comp_not_avail': False, - 'hcw_not_avail': False, - 'ga_anc_one': 0, - 'anc_ints': [], - 'abortion_onset': pd.NaT, - 'abortion_haem_onset': pd.NaT, - 'abortion_sep_onset': pd.NaT, - 'eclampsia_onset': pd.NaT, - 'mild_mod_aph_onset': pd.NaT, - 'severe_aph_onset': pd.NaT, - 'chorio_onset': pd.NaT, - 'chorio_in_preg': False, # use in predictor in newborn linear models - 'ectopic_onset': pd.NaT, - 'ectopic_rupture_onset': pd.NaT, - 'gest_diab_onset': pd.NaT, - 'gest_diab_diagnosed_onset': pd.NaT, - 'gest_diab_resolution': pd.NaT, - 'mild_anaemia_onset': pd.NaT, - 'mild_anaemia_resolution': pd.NaT, - 'moderate_anaemia_onset': pd.NaT, - 'moderate_anaemia_resolution': pd.NaT, - 'severe_anaemia_onset': pd.NaT, - 'severe_anaemia_resolution': pd.NaT, - 'mild_anaemia_pp_onset': pd.NaT, - 'mild_anaemia_pp_resolution': pd.NaT, - 'moderate_anaemia_pp_onset': pd.NaT, - 'moderate_anaemia_pp_resolution': 
pd.NaT, - 'severe_anaemia_pp_onset': pd.NaT, - 'severe_anaemia_pp_resolution': pd.NaT, - 'hypertension_onset': pd.NaT, - 'hypertension_resolution': pd.NaT, - 'obstructed_labour_onset': pd.NaT, - 'sepsis_onset': pd.NaT, - 'uterine_rupture_onset': pd.NaT, - 'mild_mod_pph_onset': pd.NaT, - 'severe_pph_onset': pd.NaT, - 'secondary_pph_onset': pd.NaT, - 'vesicovaginal_fistula_onset': pd.NaT, - 'vesicovaginal_fistula_resolution': pd.NaT, - 'rectovaginal_fistula_onset': pd.NaT, - 'rectovaginal_fistula_resolution': pd.NaT, - 'test_run': False, # used by labour module when running some model tests - 'pred_syph_infect': pd.NaT, # date syphilis is predicted to onset - 'new_onset_spe': False, - 'cs_indication': 'none' - } + mni[individual_id] = self.sim.modules['PregnancySupervisor'].default_mni_values.copy() elif self == self.sim.modules['Labour']: - labour_variables = {'labour_state': None, - # Term Labour (TL), Early Preterm (EPTL), Late Preterm (LPTL) or Post Term (POTL) - 'birth_weight': 'normal_birth_weight', - 'birth_size': 'average_for_gestational_age', - 'delivery_setting': None, # home_birth, health_centre, hospital - 'twins': df.at[individual_id, 'ps_multiple_pregnancy'], - 'twin_count': 0, - 'twin_one_comps': False, - 'pnc_twin_one': 'none', - 'bf_status_twin_one': 'none', - 'eibf_status_twin_one': False, - 'an_placental_abruption': df.at[individual_id, 'ps_placental_abruption'], - 'corticosteroids_given': False, - 'clean_birth_practices': False, - 'abx_for_prom_given': False, - 'abx_for_pprom_given': False, - 'endo_pp': False, - 'retained_placenta': False, - 'uterine_atony': False, - 'amtsl_given': False, - 'cpd': False, - 'mode_of_delivery': 'vaginal_delivery', - 'neo_will_receive_resus_if_needed': False, - # vaginal_delivery, instrumental, caesarean_section - 'hsi_cant_run': False, # True (T) or False (F) - 'sought_care_for_complication': False, # True (T) or False (F) - 'sought_care_labour_phase': 'none', - 'referred_for_cs': False, # True (T) or False (F) - 
'referred_for_blood': False, # True (T) or False (F) - 'received_blood_transfusion': False, # True (T) or False (F) - 'referred_for_surgery': False, # True (T) or False (F)' - 'death_in_labour': False, # True (T) or False (F) - 'single_twin_still_birth': False, # True (T) or False (F) - 'will_receive_pnc': 'none', - 'passed_through_week_one': False} - - mni[individual_id].update(labour_variables) + + labour_default = self.sim.modules['PregnancySupervisor'].default_labour_values.copy() + mni[individual_id].update(labour_default) + + # Update from default based on individual case + mni[individual_id]['twins'] = df.at[individual_id, 'ps_multiple_pregnancy'] + mni[individual_id]['an_placental_abruption'] = df.at[individual_id, 'ps_placental_abruption'] + + diff --git a/src/tlo/methods/pregnancy_supervisor.py b/src/tlo/methods/pregnancy_supervisor.py index 1be38175f7..ef522d62bd 100644 --- a/src/tlo/methods/pregnancy_supervisor.py +++ b/src/tlo/methods/pregnancy_supervisor.py @@ -27,7 +27,9 @@ ) from tlo.methods.causes import Cause from tlo.methods.hsi_generic_first_appts import GenericFirstAppointmentsMixin -from tlo.util import BitsetHandler, read_csv_files + +from tlo.util import BitsetHandler,read_csv_files +from tlo.methods.demography import InstantaneousPartialDeath if TYPE_CHECKING: from tlo.methods.hsi_generic_first_appts import HSIEventScheduler @@ -60,6 +62,100 @@ def __init__(self, name=None): # This variable will store a Bitset handler for the property ps_abortion_complications self.abortion_complications = None + + self.default_mni_values = {'delete_mni': False, # if True, mni deleted in report_daly_values function + 'didnt_seek_care': False, + 'cons_not_avail': False, + 'comp_not_avail': False, + 'hcw_not_avail': False, + 'ga_anc_one': 0, + 'anc_ints': [], + 'abortion_onset': pd.NaT, + 'abortion_haem_onset': pd.NaT, + 'abortion_sep_onset': pd.NaT, + 'eclampsia_onset': pd.NaT, + 'mild_mod_aph_onset': pd.NaT, + 'severe_aph_onset': pd.NaT, + 'chorio_onset': 
pd.NaT, + 'chorio_in_preg': False, # use in predictor in newborn linear models + 'ectopic_onset': pd.NaT, + 'ectopic_rupture_onset': pd.NaT, + 'gest_diab_onset': pd.NaT, + 'gest_diab_diagnosed_onset': pd.NaT, + 'gest_diab_resolution': pd.NaT, + 'none_anaemia_onset': pd.NaT, + 'none_anaemia_resolution': pd.NaT, + 'mild_anaemia_onset': pd.NaT, + 'mild_anaemia_resolution': pd.NaT, + 'moderate_anaemia_onset': pd.NaT, + 'moderate_anaemia_resolution': pd.NaT, + 'severe_anaemia_onset': pd.NaT, + 'severe_anaemia_resolution': pd.NaT, + 'mild_anaemia_pp_onset': pd.NaT, + 'mild_anaemia_pp_resolution': pd.NaT, + 'moderate_anaemia_pp_onset': pd.NaT, + 'moderate_anaemia_pp_resolution': pd.NaT, + 'severe_anaemia_pp_onset': pd.NaT, + 'severe_anaemia_pp_resolution': pd.NaT, + 'hypertension_onset': pd.NaT, + 'hypertension_resolution': pd.NaT, + 'obstructed_labour_onset': pd.NaT, + 'sepsis_onset': pd.NaT, + 'uterine_rupture_onset': pd.NaT, + 'mild_mod_pph_onset': pd.NaT, + 'severe_pph_onset': pd.NaT, + 'secondary_pph_onset': pd.NaT, + 'vesicovaginal_fistula_onset': pd.NaT, + 'vesicovaginal_fistula_resolution': pd.NaT, + 'rectovaginal_fistula_onset': pd.NaT, + 'rectovaginal_fistula_resolution': pd.NaT, + 'test_run': False, # used by labour module when running some model tests + 'pred_syph_infect': pd.NaT, # date syphilis is predicted to onset + 'new_onset_spe': False, + 'cs_indication': 'none' + } + self.default_labour_values = {'labour_state': None, + # Term Labour (TL), Early Preterm (EPTL), Late Preterm (LPTL) or Post Term (POTL) + 'birth_weight': 'normal_birth_weight', + 'birth_size': 'average_for_gestational_age', + 'delivery_setting': None, # home_birth, health_centre, hospital + 'twins': None, + 'twin_count': 0, + 'twin_one_comps': False, + 'pnc_twin_one': 'none', + 'bf_status_twin_one': 'none', + 'eibf_status_twin_one': False, + 'an_placental_abruption': None, + 'corticosteroids_given': False, + 'clean_birth_practices': False, + 'abx_for_prom_given': False, + 
'abx_for_pprom_given': False, + 'endo_pp': False, + 'retained_placenta': False, + 'uterine_atony': False, + 'amtsl_given': False, + 'cpd': False, + 'mode_of_delivery': 'vaginal_delivery', + 'neo_will_receive_resus_if_needed': False, + # vaginal_delivery, instrumental, caesarean_section + 'hsi_cant_run': False, # True (T) or False (F) + 'sought_care_for_complication': False, # True (T) or False (F) + 'sought_care_labour_phase': 'none', + 'referred_for_cs': False, # True (T) or False (F) + 'referred_for_blood': False, # True (T) or False (F) + 'received_blood_transfusion': False, # True (T) or False (F) + 'referred_for_surgery': False, # True (T) or False (F)' + 'death_in_labour': False, # True (T) or False (F) + 'single_twin_still_birth': False, # True (T) or False (F) + 'will_receive_pnc': 'none', + 'passed_through_week_one': False} + + self.default_all_mni_values = self.default_mni_values + self.default_all_mni_values.update(self.default_labour_values) + + # Finally we create a dictionary to capture the frequency of key outcomes for logging + mnh_oc = pregnancy_helper_functions.generate_mnh_outcome_counter() + self.mnh_outcome_counter = mnh_oc['counter'] # Finally we create a dictionary to capture the frequency of key outcomes for logging mnh_oc = pregnancy_helper_functions.generate_mnh_outcome_counter() @@ -399,6 +495,7 @@ def __init__(self, name=None): 'intervention_analysis_availability': Parameter( Types.REAL, 'Probability an intervention which is included in "interventions_under_analysis" will be' 'available'), + } PROPERTIES = { @@ -452,6 +549,7 @@ def read_parameters(self, resourcefilepath: Optional[Path] = None): files='parameter_values') self.load_parameters_from_dataframe(parameter_dataframe) + # Here we map 'disability' parameters to associated DALY weights to be passed to the health burden module. 
# Currently this module calculates and reports all DALY weights from all maternal modules if 'HealthBurden' in self.sim.modules.keys(): @@ -1240,19 +1338,37 @@ def apply_risk_of_death_from_hypertension(self, gestation_of_interest): (df.ps_ectopic_pregnancy == 'none') & ~df.la_currently_in_labour & \ (df.ps_htn_disorders == 'severe_gest_htn') - at_risk_of_death_htn = pd.Series(self.rng.random_sample(len(at_risk.loc[at_risk])) < - params['prob_monthly_death_severe_htn'], index=at_risk.loc[at_risk].index) + if self.sim.generate_event_chains: + + # Every woman at risk will partially die + for person in df.index[at_risk]: + + original_aliveness_weight = df.loc[person,'aliveness_weight'] + # Individual is partially less alive + df.loc[person,'aliveness_weight'] *= (1. - params['prob_monthly_death_severe_htn']) + # Individual is partially dead + death_weight = original_aliveness_weight * params['prob_monthly_death_severe_htn'] + df.loc[person, 'date_of_partial_death'].append(str(self.sim.date)) + df.loc[person, 'death_weight'].append(death_weight) + df.loc[person, 'cause_of_partial_death'].append('severe_gestational_hypertension') + + # Not deleting woman from mni because she always survives + else: + + at_risk_of_death_htn = pd.Series(self.rng.random_sample(len(at_risk.loc[at_risk])) < + params['prob_monthly_death_severe_htn'], index=at_risk.loc[at_risk].index) - if not at_risk_of_death_htn.loc[at_risk_of_death_htn].empty: - # Those women who die have InstantaneousDeath scheduled - for person in at_risk_of_death_htn.loc[at_risk_of_death_htn].index: - self.mnh_outcome_counter['severe_gestational_hypertension_m_death'] += 1 - self.mnh_outcome_counter['direct_mat_death'] += 1 + if not at_risk_of_death_htn.loc[at_risk_of_death_htn].empty: + # Those women who die have InstantaneousDeath scheduled + for person in at_risk_of_death_htn.loc[at_risk_of_death_htn].index: + self.mnh_outcome_counter['severe_gestational_hypertension_m_death'] += 1 + 
self.mnh_outcome_counter['direct_mat_death'] += 1 - self.sim.modules['Demography'].do_death(individual_id=person, cause='severe_gestational_hypertension', - originating_module=self.sim.modules['PregnancySupervisor']) + self.sim.modules['Demography'].do_death(individual_id=person, cause='severe_gestational_hypertension', + originating_module=self.sim.modules['PregnancySupervisor']) - del mni[person] + del mni[person] + def apply_risk_of_placental_abruption(self, gestation_of_interest): """ @@ -1559,21 +1675,36 @@ def apply_risk_of_death_from_monthly_complications(self, individual_id): mother = df.loc[individual_id] - # Function checks df for any potential cause of death, uses CFR parameters to determine risk of death - # (either from one or multiple causes) and if death occurs returns the cause - potential_cause_of_death = pregnancy_helper_functions.check_for_risk_of_death_from_cause_maternal( + survived = True + + if self.sim.generate_event_chains: + + # Get all potential causes of death and associated risks + risks = pregnancy_helper_functions.get_risk_of_death_from_cause_maternal( self, individual_id=individual_id, timing='antenatal') - - # If a cause is returned death is scheduled - if potential_cause_of_death: - pregnancy_helper_functions.log_mni_for_maternal_death(self, individual_id) - self.sim.modules['Demography'].do_death(individual_id=individual_id, cause=potential_cause_of_death, - originating_module=self.sim.modules['PregnancySupervisor']) - self.mnh_outcome_counter['direct_mat_death'] += 1 - del mni[individual_id] + + pregnancy_helper_functions.apply_multiple_partial_deaths(self, risks, individual_id=individual_id) + + else: + + # Function checks df for any potential cause of death, uses CFR parameters to determine risk of death + # (either from one or multiple causes) and if death occurs returns the cause + potential_cause_of_death = pregnancy_helper_functions.check_for_risk_of_death_from_cause_maternal( + self, individual_id=individual_id, 
timing='antenatal') + + # If a cause is returned death is scheduled + if potential_cause_of_death: + + # If woman has been selected to die, set survived to false + survived = False + pregnancy_helper_functions.log_mni_for_maternal_death(self, individual_id) + self.sim.modules['Demography'].do_death(individual_id=individual_id, cause=potential_cause_of_death, + originating_module=self.sim.modules['PregnancySupervisor']) + self.mnh_outcome_counter['direct_mat_death'] += 1 + del mni[individual_id] # If not we reset variables and the woman survives - else: + if survived: mni[individual_id]['didnt_seek_care'] = False if (mother.ps_htn_disorders == 'severe_pre_eclamp') and mni[individual_id]['new_onset_spe']: @@ -1991,24 +2122,45 @@ def apply(self, individual_id): if not df.at[individual_id, 'is_alive']: return + # Woman assumed to survive, will overwrite if selected for death + survived = True + # Individual risk of death is calculated through the linear model risk_of_death = self.module.ps_linear_models[f'{self.cause}_death'].predict( df.loc[[individual_id]])[individual_id] - # If the death occurs we record it here - if self.module.rng.random_sample() < risk_of_death: + if self.sim.generate_event_chains: + original_aliveness_weight = df.loc[individual_id,'aliveness_weight'] + + # Individual partially survives + df.loc[individual_id,'aliveness_weight'] *= (1. 
- risk_of_death) + + # Individual partially dies + death_weight = original_aliveness_weight * risk_of_death + df.loc[individual_id, 'date_of_partial_death'].append(str(self.sim.date)) + df.loc[individual_id, 'death_weight'].append(death_weight) + df.loc[individual_id, 'cause_of_partial_death'].append('severe_gestational_hypertension') - if individual_id in mni: - pregnancy_helper_functions.log_mni_for_maternal_death(self.module, individual_id) - mni[individual_id]['delete_mni'] = True + else: - self.module.mnh_outcome_counter[f'{self.cause}_m_death'] += 1 - self.module.mnh_outcome_counter['direct_mat_death'] += 1 - self.sim.modules['Demography'].do_death(individual_id=individual_id, cause=f'{self.cause}', - originating_module=self.sim.modules['PregnancySupervisor']) + # If the death occurs we record it here + if self.module.rng.random_sample() < risk_of_death: + + survived = False - else: - # Otherwise we reset any variables + if individual_id in mni: + pregnancy_helper_functions.log_mni_for_maternal_death(self.module, individual_id) + mni[individual_id]['delete_mni'] = True + + self.module.mnh_outcome_counter[f'{self.cause}_m_death'] += 1 + self.module.mnh_outcome_counter['direct_mat_death'] += 1 + self.sim.modules['Demography'].do_death(individual_id=individual_id, cause=f'{self.cause}', + originating_module=self.sim.modules['PregnancySupervisor']) + + # If individual survived, updated variables + if survived: + + # If woman survived we reset any variables if self.cause == 'ectopic_pregnancy': df.at[individual_id, 'ps_ectopic_pregnancy'] = 'none' diff --git a/src/tlo/methods/rti.py b/src/tlo/methods/rti.py index 994b8d1054..92f79f7538 100644 --- a/src/tlo/methods/rti.py +++ b/src/tlo/methods/rti.py @@ -2776,7 +2776,7 @@ class RTIPollingEvent(RegularEvent, PopulationScopeEventMixin): def __init__(self, module): """Schedule to take place every month """ - super().__init__(module, frequency=DateOffset(months=1)) + super().__init__(module, 
frequency=DateOffset(months=1000)) # Single polling event p = module.parameters # Parameters which transition the model between states self.base_1m_prob_rti = (p['base_rate_injrti'] / 12) @@ -2864,9 +2864,15 @@ def apply(self, population): .when('.between(70,79)', self.rr_injrti_age7079), Predictor('li_ex_alc').when(True, self.rr_injrti_excessalcohol) ) + #if self.sim.generate_event_chains is True and self.sim.generate_event_chains_overwrite_epi is True: + #pred = 1.0 + #else: pred = eq.predict(df.loc[rt_current_non_ind]) + + random_draw_in_rti = self.module.rng.random_sample(size=len(rt_current_non_ind)) selected_for_rti = rt_current_non_ind[pred > random_draw_in_rti] + # Update to say they have been involved in a rti df.loc[selected_for_rti, 'rt_road_traffic_inc'] = True # Set the date that people were injured to now @@ -4874,6 +4880,7 @@ def __init__(self, module, person_id): self.treated_code = 'none' def apply(self, person_id, squeeze_factor): + self._number_of_times_this_event_has_run += 1 df = self.sim.population.props rng = self.module.rng @@ -4922,10 +4929,12 @@ def apply(self, person_id, squeeze_factor): # injury is being treated in this surgery # find untreated injury codes that are treated with major surgery relevant_codes = np.intersect1d(injuries_to_be_treated, surgically_treated_codes) + # check that the person sent here has an appropriate code(s) assert len(relevant_codes) > 0 # choose a code at random self.treated_code = rng.choice(relevant_codes) + if request_outcome: # check the people sent here hasn't died due to rti, have had their injuries diagnosed and been through # RTI_Med @@ -5012,7 +5021,9 @@ def apply(self, person_id, squeeze_factor): # ------------------------------------- Perm disability from amputation ------------------------------------ codes = ['782', '782a', '782b', '782c', '783', '882', '883', '884'] + if self.treated_code in codes: + # Track whether they are permanently disabled df.at[person_id, 'rt_perm_disability'] = True # 
Find the column and code where the permanent injury is stored diff --git a/src/tlo/methods/tb.py b/src/tlo/methods/tb.py index d9ba7309e0..71361a7951 100644 --- a/src/tlo/methods/tb.py +++ b/src/tlo/methods/tb.py @@ -864,29 +864,31 @@ def initialise_population(self, population): df["tb_on_ipt"] = False df["tb_date_ipt"] = pd.NaT - # # ------------------ infection status ------------------ # - # WHO estimates of active TB for 2010 to get infected initial population - # don't need to scale or include treated proportion as no-one on treatment yet - inc_estimates = p["who_incidence_estimates"] - incidence_year = (inc_estimates.loc[ - (inc_estimates.year == self.sim.date.year), "incidence_per_100k" - ].values[0]) / 100_000 - - incidence_year = incidence_year * p["scaling_factor_WHO"] - self.assign_active_tb( - population, - strain="ds", - incidence=incidence_year) - - self.assign_active_tb( - population, - strain="mdr", - incidence=incidence_year * p['prop_mdr2010']) - - self.send_for_screening_general( - population - ) # send some baseline population for screening + # # ------------------ infection status ------------------ # + if self.sim.generate_event_chains is False or self.sim.generate_event_chains is None: + # WHO estimates of active TB for 2010 to get infected initial population + # don't need to scale or include treated proportion as no-one on treatment yet + inc_estimates = p["who_incidence_estimates"] + incidence_year = (inc_estimates.loc[ + (inc_estimates.year == self.sim.date.year), "incidence_per_100k" + ].values[0]) / 100_000 + + incidence_year = incidence_year * p["scaling_factor_WHO"] + + self.assign_active_tb( + population, + strain="ds", + incidence=incidence_year) + + self.assign_active_tb( + population, + strain="mdr", + incidence=incidence_year * p['prop_mdr2010']) + + self.send_for_screening_general( + population + ) # send some baseline population for screening def initialise_simulation(self, sim): """ @@ -899,8 +901,10 @@ def 
initialise_simulation(self, sim): sim.schedule_event(TbActiveEvent(self), sim.date) sim.schedule_event(TbRegularEvents(self), sim.date) sim.schedule_event(TbSelfCureEvent(self), sim.date) + sim.schedule_event(TbActiveCasePoll(self), sim.date + DateOffset(years=1)) + # 2) log at the end of the year # Optional: Schedule the scale-up of programs if self.parameters["type_of_scaleup"] != 'none': @@ -1402,6 +1406,53 @@ def is_subset(col_for_set, col_for_subset): # # TB infection event # # --------------------------------------------------------------------------- +class TbActiveCasePollGenerateData(RegularEvent, PopulationScopeEventMixin): + """The Tb Regular Poll Event for Data Generation for assigning active infections + * selects everyone to develop an active infection and schedules onset of active tb + sometime during the simulation + """ + + def __init__(self, module): + super().__init__(module, frequency=DateOffset(years=120)) + + def apply(self, population): + + df = population.props + now = self.sim.date + rng = self.module.rng + # Make everyone who is alive and not infected (no-one should be) susceptible + susc_idx = df.loc[ + df.is_alive + & (df.tb_inf != "active") + ].index + + len(susc_idx) + + middle_index = len(susc_idx) // 2 + + # Will equally split two strains among the population + list_ds = susc_idx[:middle_index] + list_mdr = susc_idx[middle_index:] + + # schedule onset of active tb. This will be equivalent to the "Onset", so it + # doesn't matter how long after we have decided which infection this is. 
+ for person_id in list_ds: + date_progression = now + pd.DateOffset( + # At some point during their lifetime, this person will develop TB + days=self.module.rng.randint(0, 365*(int(self.sim.end_date.year - self.sim.date.year)+1)) + ) + # set date of active tb - properties will be updated at TbActiveEvent poll daily + df.at[person_id, "tb_scheduled_date_active"] = date_progression + df.at[person_id, "tb_strain"] = "ds" + + for person_id in list_mdr: + date_progression = now + pd.DateOffset( + days=rng.randint(0, 365*int(self.sim.end_date.year - self.sim.start_date.year + 1)) + ) + # set date of active tb - properties will be updated at TbActiveEvent poll daily + df.at[person_id, "tb_scheduled_date_active"] = date_progression + df.at[person_id, "tb_strain"] = "mdr" + class TbActiveCasePoll(RegularEvent, PopulationScopeEventMixin): """The Tb Regular Poll Event for assigning active infections @@ -1476,7 +1527,6 @@ def apply(self, population): self.module.update_parameters_for_program_scaleup() # note also culture test used in target/max scale-up in place of clinical dx - class TbActiveEvent(RegularEvent, PopulationScopeEventMixin): """ * check for those with dates of active tb onset within last time-period diff --git a/src/tlo/simulation.py b/src/tlo/simulation.py index d2560f92d9..d10ef1b24b 100644 --- a/src/tlo/simulation.py +++ b/src/tlo/simulation.py @@ -8,9 +8,12 @@ import time from collections import Counter, OrderedDict from pathlib import Path +from typing import Optional from typing import TYPE_CHECKING, Optional - +import pandas as pd +import tlo.population import numpy as np +from tlo.util import df_to_EAV, convert_chain_links_into_EAV try: import dill @@ -35,6 +38,9 @@ logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) +logger_chains = logging.getLogger("tlo.methods.event") +logger_chains.setLevel(logging.INFO) + class SimulationPreviouslyInitialisedError(Exception): """Exception raised when trying to initialise an already initialised 
simulation.""" @@ -103,10 +109,16 @@ def __init__( self.date = self.start_date = start_date self.modules = OrderedDict() self.event_queue = EventQueue() + + self.generate_event_chains = True + self.generate_event_chains_modules_of_interest = [] + self.generate_event_chains_ignore_events = [] + self.end_date = None self.output_file = None self.population: Optional[Population] = None - + + self.show_progress_bar = show_progress_bar self.resourcefilepath = Path(resourcefilepath) @@ -282,13 +294,23 @@ def make_initial_population(self, *, n: int) -> None: key="debug", data=f"{module.name}.initialise_population() {time.time() - start1} s", ) - + + # When logging events for each individual to reconstruct chains, only the changes in individual properties will be logged. + # At the start of the simulation + when a new individual is born, we therefore want to store all of their properties at the start. + if self.generate_event_chains: + + # EDNAV structure to capture status of individuals at the start of the simulation + ednav = df_to_EAV(self.population.props, self.date, 'StartOfSimulation') + + logger.info(key='event_chains', + data = ednav.to_dict(), + description='Links forming chains of events for simulated individuals') + end = time.time() logger.info(key="info", data=f"make_initial_population() {end - start} s") def initialise(self, *, end_date: Date) -> None: """Initialise all modules in simulation. - :param end_date: Date to end simulation on - accessible to modules to allow initialising data structures which may depend (in size for example) on the date range being simulated. 
@@ -298,6 +320,16 @@ def initialise(self, *, end_date: Date) -> None: raise SimulationPreviouslyInitialisedError(msg) self.date = self.start_date self.end_date = end_date # store the end_date so that others can reference it + + #self.generate_event_chains = generate_event_chains + if self.generate_event_chains: + # For now keep these fixed, eventually they will be input from user + self.generate_event_chains_modules_of_interest = [self.modules] + self.generate_event_chains_ignore_events = ['AgeUpdateEvent','HealthSystemScheduler', 'SimplifiedBirthsPoll','DirectBirth', 'LifestyleEvent', 'TbActiveCasePollGenerateData','HivPollingEventForDataGeneration', 'RTIPollingEvent', 'DepressionPollingEvent','Get_Current_DALYs', 'PostnatalSupervisorEvent','PregnancySupervisorEvent', 'MalariaPollingEventDistrict','CopdPollEvent', 'CardioMetabolicDisorders_MainPollingEvent', 'CardioMetabolicDisorders_LoggingEvent'] + + # Reorder columns to place the new columns at the front + pd.set_option('display.max_columns', None) + for module in self.modules.values(): module.initialise_simulation(self) self._initialised = True @@ -366,6 +398,8 @@ def run_simulation_to(self, *, to_date: Date) -> None: :param to_date: Date to simulate up to but not including - must be before or equal to simulation end date specified in call to :py:meth:`initialise`. """ + open('output.txt', mode='a') + if not self._initialised: msg = "Simulation must be initialised before calling run_simulation_to" raise SimulationNotInitialisedError(msg) @@ -382,6 +416,7 @@ def run_simulation_to(self, *, to_date: Date) -> None: self._update_progress_bar(progress_bar, date) self.fire_single_event(event, date) self.date = to_date + if self.show_progress_bar: progress_bar.stop() @@ -423,6 +458,7 @@ def fire_single_event(self, event: Event, date: Date) -> None: """ self.date = date event.run() + def do_birth(self, mother_id: int) -> int: """Create a new child person. 
@@ -436,6 +472,21 @@ def do_birth(self, mother_id: int) -> int: child_id = self.population.do_birth() for module in self.modules.values(): module.on_birth(mother_id, child_id) + + if self.generate_event_chains: + # When individual is born, store their initial properties to provide a starting point to the chain of property + # changes that this individual will undergo as a result of events taking place. + link_info = self.population.props.loc[child_id].to_dict() + link_info['EventName'] = 'Birth' + chain_links = {} + chain_links[child_id] = link_info # Convert to string to avoid issue of length + + ednav = convert_chain_links_into_EAV(chain_links) + + logger.info(key='event_chains', + data = ednav.to_dict(), + description='Links forming chains of events for simulated individuals') + return child_id def find_events_for_person(self, person_id: int) -> list[tuple[Date, Event]]: diff --git a/src/tlo/util.py b/src/tlo/util.py index efe17a9920..fdc7390b98 100644 --- a/src/tlo/util.py +++ b/src/tlo/util.py @@ -94,6 +94,33 @@ def transition_states(initial_series: pd.Series, prob_matrix: pd.DataFrame, rng: return final_states +def df_to_EAV(df, date, event_name): + """Function to convert dataframe into EAV""" + eav = df.stack().reset_index() + eav.columns = ['E', 'A', 'V'] + eav['EventName'] = event_name + eav = eav[["E", "EventName", "A", "V"]] + + return eav + + +def convert_chain_links_into_EAV(chain_links): + df = pd.DataFrame.from_dict(chain_links, orient="index") + id_cols = ["EventName"] + + eav = df.reset_index().melt( + id_vars=["index"] + id_cols, # index = person ID + var_name="A", + value_name="V" + ) + + eav.rename(columns={"index": "E"}, inplace=True) + + eav = eav[["E", "EventName", "A", "V"]] + + return eav + + def sample_outcome(probs: pd.DataFrame, rng: np.random.RandomState): """ Helper function to randomly sample an outcome for each individual in a population from a set of probabilities that are specific to each individual. 
@@ -115,10 +142,13 @@ def sample_outcome(probs: pd.DataFrame, rng: np.random.RandomState): cumsum = _probs.cumsum(axis=1) draws = pd.Series(rng.rand(len(cumsum)), index=cumsum.index) y = cumsum.gt(draws, axis=0) - outcome = y.idxmax(axis=1) + if not y.empty: + outcome = y.idxmax(axis=1) + # return as a dict of form {person_id: outcome} only in those cases where the outcome is one of the events. + return outcome.loc[outcome != '_'].to_dict() - # return as a dict of form {person_id: outcome} only in those cases where the outcome is one of the events. - return outcome.loc[outcome != '_'].to_dict() + else: + return dict() BitsetDType = Property.PANDAS_TYPE_MAP[Types.BITSET] diff --git a/tests/test_maternal_health_helper_and_analysis_functions.py b/tests/test_maternal_health_helper_and_analysis_functions.py index a762fe4155..f396d42a47 100644 --- a/tests/test_maternal_health_helper_and_analysis_functions.py +++ b/tests/test_maternal_health_helper_and_analysis_functions.py @@ -50,6 +50,7 @@ def apply(self, person_id, squeeze_factor): return hsi_event def test_interventions_are_delivered_as_expected_not_during_analysis(seed): + sim = Simulation(start_date=start_date, seed=seed, resourcefilepath=resourcefilepath) sim.register(*fullmodel()) sim.make_initial_population(n=100) @@ -79,6 +80,7 @@ def test_interventions_are_delivered_as_expected_not_during_analysis(seed): hsi_event = get_dummy_hsi(sim, mother_id, id=0, fl=0) def override_dummy_cons(value): + updated_cons = {k: value for (k, v) in sim.modules['Labour'].item_codes_lab_consumables['delivery_core'].items()} sim.modules['HealthSystem'].override_availability_of_consumables(updated_cons) @@ -121,6 +123,7 @@ def override_dummy_cons(value): def test_interventions_are_delivered_as_expected_during_analysis(seed): + sim = Simulation(start_date=start_date, seed=seed, resourcefilepath=resourcefilepath) sim.register(*fullmodel()) sim.make_initial_population(n=100) @@ -142,6 +145,7 @@ def 
test_interventions_are_delivered_as_expected_during_analysis(seed): hsi_event = get_dummy_hsi(sim, mother_id, id=0, fl=0) def override_dummy_cons(value): + updated_cons = {k: value for (k, v) in sim.modules['Labour'].item_codes_lab_consumables['delivery_core'].items()} sim.modules['HealthSystem'].override_availability_of_consumables(updated_cons) @@ -170,6 +174,7 @@ def override_dummy_cons(value): def test_analysis_analysis_events_run_as_expected_and_update_parameters(seed): """Test that the analysis events run when scheduled and that they update the correct parameters as expected when they run""" + sim = Simulation(start_date=start_date, seed=seed, resourcefilepath=resourcefilepath) sim.register(*fullmodel()) @@ -202,6 +207,7 @@ def test_analysis_analysis_events_run_as_expected_and_update_parameters(seed): unchanged_odds_anc = pparams['odds_early_init_anc4'][0] unchanged_odds_pnc = lparams['odds_will_attend_pnc'][0] + sim.make_initial_population(n=100) # run the model for 1 day sim.make_initial_population(n=100) sim.simulate(end_date=Date(2010, 1, 2)) @@ -230,6 +236,7 @@ def test_analysis_analysis_events_run_as_expected_and_update_parameters(seed): def test_analysis_analysis_events_run_as_expected_when_using_sensitivity_max_parameters(seed): + sim = Simulation(start_date=start_date, seed=seed, resourcefilepath=resourcefilepath) sim.register(*fullmodel()) lparams = sim.modules['Labour'].parameters @@ -281,6 +288,7 @@ def test_analysis_analysis_events_run_as_expected_when_using_sensitivity_max_par def test_analysis_analysis_events_run_as_expected_when_using_sensitivity_min_parameters(seed): + sim = Simulation(start_date=start_date, seed=seed, resourcefilepath=resourcefilepath) sim.register(*fullmodel()) lparams = sim.modules['Labour'].parameters @@ -324,6 +332,7 @@ def test_analysis_analysis_events_run_as_expected_when_using_sensitivity_min_par def test_analysis_events_force_availability_of_consumables_when_scheduled_in_anc(seed): """Test that when analysis is being 
conducted during a simulation that consumable availability is determined via some pre-defined analysis parameter and not via the health system within the ANC HSIs""" + sim = Simulation(start_date=start_date, seed=seed, resourcefilepath=resourcefilepath) sim.register(*fullmodel()) sim.make_initial_population(n=100) @@ -338,6 +347,7 @@ def test_analysis_events_force_availability_of_consumables_when_scheduled_in_anc cparams['sensitivity_blood_test_syphilis'] = [1.0, 1.0] cparams['specificity_blood_test_syphilis'] = [1.0, 1.0] + sim.make_initial_population(n=100) sim.simulate(end_date=Date(2010, 1, 2)) # check the event ran @@ -407,6 +417,7 @@ def test_analysis_events_force_availability_of_consumables_when_scheduled_in_anc def test_analysis_events_force_availability_of_consumables_for_sba_analysis(seed): """Test that when analysis is being conducted during a simulation that consumable availability is determined via some pre-defined analysis parameter and not via the health system within the SBA HSIs""" + sim = Simulation(start_date=start_date, seed=seed, resourcefilepath=resourcefilepath) sim.register(*fullmodel()) @@ -543,6 +554,7 @@ def test_analysis_events_force_availability_of_consumables_for_sba_analysis(seed def test_analysis_events_force_availability_of_consumables_for_pnc_analysis(seed): """Test that when analysis is being conducted during a simulation that consumable availability is determined via some pre-defined analysis parameter and not via the health system within the PNC HSIs""" + sim = Simulation(start_date=start_date, seed=seed, resourcefilepath=resourcefilepath) sim.register(*fullmodel()) @@ -611,6 +623,7 @@ def test_analysis_events_force_availability_of_consumables_for_pnc_analysis(seed def test_analysis_events_force_availability_of_consumables_for_newborn_hsi(seed): """Test that when analysis is being conducted during a simulation that consumable availability is determined via some pre-defined analysis parameter and not via the health system within 
the newborn HSIs""" + sim = Simulation(start_date=start_date, seed=seed, resourcefilepath=resourcefilepath) sim.register(*fullmodel()) sim.make_initial_population(n=100) @@ -626,6 +639,7 @@ def test_analysis_events_force_availability_of_consumables_for_newborn_hsi(seed) lparams['pnc_availability_probability'] = 1.0 lparams['bemonc_availability'] = 1.0 + sim.make_initial_population(n=100) sim.simulate(end_date=Date(2010, 1, 2)) df = sim.population.props @@ -681,6 +695,7 @@ def test_analysis_events_force_availability_of_consumables_for_newborn_hsi(seed) def test_analysis_events_circumnavigates_sf_and_competency_parameters(seed): """Test that the analysis event correctly overrides the parameters which controle whether the B/CEmONC signal functions can run""" + sim = Simulation(start_date=start_date, seed=seed, resourcefilepath=resourcefilepath) sim.register(*fullmodel()) diff --git a/tests/test_mnh_cohort.py b/tests/test_mnh_cohort.py new file mode 100644 index 0000000000..7bfcbed194 --- /dev/null +++ b/tests/test_mnh_cohort.py @@ -0,0 +1,85 @@ +import os + +import pandas as pd + +from pathlib import Path + +from tlo import Date, Simulation, logging +from tlo.methods import mnh_cohort_module +from tlo.methods.fullmodel import fullmodel +from tlo.analysis.utils import parse_log_file + +# The resource files +try: + resourcefilepath = Path(os.path.dirname(__file__)) / '../resources' +except NameError: + # running interactively + resourcefilepath = Path('./resources') + +start_date = Date(2024, 1, 1) + + +def register_modules(sim): + """Defines sim variable and registers all modules that can be called when running the full suite of pregnancy + modules""" + + sim.register(*fullmodel(resourcefilepath=resourcefilepath), + mnh_cohort_module.MaternalNewbornHealthCohort(resourcefilepath=resourcefilepath)) + +def test_run_sim_with_mnh_cohort(tmpdir, seed): + sim = Simulation(start_date=start_date, seed=12345, log_config={"filename": "log", "custom_levels":{ + "*": 
logging.DEBUG},"directory": tmpdir}) + + register_modules(sim) + sim.make_initial_population(n=3000) + sim.simulate(end_date=Date(2025, 1, 2)) + + output= parse_log_file(sim.log_filepath) + live_births = len(output['tlo.methods.demography']['on_birth']) + + deaths_df = output['tlo.methods.demography']['death'] + prop_deaths_df = output['tlo.methods.demography.detail']['properties_of_deceased_persons'] + + dir_mat_deaths = deaths_df.loc[(deaths_df['label'] == 'Maternal Disorders')] + init_indir_mat_deaths = prop_deaths_df.loc[(prop_deaths_df['is_pregnant'] | prop_deaths_df['la_is_postpartum']) & + (prop_deaths_df['cause_of_death'].str.contains('Malaria|Suicide|ever_stroke|diabetes|' + 'chronic_ischemic_hd|ever_heart_attack|' + 'chronic_kidney_disease') | + (prop_deaths_df['cause_of_death'] == 'TB'))] + + hiv_mat_deaths = prop_deaths_df.loc[(prop_deaths_df['is_pregnant'] | prop_deaths_df['la_is_postpartum']) & + (prop_deaths_df['cause_of_death'].str.contains('AIDS_non_TB|AIDS_TB'))] + + indir_mat_deaths = len(init_indir_mat_deaths) + (len(hiv_mat_deaths) * 0.3) + total_deaths = len(dir_mat_deaths) + indir_mat_deaths + + # TOTAL_DEATHS + mmr = (total_deaths / live_births) * 100_000 + + print(f'The MMR for this simulation is {mmr}') + print(f'The maternal deaths for this simulation (unscaled) are {total_deaths}') + print(f'The total maternal deaths for this simulation (scaled) are ' + f'{total_deaths * output["tlo.methods.population"]["scaling_factor"]["scaling_factor"].values[0]}') + + maternal_dalys = output['tlo.methods.healthburden']['dalys_stacked']['Maternal Disorders'].sum() + print(f'The maternal DALYs for this simulation (unscaled) are {maternal_dalys}') + + +def test_mnh_cohort_module_updates_properties_as_expected(tmpdir, seed): + sim = Simulation(start_date=start_date, seed=seed, log_config={"filename": "log", "directory": tmpdir}) + + register_modules(sim) + sim.make_initial_population(n=3000) + sim.simulate(end_date=sim.date + pd.DateOffset(days=0)) + + 
df = sim.population.props + pop = df.loc[df.is_alive] + + assert (df.loc[pop.index, 'sex'] == 'F').all() + assert df.loc[pop.index, 'is_pregnant'].all() + assert not (pd.isnull(df.loc[pop.index, 'la_due_date_current_pregnancy'])).all() + assert (df.loc[pop.index, 'co_contraception'] == 'not_using').all() + + # orig = sim.population.new_row + # assert (df.dtypes == orig.dtypes).all() + diff --git a/tests/test_rti.py b/tests/test_rti.py index 2b65595782..b696a249f5 100644 --- a/tests/test_rti.py +++ b/tests/test_rti.py @@ -25,6 +25,17 @@ end_date = Date(2012, 1, 1) popsize = 1000 +@pytest.mark.slow +def test_data_harvesting(seed): + """ + This test runs a simulation with a functioning health system with full service availability and no set + constraints + """ + # create sim object + sim = create_basic_rti_sim(popsize, seed) + # run simulation + sim.simulate(end_date=end_date) + exit(-1) def check_dtypes(simulation): # check types of columns in dataframe, check they are the same, list those that aren't @@ -66,6 +77,7 @@ def test_run(seed): check_dtypes(sim) + @pytest.mark.slow def test_all_injuries_run(seed): """