Merge pull request #47 from cpmpercussion/move_to_tf2_api
Move to tf2 api, thanks u7484835
cpmpercussion authored Jun 21, 2024
2 parents 7c5bfb9 + 5283d73 commit 336cfc9
Showing 11 changed files with 428 additions and 113 deletions.
examples/4-robojam-touch-generation.py (10 changes: 5 additions & 5 deletions)
@@ -1,8 +1,8 @@
-from tensorflow.compat.v1 import keras
-from tensorflow.compat.v1.keras import backend as K
-from tensorflow.compat.v1.keras.layers import Dense, Input
+from tensorflow import keras
+from tensorflow.keras import backend as K
+from tensorflow.keras.layers import Dense, Input
 import numpy as np
-import tensorflow.compat.v1 as tf
+import tensorflow as tf
 import math
 import h5py
 import random
@@ -25,7 +25,7 @@
 
 # Download microjam performance data if needed.
 import urllib.request
-url = 'http://folk.uio.no/charlepm/datasets/TinyPerformanceCorpus.h5'
+url = 'https://github.com/cpmpercussion/creative-prediction-datasets/raw/main/datasets/TinyPerformanceCorpus.h5'
 urllib.request.urlretrieve(url, './TinyPerformanceCorpus.h5')
 
 
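The same import migration recurs throughout this commit: the tensorflow.compat.v1 shim is dropped in favour of the native TF2 namespaces. A minimal before/after sketch (illustrative, not lines from the diff):

# Before (TF1 compatibility shim):
#   from tensorflow.compat.v1 import keras
#   import tensorflow.compat.v1 as tf
# After (native TF2 API):
from tensorflow import keras
import tensorflow as tf

assert tf.executing_eagerly()  # TF2 executes eagerly by default; no Session setup needed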
keras_mdn_layer/__init__.py (14 changes: 8 additions & 6 deletions)
@@ -9,11 +9,11 @@
 Provided under MIT License
 """
 from .version import __version__
-from tensorflow.compat.v1 import keras
-from tensorflow.compat.v1.keras import backend as K
-from tensorflow.compat.v1.keras import layers
+from tensorflow import keras
+from tensorflow.keras import backend as K
+from tensorflow.keras import layers
 import numpy as np
-import tensorflow.compat.v1 as tf
+import tensorflow as tf
 from tensorflow_probability import distributions as tfd
 
 
@@ -180,6 +180,7 @@ def split_mixture_params(params, output_dim, num_mixes):
     output_dim -- the dimension of the normal models in the mixture model
     num_mixes -- the number of mixtures represented
     """
+    assert len(params) == num_mixes + (output_dim * 2 * num_mixes), "The size of params needs to match the mixture configuration"
     mus = params[:num_mixes * output_dim]
     sigs = params[num_mixes * output_dim:2 * num_mixes * output_dim]
     pi_logits = params[-num_mixes:]
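The assert added here pins down the MDN layer's flat parameter layout: num_mixes * output_dim means, the same number of sigmas, then num_mixes mixture-weight logits. A quick arithmetic sketch (values illustrative):

# Flat MDN parameter vector:
#   [mus: num_mixes*output_dim | sigmas: num_mixes*output_dim | pi_logits: num_mixes]
num_mixes, output_dim = 5, 3
expected_len = num_mixes + (output_dim * 2 * num_mixes)
assert expected_len == 35  # 15 mus + 15 sigmas + 5 logits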
@@ -236,8 +237,9 @@ def sample_from_output(params, output_dim, num_mixes, temp=1.0, sigma_temp=1.0):
     sigma_temp -- the temperature for sampling from the normal distribution (default 1.0)
     Returns:
-    One sample from the the mixture model.
+    One sample from the the mixture model, that is a numpy array of length output_dim
     """
+    assert len(params) == num_mixes + (output_dim * 2 * num_mixes), "The size of params needs to match the mixture configuration"
     mus, sigs, pi_logits = split_mixture_params(params, output_dim, num_mixes)
     pis = softmax(pi_logits, t=temp)
     m = sample_from_categorical(pis)
@@ ... @@
     cov_matrix = np.matmul(scale_matrix, scale_matrix.T) # cov is scale squared.
     cov_matrix = cov_matrix * sigma_temp # adjust for sigma temperature
     sample = np.random.multivariate_normal(mus_vector, cov_matrix, 1)
-    return sample
+    return sample[0]
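With return sample[0], sample_from_output now returns a flat vector of length output_dim rather than a (1, output_dim) matrix, which is what the simpler indexing in the notebooks below relies on. A sketch of the new contract (random, untrained parameters, purely to exercise the shape):

import numpy as np
import keras_mdn_layer as mdn

OUTPUT_DIM, N_MIXES = 3, 5
# Random stand-in for one MDN output vector; real values come from an MDN layer.
params = np.random.randn(N_MIXES + 2 * N_MIXES * OUTPUT_DIM).astype(np.float32)
sample = mdn.sample_from_output(params, OUTPUT_DIM, N_MIXES)
assert sample.shape == (OUTPUT_DIM,)  # was (1, OUTPUT_DIM) before this commit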
keras_mdn_layer/tests/test_mdn.py (32 changes: 31 additions & 1 deletion)
@@ -1,4 +1,4 @@
-from tensorflow.compat.v1 import keras
+from tensorflow import keras
 import keras_mdn_layer as mdn
 import numpy as np
 
@@ -40,3 +40,33 @@ def test_save_mdn():
     model.save('test_save.keras', overwrite=True)
     m_2 = keras.models.load_model('test_save.keras', custom_objects={'MDN': mdn.MDN, 'mdn_loss_func': mdn.get_mixture_loss_func(1, N_MIXES)})
     assert isinstance(m_2, keras.Sequential)
+
+def test_output_shapes():
+    """Checks that the output shapes on an MDN model end up correct. Builds a 1-layer LSTM network to do it."""
+    # parameters
+    N_HIDDEN = 5
+    N_MIXES = 10
+    N_DIMENSION = 7
+
+    # set up a 1-layer MDRNN
+    inputs = keras.layers.Input(shape=(1,N_DIMENSION))
+    lstm_1_state_h_input = keras.layers.Input(shape=(N_HIDDEN,))
+    lstm_1_state_c_input = keras.layers.Input(shape=(N_HIDDEN,))
+    lstm_1_state_input = [lstm_1_state_h_input, lstm_1_state_c_input]
+    lstm_1, state_h_1, state_c_1 = keras.layers.LSTM(N_HIDDEN, return_state=True)(inputs, initial_state=lstm_1_state_input)
+    lstm_1_state_output = [state_h_1, state_c_1]
+    mdn_out = mdn.MDN(N_DIMENSION, N_MIXES)(lstm_1)
+    decoder = keras.Model(inputs=[inputs] + lstm_1_state_input, outputs=[mdn_out] + lstm_1_state_output)
+
+    # create starting input and generate one output
+    starting_input = np.zeros((1, 1, N_DIMENSION), dtype=np.float32)
+    initial_state = [np.zeros((1,N_HIDDEN), dtype=np.float32), np.zeros((1,N_HIDDEN), dtype=np.float32)]
+    output_list = decoder([starting_input] + initial_state) # run the network
+    mdn_parameters = output_list[0][0].numpy()
+
+    # sample from the output to test sampling functions
+    generated_sample = mdn.sample_from_output(mdn_parameters, N_DIMENSION, N_MIXES)
+    print("Sample shape:", generated_sample.shape)
+    print("Sample:", generated_sample)
+    # test that the length of the generated sample is the same as N_DIMENSION
+    assert len(generated_sample) == N_DIMENSION
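The new test also shows the TF2 calling convention used throughout the commit: calling a model directly returns eager tensors, and .numpy() converts them, replacing the TF1 Session-based workflow and the per-step model.predict calls. A toy sketch of that pattern (model and shapes are illustrative):

import numpy as np
from tensorflow import keras

inputs = keras.layers.Input(shape=(2,))
outputs = keras.layers.Dense(4)(inputs)
toy = keras.Model(inputs, outputs)

out = toy(np.zeros((1, 2), dtype=np.float32))  # direct call -> eager tf.Tensor
params = out[0].numpy()                        # first batch item as an ndarray
assert params.shape == (4,)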
notebooks/MDN-1D-sine-prediction.ipynb (4 changes: 2 additions & 2 deletions)
@@ -84,7 +84,7 @@
 "model.add(keras.layers.Dense(N_HIDDEN, batch_input_shape=(None, 1), activation='relu'))\n",
 "model.add(keras.layers.Dense(N_HIDDEN, activation='relu'))\n",
 "model.add(mdn.MDN(1, N_MIXES))\n",
-"model.compile(loss=mdn.get_mixture_loss_func(1,N_MIXES), optimizer=keras.optimizers.Adam()) #, metrics=[mdn.get_mixture_mse_accuracy(1,N_MIXES)])\n",
+"model.compile(loss=mdn.get_mixture_loss_func(1,N_MIXES), optimizer='adam') #, metrics=[mdn.get_mixture_mse_accuracy(1,N_MIXES)])\n",
 "model.summary()"
 ]
 },
@@ -223,7 +223,7 @@
 "source": [
 "# Plot the samples\n",
 "plt.figure(figsize=(8, 8))\n",
-"plt.plot(x_data,y_data,'ro', x_test, y_samples[:,:,0], 'bo',alpha=0.3)\n",
+"plt.plot(x_data,y_data,'ro', x_test, y_samples[:,0], 'bo',alpha=0.3)\n",
 "plt.show()\n",
 "# These look pretty good!"
 ]
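The plot fix follows from the new sample_from_output return shape: each draw is now (1,) instead of (1, 1), so np.apply_along_axis yields an (NTEST, 1) array and the plot indexes y_samples[:,0]. A sketch of the presumed sampling step (random parameter vectors stand in for the model's predictions):

import numpy as np
import keras_mdn_layer as mdn

N_MIXES, NTEST = 10, 50
# Stand-in for y_test: one MDN parameter vector per test input (1-D output).
y_test = np.random.randn(NTEST, N_MIXES + 2 * N_MIXES * 1).astype(np.float32)
y_samples = np.apply_along_axis(mdn.sample_from_output, 1, y_test, 1, N_MIXES)
assert y_samples.shape == (NTEST, 1)  # previously (NTEST, 1, 1)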
notebooks/MDN-2D-spiral-prediction.ipynb (15 changes: 8 additions & 7 deletions)
@@ -95,7 +95,7 @@
 "model.add(keras.layers.Dense(N_HIDDEN, batch_input_shape=(None, 1), activation='relu'))\n",
 "model.add(keras.layers.Dense(N_HIDDEN, activation='relu'))\n",
 "model.add(mdn.MDN(OUTPUT_DIMS, N_MIXES))\n",
-"model.compile(loss=mdn.get_mixture_loss_func(OUTPUT_DIMS,N_MIXES), optimizer=keras.optimizers.Adam())\n",
+"model.compile(loss=mdn.get_mixture_loss_func(OUTPUT_DIMS,N_MIXES), optimizer='adam')\n",
 "model.summary()"
 ]
 },
@@ ... @@
 "metadata": {},
 "outputs": [],
 "source": [
-"history = model.fit(x=x_input, y=y_input, batch_size=128, epochs=300, validation_split=0.15, callbacks=[keras.callbacks.TerminateOnNaN()])"
+"history = model.fit(x=x_input, y=y_input, batch_size=128, epochs=200, validation_split=0.15, callbacks=[keras.callbacks.TerminateOnNaN()])"
 ]
 },
 {
@@ -159,13 +159,14 @@
 "outputs": [],
 "source": [
 "## Sample on some test data:\n",
-"x_test = np.float32(np.arange(-15,15,0.1))\n",
+"x_test = np.float32(np.arange(-15,15,0.05))\n",
 "NTEST = x_test.size\n",
+"\n",
 "print(\"Testing:\", NTEST, \"samples.\")\n",
-"x_test = x_test.reshape(NTEST,1) # needs to be a matrix, not a vector\n",
+"x_test_pred = x_test.reshape(NTEST,1) # needs to be a matrix for predictions but a vector for display, not a vector\n",
 "\n",
 "# Make predictions from the model\n",
-"y_test = model.predict(x_test)\n",
+"y_test = model.predict(x_test_pred)\n",
 "# y_test contains parameters for distributions, not actual points on the graph.\n",
 "# To find points on the graph, we need to sample from each distribution.\n",
 "\n",
@@ ... @@
 "# Plot the predicted samples.\n",
 "fig = plt.figure(figsize=(8, 8))\n",
 "ax = fig.add_subplot(111, projection='3d')\n",
-"ax.scatter(x_data, y_data, z_data, alpha=0.1, c='r') #c=perf_down_sampled.moving\n",
-"ax.scatter(y_samples.T[0], y_samples.T[1], x_test, alpha=0.1, c='b') #c=perf_down_sampled.moving\n",
+"ax.scatter(x_data, y_data, z_data, alpha=0.05, c='r') #c=perf_down_sampled.moving\n",
+"ax.scatter(y_samples.T[0], y_samples.T[1], x_test, alpha=0.2, c='b') #c=perf_down_sampled.moving\n",
 "plt.show()"
 ]
 },
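The x_test_pred rename makes the two roles explicit: model.predict wants an (NTEST, 1) matrix, while the final 3-D scatter reuses the original 1-D x_test as its z-axis. Roughly (a sketch of the notebook's flow; the model-dependent lines are left as comments):

import numpy as np

x_test = np.float32(np.arange(-15, 15, 0.05))  # 1-D: kept for plotting
x_test_pred = x_test.reshape(x_test.size, 1)   # 2-D: what predict() expects
# y_test = model.predict(x_test_pred)          # mixture parameters per input
# ax.scatter(y_samples.T[0], y_samples.T[1], x_test, ...)  # 1-D vector reused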
notebooks/MDN-RNN-RoboJam-touch-generation.ipynb (81 changes: 33 additions & 48 deletions)
@@ -35,11 +35,10 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"from tensorflow.compat.v1 import keras\n",
-"from tensorflow.compat.v1.keras import backend as K\n",
-"from tensorflow.compat.v1.keras.layers import Dense, Input\n",
+"from tensorflow import keras\n",
+"from tensorflow.keras.layers import Dense, Input\n",
 "import numpy as np\n",
-"import tensorflow.compat.v1 as tf\n",
+"import tensorflow as tf\n",
 "import math\n",
 "import h5py\n",
 "import random\n",
@@ ... @@
 "gen_colour = 'firebrick'"
 ]
 },
-{
-"cell_type": "code",
-"execution_count": null,
-"metadata": {},
-"outputs": [],
-"source": [
-"# Only for GPU use:\n",
-"import os\n",
-"os.environ[\"CUDA_VISIBLE_DEVICES\"]=\"1\"\n",
-"\n",
-"config = tf.ConfigProto()\n",
-"config.gpu_options.allow_growth = True\n",
-"sess = tf.Session(config=config)\n",
-"K.set_session(sess)"
-]
-},
 {
 "cell_type": "markdown",
 "metadata": {},
@@ -131,7 +114,6 @@
" output = np.array(perf_df[['x', 'y', 'dt']])\n",
" return output\n",
"\n",
"\n",
"def perf_array_to_df(perf_array):\n",
" \"\"\"Converts an array of a performance (a,b,dt(,moving) format) into a dataframe.\"\"\"\n",
" perf_array = perf_array.T\n",
@@ ... @@
 "        perf_df['moving'] = perf_array[3]\n",
 "    else:\n",
 "        # As a rule of thumb, could classify taps with dt>0.1 as taps, dt<0.1 as moving touches.\n",
-"        perf_df['moving'] = 1\n",
-"        perf_df.at[perf_df[perf_df.dt > 0.1].index, 'moving'] = 0\n",
+"        perf_df['moving'] = perf_df['dt'].apply(lambda dt: 0 if dt > 0.1 else 1)\n",
 "    perf_df['time'] = perf_df.dt.cumsum()\n",
 "    perf_df['z'] = 38.0\n",
 "    perf_df = perf_df.set_index(['time'])\n",
@@ -177,8 +158,9 @@
 "    previous_touch = first_touch\n",
 "    performance = [previous_touch.reshape((out_dim,))]\n",
 "    while (steps < steps_limit and time < time_limit):\n",
-"        params = model.predict(previous_touch.reshape(1,1,out_dim) * SCALE_FACTOR)\n",
-"        previous_touch = mdn.sample_from_output(params[0], out_dim, n_mixtures, temp=temp, sigma_temp=sigma_temp) / SCALE_FACTOR\n",
+"        net_output = model(previous_touch.reshape(1,1,out_dim) * SCALE_FACTOR)\n",
+"        mdn_params = net_output[0].numpy()\n",
+"        previous_touch = mdn.sample_from_output(mdn_params, out_dim, n_mixtures, temp=temp, sigma_temp=sigma_temp) / SCALE_FACTOR\n",
 "        output_touch = previous_touch.reshape(out_dim,)\n",
 "        output_touch = constrain_touch(output_touch, with_moving=predict_moving)\n",
 "        performance.append(output_touch.reshape((out_dim,)))\n",
@@ ... @@
 "    steps = 0\n",
 "    # condition\n",
 "    for touch in perf:\n",
-"        params = model.predict(touch.reshape(1, 1, out_dim) * SCALE_FACTOR)\n",
-"        previous_touch = mdn.sample_from_output(params[0], out_dim, n_mixtures, temp=temp, sigma_temp=sigma_temp) / SCALE_FACTOR\n",
+"        net_output = model(touch.reshape(1, 1, out_dim) * SCALE_FACTOR)\n",
+"        mdn_params = net_output[0].numpy()\n",
+"        previous_touch = mdn.sample_from_output(mdn_params, out_dim, n_mixtures, temp=temp, sigma_temp=sigma_temp) / SCALE_FACTOR\n",
 "    output = [previous_touch.reshape((out_dim,))]\n",
 "    # generate\n",
 "    while (steps < steps_limit and time < time_limit):\n",
-"        params = model.predict(previous_touch.reshape(1, 1, out_dim) * SCALE_FACTOR)\n",
-"        previous_touch = mdn.sample_from_output(params[0], out_dim, n_mixtures, temp=temp, sigma_temp=sigma_temp) / SCALE_FACTOR\n",
+"        net_output = model(previous_touch.reshape(1, 1, out_dim) * SCALE_FACTOR)\n",
+"        mdn_params = net_output[0].numpy()\n",
+"        previous_touch = mdn.sample_from_output(mdn_params, out_dim, n_mixtures, temp=temp, sigma_temp=sigma_temp) / SCALE_FACTOR\n",
 "        output_touch = previous_touch.reshape(out_dim,)\n",
 "        output_touch = constrain_touch(output_touch, with_moving=predict_moving)\n",
 "        output.append(output_touch.reshape((out_dim,)))\n",
@@ -340,8 +324,8 @@
 "# Training Hyperparameters:\n",
 "SEQ_LEN = 30\n",
 "BATCH_SIZE = 256\n",
-"HIDDEN_UNITS = 256\n",
-"EPOCHS = 100\n",
+"HIDDEN_UNITS = 128\n",
+"EPOCHS = 10\n",
 "VAL_SPLIT=0.15\n",
 "\n",
 "# Set random seed for reproducibility\n",
@@ -388,11 +372,13 @@
 "OUTPUT_DIMENSION = 3\n",
 "NUMBER_MIXTURES = 5\n",
 "\n",
-"model = keras.Sequential()\n",
-"model.add(keras.layers.LSTM(HIDDEN_UNITS, batch_input_shape=(None,SEQ_LEN,OUTPUT_DIMENSION), return_sequences=True))\n",
-"model.add(keras.layers.LSTM(HIDDEN_UNITS))\n",
-"model.add(mdn.MDN(OUTPUT_DIMENSION, NUMBER_MIXTURES))\n",
-"model.compile(loss=mdn.get_mixture_loss_func(OUTPUT_DIMENSION,NUMBER_MIXTURES), optimizer=keras.optimizers.Adam())\n",
+"inputs = keras.layers.Input(shape=(SEQ_LEN,OUTPUT_DIMENSION))\n",
+"lstm_1 = keras.layers.LSTM(HIDDEN_UNITS, return_sequences=True)(inputs)\n",
+"lstm_2 = keras.layers.LSTM(HIDDEN_UNITS)(lstm_1)\n",
+"mdn_out = mdn.MDN(OUTPUT_DIMENSION, NUMBER_MIXTURES)(lstm_2)\n",
+"\n",
+"model = keras.Model(inputs=inputs, outputs=mdn_out, name=\"robojam-training\")\n",
+"model.compile(loss=mdn.get_mixture_loss_func(OUTPUT_DIMENSION,NUMBER_MIXTURES), optimizer='adam')\n",
 "model.summary()"
 ]
 },
@@ -425,15 +411,15 @@
 "# Train the model\n",
 "\n",
 "# Define callbacks\n",
-"filepath=\"robojam_mdrnn-E{epoch:02d}-VL{val_loss:.2f}.h5\"\n",
-"checkpoint = keras.callbacks.ModelCheckpoint(filepath, save_weights_only=True, verbose=1, save_best_only=True, mode='min')\n",
+"filepath=\"robojam_mdrnn-E{epoch:02d}-VL{val_loss:.2f}.keras\"\n",
+"checkpoint = keras.callbacks.ModelCheckpoint(filepath, verbose=1, save_best_only=True, mode='min')\n",
 "early_stopping = keras.callbacks.EarlyStopping(monitor='val_loss', mode='min', verbose=1, patience=10)\n",
 "callbacks = [keras.callbacks.TerminateOnNaN(), checkpoint, early_stopping]\n",
 "\n",
 "history = model.fit(X, y, batch_size=BATCH_SIZE, epochs=EPOCHS, callbacks=callbacks, validation_split=VAL_SPLIT)\n",
 "\n",
 "# Save the Model\n",
-"model.save('robojam-mdrnn.h5') # creates a HDF5 file of the model\n",
+"model.save('robojam-mdrnn.keras')\n",
 "\n",
 "# Plot the loss\n",
 "%matplotlib inline\n",
@@ -462,16 +448,15 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"# Decoding Model\n",
-"decoder = keras.Sequential()\n",
-"decoder.add(keras.layers.LSTM(HIDDEN_UNITS, batch_input_shape=(1,1,OUTPUT_DIMENSION), return_sequences=True, stateful=True))\n",
-"decoder.add(keras.layers.LSTM(HIDDEN_UNITS, stateful=True))\n",
-"decoder.add(mdn.MDN(OUTPUT_DIMENSION, NUMBER_MIXTURES))\n",
-"decoder.compile(loss=mdn.get_mixture_loss_func(OUTPUT_DIMENSION,NUMBER_MIXTURES), optimizer=keras.optimizers.Adam())\n",
-"decoder.summary()\n",
+"# Stateful Decoding Model\n",
+"inputs = keras.layers.Input(batch_input_shape=(1,1,OUTPUT_DIMENSION))\n",
+"lstm_1 = keras.layers.LSTM(HIDDEN_UNITS, return_sequences=True, stateful=True)(inputs)\n",
+"lstm_2 = keras.layers.LSTM(HIDDEN_UNITS, stateful=True)(lstm_1)\n",
+"mdn_out = mdn.MDN(OUTPUT_DIMENSION, NUMBER_MIXTURES)(lstm_2)\n",
 "\n",
-"# decoder.set_weights(model.get_weights())\n",
-"decoder.load_weights(\"robojam-mdrnn.h5\")"
+"decoder = keras.Model(inputs=inputs, outputs=mdn_out, name=\"robojam-generating\")\n",
+"decoder.summary()\n",
+"decoder.load_weights(\"robojam-mdrnn.keras\")"
 ]
 },
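The decoder mirrors the training network layer for layer, so load_weights can pull weights straight from the full model saved above. One sketch of using it for generation (assumes decoder, OUTPUT_DIMENSION and NUMBER_MIXTURES from the cells above; the SCALE_FACTOR scaling used in the generation functions is omitted for brevity):

import numpy as np
import keras_mdn_layer as mdn

decoder.reset_states()  # clear the stateful LSTMs between performances
touch = np.array([0.5, 0.5, 0.1], dtype=np.float32)  # illustrative (x, y, dt)
params = decoder(touch.reshape(1, 1, OUTPUT_DIMENSION))[0].numpy()
next_touch = mdn.sample_from_output(params, OUTPUT_DIMENSION, NUMBER_MIXTURES)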