We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 273f914 commit 7c5bfb9Copy full SHA for 7c5bfb9
keras_mdn_layer/tests/test_mdn.py
@@ -37,6 +37,6 @@ def test_save_mdn():
37
model.add(keras.layers.Dense(N_HIDDEN, batch_input_shape=(None, 1), activation='relu'))
38
model.add(mdn.MDN(1, N_MIXES))
39
model.compile(loss=mdn.get_mixture_loss_func(1, N_MIXES), optimizer=keras.optimizers.Adam())
40
- model.save('test_save.h5', overwrite=True, save_format="h5")
41
- m_2 = keras.models.load_model('test_save.h5', custom_objects={'MDN': mdn.MDN, 'mdn_loss_func': mdn.get_mixture_loss_func(1, N_MIXES)})
+ model.save('test_save.keras', overwrite=True)
+ m_2 = keras.models.load_model('test_save.keras', custom_objects={'MDN': mdn.MDN, 'mdn_loss_func': mdn.get_mixture_loss_func(1, N_MIXES)})
42
assert isinstance(m_2, keras.Sequential)
0 commit comments