From b5d7468a475e6c005d125b7494b6684fd917c394 Mon Sep 17 00:00:00 2001 From: Charles Martin Date: Fri, 26 Oct 2018 10:38:50 +0200 Subject: [PATCH] updated MDN RNN TD example --- notebooks/MDN-RNN-time-distributed-MDN-training.ipynb | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/notebooks/MDN-RNN-time-distributed-MDN-training.ipynb b/notebooks/MDN-RNN-time-distributed-MDN-training.ipynb index a58f0e8..4f27dde 100644 --- a/notebooks/MDN-RNN-time-distributed-MDN-training.ipynb +++ b/notebooks/MDN-RNN-time-distributed-MDN-training.ipynb @@ -110,7 +110,7 @@ "inputs = keras.layers.Input(shape=(SEQ_LEN,OUTPUT_DIMENSION), name='inputs')\n", "lstm1_out = keras.layers.LSTM(HIDDEN_UNITS, name='lstm1', return_sequences=True)(inputs)\n", "lstm2_out = keras.layers.LSTM(HIDDEN_UNITS, name='lstm2', return_sequences=True)(lstm1_out)\n", - "mdn_out = mdn.MDN(OUTPUT_DIMENSION, NUMBER_MIXTURES, name='mdn_outputs')(lstm2_out)\n", + "mdn_out = keras.layers.TimeDistributed(mdn.MDN(OUTPUT_DIMENSION, NUMBER_MIXTURES, name='mdn_outputs'), name='td_mdn')(lstm2_out)\n", "\n", "model = keras.models.Model(inputs=inputs, outputs=mdn_out)\n", "model.compile(loss=mdn.get_mixture_loss_func(OUTPUT_DIMENSION,NUMBER_MIXTURES), optimizer='adam')\n", @@ -204,7 +204,7 @@ "source": [ "# Fit the model\n", "filepath=\"kanji_mdnrnn-{epoch:02d}-{val_acc:.2f}.hdf5\"\n", - "checkpoint = keras.callbacks.ModelCheckpoint(filepath, monitor='val_acc', verbose=1, save_best_only=True, mode='max')\n", + "checkpoint = keras.callbacks.ModelCheckpoint(filepath, monitor='val_loss', verbose=1, save_best_only=True, mode='min')\n", "callbacks = [keras.callbacks.TerminateOnNaN(), checkpoint]\n", "\n", "history = model.fit(X, y, batch_size=BATCH_SIZE, epochs=EPOCHS, callbacks=callbacks, validation_data=(Xval,yval))\n",