+{"cells": [{"cell_type": "markdown", "metadata": {}, "source": ["\nTrains a denoising autoencoder on MNIST dataset.<br>\n", "Denoising is one of the classic applications of autoencoders.<br>\n", "The denoising process removes unwanted noise that corrupted the<br>\n", "true signal.<br>\n", "Noise + Data ---> Denoising Autoencoder ---> Data<br>\n", "Given a training dataset of corrupted data as input and<br>\n", "true signal as output, a denoising autoencoder can recover the<br>\n", "hidden structure to generate clean data.<br>\n", "This example has modular design. The encoder, decoder and autoencoder<br>\n", "are 3 models that share weights. For example, after training the<br>\n", "autoencoder, the encoder can be used to generate latent vectors<br>\n", "of input data for low-dim visualization like PCA or TSNE.<br>\n", ""]}, {"cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": ["from __future__ import absolute_import\n", "from __future__ import division\n", "from __future__ import print_function\n", "import keras\n", "from keras.layers import Activation, Dense, Input\n", "from keras.layers import Conv2D, Flatten\n", "from keras.layers import Reshape, Conv2DTranspose\n", "from keras.models import Model\n", "from keras import backend as K\n", "from keras.datasets import mnist\n", "import numpy as np\n", "import matplotlib.pyplot as plt\n", "from PIL import Image"]}, {"cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": ["np.random.seed(1337)"]}, {"cell_type": "markdown", "metadata": {}, "source": ["MNIST dataset"]}, {"cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": ["(x_train, _), (x_test, _) = mnist.load_data()"]}, {"cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": ["image_size = x_train.shape[1]\n", "x_train = np.reshape(x_train, [-1, image_size, image_size, 1])\n", "x_test = np.reshape(x_test, [-1, image_size, image_size, 1])\n", "x_train = x_train.astype('float32') / 255\n", "x_test = x_test.astype('float32') / 255"]}, {"cell_type": "markdown", "metadata": {}, "source": ["Generate corrupted MNIST images by adding noise with normal dist<br>\n", "centered at 0.5 and std=0.5"]}, {"cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": ["noise = np.random.normal(loc=0.5, scale=0.5, size=x_train.shape)\n", "x_train_noisy = x_train + noise\n", "noise = np.random.normal(loc=0.5, scale=0.5, size=x_test.shape)\n", "x_test_noisy = x_test + noise"]}, {"cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": ["x_train_noisy = np.clip(x_train_noisy, 0., 1.)\n", "x_test_noisy = np.clip(x_test_noisy, 0., 1.)"]}, {"cell_type": "markdown", "metadata": {}, "source": ["Network parameters"]}, {"cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": ["input_shape = (image_size, image_size, 1)\n", "batch_size = 128\n", "kernel_size = 3\n", "latent_dim = 16\n", "# Encoder/Decoder number of CNN layers and filters per layer\n", "layer_filters = [32, 64]"]}, {"cell_type": "markdown", "metadata": {}, "source": ["Build the Autoencoder Model<br>\n", "First build the Encoder Model"]}, {"cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": ["inputs = Input(shape=input_shape, name='encoder_input')\n", "x = inputs\n", "# Stack of Conv2D blocks\n", "# Notes:\n", "# 1) Use Batch Normalization before ReLU on deep networks\n", "# 2) Use MaxPooling2D as 
"# 2) Use MaxPooling2D as an alternative to strides>1\n", "#    - faster but not as good as strides>1\n", "for filters in layer_filters:\n", "    x = Conv2D(filters=filters,\n", "               kernel_size=kernel_size,\n", "               strides=2,\n", "               activation='relu',\n", "               padding='same')(x)"]}, {"cell_type": "markdown", "metadata": {}, "source": ["Shape info needed to build Decoder Model"]}, {"cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": ["shape = K.int_shape(x)"]}, {"cell_type": "markdown", "metadata": {}, "source": ["Generate the latent vector"]}, {"cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": ["x = Flatten()(x)\n", "latent = Dense(latent_dim, name='latent_vector')(x)"]}, {"cell_type": "markdown", "metadata": {}, "source": ["Instantiate Encoder Model"]}, {"cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": ["encoder = Model(inputs, latent, name='encoder')\n", "encoder.summary()"]}, {"cell_type": "markdown", "metadata": {}, "source": ["Build the Decoder Model"]}, {"cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": ["latent_inputs = Input(shape=(latent_dim,), name='decoder_input')\n", "x = Dense(shape[1] * shape[2] * shape[3])(latent_inputs)\n", "x = Reshape((shape[1], shape[2], shape[3]))(x)"]}, {"cell_type": "markdown", "metadata": {}, "source": ["Stack of Transposed Conv2D blocks<br>\n", "Notes:<br>\n", "1) Use Batch Normalization before ReLU on deep networks<br>\n", "2) Use UpSampling2D as an alternative to strides>1<br>\n", "- faster but not as good as strides>1"]}, {"cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": ["for filters in layer_filters[::-1]:\n", "    x = Conv2DTranspose(filters=filters,\n", "                        kernel_size=kernel_size,\n", "                        strides=2,\n", "                        activation='relu',\n", "                        padding='same')(x)"]}, {"cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": ["x = Conv2DTranspose(filters=1,\n", "                    kernel_size=kernel_size,\n", "                    padding='same')(x)"]}, {"cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": ["outputs = Activation('sigmoid', name='decoder_output')(x)"]}, {"cell_type": "markdown", "metadata": {}, "source": ["Instantiate Decoder Model"]}, {"cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": ["decoder = Model(latent_inputs, outputs, name='decoder')\n", "decoder.summary()"]}, {"cell_type": "markdown", "metadata": {}, "source": ["Autoencoder = Encoder + Decoder<br>\n", "Instantiate Autoencoder Model"]}, {"cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": ["autoencoder = Model(inputs, decoder(encoder(inputs)), name='autoencoder')\n", "autoencoder.summary()"]}, {"cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": ["autoencoder.compile(loss='mse', optimizer='adam')"]}, {"cell_type": "markdown", "metadata": {}, "source": ["Train the autoencoder"]}, {"cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": ["autoencoder.fit(x_train_noisy,\n", "                x_train,\n", "                validation_data=(x_test_noisy, x_test),\n", "                epochs=30,\n", "                batch_size=batch_size)"]}, {"cell_type": "markdown", "metadata": {}, "source": ["Predict the Autoencoder output from corrupted test images"]}, {"cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": ["x_decoded = autoencoder.predict(x_test_noisy)"]}, {"cell_type": "markdown", "metadata": {}, "source":
["Display the first 300 images: original, corrupted, and denoised"]}, {"cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": ["rows, cols = 10, 30\n", "num = rows * cols\n", "imgs = np.concatenate([x_test[:num], x_test_noisy[:num], x_decoded[:num]])\n", "imgs = imgs.reshape((rows * 3, cols, image_size, image_size))\n", "# interleave the rows so that each group of 3 rows shows\n", "# original (top), corrupted (middle) and denoised (bottom) digits\n", "imgs = np.vstack(np.split(imgs, rows, axis=1))\n", "imgs = imgs.reshape((rows * 3, -1, image_size, image_size))\n", "imgs = np.vstack([np.hstack(i) for i in imgs])\n", "imgs = (imgs * 255).astype(np.uint8)\n", "plt.figure()\n", "plt.axis('off')\n", "plt.title('Original images: top rows, '\n", "          'Corrupted Input: middle rows, '\n", "          'Denoised Input: bottom rows')\n", "plt.imshow(imgs, interpolation='none', cmap='gray')\n", "Image.fromarray(imgs).save('corrupted_and_denoised.png')\n", "plt.show()"]}], "metadata": {"kernelspec": {"display_name": "Python 3", "language": "python", "name": "python3"}, "language_info": {"codemirror_mode": {"name": "ipython", "version": 3}, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.4"}}, "nbformat": 4, "nbformat_minor": 2}