diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md index bc323ae..84fa182 100644 --- a/.github/pull_request_template.md +++ b/.github/pull_request_template.md @@ -7,9 +7,9 @@ ## Changes Made -- -- -- +- +- +- ## Testing @@ -27,4 +27,3 @@ ## Additional Notes - diff --git a/.gitignore b/.gitignore index d168a8c..aaa3416 100644 --- a/.gitignore +++ b/.gitignore @@ -164,3 +164,4 @@ scrap/ .DS_Store .vscode/ .ruff_cache/ +.python-version diff --git a/healthchain/data_generators/cdsdatagenerator.py b/healthchain/data_generators/cdsdatagenerator.py index 07ff218..b3d1340 100644 --- a/healthchain/data_generators/cdsdatagenerator.py +++ b/healthchain/data_generators/cdsdatagenerator.py @@ -75,6 +75,7 @@ def generate( constraints: Optional[list] = None, free_text_path: Optional[str] = None, column_name: Optional[str] = None, + random_seed: Optional[int] = None, ) -> BaseModel: """ Generates CDS data based on the current workflow, constraints, and optional free text data. @@ -83,6 +84,7 @@ def generate( constraints (Optional[list]): A list of constraints to apply to the data generation. free_text_path (Optional[str]): The path to a CSV file containing free text data. column_name (Optional[str]): The column name in the CSV file to use for free text data. + random_seed (Optional[int]): The random seed to use for reproducible data generation. Returns: BaseModel: The generated CDS FHIR data. @@ -95,7 +97,9 @@ def generate( for resource in self.mappings[self.workflow]: generator_name = resource["generator"] generator = self.fetch_generator(generator_name) - result = generator.generate(constraints=constraints) + result = generator.generate( + constraints=constraints, random_seed=random_seed + ) results.append(BundleEntry(resource=result)) diff --git a/healthchain/data_generators/conditiongenerators.py b/healthchain/data_generators/conditiongenerators.py index cef3fa8..924c367 100644 --- a/healthchain/data_generators/conditiongenerators.py +++ b/healthchain/data_generators/conditiongenerators.py @@ -147,7 +147,9 @@ def generate( subject_reference: Optional[str] = None, encounter_reference: Optional[str] = None, constraints: Optional[list] = None, + random_seed: Optional[int] = None, ): + Faker.seed(random_seed) subject_reference = subject_reference or "Patient/123" encounter_reference = encounter_reference or "Encounter/123" code = generator_registry.get("SnomedCodeGenerator").generate( diff --git a/healthchain/data_generators/encountergenerators.py b/healthchain/data_generators/encountergenerators.py index 1f1b3e4..c9c6681 100644 --- a/healthchain/data_generators/encountergenerators.py +++ b/healthchain/data_generators/encountergenerators.py @@ -143,14 +143,16 @@ class EncounterGenerator(BaseGenerator): A generator class for creating FHIR Encounter resources. Methods: - generate(constraints: Optional[list] = None) -> Encounter: - Generates a FHIR Encounter resource with optional constraints. + generate(constraints: Optional[list] = None, random_seed: Optional[int] = None) -> Encounter: + Generates a FHIR Encounter resource with optional constraints and random_seed. """ @staticmethod def generate( constraints: Optional[list] = None, + random_seed: Optional[int] = None, ) -> Encounter: + Faker.seed(random_seed) patient_reference = "Patient/123" return Encounter( resourceType="Encounter", diff --git a/healthchain/data_generators/medicationrequestgenerators.py b/healthchain/data_generators/medicationrequestgenerators.py index 979efca..3b1f2ab 100644 --- a/healthchain/data_generators/medicationrequestgenerators.py +++ b/healthchain/data_generators/medicationrequestgenerators.py @@ -46,7 +46,9 @@ class MedicationRequestGenerator(BaseGenerator): @staticmethod def generate( constraints: Optional[list] = None, + random_seed: Optional[int] = None, ): + Faker.seed(random_seed) subject_reference = "Patient/123" encounter_reference = "Encounter/123" contained_medication = Medication( diff --git a/healthchain/data_generators/patientgenerators.py b/healthchain/data_generators/patientgenerators.py index a09ba30..7076ed3 100644 --- a/healthchain/data_generators/patientgenerators.py +++ b/healthchain/data_generators/patientgenerators.py @@ -116,7 +116,11 @@ def generate(): @register_generator class PatientGenerator(BaseGenerator): @staticmethod - def generate(constraints: Optional[list] = None): + def generate( + constraints: Optional[list] = None, + random_seed: Optional[int] = None, + ) -> Patient: + Faker.seed(random_seed) return Patient( resourceType="Patient", id=generator_registry.get("IdGenerator").generate(), diff --git a/healthchain/data_generators/proceduregenerators.py b/healthchain/data_generators/proceduregenerators.py index eb51ef6..d2a4426 100644 --- a/healthchain/data_generators/proceduregenerators.py +++ b/healthchain/data_generators/proceduregenerators.py @@ -42,7 +42,9 @@ def generate( subject_reference: Optional[str] = None, encounter_reference: Optional[str] = None, constraints: Optional[list] = None, + random_seed: Optional[int] = None, ): + Faker.seed(random_seed) subject_reference = subject_reference or "Patient/123" encounter_reference = encounter_reference or "Encounter/123" code = generator_registry.get("ProcedureSnomedCodeGenerator").generate( diff --git a/healthchain/models/data/concept.py b/healthchain/models/data/concept.py index 0ef4d7d..f6b859c 100644 --- a/healthchain/models/data/concept.py +++ b/healthchain/models/data/concept.py @@ -1,5 +1,5 @@ from enum import Enum -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, field_validator from typing import Optional, Dict, Union @@ -21,6 +21,28 @@ class Quantity(DataType): value: Optional[Union[str, float]] = None unit: Optional[str] = None + @field_validator("value") + @classmethod + def validate_value(cls, value: Union[str, float]): + if value is None: + return None + + if not isinstance(value, (str, float)): + raise TypeError( + f"Value CANNOT be a {type(value)} object. Must be float or string in float format." + ) + + try: + return float(value) + + except ValueError: + raise ValueError(f"Invalid value '{value}' . Must be a float Number.") + + except OverflowError: + raise OverflowError( + "Invalid value . Value is too large resulting in overflow." + ) + class Range(DataType): low: Optional[Quantity] = None diff --git a/poetry.lock b/poetry.lock index bd41e59..f653cf9 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand. [[package]] name = "annotated-types" diff --git a/tests/test_cdaannotator.py b/tests/test_cdaannotator.py index dd1562e..86e4d85 100644 --- a/tests/test_cdaannotator.py +++ b/tests/test_cdaannotator.py @@ -68,12 +68,12 @@ def test_extract_medications(cda_annotator): assert medications[0].route.code_system_name == "NCI Thesaurus" assert medications[0].route.display_name == "Oral" - assert medications[0].frequency.period.value == ".5" + assert medications[0].frequency.period.value == 0.5 assert medications[0].frequency.period.unit == "d" assert medications[0].frequency.institution_specified assert medications[0].duration.low is None - assert medications[0].duration.high.value == "20221020" + assert medications[0].duration.high.value == 20221020 assert medications[0].precondition == { "@typeCode": "PRCN", diff --git a/tests/test_quantity_class.py b/tests/test_quantity_class.py new file mode 100644 index 0000000..1413fd5 --- /dev/null +++ b/tests/test_quantity_class.py @@ -0,0 +1,40 @@ +import pytest +from healthchain.models.data.concept import Quantity +from pydantic import ValidationError + + +# Valid Cases +def test_valid(): + valid_floats = [1.0, 0.1, 4.5, 5.99999, 12455.321, 33, 1234, None] + for num in valid_floats: + q = Quantity(value=num, unit="mg") + assert q.value == num + + +def test_valid_string(): + valid_strings = ["100", "100.000001", ".1", "1.", ".123", "1234.", "123989"] + for string in valid_strings: + q = Quantity(value=string, unit="mg") + assert q.value == float(string) + + +# Invalid Cases +def test_invalid_strings(): + invalid_strings = [ + "1.0.0", + "1..123", + "..123", + "12..", + "12a.56", + "1e4.6", + "12#.45", + "12.12@3", + "12@3", + "abc", + "None", + "", + ] + for string in invalid_strings: + with pytest.raises(ValidationError) as exception_info: + Quantity(value=string, unit="mg") + assert "Invalid value" in str(exception_info.value)