Skip to content

Commit

Permalink
updated RoboJam and Kanji notebooks
Browse files Browse the repository at this point in the history
  • Loading branch information
cpmpercussion committed Feb 26, 2019
1 parent d576061 commit c12f4cd
Show file tree
Hide file tree
Showing 2 changed files with 152 additions and 58 deletions.
70 changes: 30 additions & 40 deletions notebooks/MDN-RNN-RoboJam-touch-generation.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,9 @@
"import time\n",
"import pandas as pd\n",
"from context import * # imports MDN\n",
"import matplotlib.pyplot as plt\n",
"%matplotlib inline\n",
"import matplotlib.pyplot as plt\n",
"\n",
"\n",
"input_colour = 'darkblue'\n",
"gen_colour = 'firebrick'\n",
Expand Down Expand Up @@ -228,48 +229,38 @@
" return performance_swipes\n",
"\n",
"\n",
"input_colour = 'darkblue'\n",
"gen_colour = 'firebrick'\n",
"\n",
"input_colour = \"#4388ff\"\n",
"gen_colour = \"#ec0205\"\n",
"\n",
"def plot_2D(perf_df, name=\"foo\", saving=False, figsize=(8, 8)):\n",
"def plot_perf_on_ax(perf_df, ax, color=\"#ec0205\", linewidth=3, alpha=0.5):\n",
" \"\"\"Plot a 2D representation of a performance 2D\"\"\"\n",
" swipes = divide_performance_into_swipes(perf_df)\n",
" plt.figure(figsize=figsize)\n",
" for swipe in swipes:\n",
" p = plt.plot(swipe.x, swipe.y, 'o-')\n",
" plt.setp(p, color=gen_colour, linewidth=5.0)\n",
" plt.ylim(1.0,0)\n",
" plt.xlim(0,1.0)\n",
" plt.xticks([])\n",
" plt.yticks([])\n",
" p = ax.plot(swipe.x, swipe.y, 'o-', alpha=alpha, markersize=linewidth)\n",
" plt.setp(p, color=color, linewidth=linewidth)\n",
" ax.set_ylim([1.0,0])\n",
" ax.set_xlim([0,1.0])\n",
" ax.set_xticks([])\n",
" ax.set_yticks([])\n",
"\n",
"def plot_2D(perf_df, name=\"foo\", saving=False, figsize=(5, 5)):\n",
" \"\"\"Plot a 2D representation of a performance 2D\"\"\"\n",
" fig, ax = plt.subplots(figsize=(figsize))\n",
" plot_perf_on_ax(perf_df, ax, color=gen_colour, linewidth=5, alpha=0.7)\n",
" if saving:\n",
" plt.savefig(name+\".png\", bbox_inches='tight')\n",
" plt.close()\n",
" else:\n",
" plt.show()\n",
"\n",
" fig.savefig(name+\".png\", bbox_inches='tight')\n",
"\n",
"def plot_double_2d(perf1, perf2, name=\"foo\", saving=False, figsize=(8, 8)):\n",
" \"\"\"Plot two performances in 2D\"\"\"\n",
" plt.figure(figsize=figsize)\n",
" swipes = divide_performance_into_swipes(perf1)\n",
" for swipe in swipes:\n",
" p = plt.plot(swipe.x, swipe.y, 'o-')\n",
" plt.setp(p, color=input_colour, linewidth=5.0)\n",
" swipes = divide_performance_into_swipes(perf2)\n",
" for swipe in swipes:\n",
" p = plt.plot(swipe.x, swipe.y, 'o-')\n",
" plt.setp(p, color=gen_colour, linewidth=5.0)\n",
" plt.ylim(1.0,0)\n",
" plt.xlim(0,1.0)\n",
" plt.xticks([])\n",
" plt.yticks([])\n",
" fig, ax = plt.subplots(figsize=(figsize))\n",
" plot_perf_on_ax(perf1, ax, color=input_colour, linewidth=5, alpha=0.7)\n",
" plot_perf_on_ax(perf2, ax, color=gen_colour, linewidth=5, alpha=0.7)\n",
" if saving:\n",
" plt.savefig(name+\".png\", bbox_inches='tight')\n",
" plt.close()\n",
" else:\n",
" plt.show()"
" fig.savefig(name+\".png\", bbox_inches='tight')\n",
" \n",
"# fig, ax = plt.subplots(figsize=(5, 5))\n",
"# plot_perf_on_ax(perf_array_to_df(p), ax, color=\"#ec0205\", linewidth=4, alpha=0.7)\n",
"# fig.show()"
]
},
{
Expand Down Expand Up @@ -338,7 +329,6 @@
" - Training model will have a sequence length of 30 (prediction model: 1 in, 1 out)\n",
" \n",
"![RoboJam MDN RNN Model](https://preview.ibb.co/cKZk9T/robojam_mdn_diagram.png)\n",
"\n",
" \n",
"- Here's the model parameters and training data preparation. \n",
"- We end up with 172K training examples."
Expand Down Expand Up @@ -493,7 +483,7 @@
"source": [
"Plotting some conditioned performances.\n",
"\n",
"This model seems to work best with a very low temperature (0.1). Might be able to do better with a large dataset, or larger model! (?)"
"This model seems to work best with a very low temperature for sampling from the Gaussian elements (`sigma_temp=0.05`) and a temperature for choosing between mixtures (pi-temperature) of around 1.0."
]
},
{
Expand All @@ -507,8 +497,8 @@
"ex = microjam_corpus[t:t+length] #sequences[600]\n",
"\n",
"decoder.reset_states()\n",
"p = condition_and_generate(decoder, ex, NUMBER_MIXTURES, temp=1.0, sigma_temp=0.05)\n",
"plot_double_2d(perf_array_to_df(ex), perf_array_to_df(p))"
"p = condition_and_generate(decoder, ex, NUMBER_MIXTURES, temp=1.5, sigma_temp=0.05)\n",
"plot_double_2d(perf_array_to_df(ex), perf_array_to_df(p), figsize=(4,4))"
]
},
{
Expand All @@ -528,8 +518,8 @@
"source": [
"decoder.reset_states()\n",
"t = random_touch()\n",
"p = generate_random_tiny_performance(decoder, NUMBER_MIXTURES, t, temp=1.1, sigma_temp=0.05)\n",
"plot_2D(perf_array_to_df(p))"
"p = generate_random_tiny_performance(decoder, NUMBER_MIXTURES, t, temp=1.2, sigma_temp=0.01)\n",
"plot_2D(perf_array_to_df(p), figsize=(4,4))"
]
},
{
Expand Down
140 changes: 122 additions & 18 deletions notebooks/MDN-RNN-time-distributed-MDN-training.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,17 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 1,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Using TensorFlow backend.\n"
]
}
],
"source": [
"import keras\n",
"from context import * # imports the MDN layer \n",
Expand Down Expand Up @@ -45,9 +53,20 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 2,
"metadata": {},
"outputs": [],
"outputs": [
{
"data": {
"text/plain": [
"('./kanji.rdp25.npz', <http.client.HTTPMessage at 0x7fa3e2847860>)"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Train from David Ha's Kanji dataset from Sketch-RNN: https://github.com/hardmaru/sketch-rnn-datasets\n",
"# Other datasets in \"Sketch 3\" format should also work.\n",
Expand All @@ -69,9 +88,19 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 3,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Training kanji: 10358\n",
"Validation kanji: 600\n",
"Testing kanji: 500\n"
]
}
],
"source": [
"with np.load('./kanji.rdp25.npz') as data:\n",
" train_set = data['train']\n",
Expand All @@ -92,9 +121,31 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 4,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"_________________________________________________________________\n",
"Layer (type) Output Shape Param # \n",
"=================================================================\n",
"inputs (InputLayer) (None, 50, 3) 0 \n",
"_________________________________________________________________\n",
"lstm1 (LSTM) (None, 50, 256) 266240 \n",
"_________________________________________________________________\n",
"lstm2 (LSTM) (None, 50, 256) 525312 \n",
"_________________________________________________________________\n",
"td_mdn (TimeDistributed) (None, 50, 70) 17990 \n",
"=================================================================\n",
"Total params: 809,542\n",
"Trainable params: 809,542\n",
"Non-trainable params: 0\n",
"_________________________________________________________________\n"
]
}
],
"source": [
"# Training Hyperparameters:\n",
"SEQ_LEN = 50\n",
Expand Down Expand Up @@ -130,9 +181,19 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 5,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Number of training examples:\n",
"X: (154279, 50, 3)\n",
"y: (154279, 50, 3)\n"
]
}
],
"source": [
"# Functions for slicing up data\n",
"def slice_sequence_examples(sequence, num_steps):\n",
Expand Down Expand Up @@ -168,9 +229,19 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 6,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Number of training examples:\n",
"X: (8928, 50, 3)\n",
"y: (8928, 50, 3)\n"
]
}
],
"source": [
"# Prepare validation data as X and Y.\n",
"slices = []\n",
Expand Down Expand Up @@ -244,9 +315,29 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 7,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"_________________________________________________________________\n",
"Layer (type) Output Shape Param # \n",
"=================================================================\n",
"lstm_1 (LSTM) (1, 1, 256) 266240 \n",
"_________________________________________________________________\n",
"lstm_2 (LSTM) (1, 256) 525312 \n",
"_________________________________________________________________\n",
"mdn_1 (MDN) (1, 70) 17990 \n",
"=================================================================\n",
"Total params: 809,542\n",
"Trainable params: 809,542\n",
"Non-trainable params: 0\n",
"_________________________________________________________________\n"
]
}
],
"source": [
"# Decoding Model\n",
"# Same as training model except for dimension and mixtures.\n",
Expand Down Expand Up @@ -274,7 +365,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -321,7 +412,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -391,11 +482,24 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 10,
"metadata": {
"scrolled": true
},
"outputs": [],
"outputs": [
{
"data": {
"image/svg+xml": [
"<svg baseProfile=\"full\" height=\"244.18982567367752\" version=\"1.1\" width=\"159.97207010041976\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:ev=\"http://www.w3.org/2001/xml-events\" xmlns:xlink=\"http://www.w3.org/1999/xlink\"><defs/><rect fill=\"white\" height=\"244.18982567367752\" width=\"159.97207010041976\" x=\"0\" y=\"0\"/><path d=\"M73.16515202279237,25 m0.0,0.0 m-34.67354909397204,1.6999544679462064 l1.2270268001350046,6.5445035376313845 -0.5708425890942318,41.61957807495959 m-14.147787139861096,-25.999227374543576 l5.610492050199192,1.8445756139703013 7.456378154620832,0.9592931208009936 l24.281225921703257,-4.290794133981884 m-21.265233573015582,15.77353682626787 l5.208090429674627,2.65424867382075 6.121680655193215,1.7704861064460389 m10.450784149172987,-16.64599174091382 l10.678654603310228,0.6073467188547201 24.382424016040815,-2.2727920365529193 m-45.741715140888026,22.037427081526925 l8.254333858925236,-0.0579562789311176 m41.79251024423384,-13.429875840835967 l4.415989602919442,0.49634210130026013 2.437233870246526,1.1896038071869723 l0.7785985077187373,1.7786542713895621 -2.81454383862975,17.875032534731023 m-58.59165303633291,0.8483468664585445 l61.98871964988354,-5.593724498128884 m-57.70274329442421,-3.8456319077379466 l1.8743106493081076,2.7601238414062763 0.8507976282615768,3.313686246509133 l0.18245103553031433,33.08474490989249 1.3521355931901835,4.002305699881221 l3.217697336357727,3.0940797808091602 6.996763683309625,2.2681637124931324 l12.430605392889836,1.0338306139207825 25.997681746146338,0.03307595604957769 l17.231325008945625,-1.422613725860504 6.403611419741662,-2.37651581918968 l3.794351455087284,-3.762538676168176 1.8991123210994898,-6.118099452631576 m-65.70767495457584,-59.04070308379762 l8.373250545833612,3.2858953823897306 11.447602044767061,6.065352330193637 m-42.35245503345113,18.040315308838636 l10.927460660887606,5.430454927327993 8.656028457557388,5.41301652764957 m-10.676762714881848,5.065189525787988 l16.503196971949148,3.1289540919649115 7.659929891210453,3.6420219818203274 m-47.28641343233738,31.240316874273674 l6.71227594034642,0.5414302590154325 21.44210527169062,-2.809780673708531 l9.17185689782044,0.22263406412836897 m-17.404765197659028,3.703386695025114 l2.6528435454951005,4.719220170427025 0.25191967245783,42.76291623193676 m-20.559962161486624,-21.1292790194749 l2.6722719935009245,1.1802014652912376 4.892041933730098,-0.3474041470988909 l43.363149727153946,-10.204695047436674 m-3.5695881423062383,-19.86491016338829 l2.479148126423477,3.3371240575327645 1.8273539876741904,12.057539549142513 l1.968352406530112,16.48005304021596 m-10.051027127885208,-4.131630609052722 l2.776146690414147,2.3940681989973362 1.5225818195545688,1.1226672436536174 l1.1525695916238874,1.4012029092821388 m9.805245189377763,-3.8934176088059145 l0.20329037675582845,1.5852614453657528 -1.270937821096872,4.589417675239919 m-17.63311393448342,1.8121000205034652 l1.6513247870931977,2.0733619898605395 m-2.323064061244624,-11.52195798017885 l9.12907407479106,3.0362207092706517 7.8578296691859375,2.1381934995409773 m-4.387263822835018,3.5291824877832876 l1.8550041019330588,3.4620619446537213 m-0.628905464447897,4.260338557267909 l-0.5053915935844143,1.4119863630846663 -3.7802206127786624,5.134151173967034 m-14.822738611087157,-6.614431341803379 l1.4913206515137998,1.4036057024891502 0.13171697712446925,3.8286269872773535 m-13.299747677442307,2.789117518976471 l6.3006674852805284,-0.7083616666690149 12.591027923096728,-1.8495719876450167 l5.54842966046134,-0.3315682902591638 1.0683800071418665,0.4864742647690695 l0.4378514468981187,1.6661489259482662 -0.9584912624200326,15.423775801528873 l-0.7923447195273654,4.209539073850624 -0.8435920006144774,2.7917947495548 l-1.1058023992321888,1.3124693479632699 -2.09181404793987,0.4528108767777787 m-41.05879896823274,-13.667826320973932 l28.790319953604726,4.800734440737193 m-30.06105278119629,10.800174191796557 l37.24149108091147,-2.322857087672493 m-6.001720062933102,-49.66747457949144 l3.306675608792087,2.9371949109086666 1.797581060237075,3.181513260207713 l-0.00020853187087258723,49.540164697814994 -1.8878813282488678,7.6754249133929555 l-2.5588594191659824,3.2567098369310683 \" fill=\"none\" stroke=\"black\" stroke-width=\"2\"/></svg>"
],
"text/plain": [
"<IPython.core.display.SVG object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Predict a character and plot the result.\n",
"temperature = 1.5 # seems to work well with rather high temperature (2.5)\n",
Expand Down

0 comments on commit c12f4cd

Please sign in to comment.