Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

1067 improve abm tests #1141

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
178 changes: 94 additions & 84 deletions cpp/models/abm/parameters.h
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ struct MaskProtection {
using Type = CustomIndexArray<UncertainValue<>, MaskType>;
static Type get_default(AgeGroup /*size*/)
{
Type defaut_value = Type(MaskType::Count, 0.0);
Type defaut_value = Type(MaskType::Count, 0.0);
// Initial values according to http://dx.doi.org/10.15585/mmwr.mm7106e1
defaut_value[MaskType::FFP2] = 0.83;
defaut_value[MaskType::Surgical] = 0.66;
Expand Down Expand Up @@ -637,108 +637,118 @@ class Parameters : public ParametersBase
*/
bool check_constraints() const
{
for (auto i = AgeGroup(0); i < AgeGroup(m_num_groups); ++i) {

if (this->get<IncubationPeriod>()[{VirusVariant::Wildtype, i}] < 0) {
log_error("Constraint check: Parameter IncubationPeriod of age group {:.0f} smaller than {:.4f}",
(size_t)i, 0);
return true;
}

if (this->get<InfectedNoSymptomsToSymptoms>()[{VirusVariant::Wildtype, i}] < 0.0) {
log_error("Constraint check: Parameter InfectedNoSymptomsToSymptoms of age group {:.0f} smaller "
"than {:d}",
(size_t)i, 0);
return true;
}

if (this->get<InfectedNoSymptomsToRecovered>()[{VirusVariant::Wildtype, i}] < 0.0) {
log_error("Constraint check: Parameter InfectedNoSymptomsToRecovered of age group {:.0f} smaller "
"than {:d}",
(size_t)i, 0);
return true;
}

if (this->get<InfectedSymptomsToRecovered>()[{VirusVariant::Wildtype, i}] < 0.0) {
log_error(
"Constraint check: Parameter InfectedSymptomsToRecovered of age group {:.0f} smaller than {:d}",
(size_t)i, 0);
return true;
}

if (this->get<InfectedSymptomsToSevere>()[{VirusVariant::Wildtype, i}] < 0.0) {
log_error("Constraint check: Parameter InfectedSymptomsToSevere of age group {:.0f} smaller than {:d}",
(size_t)i, 0);
return true;
}

if (this->get<SevereToCritical>()[{VirusVariant::Wildtype, i}] < 0.0) {
log_error("Constraint check: Parameter SevereToCritical of age group {:.0f} smaller than {:d}",
(size_t)i, 0);
return true;
}

if (this->get<SevereToRecovered>()[{VirusVariant::Wildtype, i}] < 0.0) {
log_error("Constraint check: Parameter SevereToRecovered of age group {:.0f} smaller than {:d}",
(size_t)i, 0);
return true;
}

if (this->get<CriticalToDead>()[{VirusVariant::Wildtype, i}] < 0.0) {
log_error("Constraint check: Parameter CriticalToDead of age group {:.0f} smaller than {:d}", (size_t)i,
0);
return true;
}

if (this->get<CriticalToRecovered>()[{VirusVariant::Wildtype, i}] < 0.0) {
log_error("Constraint check: Parameter CriticalToRecovered of age group {:.0f} smaller than {:d}",
(size_t)i, 0);
return true;
}

if (this->get<RecoveredToSusceptible>()[{VirusVariant::Wildtype, i}] < 0.0) {
log_error("Constraint check: Parameter RecoveredToSusceptible of age group {:.0f} smaller than {:d}",
(size_t)i, 0);
return true;
}

if (this->get<DetectInfection>()[{VirusVariant::Wildtype, i}] < 0.0 ||
this->get<DetectInfection>()[{VirusVariant::Wildtype, i}] > 1.0) {
log_error("Constraint check: Parameter DetectInfection of age group {:.0f} smaller than {:d} or "
"larger than {:d}",
(size_t)i, 0, 1);
return true;
for (auto age_group = AgeGroup(0); age_group < AgeGroup(m_num_groups); ++age_group) {
for (std::uint32_t variant_count = 0; variant_count < static_cast<std::uint32_t>(VirusVariant::Count); ++variant_count) {

auto virus_variant = static_cast<mio::abm::VirusVariant>(variant_count);
Comment on lines +640 to +643
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

With a patch to make_index_range in index_range.h, this becomes

Suggested change
for (auto age_group = AgeGroup(0); age_group < AgeGroup(m_num_groups); ++age_group) {
for (std::uint32_t variant_count = 0; variant_count < static_cast<std::uint32_t>(VirusVariant::Count); ++variant_count) {
auto virus_variant = static_cast<mio::abm::VirusVariant>(variant_count);
for (auto age_group : make_index_range(AgeGroup{m_num_groups})) {
for (auto virus_variant : make_index_range(Index<VirusVariant>{VirusVariant::Count})) {

Try to replace make_index_range with the following, so it can work with AgeGroup (otherwise Index<AgeGroup> is required):

/**
 * @brief Construct a range that can be used to iterate over all MultiIndices in the given dimensions.
 * The range spans over [0, d) for each category in the MultiIndex, where d is that category's value in dimensions.
 * @param[in] dimensions A MultiIndex that contains the dimension for each category.
 * @tparam Categories All categories of the given MultiIndex.
 * @return An iterable range over the given dimensions.
 */
template <class... Categories>
IndexRange<Index<Categories...>> make_index_range(const Index<Categories...>& dimensions)
{
    return IndexRange<Index<Categories...>>(dimensions);
}


if (this->get<IncubationPeriod>()[{virus_variant, age_group}] < 0) {
log_error("Constraint check: Parameter IncubationPeriod of age group {:.0f} smaller than {:.4f}",
(size_t)age_group, 0);
return true;
}

if (this->get<InfectedNoSymptomsToSymptoms>()[{virus_variant, age_group}] < 0.0) {
log_error("Constraint check: Parameter InfectedNoSymptomsToSymptoms of age group {:.0f} smaller "
"than {:d}",
(size_t)age_group, 0);
return true;
}

if (this->get<InfectedNoSymptomsToRecovered>()[{virus_variant, age_group}] < 0.0) {
log_error("Constraint check: Parameter InfectedNoSymptomsToRecovered of age group {:.0f} smaller "
"than {:d}",
(size_t)age_group, 0);
return true;
}

if (this->get<InfectedSymptomsToRecovered>()[{virus_variant, age_group}] < 0.0) {
log_error(
"Constraint check: Parameter InfectedSymptomsToRecovered of age group {:.0f} smaller than {:d}",
(size_t)age_group, 0);
return true;
}

if (this->get<InfectedSymptomsToSevere>()[{virus_variant, age_group}] < 0.0) {
log_error(
"Constraint check: Parameter InfectedSymptomsToSevere of age group {:.0f} smaller than {:d}",
(size_t)age_group, 0);
return true;
}

if (this->get<SevereToCritical>()[{virus_variant, age_group}] < 0.0) {
log_error("Constraint check: Parameter SevereToCritical of age group {:.0f} smaller than {:d}",
(size_t)age_group, 0);
return true;
}

if (this->get<SevereToRecovered>()[{virus_variant, age_group}] < 0.0) {
log_error("Constraint check: Parameter SevereToRecovered of age group {:.0f} smaller than {:d}",
(size_t)age_group, 0);
return true;
}

if (this->get<CriticalToDead>()[{virus_variant, age_group}] < 0.0) {
log_error("Constraint check: Parameter CriticalToDead of age group {:.0f} smaller than {:d}",
(size_t)age_group, 0);
return true;
}

if (this->get<CriticalToRecovered>()[{virus_variant, age_group}] < 0.0) {
log_error("Constraint check: Parameter CriticalToRecovered of age group {:.0f} smaller than {:d}",
(size_t)age_group, 0);
return true;
}

if (this->get<RecoveredToSusceptible>()[{virus_variant, age_group}] < 0.0) {
log_error(
"Constraint check: Parameter RecoveredToSusceptible of age group {:.0f} smaller than {:d}",
(size_t)age_group, 0);
return true;
}

if (this->get<DetectInfection>()[{virus_variant, age_group}] < 0.0 ||
this->get<DetectInfection>()[{virus_variant, age_group}] > 1.0) {
log_error("Constraint check: Parameter DetectInfection of age group {:.0f} smaller than {:d} or "
"larger than {:d}",
(size_t)age_group, 0, 1);
return true;
}
}

if (this->get<GotoWorkTimeMinimum>()[i].seconds() < 0.0 ||
this->get<GotoWorkTimeMinimum>()[i].seconds() > this->get<GotoWorkTimeMaximum>()[i].seconds()) {
if (this->get<GotoWorkTimeMinimum>()[age_group].seconds() < 0.0 ||
this->get<GotoWorkTimeMinimum>()[age_group].seconds() >
this->get<GotoWorkTimeMaximum>()[age_group].seconds()) {
log_error("Constraint check: Parameter GotoWorkTimeMinimum of age group {:.0f} smaller {:d} or "
"larger {:d}",
(size_t)i, 0, this->get<GotoWorkTimeMaximum>()[i].seconds());
(size_t)age_group, 0, this->get<GotoWorkTimeMaximum>()[age_group].seconds());
return true;
}

if (this->get<GotoWorkTimeMaximum>()[i].seconds() < this->get<GotoWorkTimeMinimum>()[i].seconds() ||
this->get<GotoWorkTimeMaximum>()[i] > days(1)) {
if (this->get<GotoWorkTimeMaximum>()[age_group].seconds() <
this->get<GotoWorkTimeMinimum>()[age_group].seconds() ||
this->get<GotoWorkTimeMaximum>()[age_group] > days(1)) {
log_error("Constraint check: Parameter GotoWorkTimeMaximum of age group {:.0f} smaller {:d} or larger "
"than one day time span",
(size_t)i, this->get<GotoWorkTimeMinimum>()[i].seconds());
(size_t)age_group, this->get<GotoWorkTimeMinimum>()[age_group].seconds());
return true;
}

if (this->get<GotoSchoolTimeMinimum>()[i].seconds() < 0.0 ||
this->get<GotoSchoolTimeMinimum>()[i].seconds() > this->get<GotoSchoolTimeMaximum>()[i].seconds()) {
if (this->get<GotoSchoolTimeMinimum>()[age_group].seconds() < 0.0 ||
this->get<GotoSchoolTimeMinimum>()[age_group].seconds() >
this->get<GotoSchoolTimeMaximum>()[age_group].seconds()) {
log_error("Constraint check: Parameter GotoSchoolTimeMinimum of age group {:.0f} smaller {:d} or "
"larger {:d}",
(size_t)i, 0, this->get<GotoWorkTimeMaximum>()[i].seconds());
(size_t)age_group, 0, this->get<GotoWorkTimeMaximum>()[age_group].seconds());
return true;
}

if (this->get<GotoSchoolTimeMaximum>()[i].seconds() < this->get<GotoSchoolTimeMinimum>()[i].seconds() ||
this->get<GotoSchoolTimeMaximum>()[i] > days(1)) {
if (this->get<GotoSchoolTimeMaximum>()[age_group].seconds() <
this->get<GotoSchoolTimeMinimum>()[age_group].seconds() ||
this->get<GotoSchoolTimeMaximum>()[age_group] > days(1)) {
log_error("Constraint check: Parameter GotoWorkTimeMaximum of age group {:.0f} smaller {:d} or larger "
"than one day time span",
(size_t)i, this->get<GotoSchoolTimeMinimum>()[i].seconds());
(size_t)age_group, this->get<GotoSchoolTimeMinimum>()[age_group].seconds());
return true;
}
}
Expand Down
67 changes: 41 additions & 26 deletions cpp/tests/test_abm_household.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,69 +21,82 @@
#include "abm_helpers.h"
#include <gtest/gtest.h>


/**
* @brief Test adding a household to a model.
* Verifies correct number of persons, age groups, and location assignments.
*/
TEST(TestHouseholds, test_add_household_to_model)
{
// Create household members
auto member1 = mio::abm::HouseholdMember(num_age_groups);
member1.set_age_weight(age_group_0_to_4, 1);
member1.set_age_weight(age_group_0_to_4, 1); // Member is a child (age 0-4)

auto member2 = mio::abm::HouseholdMember(num_age_groups);
member2.set_age_weight(age_group_5_to_14, 1);
member2.set_age_weight(age_group_5_to_14, 1); // Member is a child (age 5-14)

// Create household and add members
auto household = mio::abm::Household();
household.add_members(member1, 2);
household.add_members(member2, 2);
household.add_members(member1, 2); // Add two members of age group 0-4
household.add_members(member2, 2); // Add two members of age group 5-14

// Create model and add the household
auto model = mio::abm::Model(num_age_groups);

add_household_to_model(model, household);
auto persons = model.get_persons();

// Test size
// Test: Correct number of persons added to the model
EXPECT_EQ(persons.size(), 4);

// Test age
// Test: Correct age groups assigned to persons
EXPECT_EQ(persons[0].get_age(), age_group_0_to_4);
EXPECT_EQ(persons[1].get_age(), age_group_0_to_4);
EXPECT_EQ(persons[2].get_age(), age_group_5_to_14);
EXPECT_EQ(persons[3].get_age(), age_group_5_to_14);

// Test location
// Test: Ensure persons of the same age group are in the same location
EXPECT_EQ(persons[0].get_location(), persons[1].get_location());
EXPECT_EQ(persons[2].get_location(), persons[3].get_location());
}

/**
* @brief Test adding a group of households to a model.
* Checks correct number of persons, age distribution, and location assignments.
*/
TEST(TestHouseholds, test_add_household_group_to_model)
{

// Create household members
auto member1 = mio::abm::HouseholdMember(num_age_groups);
member1.set_age_weight(age_group_35_to_59, 1);
member1.set_age_weight(age_group_35_to_59, 1); // Member is an adult (age 35-59)

auto member2 = mio::abm::HouseholdMember(num_age_groups);
member2.set_age_weight(age_group_5_to_14, 1);

auto household_group = mio::abm::HouseholdGroup();
member2.set_age_weight(age_group_5_to_14, 1); // Member is a child (age 5-14)

// Create the first household and add members
auto household1 = mio::abm::Household();
household1.add_members(member1, 10);
household1.add_members(member2, 2);
household_group.add_households(household1, 5);
household1.add_members(member1, 10); // Add ten members of age group 35-59
household1.add_members(member2, 2); // Add two members of age group 5-14

// Create the second household and add members
auto household2 = mio::abm::Household();
household2.add_members(member1, 2);
household2.add_members(member2, 2);
household_group.add_households(household2, 10);
household2.add_members(member1, 2); // Add two members of age group 35-59
household2.add_members(member2, 2); // Add two members of age group 5-14

auto model = mio::abm::Model(num_age_groups);
// Create a household group and add households
auto household_group = mio::abm::HouseholdGroup();
household_group.add_households(household1, 5); // Add household1 5 times
household_group.add_households(household2, 10); // Add household2 10 times

// Create model and add the household group
auto model = mio::abm::Model(num_age_groups);
add_household_group_to_model(model, household_group);
auto persons = model.get_persons();

// Test size
// Test: Correct number of persons in the model (5 * 12 + 10 * 4 = 100 persons)
EXPECT_EQ(persons.size(), 100);

// Test age
// Test: Count persons in each age group
int number_of_age5to14_year_olds = 0, number_of_age35to59_year_olds = 0;

for (auto& person : persons) {
if (person.get_age() == age_group_5_to_14) {
number_of_age5to14_year_olds++;
Expand All @@ -92,10 +105,12 @@ TEST(TestHouseholds, test_add_household_group_to_model)
number_of_age35to59_year_olds++;
}
}
EXPECT_EQ(number_of_age5to14_year_olds, 30);
EXPECT_EQ(number_of_age35to59_year_olds, 70);

// Test location for some people
// Verify the age group distribution
EXPECT_EQ(number_of_age5to14_year_olds, 30); // 30 children (age 5-14)
EXPECT_EQ(number_of_age35to59_year_olds, 70); // 70 adults (age 35-59)

// Test: Ensure people in the same household share the same location (for a few checks)
EXPECT_EQ(persons[0].get_location(), persons[1].get_location());
EXPECT_EQ(persons[1].get_location(), persons[5].get_location());
EXPECT_EQ(persons[5].get_location(), persons[10].get_location());
Expand Down
Loading
Loading