Skip to content

1250 enhance testing logic in abm #1276

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 25 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
0263ae8
first changes
xsaschako Apr 15, 2025
bdedb1a
secando changes
xsaschako May 14, 2025
08a1b68
Refactor testing strategy methods for clarity and functionality
xsaschako May 15, 2025
6213264
Enhance TestingStrategy with compliance checks and improve documentat…
xsaschako May 15, 2025
9f93d6a
Add assertions and improve logic for testing scheme entry checks
xsaschako May 15, 2025
e43abdc
Enable benchmarks and adjust infection percentage in simulation logic
xsaschako May 15, 2025
c26ca83
Refactor testing strategy to consolidate testing schemes for multiple…
xsaschako May 15, 2025
298211f
Update CMake options, adjust infection percentage in simulation, and …
xsaschako May 15, 2025
7176447
Refactor TestingStrategy: remove deprecated methods and update testin…
xsaschako May 16, 2025
75fad09
Refactor TestingStrategy and tests: streamline method signatures, enh…
xsaschako May 16, 2025
640f581
formatting
xsaschako May 16, 2025
18f8388
Update cpp/models/abm/testing_strategy.cpp
xsaschako May 16, 2025
e2ee172
Refactor testing strategy method calls: replace add_testing_scheme wi…
xsaschako May 16, 2025
f931acd
Merge branch 'main' into 1250-enhance-testing-logic-in-abm
xsaschako May 19, 2025
1bbf4b9
Enhance TestingStrategy constructor: add overload for multiple local …
xsaschako May 19, 2025
36fcce3
Refactor test_abm: remove unused testing scheme initialization in tes…
xsaschako May 19, 2025
4bdfd4a
Fix indentation in abm.cpp: align method chaining for better readability
xsaschako May 19, 2025
3a8163f
Refactor testing strategy method names: unify method names to improve…
xsaschako May 20, 2025
1778fd0
Add tests for edge cases in TestingCriteria and TestingScheme functio…
xsaschako May 20, 2025
475b57d
Enhance run_and_test documentation: clarify return value regarding te…
xsaschako May 20, 2025
8db4f79
Refactor TestingScheme: move start and end date getters to private se…
xsaschako Jun 21, 2025
319cfc6
Merge branch 'main' into 1250-enhance-testing-logic-in-abm
xsaschako Jun 21, 2025
eca897c
test
xsaschako Jun 21, 2025
0f452ae
[skip CI] revert test
xsaschako Jun 21, 2025
91fb8ba
test bug
xsaschako Jun 21, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 3 additions & 14 deletions cpp/benchmarks/abm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -111,20 +111,9 @@ mio::abm::Simulation<> make_simulation(size_t num_persons, std::initializer_list
return mio::abm::TestingCriteria(random_ages, random_states);
};

model.get_testing_strategy().add_testing_scheme(
mio::abm::LocationType::School,
mio::abm::TestingScheme(random_criteria(), mio::abm::days(3), mio::abm::TimePoint(0),
mio::abm::TimePoint(0) + mio::abm::days(10), {}, 0.5));
model.get_testing_strategy().add_testing_scheme(
mio::abm::LocationType::Work,
mio::abm::TestingScheme(random_criteria(), mio::abm::days(3), mio::abm::TimePoint(0),
mio::abm::TimePoint(0) + mio::abm::days(10), {}, 0.5));
model.get_testing_strategy().add_testing_scheme(
mio::abm::LocationType::Home,
mio::abm::TestingScheme(random_criteria(), mio::abm::days(3), mio::abm::TimePoint(0),
mio::abm::TimePoint(0) + mio::abm::days(10), {}, 0.5));
model.get_testing_strategy().add_testing_scheme(
mio::abm::LocationType::SocialEvent,
model.get_testing_strategy().add_scheme(
{mio::abm::LocationType::School, mio::abm::LocationType::Work, mio::abm::LocationType::SocialEvent,
mio::abm::LocationType::Home},
mio::abm::TestingScheme(random_criteria(), mio::abm::days(3), mio::abm::TimePoint(0),
mio::abm::TimePoint(0) + mio::abm::days(10), {}, 0.5));

Expand Down
2 changes: 1 addition & 1 deletion cpp/examples/abm_history_object.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ int main()
auto testing_criteria_work = mio::abm::TestingCriteria();
auto testing_scheme_work = mio::abm::TestingScheme(testing_criteria_work, validity_period, start_date, end_date,
test_parameters, probability);
model.get_testing_strategy().add_testing_scheme(mio::abm::LocationType::Work, testing_scheme_work);
model.get_testing_strategy().add_scheme(mio::abm::LocationType::Work, testing_scheme_work);

// Assign infection state to each person.
// The infection states are chosen randomly.
Expand Down
4 changes: 2 additions & 2 deletions cpp/examples/abm_minimal.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -113,8 +113,8 @@ int main()
auto test_parameters = model.parameters.get<mio::abm::TestData>()[test_type];
auto testing_criteria_work = mio::abm::TestingCriteria();
auto testing_scheme_work = mio::abm::TestingScheme(testing_criteria_work, validity_period, start_date, end_date,
test_parameters, probability);
model.get_testing_strategy().add_testing_scheme(mio::abm::LocationType::Work, testing_scheme_work);
test_parameters, probability);
model.get_testing_strategy().add_scheme(mio::abm::LocationType::Work, testing_scheme_work);

