Skip to content

Commit c616315

Browse files
thewatitamuritbhallett
authored
Symptom Bookkeeping (#1626)
Co-authored-by: Asif Tamuri <[email protected]> Co-authored-by: Tim Hallett <[email protected]>
1 parent d0330a5 commit c616315

File tree

3 files changed

+160
-17
lines changed

3 files changed

+160
-17
lines changed
Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
"""
2+
A script that registers a large number of symptoms and calls the has_what function many times.
3+
For use in profiling.
4+
"""
5+
6+
import time
7+
from pathlib import Path
8+
9+
import numpy as np
10+
11+
from tlo import Date, Simulation
12+
from tlo.methods import demography, symptommanager
13+
14+
15+
def setup_simulation(pop_size=100_000, num_symptoms=50):
    """Build a simulation with an initial population and many registered symptoms.

    :param pop_size: number of individuals passed to ``make_initial_population``.
    :param num_symptoms: how many symptoms, named ``symptom_0``, ``symptom_1``, ...,
        are registered with the SymptomManager.
    :return: the ``Simulation`` object, after the initial population has been made.
    """
    simulation = Simulation(
        start_date=Date(2010, 1, 1),
        seed=0,
        resourcefilepath=Path("./resources"),
    )
    simulation.register(demography.Demography(), symptommanager.SymptomManager())

    # Symptoms are registered before the population exists, so that creating the
    # population also creates the symptom properties.
    symptom_manager = simulation.modules['SymptomManager']
    for index in range(num_symptoms):
        new_symptom = symptommanager.Symptom(name=f'symptom_{index}')
        symptom_manager.register_symptom(new_symptom)

    simulation.make_initial_population(n=pop_size)
    return simulation
39+
40+
41+
def assign_random_symptoms(sim, symptom_prob=0.1, rng=None):
    """Give every registered symptom to a random subset of the population.

    For each name in the SymptomManager's ``symptom_names``, each row of the
    population dataframe is selected independently with probability
    ``symptom_prob`` and the symptom is added for the selected persons.

    :param sim: simulation exposing ``population.props`` and ``modules['SymptomManager']``.
    :param symptom_prob: per-person probability of being given each symptom.
    :param rng: optional random source providing ``random(n)`` (for example
        ``numpy.random.default_rng(seed)``). Pass a seeded generator to make the
        assignment reproducible. If omitted, the global ``numpy.random`` state is
        used, as before.
    """
    df = sim.population.props
    sm = sim.modules['SymptomManager']
    random_source = np.random if rng is None else rng
    population_size = len(df)

    for symptom in sm.symptom_names:
        # Independent draw per person: True where this person gets the symptom.
        has_symptom = random_source.random(population_size) < symptom_prob
        person_ids = df.index[has_symptom].tolist()

        # Nothing to do for this symptom if nobody was selected.
        if not person_ids:
            continue

        sm.change_symptom(
            person_id=person_ids,
            symptom_string=symptom,
            add_or_remove='+',
            disease_module=sm,
            duration_in_days=None,
        )
60+
61+
62+
def profile_has_what(sim, num_tests=1000000, rng=None):
    """Time repeated calls to ``SymptomManager.has_what`` and report the results.

    Person ids are sampled with replacement from the rows of the population
    dataframe where ``is_alive`` is true, and ``has_what`` is called once per
    sampled id. Timing figures and the first five results are printed.

    :param sim: simulation exposing ``population.props`` and ``modules['SymptomManager']``.
    :param num_tests: number of ``has_what`` calls to make.
    :param rng: optional random source providing ``choice(a, size=, replace=)``
        (for example ``numpy.random.default_rng(seed)``). If omitted, the global
        ``numpy.random`` state is used, as before.
    :return: average wall-clock seconds per call (``0.0`` when ``num_tests`` is 0).
    """
    df = sim.population.props
    sm = sim.modules['SymptomManager']
    random_source = np.random if rng is None else rng

    # Sample the ids before starting the clock so only has_what calls are timed.
    test_ids = random_source.choice(df.index[df.is_alive], size=num_tests, replace=True)

    # perf_counter is the monotonic, high-resolution clock intended for timing.
    start_time = time.perf_counter()
    results = [sm.has_what(person_id) for person_id in test_ids]
    elapsed = time.perf_counter() - start_time

    # Avoid ZeroDivisionError when no calls were requested.
    avg_time = elapsed / num_tests if num_tests else 0.0

    print(f"Tested has_what() {num_tests} times")
    print(f"Total time: {elapsed:.4f} seconds")
    print(f"Average time per call: {avg_time:.6f} seconds")
    print(f"First 5 results: {results[:5]}")

    return avg_time
86+
87+
88+
# Run the profiling only when executed as a script, so that importing this
# module (e.g. to reuse the helpers above) does not build a simulation.
if __name__ == "__main__":
    print("Setting up simulation...")
    sim = setup_simulation(pop_size=100_000, num_symptoms=50)

    print("Assigning random symptoms...")
    assign_random_symptoms(sim, symptom_prob=0.6)

    print("\nProfiling has_what...")
    avg_time = profile_has_what(sim, num_tests=1000000)

    print("\nProfiling complete!")
    print(f"Average time per has_what call: {avg_time:.6f} seconds")

src/tlo/methods/demography.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ def __init__(self, name=None, equal_allocation_by_district: bool = False):
8383
self.gbd_causes_of_death = set() # will store all the causes of death defined in the GBD data
8484
self.gbd_causes_of_death_not_represented_in_disease_modules = set()
8585
# will store causes of death in GBD not represented in the simulation
86-
self.other_death_poll = None # will hold pointer to the OtherDeathPoll object
86+
self.other_death_poll = None # will hold pointer to the OtherDeathPoll object
8787
self.districts = None # will store all the districts in a list
8888

8989
OPTIONAL_INIT_DEPENDENCIES = {'ImprovedHealthSystemAndCareSeekingScenarioSwitcher'}
@@ -175,14 +175,14 @@ def read_parameters(self, resourcefilepath: Optional[Path] = None):
175175
# Lookup dicts to map from district_num_of_residence (in the df) and District name and Region name
176176
self.districts = self.parameters['pop_2010']['District'].drop_duplicates().to_list()
177177
self.parameters['district_num_to_district_name'] = \
178-
self.parameters['pop_2010'][['District_Num', 'District']].drop_duplicates()\
179-
.set_index('District_Num')['District']\
180-
.to_dict()
178+
self.parameters['pop_2010'][['District_Num', 'District']].drop_duplicates() \
179+
.set_index('District_Num')['District'] \
180+
.to_dict()
181181

182182
self.parameters['district_num_to_region_name'] = \
183-
self.parameters['pop_2010'][['District_Num', 'Region']].drop_duplicates()\
184-
.set_index('District_Num')['Region']\
185-
.to_dict()
183+
self.parameters['pop_2010'][['District_Num', 'Region']].drop_duplicates() \
184+
.set_index('District_Num')['Region'] \
185+
.to_dict()
186186

187187
districts_in_region = defaultdict(set)
188188
for _district in self.parameters['pop_2010'][['District', 'Region']].drop_duplicates().itertuples():
@@ -563,6 +563,10 @@ def do_death(self, individual_id: int, cause: str, originating_module: Module):
563563
if person.hs_is_inpatient:
564564
self.sim.modules['HealthSystem'].remove_beddays_footprint(person_id=individual_id)
565565

566+
# Clear symptoms for the deceased person
567+
if 'SymptomManager' in self.sim.modules:
568+
self.sim.modules['SymptomManager'].clear_symptoms_for_deceased_person(individual_id)
569+
566570
def create_mappers_from_causes_of_death_to_label(self):
567571
"""Use a helper function to create mappers for causes of death to label."""
568572
return create_mappers_from_causes_to_label(
@@ -683,6 +687,7 @@ class OtherDeathPoll(RegularEvent, PopulationScopeEventMixin):
683687
It does this by computing the GBD death rates that are implied by all the causes of death other than those that are
684688
represented in the disease module registered in this simulation.
685689
"""
690+
686691
def __init__(self, module):
687692
super().__init__(module, frequency=DateOffset(months=1))
688693
self.causes_to_represent = self.module.gbd_causes_of_death_not_represented_in_disease_modules

src/tlo/methods/symptommanager.py

Lines changed: 50 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616
from collections import defaultdict
1717
from pathlib import Path
18-
from typing import TYPE_CHECKING, List, Optional, Sequence, Union
18+
from typing import TYPE_CHECKING, List, Optional, Sequence, Set
1919

2020
import numpy as np
2121
import pandas as pd
@@ -26,11 +26,14 @@
2626
from tlo.util import BitsetHandler
2727

2828
if TYPE_CHECKING:
29+
from typing import Union
30+
2931
from tlo.population import IndividualProperties
3032

3133
logger = logging.getLogger(__name__)
3234
logger.setLevel(logging.INFO)
3335

36+
3437
# ---------------------------------------------------------------------------------------------------------
3538
# MODULE DEFINITIONS
3639
# ---------------------------------------------------------------------------------------------------------
@@ -218,6 +221,7 @@ def __init__(self, name=None, spurious_symptoms=None):
218221

219222
self.recognised_module_names = None
220223
self.spurious_symptom_resolve_event = None
224+
self.symptom_tracker = defaultdict(set)
221225

222226
def get_column_name_for_symptom(self, symptom_name):
223227
"""get the column name that corresponds to the symptom_name"""
@@ -395,6 +399,10 @@ def change_symptom(self, person_id, symptom_string, add_or_remove, disease_modul
395399
self.bsh.set(person_id, disease_module.name, columns=sy_columns)
396400
self._persons_with_newly_onset_symptoms = self._persons_with_newly_onset_symptoms.union(person_id)
397401

402+
# Update symptom tracker
403+
for pid in person_id:
404+
self.symptom_tracker[pid] |= set(symptom_string)
405+
398406
# If a duration is given, schedule the auto-resolve event to turn off these symptoms after specified time.
399407
if duration_in_days is not None:
400408
auto_resolve_event = SymptomManager_AutoResolveEvent(self,
@@ -417,10 +425,17 @@ def change_symptom(self, person_id, symptom_string, add_or_remove, disease_modul
417425
# Do the remove:
418426
self.bsh.unset(person_id, disease_module.name, columns=sy_columns)
419427

428+
# Update symptom tracker. Remove if no other module is causing this symptom.
429+
for pid in person_id:
430+
for sym in symptom_string:
431+
symptom_col = self.get_column_name_for_symptom(sym)
432+
if self.bsh.is_empty(pid, columns=symptom_col):
433+
self.symptom_tracker[pid].discard(sym)
434+
420435
def who_has(self, list_of_symptoms):
421436
"""
422437
This is a helper function to look up who has a particular symptom or set of symptoms.
423-
It returns a list of indicies for person that have all of the symptoms specified
438+
It returns a list of indices for persons that have all of the symptoms specified
424439
425440
:param: list_of_symptoms : string or list of strings for the symptoms of interest
426441
:return: list of person_ids for those with all of the symptoms in list_of_symptoms who are alive
@@ -462,7 +477,7 @@ def who_not_have(self, symptom_string: str) -> pd.Index:
462477
& self.bsh.is_empty(
463478
slice(None), columns=self.get_column_name_for_symptom(symptom_string)
464479
)
465-
]
480+
]
466481

