# MNIST with Pattern Completion Learning

Note: this project is ongoing.

## Quick run

```bash
python ae_classifier.py --model cnn
```

## Package dependencies

- Keras 2.1.3
- Matplotlib 2.1.0 (for the figures)


Pattern completion learning (PCL) is an inference strategy in which, given data pairs (X, Y), a PCL model learns the correlation between the latent representation of the partial pattern (X) and that of the complete pattern (XY). The main advantage of PCL is that once the latent representations are learned, downstream tasks can be learned quickly and with high quality.

## Our procedure

For the MNIST dataset, X is the handwritten digit image (a 28x28 matrix) and Y is the image label (a 10-dimensional one-hot vector). XY is the concatenation of the two.
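
For concreteness, here is a minimal sketch of how XY can be assembled; the code is illustrative and not taken from the repository:

```python
import numpy as np
from keras.datasets import mnist
from keras.utils import to_categorical

# Load MNIST and scale pixel values to [0, 1]
(x_train, y_train), (x_test, y_test) = mnist.load_data()
X = x_train.reshape(-1, 28 * 28).astype('float32') / 255.0  # X: 784-dim image vectors
Y = to_categorical(y_train, num_classes=10)                 # Y: 10-dim one-hot labels
XY = np.concatenate([X, Y], axis=1)                         # XY: 794-dim complete patterns
```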

In our experiment, the procedure is the following:

1. Use an autoencoder to learn the latent representations of images (X), labels (Y), and image-label compounds (XY). The latent space has 32 dimensions, each bounded between 0 and 1.
2. Learn a completion function for each of the following tasks: classification and generation. For classification, X is the partial pattern; for generation, Y is. The completion function models the correlation between the partial pattern and its complete pattern (XY). In the decoded output, we ignore the portion corresponding to the input partial pattern (see the sketch after this list).
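
Putting the two steps together, classification at inference time looks roughly like the sketch below; `encode_x`, `complete`, and `decode_xy` are hypothetical stand-ins for the models trained in steps 1 and 2, not names from the repository:

```python
def pcl_classify(x_image, encode_x, complete, decode_xy):
    """Hypothetical PCL inference for classification.

    encode_x  : maps a 784-dim digit X to its 32-dim latent code  (step 1)
    complete  : maps latent(X) to a predicted latent(XY)          (step 2)
    decode_xy : maps a latent(XY) code back to a 794-dim XY vector
    """
    z_x = encode_x(x_image)    # latent code of the partial pattern
    z_xy = complete(z_x)       # predicted latent code of the complete pattern
    xy = decode_xy(z_xy)       # decoded 794-dim XY vector
    return xy[..., 784:]       # keep only the label part Y; the X part is ignored
```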

## Autoencoding models

We test the following autoencoders, defined in `autoencoders/models.py`.

### Flatten

(`Flatten_AE`) The encoder and decoder each have three dense layers.
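
A minimal Keras sketch of a model of this shape; only the 794-dim input and the 32-dim sigmoid-bounded bottleneck come from the text above, the intermediate layer sizes are guesses:

```python
from keras.layers import Input, Dense
from keras.models import Model

inp = Input(shape=(794,))                  # XY: 784-dim image + 10-dim label
# Encoder: three dense layers down to the 32-dim code
h = Dense(256, activation='relu')(inp)
h = Dense(128, activation='relu')(h)
code = Dense(32, activation='sigmoid')(h)  # sigmoid bounds the code between 0 and 1
# Decoder: three dense layers back up to 794 dims
h = Dense(128, activation='relu')(code)
h = Dense(256, activation='relu')(h)
out = Dense(794, activation='sigmoid')(h)

flatten_ae = Model(inp, out)
flatten_ae.compile(optimizer='adam', loss='binary_crossentropy')
```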

### Dense+CNN

(`Dense_CNN_AE`) A dense layer reduces the input (XY) to a vector that can be reshaped to 28x28x1; the result is then treated as an image by a convolutional autoencoder. The last layer of the decoder is a dense layer that outputs a vector of size 28x28 + 10.
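
A rough sketch of that layout; the filter counts and kernel sizes are illustrative, not the repo's values:

```python
from keras.layers import (Input, Dense, Reshape, Flatten,
                          Conv2D, MaxPooling2D, UpSampling2D)
from keras.models import Model

inp = Input(shape=(794,))
h = Dense(28 * 28, activation='relu')(inp)  # dense layer squeezes XY into 784 values...
h = Reshape((28, 28, 1))(h)                 # ...reshaped into a 1-channel "image"
# Convolutional encoder down to the 32-dim code
h = Conv2D(16, (3, 3), activation='relu', padding='same')(h)
h = MaxPooling2D((2, 2))(h)                 # 14x14x16
code = Dense(32, activation='sigmoid')(Flatten()(h))
# Convolutional decoder; the final dense layer outputs 28*28 + 10 values
h = Dense(14 * 14 * 16, activation='relu')(code)
h = Reshape((14, 14, 16))(h)
h = UpSampling2D((2, 2))(h)                 # back to 28x28x16
h = Conv2D(16, (3, 3), activation='relu', padding='same')(h)
out = Dense(28 * 28 + 10, activation='sigmoid')(Flatten()(h))

dense_cnn_ae = Model(inp, out)
```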

### 2chan-CNN

(`CNN_AE`) A dense layer maps the input (XY) to a vector that can be reshaped to 28x28. This vector is appended to the input digit (X), the result is reshaped into a 28x28x2 matrix, and that matrix is treated as a 2-channel image by a convolutional autoencoder. The decoder is the same as in the previous model.
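
The distinctive part is the channel stacking at the front; a sketch under the same illustrative assumptions as above, with the convolutional body and decoder following as in the Dense+CNN sketch:

```python
from keras.layers import Input, Dense, Reshape, Lambda, Concatenate

inp = Input(shape=(794,))                        # XY vector
mixed = Dense(28 * 28, activation='relu')(inp)   # dense layer maps XY to 784 values
mixed = Reshape((28, 28, 1))(mixed)              # channel 2: mixed XY information
x_img = Lambda(lambda t: t[:, :784])(inp)        # slice out the original digit X...
x_img = Reshape((28, 28, 1))(x_img)              # ...kept unchanged as channel 1
two_chan = Concatenate(axis=-1)([x_img, mixed])  # 28x28x2 two-channel "image"
# ...followed by the same convolutional encoder and decoder as in Dense+CNN
```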

### Reweight-CNN

(`RW_CNN`) Same as 2chan-CNN, but distinguishing the kinds of data that X and Y represent: before outputting XY, Y is passed through a softmax layer, and the loss function is modified to weight the loss on X and the loss on Y equally.
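
One way to write such a loss in Keras; this is a sketch of the idea only, and the repo's exact loss terms and weights may differ:

```python
from keras import backend as K

def equal_weight_loss(y_true, y_pred):
    # Binary cross-entropy on the image part X (first 784 dims)
    x_loss = K.mean(K.binary_crossentropy(y_true[:, :784], y_pred[:, :784]), axis=-1)
    # Categorical cross-entropy on the softmaxed label part Y (last 10 dims)
    y_loss = K.categorical_crossentropy(y_true[:, 784:], y_pred[:, 784:])
    # Equal weighting of the two reconstruction losses
    return 0.5 * x_loss + 0.5 * y_loss
```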

## Completion functions

We compare the following completion functions (SDL and LLR are sketched after the list):

1. Vector addition (ADD).
2. K-nearest neighbors (KNN).
3. Least-squares linear regression (LLR).
4. Single dense layer network (SDL), without bias or regularization.
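
For illustration, SDL and LLR might look like this; `z_x_train` and `z_xy_train` are hypothetical arrays of latent codes, and the MSE training loss for SDL is an assumption:

```python
from keras.layers import Input, Dense
from keras.models import Model
from sklearn.linear_model import LinearRegression

# SDL: a single dense layer, no bias and no regularization,
# mapping the 32-dim latent(X) to the 32-dim latent(XY).
z_in = Input(shape=(32,))
z_out = Dense(32, use_bias=False)(z_in)
sdl = Model(z_in, z_out)
sdl.compile(optimizer='adam', loss='mse')
# sdl.fit(z_x_train, z_xy_train)

# LLR: ordinary least-squares regression from latent(X) to latent(XY).
llr = LinearRegression()
# llr.fit(z_x_train, z_xy_train)
```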

### Details on vector addition

We take the average difference d between the latent representations of the complete pattern (XY) and the partial pattern (X or Y) over the training set. At inference, given an input X', we compute X' + d to obtain the predicted completion of X'.
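
A sketch in NumPy; the array names are placeholders for the encoded training set:

```python
import numpy as np

def fit_add(z_partial, z_complete):
    """Average offset d from partial-pattern to complete-pattern latent codes."""
    return np.mean(z_complete - z_partial, axis=0)

def complete_add(z_partial_new, d):
    """Predicted completion: translate the new latent code by d."""
    return z_partial_new + d
```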

## Baselines

1. FN: end-to-end forward network, `baselines\flatten_classifier.py` and `baselines\flatten_generator.py`.
2. CNN: end-to-end convolutional network, `baselines\cnn_classifier.py` and `baselines\cnn_generator.py`.

## Results

### Autoencoding

| Model | Loss (binary cross-entropy) | L1 distance |
| --- | --- | --- |
| Flatten | 0.1109 | 0.0492 |
| Dense+CNN | 0.0897 | 0.0314 |
| 2chan-CNN | 0.0845 | 0.0285 |
| Reweight-CNN | 0.1921 | 0.0300 |

### Classification

Input: a handwritten digit; output: the class as a probability vector. End-to-end and feature-extraction models report a single accuracy; the pattern-based paradigms report one accuracy per completion function.

| Learning paradigm | Model | Accuracy | ADD | KNN | LLR | SDL |
| --- | --- | --- | --- | --- | --- | --- |
| End-to-End (E2E) | FN | 0.9797 | - | - | - | - |
| End-to-End (E2E) | CNN | 0.9907 | - | - | - | - |
| Feature extraction | Flatten | 0.8989 | - | - | - | - |
| Feature extraction | Dense+CNN | 0.8362 | - | - | - | - |
| Feature extraction | 2chan-CNN | 0.8227 | - | - | - | - |
| Feature extraction | Reweight-CNN | 0.8958 | - | - | - | - |
| Pattern matching | Flatten | - | 0.7457 | - | - | 0.7733 |
| Pattern matching | Dense+CNN | - | 0.5832 | 0.9684 | 0.6530 | 0.7035 |
| Pattern matching | 2chan-CNN | - | 0.4348 | 0.9682 | 0.6363 | 0.6493 |
| Pattern matching | Reweight-CNN | - | 0.7389 | 0.9794 | 0.9389 | 0.7859 |
| Pattern completion (learn both X → XY and Y → XY together) | Flatten | - | 0.8498 | - | - | 0.9097 |
| Pattern completion (learn both X → XY and Y → XY together) | Dense+CNN | - | 0.7641 | 0.9650 | 0.9523 | 0.9067 |
| Pattern completion (learn both X → XY and Y → XY together) | 2chan-CNN | - | 0.7897 | 0.9690 | 0.9537 | 0.9469 |
| Pattern completion (learn both X → XY and Y → XY together) | Reweight-CNN | - | 0.9205 | 0.9789 | 0.9697 | 0.9642 |
| Pattern completion (PCL) | Flatten | - | 0.9240 | - | - | 0.9212 |
| Pattern completion (PCL) | Dense+CNN | - | 0.9723 | 0.9650 | 0.9747 | 0.9741 |
| Pattern completion (PCL) | 2chan-CNN | - | 0.9844 | 0.9690 | 0.9841 | 0.9836 |
| Pattern completion (PCL) | Reweight-CNN | - | 0.9871 | 0.9789 | 0.9871 | 0.9866 |

### Generation from labels

Input: a one-hot encoded label; output: a handwritten digit.

(Figures: generated digit samples for the end-to-end baselines FN and CNN, and for PCL with each autoencoder (Flatten, Dense+CNN, 2chan-CNN) paired with the ADD and SDL completion functions.)

### Adding noise

We add random Gaussian noise to the latent representation of the generated digit (the SDL output). The center digit has zero noise, the digits in the first ring around the center have noise at 50% of the mean STD, and those in the outer ring at 100% of the mean STD. (STD is the standard deviation of the difference computed for ADD; see the sketch below.)

(Figures: noisy generation samples for Flatten, Dense+CNN, and 2chan-CNN.)
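
A sketch of the perturbation, assuming `mean_std` is the mean, over the 32 latent dimensions, of the standard deviation of the ADD offsets:

```python
import numpy as np

def perturb_latent(z, frac, mean_std):
    """Add Gaussian noise to a latent code.

    frac is 0.0 for the center digit, 0.5 for the first ring,
    and 1.0 for the outer ring.
    """
    return z + np.random.normal(0.0, frac * mean_std, size=z.shape)
```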

## Discussion

- PCL requires good latent representations to guarantee high performance on downstream tasks.
- Adding noise to the output of SDL can generate blurry or invalid digits. (This might be fixed by adding an adversarial component.)
- In Dense+CNN, the first dense layer mixes the digit image with the label information, so important features of the digit image may be lost. In contrast, 2chan-CNN keeps the original image in a separate channel, which can avoid such loss.

## Reference

Yi Tian Xu, Yaqiao Li, David Meger. "Human motion prediction via pattern completion in latent representation space." CRV 2019 (16th Conference on Computer and Robot Vision).
