Skip to content

Commit

Permalink
Merge branch 'master' of github.com:cpmpercussion/keras-mdn-layer
Browse files Browse the repository at this point in the history
* 'master' of github.com:cpmpercussion/keras-mdn-layer:
  Update README.md
  fixed small error in sampling procedure in MDN-2D notebook
  • Loading branch information
cpmpercussion committed May 5, 2019
2 parents e7d4723 + a99f912 commit fbf95e1
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 28 deletions.
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,12 @@ Two important functions are provided for training and prediction:

## Installation

You clone or download this repository and then install via `python setup.py install`, or copy the `mdn` folder into your own project.
This project requires Python 3.6+. You can clone or download this repository and then install via `python setup.py install`, or copy the `mdn` folder into your own project.

You can easily install this package directly from Github via `pip` like so:

pip install git+git://github.com/cpmpercussion/keras-mdn-layer.git#egg=keras-mdn-layer


And finally, import the `mdn` module in Python: `import mdn`

Expand Down
34 changes: 7 additions & 27 deletions notebooks/MDN-2D-spiral-prediction.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -142,26 +142,6 @@
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Sampling Functions\n",
"\n",
"The MDN model outputs parameters of a mixture model---a list of means (mu), variances (sigma), and weights (pi).\n",
"\n",
"The MDN package contains some functions to split up these parameters and sample from the normal distributions that they form.\n",
"\n",
"We use \n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "markdown",
"metadata": {},
Expand Down Expand Up @@ -191,12 +171,12 @@
"# To find points on the graph, we need to sample from each distribution.\n",
"\n",
"# Split up the mixture parameters (for future fun)\n",
"mus = np.apply_along_axis((lambda a: a[:N_MIXES*OUTPUT_DIMS]),1, y_test)\n",
"sigs = np.apply_along_axis((lambda a: a[N_MIXES*OUTPUT_DIMS:2*N_MIXES*OUTPUT_DIMS]),1, y_test)\n",
"pis = np.apply_along_axis((lambda a: softmax(a[-N_MIXES:])),1, y_test)\n",
"mus = np.apply_along_axis((lambda a: a[:N_MIXES*OUTPUT_DIMS]), 1, y_test)\n",
"sigs = np.apply_along_axis((lambda a: a[N_MIXES*OUTPUT_DIMS:2*N_MIXES*OUTPUT_DIMS]), 1, y_test)\n",
"pis = np.apply_along_axis((lambda a: mdn.softmax(a[-N_MIXES:])), 1, y_test)\n",
"\n",
"# Sample from the predicted distributions\n",
"y_samples = np.apply_along_axis(sample_from_output, 1, y_test, N_MIXES,OUTPUT_DIMS,temp=1.0)"
"y_samples = np.apply_along_axis(mdn.sample_from_output, 1, y_test, OUTPUT_DIMS, N_MIXES, temp=1.0, sigma_temp=1.0)"
]
},
{
Expand Down Expand Up @@ -254,9 +234,9 @@
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"display_name": "venv",
"language": "python",
"name": "python3"
"name": "venv"
},
"language_info": {
"codemirror_mode": {
Expand All @@ -268,7 +248,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.5"
"version": "3.6.8"
}
},
"nbformat": 4,
Expand Down

0 comments on commit fbf95e1

Please sign in to comment.