Skip to content

Commit a4f1a72

Browse files
authored
add csai test cases (#535)
1 parent 6c5777e commit a4f1a72

File tree

2 files changed

+285
-0
lines changed

2 files changed

+285
-0
lines changed

tests/classification/csai.py

Lines changed: 148 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,148 @@
1+
"""
2+
Test cases for CSAI classification model.
3+
"""
4+
5+
# Created by Linglong Qian <[email protected]>
6+
# License: BSD-3-Clause
7+
8+
import os
9+
import unittest
10+
11+
import pytest
12+
13+
from pypots.classification import CSAI
14+
from pypots.optim import Adam
15+
from pypots.utils.logging import logger
16+
from pypots.utils.metrics import calc_binary_classification_metrics
17+
from tests.global_test_config import (
18+
DATA,
19+
EPOCHS,
20+
DEVICE,
21+
TRAIN_SET,
22+
VAL_SET,
23+
TEST_SET,
24+
GENERAL_H5_TRAIN_SET_PATH,
25+
GENERAL_H5_VAL_SET_PATH,
26+
GENERAL_H5_TEST_SET_PATH,
27+
RESULT_SAVING_DIR_FOR_CLASSIFICATION,
28+
check_tb_and_model_checkpoints_existence,
29+
)
30+
31+
32+
class TestCSAI(unittest.TestCase):
    """Test cases for the CSAI classification model.

    The optimizer and the model are created once as class attributes, so
    every test method operates on the same CSAI instance. The methods are
    numbered (``test_0_`` ... ``test_4_``) so that the alphabetical ordering
    used by unittest runs the fitting step before the tests that use the
    fitted model. All methods carry the same ``xdist_group`` mark.
    """

    logger.info("Running tests for a classification model CSAI...")

    # Set the log and model saving path
    saving_path = os.path.join(RESULT_SAVING_DIR_FOR_CLASSIFICATION, "CSAI")
    model_save_name = "saved_CSAI_model.pypots"

    # Initialize an Adam optimizer
    optimizer = Adam(lr=0.001, weight_decay=1e-5)

    # Initialize the CSAI model for classification
    csai = CSAI(
        n_steps=DATA["n_steps"],
        n_features=DATA["n_features"],
        n_classes=DATA["n_classes"],
        rnn_hidden_size=32,
        imputation_weight=0.7,
        consistency_weight=0.3,
        classification_weight=1.0,
        removal_percent=10,
        increase_factor=0.1,
        compute_intervals=True,
        step_channels=16,
        batch_size=64,
        epochs=EPOCHS,
        dropout=0.5,
        optimizer=optimizer,
        num_workers=4,
        device=DEVICE,
        saving_path=saving_path,
        model_saving_strategy="better",
        verbose=True,
    )

    @pytest.mark.xdist_group(name="classification-csai")
    def test_0_fit(self):
        """Fit the CSAI model on the training and validation datasets."""
        self.csai.fit(TRAIN_SET, VAL_SET)

    @pytest.mark.xdist_group(name="classification-csai")
    def test_1_classify(self):
        """Classify the test set and require a ROC-AUC of at least 0.5."""
        # Classify test set using the trained CSAI model
        results = self.csai.classify(TEST_SET)

        # Calculate binary classification metrics
        metrics = calc_binary_classification_metrics(results, DATA["test_y"])

        logger.info(
            f'CSAI ROC_AUC: {metrics["roc_auc"]}, '
            f'PR_AUC: {metrics["pr_auc"]}, '
            f'F1: {metrics["f1"]}, '
            f'Precision: {metrics["precision"]}, '
            f'Recall: {metrics["recall"]}'
        )

        assert metrics["roc_auc"] >= 0.5, "ROC-AUC < 0.5"

    @pytest.mark.xdist_group(name="classification-csai")
    def test_2_parameters(self):
        """Check that the model exposes the attributes expected after training."""
        assert hasattr(self.csai, "model") and self.csai.model is not None

        assert hasattr(self.csai, "optimizer") and self.csai.optimizer is not None

        # best_loss must have moved away from its infinite starting value
        assert hasattr(self.csai, "best_loss")
        self.assertNotEqual(self.csai.best_loss, float("inf"))

        assert (
            hasattr(self.csai, "best_model_dict")
            and self.csai.best_model_dict is not None
        )

    @pytest.mark.xdist_group(name="classification-csai")
    def test_3_saving_path(self):
        """Check the saving directory, then save the model and load it back."""
        # Ensure the root saving directory exists
        assert os.path.exists(
            self.saving_path
        ), f"file {self.saving_path} does not exist"

        # Check if the tensorboard file and model checkpoints exist
        check_tb_and_model_checkpoints_existence(self.csai)

        # Save the trained model to file, and verify the file existence
        saved_model_path = os.path.join(self.saving_path, self.model_save_name)
        self.csai.save(saved_model_path)
        assert os.path.exists(
            saved_model_path
        ), f"file {saved_model_path} does not exist"

        # Test loading the saved model
        self.csai.load(saved_model_path)

    @pytest.mark.xdist_group(name="classification-csai")
    def test_4_lazy_loading(self):
        """Fit and classify with datasets given as H5 file paths."""
        # Fit the CSAI model using lazy-loading datasets from H5 files
        self.csai.fit(GENERAL_H5_TRAIN_SET_PATH, GENERAL_H5_VAL_SET_PATH)

        # Perform classification using lazy-loaded data
        results = self.csai.classify(GENERAL_H5_TEST_SET_PATH)

        # Calculate binary classification metrics
        metrics = calc_binary_classification_metrics(results, DATA["test_y"])

        logger.info(
            f'Lazy-loading CSAI ROC_AUC: {metrics["roc_auc"]}, '
            f'PR_AUC: {metrics["pr_auc"]}, '
            f'F1: {metrics["f1"]}, '
            f'Precision: {metrics["precision"]}, '
            f'Recall: {metrics["recall"]}'
        )

        assert metrics["roc_auc"] >= 0.5, "ROC-AUC < 0.5"
145+
146+
147+
# Run the test cases with unittest when this file is executed as a script.
if __name__ == "__main__":
    unittest.main()

tests/imputation/csai.py

Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
"""
2+
Test cases for CSAI imputation model.
3+
"""
4+
5+
# Created by Linglong Qian <[email protected]>
6+
# License: BSD-3-Clause
7+
8+
9+
import os.path
10+
import unittest
11+
12+
import numpy as np
13+
import pytest
14+
15+
from pypots.imputation import CSAI
16+
from pypots.optim import Adam
17+
from pypots.utils.logging import logger
18+
from pypots.utils.metrics import calc_mse
19+
from tests.global_test_config import (
20+
DATA,
21+
EPOCHS,
22+
DEVICE,
23+
TRAIN_SET,
24+
VAL_SET,
25+
TEST_SET,
26+
GENERAL_H5_TRAIN_SET_PATH,
27+
GENERAL_H5_VAL_SET_PATH,
28+
GENERAL_H5_TEST_SET_PATH,
29+
RESULT_SAVING_DIR_FOR_IMPUTATION,
30+
check_tb_and_model_checkpoints_existence,
31+
)
32+
33+
34+
class TestCSAI(unittest.TestCase):
    """Test cases for the CSAI imputation model.

    The optimizer and the model are created once as class attributes, so
    every test method operates on the same CSAI instance. The methods are
    numbered (``test_0_`` ... ``test_4_``) so that the alphabetical ordering
    used by unittest runs the fitting step before the tests that use the
    fitted model. All methods carry the same ``xdist_group`` mark.
    """

    logger.info("Running tests for the CSAI imputation model...")

    # Set the log and model saving path
    saving_path = os.path.join(RESULT_SAVING_DIR_FOR_IMPUTATION, "CSAI")
    model_save_name = "saved_CSAI_model.pypots"

    # Initialize an Adam optimizer
    optimizer = Adam(lr=0.001, weight_decay=1e-5)

    # Initialize the CSAI model
    csai = CSAI(
        n_steps=DATA["n_steps"],
        n_features=DATA["n_features"],
        rnn_hidden_size=32,
        imputation_weight=0.7,
        consistency_weight=0.3,
        removal_percent=10,  # Assume we are removing 10% of the data
        increase_factor=0.1,
        compute_intervals=True,
        step_channels=16,
        batch_size=64,
        epochs=EPOCHS,
        optimizer=optimizer,
        num_workers=0,
        device=DEVICE,
        saving_path=saving_path,
        model_saving_strategy="best",
        verbose=True,
    )

    @pytest.mark.xdist_group(name="imputation-csai")
    def test_0_fit(self):
        """Fit the CSAI model on the training and validation datasets."""
        self.csai.fit(TRAIN_SET, VAL_SET)

    @pytest.mark.xdist_group(name="imputation-csai")
    def test_1_impute(self):
        """Impute the test set, require a NaN-free result, and log the MSE."""
        # Impute missing values using the trained CSAI model
        imputed_X = self.csai.impute(TEST_SET)
        assert not np.isnan(
            imputed_X
        ).any(), "Output still has missing values after running impute()."

        # Calculate mean squared error (MSE) for the test set
        test_MSE = calc_mse(
            imputed_X, DATA["test_X_ori"], DATA["test_X_indicating_mask"]
        )
        logger.info(f"CSAI test_MSE: {test_MSE}")

    @pytest.mark.xdist_group(name="imputation-csai")
    def test_2_parameters(self):
        """Check that the model exposes the attributes expected after training."""
        assert hasattr(self.csai, "model") and self.csai.model is not None

        assert hasattr(self.csai, "optimizer") and self.csai.optimizer is not None

        # best_loss must have moved away from its infinite starting value
        assert hasattr(self.csai, "best_loss")
        self.assertNotEqual(self.csai.best_loss, float("inf"))

        assert (
            hasattr(self.csai, "best_model_dict")
            and self.csai.best_model_dict is not None
        )

    @pytest.mark.xdist_group(name="imputation-csai")
    def test_3_saving_path(self):
        """Check the saving directory, then save the model and load it back."""
        # Ensure the root saving directory exists
        assert os.path.exists(
            self.saving_path
        ), f"file {self.saving_path} does not exist"

        # Check if the tensorboard file and model checkpoints exist
        check_tb_and_model_checkpoints_existence(self.csai)

        # Save the trained model to file, and verify the file existence
        saved_model_path = os.path.join(self.saving_path, self.model_save_name)
        self.csai.save(saved_model_path)
        assert os.path.exists(
            saved_model_path
        ), f"file {saved_model_path} does not exist"

        # Test loading the saved model
        self.csai.load(saved_model_path)

    @pytest.mark.xdist_group(name="imputation-csai")
    def test_4_lazy_loading(self):
        """Fit and predict with datasets given as H5 file paths."""
        # Fit the CSAI model using lazy-loading datasets from H5 files
        self.csai.fit(GENERAL_H5_TRAIN_SET_PATH, GENERAL_H5_VAL_SET_PATH)

        # Perform imputation using lazy-loaded data
        imputation_results = self.csai.predict(GENERAL_H5_TEST_SET_PATH)
        assert not np.isnan(
            imputation_results["imputation"]
        ).any(), "Output still has missing values after running impute()."

        # Calculate the MSE on the test set
        test_MSE = calc_mse(
            imputation_results["imputation"],
            DATA["test_X_ori"],
            DATA["test_X_indicating_mask"],
        )
        logger.info(f"Lazy-loading CSAI test_MSE: {test_MSE}")
134+
135+
136+
# Run the test cases with unittest when this file is executed as a script.
if __name__ == "__main__":
    unittest.main()

0 commit comments

Comments
 (0)