Skip to content

Commit d576061

Browse files
committed
updated kanji and robojam examples
1 parent 59f945d commit d576061

4 files changed

+294
-442
lines changed

notebooks/MDN-RNN-RoboJam-touch-generation.ipynb

+98-42
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,24 @@
5454
"plt.style.use('seaborn-talk')"
5555
]
5656
},
57+
{
58+
"cell_type": "code",
59+
"execution_count": null,
60+
"metadata": {},
61+
"outputs": [],
62+
"source": [
63+
"# Only for GPU use:\n",
64+
"import os\n",
65+
"os.environ[\"CUDA_VISIBLE_DEVICES\"]=\"1\"\n",
66+
"\n",
67+
"import tensorflow as tf\n",
68+
"config = tf.ConfigProto()\n",
69+
"config.gpu_options.allow_growth = True\n",
70+
"sess = tf.Session(config=config)\n",
71+
"from keras import backend as K\n",
72+
"K.set_session(sess)"
73+
]
74+
},
5775
{
5876
"cell_type": "markdown",
5977
"metadata": {},
@@ -95,83 +113,108 @@
95113
"metadata": {},
96114
"outputs": [],
97115
"source": [
98-
"def perf_df_to_array(perf_df):\n",
116+
"SCALE_FACTOR = 1\n",
117+
"\n",
118+
"def perf_df_to_array(perf_df, include_moving=False):\n",
99119
" \"\"\"Converts a dataframe of a performance into array a,b,dt format.\"\"\"\n",
100120
" perf_df['dt'] = perf_df.time.diff()\n",
101121
" perf_df.dt = perf_df.dt.fillna(0.0)\n",
102122
" # Clean performance data\n",
103123
" # Tiny Performance bounds defined to be in [[0,1],[0,1]], edit to fix this.\n",
104-
" perf_df.set_value(perf_df[perf_df.dt > 5].index, 'dt', 5.0)\n",
105-
" perf_df.set_value(perf_df[perf_df.dt < 0].index, 'dt', 0.0)\n",
106-
" perf_df.set_value(perf_df[perf_df.x > 1].index, 'x', 1.0)\n",
107-
" perf_df.set_value(perf_df[perf_df.x < 0].index, 'x', 0.0)\n",
108-
" perf_df.set_value(perf_df[perf_df.y > 1].index, 'y', 1.0)\n",
109-
" perf_df.set_value(perf_df[perf_df.y < 0].index, 'y', 0.0)\n",
110-
" return np.array(perf_df[['x', 'y', 'dt']])\n",
124+
" perf_df.at[perf_df[perf_df.dt > 5].index, 'dt'] = 5.0\n",
125+
" perf_df.at[perf_df[perf_df.dt < 0].index, 'dt'] = 0.0\n",
126+
" perf_df.at[perf_df[perf_df.x > 1].index, 'x'] = 1.0\n",
127+
" perf_df.at[perf_df[perf_df.x < 0].index, 'x'] = 0.0\n",
128+
" perf_df.at[perf_df[perf_df.y > 1].index, 'y'] = 1.0\n",
129+
" perf_df.at[perf_df[perf_df.y < 0].index, 'y'] = 0.0\n",
130+
" if include_moving:\n",
131+
" output = np.array(perf_df[['x', 'y', 'dt', 'moving']])\n",
132+
" else:\n",
133+
" output = np.array(perf_df[['x', 'y', 'dt']])\n",
134+
" return output\n",
111135
"\n",
112136
"\n",
113137
"def perf_array_to_df(perf_array):\n",
114-
" \"\"\"Converts an array of a performance (a,b,dt format) into a dataframe.\"\"\"\n",
138+
" \"\"\"Converts an array of a performance (a,b,dt(,moving) format) into a dataframe.\"\"\"\n",
115139
" perf_array = perf_array.T\n",
116140
" perf_df = pd.DataFrame({'x': perf_array[0], 'y': perf_array[1], 'dt': perf_array[2]})\n",
141+
" if len(perf_array) == 4:\n",
142+
" perf_df['moving'] = perf_array[3]\n",
143+
" else:\n",
144+
" # As a rule of thumb, could classify taps with dt>0.1 as taps, dt<0.1 as moving touches.\n",
145+
" perf_df['moving'] = 1\n",
146+
" perf_df.at[perf_df[perf_df.dt > 0.1].index, 'moving'] = 0\n",
117147
" perf_df['time'] = perf_df.dt.cumsum()\n",
118148
" perf_df['z'] = 38.0\n",
119-
" # As a rule of thumb, could classify taps with dt>0.1 as taps, dt<0.1 as moving touches.\n",
120-
" perf_df['moving'] = 1\n",
121-
" perf_df.set_value(perf_df[perf_df.dt > 0.1].index, 'moving', 0)\n",
122149
" perf_df = perf_df.set_index(['time'])\n",
123150
" return perf_df[['x', 'y', 'z', 'moving']]\n",
124151
"\n",
125152
"\n",
126-
"def random_touch():\n",
153+
"def random_touch(with_moving=False):\n",
127154
" \"\"\"Generate a random tiny performance touch.\"\"\"\n",
128-
" return np.array([np.random.rand(), np.random.rand(), 0.01])\n",
155+
" if with_moving:\n",
156+
" return np.array([np.random.rand(), np.random.rand(), 0.01, 0])\n",
157+
" else:\n",
158+
" return np.array([np.random.rand(), np.random.rand(), 0.01])\n",
129159
"\n",
130160
"\n",
131-
"def constrain_touch(touch):\n",
161+
"def constrain_touch(touch, with_moving=False):\n",
132162
" \"\"\"Constrain touch values from the MDRNN\"\"\"\n",
133163
" touch[0] = min(max(touch[0], 0.0), 1.0) # x in [0,1]\n",
134164
" touch[1] = min(max(touch[1], 0.0), 1.0) # y in [0,1]\n",
135165
" touch[2] = max(touch[2], 0.001) # dt # define minimum time step\n",
166+
" if with_moving:\n",
167+
" touch[3] = np.greater(touch[3], 0.5) * 1.0\n",
136168
" return touch\n",
137169
"\n",
138-
"def generate_random_tiny_performance(model, n_mixtures, first_touch, time_limit=5.0, steps_limit=1000, temp=1.0):\n",
170+
"\n",
171+
"def generate_random_tiny_performance(model, n_mixtures, first_touch, time_limit=5.0, steps_limit=1000, temp=1.0, sigma_temp=0.0, predict_moving=False):\n",
139172
" \"\"\"Generates a tiny performance up to 5 seconds in length.\"\"\"\n",
173+
" if predict_moving:\n",
174+
" out_dim = 4\n",
175+
" else:\n",
176+
" out_dim = 3\n",
140177
" time = 0\n",
141178
" steps = 0\n",
142179
" previous_touch = first_touch\n",
143-
" performance = [previous_touch.reshape((3,))]\n",
180+
" performance = [previous_touch.reshape((out_dim,))]\n",
144181
" while (steps < steps_limit and time < time_limit):\n",
145-
" params = model.predict(previous_touch.reshape(1,1,3))\n",
146-
" previous_touch = mdn.sample_from_output(params[0], 3, n_mixtures, temp=temp)\n",
147-
" output_touch = previous_touch.reshape(3,)\n",
148-
" output_touch = constrain_touch(output_touch)\n",
149-
" performance.append(output_touch.reshape((3,)))\n",
182+
" params = model.predict(previous_touch.reshape(1,1,out_dim) * SCALE_FACTOR)\n",
183+
" previous_touch = mdn.sample_from_output(params[0], out_dim, n_mixtures, temp=temp, sigma_temp=sigma_temp) / SCALE_FACTOR\n",
184+
" output_touch = previous_touch.reshape(out_dim,)\n",
185+
" output_touch = constrain_touch(output_touch, with_moving=predict_moving)\n",
186+
" performance.append(output_touch.reshape((out_dim,)))\n",
150187
" steps += 1\n",
151188
" time += output_touch[2]\n",
152189
" return np.array(performance)\n",
153190
"\n",
154191
"\n",
155-
"def condition_and_generate(model, perf, n_mixtures, time_limit=5.0, steps_limit=1000, temp=1.0):\n",
192+
"def condition_and_generate(model, perf, n_mixtures, time_limit=5.0, steps_limit=1000, temp=1.0, sigma_temp=0.0, predict_moving=False):\n",
156193
" \"\"\"Conditions the network on an existing tiny performance, then generates a new one.\"\"\"\n",
194+
" if predict_moving:\n",
195+
" out_dim = 4\n",
196+
" else:\n",
197+
" out_dim = 3\n",
157198
" time = 0\n",
158199
" steps = 0\n",
159200
" # condition\n",
160201
" for touch in perf:\n",
161-
" params = model.predict(touch.reshape(1,1,3))\n",
162-
" previous_touch = mdn.sample_from_output(params[0], 3, n_mixtures, temp=temp)\n",
163-
" output = [previous_touch.reshape((3,))]\n",
202+
" params = model.predict(touch.reshape(1, 1, out_dim) * SCALE_FACTOR)\n",
203+
" previous_touch = mdn.sample_from_output(params[0], out_dim, n_mixtures, temp=temp, sigma_temp=sigma_temp) / SCALE_FACTOR\n",
204+
" output = [previous_touch.reshape((out_dim,))]\n",
205+
" # generate\n",
164206
" while (steps < steps_limit and time < time_limit):\n",
165-
" params = model.predict(previous_touch.reshape(1,1,3))\n",
166-
" previous_touch = mdn.sample_from_output(params[0], 3, n_mixtures, temp=temp)\n",
167-
" output_touch = previous_touch.reshape(3,)\n",
168-
" output_touch = constrain_touch(output_touch)\n",
169-
" output.append(output_touch.reshape((3,)))\n",
207+
" params = model.predict(previous_touch.reshape(1, 1, out_dim) * SCALE_FACTOR)\n",
208+
" previous_touch = mdn.sample_from_output(params[0], out_dim, n_mixtures, temp=temp, sigma_temp=sigma_temp) / SCALE_FACTOR\n",
209+
" output_touch = previous_touch.reshape(out_dim,)\n",
210+
" output_touch = constrain_touch(output_touch, with_moving=predict_moving)\n",
211+
" output.append(output_touch.reshape((out_dim,)))\n",
170212
" steps += 1\n",
171213
" time += output_touch[2]\n",
172214
" net_output = np.array(output)\n",
173215
" return net_output\n",
174216
"\n",
217+
"\n",
175218
"def divide_performance_into_swipes(perf_df):\n",
176219
" \"\"\"Divides a performance into a sequence of swipe dataframes for plotting.\"\"\"\n",
177220
" touch_starts = perf_df[perf_df.moving == 0].index\n",
@@ -184,10 +227,15 @@
184227
" performance_swipes.append(remainder)\n",
185228
" return performance_swipes\n",
186229
"\n",
187-
"def plot_2D(perf_df, name=\"foo\", saving=False):\n",
230+
"\n",
231+
"input_colour = 'darkblue'\n",
232+
"gen_colour = 'firebrick'\n",
233+
"\n",
234+
"\n",
235+
"def plot_2D(perf_df, name=\"foo\", saving=False, figsize=(8, 8)):\n",
188236
" \"\"\"Plot a 2D representation of a performance 2D\"\"\"\n",
189237
" swipes = divide_performance_into_swipes(perf_df)\n",
190-
" plt.figure(figsize=(8, 8))\n",
238+
" plt.figure(figsize=figsize)\n",
191239
" for swipe in swipes:\n",
192240
" p = plt.plot(swipe.x, swipe.y, 'o-')\n",
193241
" plt.setp(p, color=gen_colour, linewidth=5.0)\n",
@@ -200,10 +248,11 @@
200248
" plt.close()\n",
201249
" else:\n",
202250
" plt.show()\n",
203-
" \n",
204-
"def plot_double_2d(perf1, perf2, name=\"foo\", saving=False):\n",
251+
"\n",
252+
"\n",
253+
"def plot_double_2d(perf1, perf2, name=\"foo\", saving=False, figsize=(8, 8)):\n",
205254
" \"\"\"Plot two performances in 2D\"\"\"\n",
206-
" plt.figure(figsize=(8, 8))\n",
255+
" plt.figure(figsize=figsize)\n",
207256
" swipes = divide_performance_into_swipes(perf1)\n",
208257
" for swipe in swipes:\n",
209258
" p = plt.plot(swipe.x, swipe.y, 'o-')\n",
@@ -387,10 +436,17 @@
387436
"outputs": [],
388437
"source": [
389438
"# Train the model\n",
390-
"history = model.fit(X, y, batch_size=BATCH_SIZE, epochs=EPOCHS, validation_split=VAL_SPLIT)\n",
439+
"\n",
440+
"# Define callbacks\n",
441+
"filepath=\"robojam_mdrnn-E{epoch:02d}-VL{val_loss:.2f}.h5\"\n",
442+
"checkpoint = keras.callbacks.ModelCheckpoint(filepath, save_weights_only=True, verbose=1, save_best_only=True, mode='min')\n",
443+
"early_stopping = keras.callbacks.EarlyStopping(monitor='val_loss', mode='min', verbose=1, patience=10)\n",
444+
"callbacks = [keras.callbacks.TerminateOnNaN(), checkpoint, early_stopping]\n",
445+
"\n",
446+
"history = model.fit(X, y, batch_size=BATCH_SIZE, epochs=EPOCHS, callbacks=callbacks, validation_split=VAL_SPLIT)\n",
391447
"\n",
392448
"# Save the Model\n",
393-
"#model.save('robojam-mdn-rnn.h5') # creates a HDF5 file of the model\n",
449+
"model.save('robojam-mdrnn.h5') # creates a HDF5 file of the model\n",
394450
"\n",
395451
"# Plot the loss\n",
396452
"%matplotlib inline\n",
@@ -428,7 +484,7 @@
428484
"decoder.summary()\n",
429485
"\n",
430486
"# decoder.set_weights(model.get_weights())\n",
431-
"decoder.load_weights(\"robojam-mdn-rnn.h5\")"
487+
"decoder.load_weights(\"robojam-mdrnn.h5\")"
432488
]
433489
},
434490
{
@@ -451,7 +507,7 @@
451507
"ex = microjam_corpus[t:t+length] #sequences[600]\n",
452508
"\n",
453509
"decoder.reset_states()\n",
454-
"p = condition_and_generate(decoder, ex, NUMBER_MIXTURES, temp=0.2)\n",
510+
"p = condition_and_generate(decoder, ex, NUMBER_MIXTURES, temp=1.0, sigma_temp=0.05)\n",
455511
"plot_double_2d(perf_array_to_df(ex), perf_array_to_df(p))"
456512
]
457513
},
@@ -472,7 +528,7 @@
472528
"source": [
473529
"decoder.reset_states()\n",
474530
"t = random_touch()\n",
475-
"p = generate_random_tiny_performance(decoder, NUMBER_MIXTURES, t, temp=0.1)\n",
531+
"p = generate_random_tiny_performance(decoder, NUMBER_MIXTURES, t, temp=1.1, sigma_temp=0.05)\n",
476532
"plot_2D(perf_array_to_df(p))"
477533
]
478534
},
@@ -506,7 +562,7 @@
506562
"name": "python",
507563
"nbconvert_exporter": "python",
508564
"pygments_lexer": "ipython3",
509-
"version": "3.6.6"
565+
"version": "3.6.7"
510566
}
511567
},
512568
"nbformat": 4,

