Skip to content

Commit 600f8e4

Browse files
committed
Add a test for TF SavedModel format saving and loading
1 parent 81d7ab7 commit 600f8e4

File tree

1 file changed

+7
-2
lines changed

1 file changed

+7
-2
lines changed

tests/test_model_saving.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -130,8 +130,8 @@ def _supervised_custom_model_saving(model_filepath, save_fn, load_fn):
130130
y_pred_2 = model_2.fit_transform(X, Y)
131131

132132
### Save and load ###
133-
def _save_ivis_model(model, filepath):
134-
model.save_model(filepath, overwrite=True)
133+
def _save_ivis_model(model, filepath, save_format='h5'):
134+
model.save_model(filepath, save_format=save_format, overwrite=True)
135135

136136
def _load_ivis_model(filepath):
137137
model_2 = Ivis()
@@ -179,6 +179,11 @@ def _undill_ivis_model(filepath):
179179
test_supervised_custom_model_pickling = partial(_supervised_custom_model_saving,
180180
save_fn=_dill_ivis_model, load_fn=_undill_ivis_model)
181181

182+
### Other ###
183+
test_tf_savedmodel_persistence = partial(_unsupervised_model_save_test,
184+
save_fn=partial(_save_ivis_model, save_format='tfs'),
185+
load_fn=_load_ivis_model)
186+
182187
def test_save_overwriting(model_filepath):
183188
model = Ivis(k=15, batch_size=16, epochs=2)
184189
iris = datasets.load_iris()

0 commit comments

Comments
 (0)