// Assign infection state to each person.
// The infection states are chosen randomly with the following distribution
Expand Down
2 changes: 1 addition & 1 deletion cpp/models/abm/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ auto testing_criteria_work =
std::vector<mio::abm::TestingCriteria>{mio::abm::TestingCriteria({}, test_at_work, {})};
auto testing_scheme_work =
mio::abm::TestingScheme(testing_criteria_work, start_date, end_date, test_type, probability);
model.get_testing_strategy().add_testing_scheme(testing_scheme_work);
model.get_testing_strategy().add_scheme(testing_scheme_work);
```

For some infections to happen during the simulation, we have to initialize people with infections.
Expand Down
9 changes: 4 additions & 5 deletions cpp/models/abm/model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -123,10 +123,11 @@ void Model::perform_mobility(TimePoint t, TimeSpan dt)
get_number_persons(target_location.get_id()) >= target_location.get_capacity().persons) {
return false;
}
// the Person cannot move if the performed TestingStrategy is positive
if (!m_testing_strategy.run_strategy(personal_rng, person, target_location, t)) {
// The person cannot move if he has a positive test result
if (!m_testing_strategy.run_and_check(personal_rng, person, target_location, t)) {
return false;
}

// update worn mask to target location's requirements
if (target_location.is_mask_required()) {
// if the current MaskProtection level is lower than required, the Person changes mask
Expand Down Expand Up @@ -190,7 +191,7 @@ void Model::perform_mobility(TimePoint t, TimeSpan dt)
continue;
}
// skip the trip if the performed TestingStrategy is positive
if (!m_testing_strategy.run_strategy(personal_rng, person, target_location, t)) {
if (!m_testing_strategy.run_and_check(personal_rng, person, target_location, t)) {
continue;
}
// all requirements are met, move to target location
Expand Down Expand Up @@ -297,8 +298,6 @@ void Model::compute_exposure_caches(TimePoint t, TimeSpan dt)

void Model::begin_step(TimePoint t, TimeSpan dt)
{
m_testing_strategy.update_activity_status(t);

if (!m_is_local_population_cache_valid) {
build_compute_local_population_cache();
m_is_local_population_cache_valid = true;
Expand Down
16 changes: 0 additions & 16 deletions cpp/models/abm/model.h
Original file line number Diff line number Diff line change
Expand Up @@ -341,22 +341,6 @@ class Model
return m_id;
}

/**
* @brief Add a TestingScheme to the set of schemes that are checked for testing at all Locations that have
* the LocationType.
* @param[in] loc_type LocationId key for TestingScheme to be added.
* @param[in] scheme TestingScheme to be added.
*/
void add_testing_scheme(const LocationType& loc_type, const TestingScheme& scheme);

/**
* @brief Remove a TestingScheme from the set of schemes that are checked for testing at all Locations that have
* the LocationType.
* @param[in] loc_type LocationId key for TestingScheme to be added.
* @param[in] scheme TestingScheme to be added.
*/
void remove_testing_scheme(const LocationType& loc_type, const TestingScheme& scheme);

/**
* @brief Get a reference to a Person from this Model.
* @param[in] person_id A Person's PersonId.
Expand Down
158 changes: 61 additions & 97 deletions cpp/models/abm/testing_strategy.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,26 +43,6 @@ bool TestingCriteria::operator==(const TestingCriteria& other) const
return m_ages == other.m_ages && m_infection_states == other.m_infection_states;
}

void TestingCriteria::add_age_group(const AgeGroup age_group)
{
m_ages.set(static_cast<size_t>(age_group), true);
}

void TestingCriteria::remove_age_group(const AgeGroup age_group)
{
m_ages.set(static_cast<size_t>(age_group), false);
}

void TestingCriteria::add_infection_state(const InfectionState infection_state)
{
m_infection_states.set(static_cast<size_t>(infection_state), true);
}

void TestingCriteria::remove_infection_state(const InfectionState infection_state)
{
m_infection_states.set(static_cast<size_t>(infection_state), false);
}

bool TestingCriteria::evaluate(const Person& p, TimePoint t) const
{
// An empty vector of ages or none bitset of #InfectionStates% means that no condition on the corresponding property is set.
Expand All @@ -79,6 +59,7 @@ TestingScheme::TestingScheme(const TestingCriteria& testing_criteria, TimeSpan v
, m_test_parameters(test_parameters)
, m_probability(probability)
{
assert(start_date <= end_date && "Start date must be before or equal to end date");
}

bool TestingScheme::operator==(const TestingScheme& other) const
Expand All @@ -91,122 +72,105 @@ bool TestingScheme::operator==(const TestingScheme& other) const
//To be adjusted and also TestType should be static.
}

bool TestingScheme::is_active() const
{
return m_is_active;
}

void TestingScheme::update_activity_status(TimePoint t)
bool TestingScheme::is_active(TimePoint t) const
{
m_is_active = (m_start_date <= t && t <= m_end_date);
return (m_start_date <= t && t < m_end_date);
}

bool TestingScheme::run_scheme(PersonalRandomNumberGenerator& rng, Person& person, TimePoint t) const
bool TestingScheme::run_and_test(PersonalRandomNumberGenerator& rng, Person& person, TimePoint t) const
{
if (!is_active(t)) { // If the scheme is not active, do nothing; early return
return false;
}
if (!person.is_compliant(
rng, InterventionType::Testing)) { // If the person is not compliant with the testing intervention
return true; // Assume positive test result as this should not allow entry although it is not the same
}
auto test_result = person.get_test_result(m_test_parameters.type);
// If the agent has a test result valid until now, use the result directly
if ((test_result.time_of_testing > TimePoint(std::numeric_limits<int>::min())) &&
(test_result.time_of_testing + m_validity_period >= t)) {
return !test_result.result;
return test_result.result; // If the test is positive, the entry is not allowed, and vice versa
}
// Otherwise, the time_of_testing in the past (i.e. the agent has already performed it).
if (m_testing_criteria.evaluate(person, t - m_test_parameters.required_time)) {
double random = UniformDistribution<double>::get_instance()(rng);
if (random < m_probability) {
bool result = person.get_tested(rng, t - m_test_parameters.required_time, m_test_parameters);
person.add_test_result(t, m_test_parameters.type, result);
return !result;
return result; // If the test is positive, the entry is not allowed, and vice versa
}
}
return true;
// If the test is not performed, the entry is allowed
return false;
}

TestingStrategy::TestingStrategy(const std::vector<LocalStrategy>& location_to_schemes_map)
: m_location_to_schemes_map(location_to_schemes_map.begin(), location_to_schemes_map.end())
TestingStrategy::TestingStrategy(const std::vector<LocalStrategy>& location_to_schemes_id,
const std::vector<LocalStrategy>& location_to_schemes_type)
: m_testing_schemes_at_location_id(location_to_schemes_id.begin(), location_to_schemes_id.end())
, m_testing_schemes_at_location_type(location_to_schemes_type.begin(), location_to_schemes_type.end())
{
}

void TestingStrategy::add_testing_scheme(const LocationType& loc_type, const LocationId& loc_id,
const TestingScheme& scheme)
void TestingStrategy::add_scheme(const LocationId& loc_id, const TestingScheme& scheme)
{
auto iter_schemes =
std::find_if(m_location_to_schemes_map.begin(), m_location_to_schemes_map.end(), [&](const auto& p) {
return p.type == loc_type && p.id == loc_id;
});
if (iter_schemes == m_location_to_schemes_map.end()) {
//no schemes for this location yet, add a new list with one scheme
m_location_to_schemes_map.push_back({loc_type, loc_id, std::vector<TestingScheme>(1, scheme)});
}
else {
//add scheme to existing vector if the scheme doesn't exist yet
auto& schemes = iter_schemes->schemes;
if (std::find(schemes.begin(), schemes.end(), scheme) == schemes.end()) {
schemes.push_back(scheme);
}
if (loc_id.get() >= m_testing_schemes_at_location_id.size()) {
m_testing_schemes_at_location_id.resize(loc_id.get() + 1);
}
m_testing_schemes_at_location_id[loc_id.get()].schemes.push_back(scheme);
}

void TestingStrategy::remove_testing_scheme(const LocationType& loc_type, const LocationId& loc_id,
const TestingScheme& scheme)
void TestingStrategy::add_scheme(const LocationType& loc_type, const TestingScheme& scheme)
{
auto iter_schemes =
std::find_if(m_location_to_schemes_map.begin(), m_location_to_schemes_map.end(), [&](const auto& p) {
return p.type == loc_type && p.id == loc_id;
});
if (iter_schemes != m_location_to_schemes_map.end()) {
//remove the scheme from the list
auto& schemes_vector = iter_schemes->schemes;
auto last = std::remove(schemes_vector.begin(), schemes_vector.end(), scheme);
schemes_vector.erase(last, schemes_vector.end());
//delete the list of schemes for this location if no schemes left
if (schemes_vector.empty()) {
m_location_to_schemes_map.erase(iter_schemes);
}
if ((size_t)loc_type >= m_testing_schemes_at_location_type.size()) {
m_testing_schemes_at_location_type.resize((size_t)loc_type + 1);
}
m_testing_schemes_at_location_type[(size_t)loc_type].schemes.push_back(scheme);
}

void TestingStrategy::update_activity_status(TimePoint t)
bool TestingStrategy::run_and_check(PersonalRandomNumberGenerator& rng, Person& person, const Location& location,
TimePoint t)
{
for (auto& [_type, _id, testing_schemes] : m_location_to_schemes_map) {
for (auto& scheme : testing_schemes) {
scheme.update_activity_status(t);
}
}
}
// Early return if no scheme defined for this location or type
auto loc_id = location.get_id().get();
auto loc_type = static_cast<size_t>(location.get_type());

bool TestingStrategy::run_strategy(PersonalRandomNumberGenerator& rng, Person& person, const Location& location,
TimePoint t)
{
// A Person is always allowed to go home and this is never called if a person is not discharged from a hospital or ICU.
if (location.get_type() == mio::abm::LocationType::Home) {
return true;
bool has_id_schemes =
loc_id < m_testing_schemes_at_location_id.size() && !m_testing_schemes_at_location_id[loc_id].schemes.empty();

bool has_type_schemes = loc_type < m_testing_schemes_at_location_type.size() &&
!m_testing_schemes_at_location_type[loc_type].schemes.empty();

if (!has_id_schemes && !has_type_schemes) {
return true; // No applicable schemes
}

// If the Person does not comply to Testing where there is a testing scheme at the target location, it is not allowed to enter.
if (!person.is_compliant(rng, InterventionType::Testing)) {
return false;
bool entry_allowed = true; // Assume entry is allowed unless a scheme denies it
// Check schemes for specific location id
if (has_id_schemes) {
for (const auto& scheme : m_testing_schemes_at_location_id[loc_id].schemes) {
if (scheme.run_and_test(rng, person, t)) {
entry_allowed = false; // Deny entry
}
}
}

// Lookup schemes for this specific location as well as the location type
// Lookup in std::vector instead of std::map should be much faster unless for large numbers of schemes
for (auto key : {std::make_pair(location.get_type(), location.get_id()),
std::make_pair(location.get_type(), LocationId::invalid_id())}) {
auto iter_schemes =
std::find_if(m_location_to_schemes_map.begin(), m_location_to_schemes_map.end(), [&](const auto& p) {
return p.type == key.first && p.id == key.second;
});
if (iter_schemes != m_location_to_schemes_map.end()) {
// Apply all testing schemes that are found
auto& schemes = iter_schemes->schemes;
// Whether the Person is allowed to enter or not depends on the test result(s).
if (!std::all_of(schemes.begin(), schemes.end(), [&rng, &person, t](TestingScheme& ts) {
return !ts.is_active() || ts.run_scheme(rng, person, t);
})) {
return false;
// Check schemes for location type
if (has_type_schemes) {
for (const auto& scheme : m_testing_schemes_at_location_type[loc_type].schemes) {
if (scheme.run_and_test(rng, person, t)) {
entry_allowed = false; // Deny entry
}
}
}
return true;

// If the location is a home, entry is always allowed regardless of testing, no early return here because we still need to test
if (location.get_type() == LocationType::Home) {
return true;
}
else {
return entry_allowed;
}
}

} // namespace abm
Expand Down
Loading
Loading