Skip to content

Commit

Permalink
setup obs and variables duckdb backed, first proper minimal tests
Browse files Browse the repository at this point in the history
  • Loading branch information
eroell committed Oct 29, 2024
1 parent 4cec971 commit d3fca1f
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 157 deletions.
143 changes: 19 additions & 124 deletions src/ehrdata/io/omop/omop.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ def setup_variables(
edata
The EHRData object to which the variables should be added.
data_tables
The tables to be used.
The tables to be used. For now, only one can be used.
data_field_to_keep
The CDM Field in the data table to be kept. Can be e.g. "value_as_number" or "value_as_concept_id".
If multiple tables are used, this can be a dictionary with the table name as key and the column name as value, e.g. {"measurement": "value_as_number", "observation": "value_as_concept_id"}.
Expand All @@ -150,78 +150,33 @@ def setup_variables(
-------
An EHRData object with populated .r and .var field.
"""
time_interval_tables = []
from ehrdata import EHRData

time_defining_table = edata.uns.get("omop_io_observation_table", None)
if time_defining_table is None:
raise ValueError("The observation table must be set up first, use the `setup_obs` function.")

for data_table in data_tables:
ds = (
time_interval_table_query_long_format(
backend_handle=backend_handle,
time_defining_table=time_defining_table,
data_table=data_table,
data_field_to_keep=data_field_to_keep,
interval_length_number=interval_length_number,
interval_length_unit=interval_length_unit,
num_intervals=num_intervals,
aggregation_strategy=aggregation_strategy,
)
.set_index(["person_id", "data_table_concept_id", "interval_step"])
.to_xarray()
ds = (
time_interval_table_query_long_format(
backend_handle=backend_handle,
time_defining_table=time_defining_table,
data_table=data_tables[0],
data_field_to_keep=data_field_to_keep,
interval_length_number=interval_length_number,
interval_length_unit=interval_length_unit,
num_intervals=num_intervals,
aggregation_strategy=aggregation_strategy,
)
# TODO: interval_start to var
# TODO: concept_ids to var
# TODO: concept_names to var
# TODO: for measurement, observation: store unit_concept_id and unit_name in var
time_interval_tables.append(ds)

return ds
# for table in tables:
# if table not in VALID_VARIABLE_TABLES:
# raise ValueError(f"tables must be a sequence of from [{VALID_VARIABLE_TABLES}].")

# id_column = f"{table}_type_concept_id" if table in ["note", "death"] else f"{table}_concept_id"

# concept_ids_present = _lowercase_column_names(
# backend_handle.sql(f"SELECT DISTINCT {id_column} FROM {table}").df()
# )

# personxfeature_pairs_of_value_timestamp = _extract_personxfeature_pairs_of_value_timestamp(backend_handle)

# # Create the time interval table
# time_interval_table = get_time_interval_table(
# backend_handle,
# personxfeature_pairs_of_value_timestamp,
# edata.obs,
# start_time="observation_period_start",
# interval_length_number=interval_length_number,
# interval_length_unit=interval_length_unit,
# num_intervals=num_intervals,
# concept_ids=concept_ids,
# aggregation_strategy=aggregation_strategy,
# )

# # Append
# concept_ids_present_list.append(concept_ids_present)
# time_interval_tables.append(time_interval_table)

# Combine time interval tables
# if len(time_interval_tables) > 1:
# time_interval_table = np.concatenate([time_interval_table, time_interval_table], axis=1)
# concept_ids_present = pd.concat(concept_ids_present_list)
# else:
# time_interval_table = time_interval_tables[0]
# concept_ids_present = concept_ids_present_list[0]

# # Update edata with the new variables
# edata = EHRData(r=time_interval_table, obs=edata.obs, var=concept_ids_present)
.set_index(["person_id", "data_table_concept_id", "interval_step"])
.to_xarray()
)

# return edata
var = ds["data_table_concept_id"].to_dataframe()
t = ds["interval_step"].to_dataframe()

edata = EHRData(r=ds[data_field_to_keep[0]].values, obs=edata.obs, var=var, uns=edata.uns, t=t)

# DEVICE EXPOSURE and DRUG EXPOSURE NEEDS TO BE IMPLEMENTED BECAUSE THEY CONTAIN START DATE
return edata


def load(
Expand Down Expand Up @@ -258,66 +213,6 @@ def _get_table_join(
)


def _extract_personxfeature_pairs_of_value_timestamp(
    duckdb_instance, table_name: str, concept_id_col: str, value_col: str, timestamp_col: str
):
    """
    Generalized extraction function to extract data from an OMOP CDM table.

    Parameters
    ----------
    duckdb_instance: duckdb.DuckDB
        The DuckDB instance for querying the database.
    table_name: str
        The name of the table to extract data from (e.g., "measurement", "observation").
    concept_id_col: str
        The name of the column that contains the concept IDs (e.g., "measurement_concept_id").
    value_col: str
        The name of the column that contains the values (e.g., "value_as_number").
    timestamp_col: str
        The name of the column that contains the timestamps (e.g., "measurement_datetime").

    Returns
    -------
    ak.Array
        An Awkward Array with the structure: n_person x n_features x 2 (value, time).
    """
    # Pull the data table and normalize its column names to lowercase.
    data_df = _lowercase_column_names(duckdb_instance.sql(f"SELECT * FROM {table_name}").df())

    # The person table defines the first (outer) axis of the result.
    persons_df = _lowercase_column_names(duckdb_instance.sql("SELECT * FROM person").df())
    unique_person_ids = persons_df["person_id"].unique()

    # Distinct concept IDs define the second (feature) axis.
    unique_features = data_df[concept_id_col].unique()

    collected = []
    for pid in unique_person_ids:
        # Rows belonging to the current person.
        rows_for_person = data_df[data_df["person_id"] == pid]

        # For every feature, collect its (values, timestamps) pair for this person.
        per_feature = []
        for concept in unique_features:
            rows_for_feature = rows_for_person[rows_for_person[concept_id_col] == concept]
            per_feature.append([rows_for_feature[value_col], rows_for_feature[timestamp_col]])

        collected.append(per_feature)

    # Ragged structure: persons may have differing numbers of records per feature.
    return ak.Array(collected)


def extract_measurement(duckdb_instance):
"""Extract a table of an OMOP CDM Database."""
return get_table(
Expand Down
3 changes: 3 additions & 0 deletions tests/data/toy_omop/vanilla/death.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
person_id,death_date,death_datetime,death_type_concept_id,cause_concept_id,cause_source_value,cause_source_concept_id
1,2100-03-31,2100-03-31 00:00:00,32817,0,0,
2,2100-03-31,2100-03-31 00:00:00,32817,0,0,
66 changes: 33 additions & 33 deletions tests/test_io/test_omop.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,23 +4,23 @@

import ehrdata as ed

# def test_register_omop_to_db_connection():
# register_omop_to_db_connection(path="tests/data/toy_omop/vanilla", backend_handle=duckdb.connect(), source="csv")


# TODO: add test for death argument
@pytest.mark.parametrize(
"observation_table, expected_length, expected_obs_num_columns",
"observation_table, death_table, expected_length, expected_obs_num_columns",
[
("person", 4, 18),
("person_cohort", 3, 22),
("person_observation_period", 3, 23),
("person_visit_occurrence", 3, 35),
("person", False, 4, 18),
("person", True, 4, 24),
("person_cohort", False, 3, 22),
("person_cohort", True, 3, 28),
("person_observation_period", False, 3, 23),
("person_observation_period", True, 3, 29),
("person_visit_occurrence", False, 3, 35),
("person_visit_occurrence", True, 3, 41),
],
)
def test_setup_obs(omop_connection_vanilla, observation_table, expected_length, expected_obs_num_columns):
def test_setup_obs(omop_connection_vanilla, observation_table, death_table, expected_length, expected_obs_num_columns):
con = omop_connection_vanilla
edata = ed.io.omop.setup_obs(backend_handle=con, observation_table=observation_table)
edata = ed.io.omop.setup_obs(backend_handle=con, observation_table=observation_table, death_table=death_table)
assert isinstance(edata, ed.EHRData)

# 4 persons, only 3 are in cohort, or have observation period, or visit occurrence
Expand All @@ -44,32 +44,32 @@ def test_setup_obs_invalid_observation_table_argument(omop_connection_vanilla):
ed.io.omop.setup_obs(backend_handle=con, observation_table="perso")


def test_setup_variables_measurement_startdate_fixed(omop_connection_vanilla):
@pytest.mark.parametrize(
"observation_table",
["person_cohort", "person_observation_period", "person_visit_occurrence"],
)
@pytest.mark.parametrize(
"data_tables",
[["measurement"], ["observation"]],
)
@pytest.mark.parametrize(
"data_field_to_keep",
[["value_as_number"], ["value_as_concept_id"]],
)
def test_setup_variables(omop_connection_vanilla, observation_table, data_tables, data_field_to_keep):
con = omop_connection_vanilla
edata = ed.io.omop.setup_obs(backend_handle=con, observation_table="person")
ed.io.omop.setup_variables(
edata = ed.io.omop.setup_obs(backend_handle=con, observation_table=observation_table)
edata = ed.io.omop.setup_variables(
edata,
backend_handle=con,
tables=["measurement"],
start_time="2100-01-01",
data_tables=data_tables,
data_field_to_keep=data_field_to_keep,
interval_length_number=1,
interval_length_unit="day",
num_intervals=31,
num_intervals=30,
)
# check precise expected table
assert edata.vars.shape[1] == 8


def test_setup_var_measurement_startdate_observation_period():
# check precise expected table
pass


def test_setup_var_observation_startdate_fixed():
# check precise expected table
pass


def test_setup_var_observation_startdate_observation_period():
# check precise expected table
pass
assert isinstance(edata, ed.EHRData)
assert edata.n_obs == 3
assert edata.n_vars == 2
assert edata.r.shape[2] == 30

0 comments on commit d3fca1f

Please sign in to comment.