467482
def has_what(
468483
self,
@@ -492,6 +507,10 @@ def has_what(
492507
else True
493508
), "Disease Module Name is not recognised"
494509

510+
# Faster to get current symptoms using tracker when no disease is specified
511+
if disease_module is None and person_id is not None:
512+
return list(self._get_current_symptoms_from_tracker(person_id))
513+
495514
if individual_details is not None:
496515
# We are working in an IndividualDetails context, avoid lookups to the
497516
# population DataFrame as we have this context stored already.
@@ -503,10 +522,10 @@ def has_what(
503522
symptom
504523
for symptom in self.symptom_names
505524
if individual_details[
506-
self.bsh._get_columns(self.get_column_name_for_symptom(symptom))
507-
]
508-
& int_repr
509-
!= 0
525+
self.bsh._get_columns(self.get_column_name_for_symptom(symptom))
526+
]
527+
& int_repr
528+
!= 0
510529
]
511530
else:
512531
return [
@@ -582,6 +601,17 @@ def clear_symptoms(self, person_id: Union[int, Sequence[int]], disease_module: M
582601
sy_columns = [self.get_column_name_for_symptom(sym) for sym in self.symptom_names]
583602
self.bsh.unset(person_id, disease_module.name, columns=sy_columns)
584603

604+
# Update bookkeeping
605+
for pid in person_id:
606+
for sym in self.symptom_names:
607+
symptom_col = self.get_column_name_for_symptom(sym)
608+
if self.bsh.is_empty(pid, columns=symptom_col):
609+
self.symptom_tracker[pid].discard(sym)
610+
611+
# Remove the person's entry from the tracker if the symptom set is empty
612+
if pid in self.symptom_tracker and not self.symptom_tracker[pid]:
613+
del self.symptom_tracker[pid]
614+
585615
def caused_by(self, disease_module: Module):
586616
"""Find the persons experiencing symptoms due to a particular module.
587617
Returns a dict of the form {<<person_id>>, <<list_of_symptoms>>}."""
@@ -600,6 +630,17 @@ def get_persons_with_newly_onset_symptoms(self):
600630
def reset_persons_with_newly_onset_symptoms(self):
601631
self._persons_with_newly_onset_symptoms.clear()
602632

633+
def _get_current_symptoms_from_tracker(self, person_id: int) -> Set[str]:
    """Look up a person's current symptoms in the bookkeeping dictionary.

    :param person_id: id of the person to look up.
    :return: the set stored in ``symptom_tracker`` for this person (the stored
        object itself, not a copy), or a new empty set if the person has no
        entry. No entry is added to the tracker by this lookup.
    """
    tracker = self.symptom_tracker
    # Membership test rather than indexing, so no entry is created for an
    # unknown person.
    if person_id in tracker:
        return tracker[person_id]
    return set()
636+
637+
def clear_symptoms_for_deceased_person(self, person_id: int):
    """Drop a deceased person's entry from the symptom tracker.

    :param person_id: id of the person who has died. If the person has no entry
        in ``symptom_tracker``, nothing happens.
    """
    # pop with a default removes the entry if present and is a no-op otherwise.
    self.symptom_tracker.pop(person_id, None)
642+
643+
603644
# ---------------------------------------------------------------------------------------------------------
604645
# EVENTS
605646
# ---------------------------------------------------------------------------------------------------------
@@ -696,22 +737,21 @@ def apply(self, population):
696737
do_not_have_symptom = self.module.who_not_have(symptom_string=symp)
697738

698739
for group in ['children', 'adults']:
699-
700740
p = self.generic_symptoms['prob_per_day'][group][symp]
701741
dur = self.generic_symptoms['duration_in_days'][group][symp]
702742
persons_eligible_to_get_symptom = group_indices[group][
703743
group_indices[group].isin(do_not_have_symptom)
704744
]
705745
persons_to_onset_with_this_symptom = persons_eligible_to_get_symptom[
706746
self.rand(len(persons_eligible_to_get_symptom)) < p
707-
]
747+
]
708748

709749
# Do onset
710750
self.sim.modules['SymptomManager'].change_symptom(
711751
symptom_string=symp,
712752
add_or_remove='+',
713753
person_id=persons_to_onset_with_this_symptom,
714-
duration_in_days=None, # <- resolution for these is handled by the SpuriousSymptomsResolve Event
754+
duration_in_days=None, # <- resolution for these is handled by the SpuriousSymptomsResolve Event
715755
disease_module=self.module,
716756
)
717757

0 commit comments

Comments
 (0)