notebooks/MDN-RNN-kanji-generation-example.ipynb

+186-14
Large diffs are not rendered by default.

notebooks/MDN-RNN-time-distributed-MDN-training.ipynb

+10-8
Original file line numberDiff line numberDiff line change
@@ -203,10 +203,10 @@
203203
"outputs": [],
204204
"source": [
205205
"# Fit the model\n",
206-
"filepath=\"kanji_mdnrnn-{epoch:02d}.hdf5\"\n",
207-
"checkpoint = keras.callbacks.ModelCheckpoint(filepath, monitor='val_loss', verbose=1, save_best_only=True, mode='min')\n",
208-
"callbacks = [keras.callbacks.TerminateOnNaN(), checkpoint]\n",
209-
"\n",
206+
"filepath=\"kanji_mdnrnn-{epoch:02d}.h5\"\n",
207+
"checkpoint = keras.callbacks.ModelCheckpoint(filepath, save_weights_only=True, verbose=1, save_best_only=True, mode='min')\n",
208+
"early_stopping = keras.callbacks.EarlyStopping(monitor='val_loss', mode='min', verbose=1, patience=10)\n",
209+
"callbacks = [keras.callbacks.TerminateOnNaN(), checkpoint, early_stopping]\n",
210210
"history = model.fit(X, y, batch_size=BATCH_SIZE, epochs=EPOCHS, callbacks=callbacks, validation_data=(Xval,yval))\n",
211211
"model.save('kanji_mdnrnn_model_time_distributed.h5') # creates a HDF5 file 'my_model.h5'"
212212
]
@@ -259,7 +259,8 @@
259259
"decoder.summary()\n",
260260
"\n",
261261
"#decoder.load_weights('kanji_mdnrnn_model_time_distributed.h5') # load weights independently from file\n",
262-
"decoder.load_weights('kanji_mdnrnn-99.hdf5')"
262+
"#decoder.load_weights('kanji_mdnrnn-99.hdf5')\n",
263+
"decoder.load_weights('kanji_mdnrnn_model_time_distributed.h5')"
263264
]
264265
},
265266
{
@@ -397,14 +398,15 @@
397398
"outputs": [],
398399
"source": [
399400
"# Predict a character and plot the result.\n",
400-
"temperature = 2 # seems to work well with rather high temperature (2.5)\n",
401+
"temperature = 1.5 # seems to work well with rather high temperature (2.5)\n",
402+
"sigma_temp = 0.01\n",
401403
"\n",
402404
"p = zero_start_position()\n",
403405
"sketch = [p.reshape(3,)]\n",
404406
"\n",
405407
"for i in range(100):\n",
406408
" params = decoder.predict(p.reshape(1,1,3))\n",
407-
" p = mdn.sample_from_output(params[0], OUTPUT_DIMENSION, NUMBER_MIXTURES, temp=temperature)\n",
409+
" p = mdn.sample_from_output(params[0], OUTPUT_DIMENSION, NUMBER_MIXTURES, temp=temperature, sigma_temp=sigma_temp)\n",
408410
" sketch.append(p.reshape((3,)))\n",
409411
"\n",
410412
"sketch = np.array(sketch)\n",
@@ -432,7 +434,7 @@
432434
"name": "python",
433435
"nbconvert_exporter": "python",
434436
"pygments_lexer": "ipython3",
435-
"version": "3.6.6"
437+
"version": "3.6.7"
436438
}
437439
},
438440
"nbformat": 4,

0 commit comments

Comments
 (0)