
A TensorFlow implementation of Memory-Augmented Neural Networks

Leputa/MANN-meta-learning

Meta-Learning with Memory-Augmented Neural Networks in TensorFlow

A concise alternative TensorFlow implementation of the paper Santoro, Adam, et al. "Meta-learning with memory-augmented neural networks." International Conference on Machine Learning (ICML), 2016. The model is encapsulated in the class MANNCell, which can be used in the same way as an RNN cell such as BasicRNNCell. The code is inspired by the excellent implementations of tristandeleu and snowkylin.

Memory-Augmented Neural Networks

As described in the reference paper, MANNs (Memory-Augmented Neural Networks) are the class of networks equipped with an external memory, such as NTMs (Neural Turing Machines).
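For intuition, here is a minimal NumPy sketch of the content-based read described in the reference paper, assuming a single read head: the controller emits a key, the read weights are a softmax over the cosine similarity between that key and every memory row, and the read vector is the weighted sum of those rows. It is an illustration under these assumptions, not the code inside MANNCell; the function names are hypothetical.

```python
import numpy as np

def cosine_similarity(key, memory, eps=1e-8):
    """Cosine similarity between a key (memory_dim,) and each memory row (memory_size, memory_dim)."""
    return memory @ key / ((np.linalg.norm(memory, axis=1) + eps) * (np.linalg.norm(key) + eps))

def content_read(key, memory):
    """Return softmax read weights over memory slots and the resulting read vector."""
    sim = cosine_similarity(key, memory)
    w_r = np.exp(sim - sim.max())  # subtract the max for numerical stability
    w_r /= w_r.sum()               # weights sum to 1 over the memory slots
    return w_r, w_r @ memory       # r = sum_i w_r(i) * M(i)

# memory_size=128 and memory_dim=40, as in the MANNCell example
rng = np.random.RandomState(0)
memory = rng.randn(128, 40)
key = memory[7] + 0.01 * rng.randn(40)  # a slightly noisy copy of slot 7
w_r, r = content_read(key, memory)
print(int(np.argmax(w_r)))  # → 7
```

The slot whose content is closest to the key receives the largest read weight, which is how the network retrieves a stored sample-label binding.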

[Figure: MANN]

Dependencies

  • Python 3.6
  • TensorFlow==1.14
  • numpy==1.16.4
  • Pillow==7.1.1 (imported as PIL)

Usage

Omniglot DataSet

Download images_background.zip (964 classes) and images_evaluation.zip (659 classes), and place them in the ./omniglot folder.
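As a rough sketch of how one training episode can be assembled from these classes, following the setup in the reference paper: sample five classes, assign them labels that are reshuffled every episode, interleave their samples, and feed each image together with the one-hot label of the previous step. The function make_episode and the random stand-in images are hypothetical; the repository's own data loader may differ.

```python
import numpy as np

def make_episode(images_by_class, nb_classes=5, nb_samples_per_class=10, rng=None):
    """Build one episode of (input, target) pairs in the style of Santoro et al. (2016).

    images_by_class: list of arrays, each (n_images, n_pixels) for one character class.
    Returns inputs (T, n_pixels + nb_classes) and integer targets (T,),
    where T = nb_classes * nb_samples_per_class.
    """
    if rng is None:
        rng = np.random.RandomState()
    classes = rng.choice(len(images_by_class), size=nb_classes, replace=False)
    labels = rng.permutation(nb_classes)  # labels are reassigned every episode
    xs, ys = [], []
    for cls, label in zip(classes, labels):
        idx = rng.choice(len(images_by_class[cls]), size=nb_samples_per_class, replace=False)
        xs.append(images_by_class[cls][idx])
        ys.append(np.full(nb_samples_per_class, label))
    xs, ys = np.concatenate(xs), np.concatenate(ys)
    order = rng.permutation(len(ys))      # interleave the classes randomly
    xs, ys = xs[order], ys[order]
    # Image t is paired with the label of step t-1; step 0 gets an all-zero label.
    prev_onehot = np.zeros((len(ys), nb_classes))
    prev_onehot[1:] = np.eye(nb_classes)[ys[:-1]]
    return np.concatenate([xs, prev_onehot], axis=1), ys

# Hypothetical stand-in data: 30 classes of 20 flattened 400-pixel images.
rng = np.random.RandomState(0)
fake_classes = [rng.rand(20, 400) for _ in range(30)]
inputs, targets = make_episode(fake_classes, rng=rng)
print(inputs.shape, targets.shape)  # → (50, 405) (50,)
```

Offsetting the labels by one step forces the network to store each sample-label binding and recall it the next time the class appears.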

Running

```shell
python run_mann.py
python run_mann.py --mode test
python run_mann.py --model LSTM
python run_mann.py --model LSTM --mode test
```

Class MANNCell()

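```python
import tensorflow as tf  # TensorFlow 1.x API (see Dependencies)
from mann.mann_cell import MANNCell

# Hypothetical sizes for illustration; the cell settings match the original example.
batch_size, seq_length, input_dim = 16, 50, 405
lstm_size, memory_size, memory_dim, nb_reads = 200, 128, 40, 4

cell = MANNCell(lstm_size=lstm_size, memory_size=memory_size,
                memory_dim=memory_dim, nb_reads=nb_reads, gamma=0.95)

# Batch-major input: (batch_size, seq_length, input_dim)
inputs = tf.placeholder(tf.float32, shape=(batch_size, seq_length, input_dim))
state = cell.zero_state(batch_size, tf.float32)
# tf.scan iterates over axis 0, so feed the sequence time-major; each step
# receives the previous (output, state) pair and returns the new one.
output, state = tf.scan(
    lambda prev, x_t: cell(x_t, prev[1]),
    elems=tf.transpose(inputs, perm=[1, 0, 2]),
    initializer=(tf.zeros(shape=(batch_size, lstm_size + nb_reads * memory_dim)), state))
# Back to batch-major: (batch_size, seq_length, lstm_size + nb_reads * memory_dim)
output = tf.transpose(output, perm=[1, 0, 2])
```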
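The gamma argument presumably corresponds to the usage decay factor γ of the paper's least recently used access (LRUA) module. The NumPy sketch below shows that bookkeeping for a single head under this assumption: the write weights interpolate between the previous read weights and the previous least-used weights through a sigmoid gate σ(α), and usage decays by γ before the new read and write weights are added. Names such as lrua_step, least_used_weights, and alpha are hypothetical and are not part of MANNCell's interface.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def least_used_weights(usage, nb_reads):
    """w_lu(i) = 1 if usage(i) <= the nb_reads-th smallest usage, else 0."""
    nth_smallest = np.sort(usage)[nb_reads - 1]
    return (usage <= nth_smallest).astype(usage.dtype)

def lrua_step(usage_prev, w_r_prev, w_lu_prev, w_r, alpha, gamma=0.95):
    """One LRUA bookkeeping step for a single head; returns (write weights, new usage)."""
    g = sigmoid(alpha)                          # learned interpolation gate
    w_w = g * w_r_prev + (1.0 - g) * w_lu_prev  # write to the last-read or a least-used slot
    usage = gamma * usage_prev + w_r + w_w      # decayed usage plus current access
    return w_w, usage

# Toy memory with 4 slots and one read head.
usage_prev = np.array([0.5, 0.1, 0.9, 0.0])
w_lu_prev = least_used_weights(usage_prev, nb_reads=1)  # slot 3 is least used
w_r_prev = np.array([0.0, 0.0, 1.0, 0.0])
w_r = np.array([0.0, 1.0, 0.0, 0.0])
w_w, usage = lrua_step(usage_prev, w_r_prev, w_lu_prev, w_r, alpha=0.0)
# With alpha=0 the gate is 0.5, so w_w is [0, 0, 0.5, 0.5]
# and usage is [0.475, 1.095, 1.355, 0.5].
```

A larger gamma makes past accesses count for longer, so slots stay "in use" longer before they become candidates for overwriting.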

Performance

Omniglot Classification:

[Figures: LSTM, MANN]

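Test-set classification accuracies on the Omniglot dataset, using one-hot encodings of labels and five classes presented per episode. Each column gives the accuracy on the n-th presentation (instance) of a class within an episode. "Paper" rows are the values reported in the reference paper; "this repo" rows are the results of this implementation.

| Model | 1st | 2nd | 3rd | 4th | 5th | 10th |
|---|---|---|---|---|---|---|
| LSTM (paper) | 24.4% | 49.5% | 55.3% | 61.0% | 63.6% | 62.5% |
| LSTM (this repo) | 30.4% | 77.9% | 85.3% | 87.5% | 88.8% | 91.6% |
| MANN (paper) | 36.4% | 82.8% | 91.0% | 92.6% | 94.9% | 98.1% |
| MANN (this repo) | 35.4% | 89.2% | 95.2% | 96.3% | 96.9% | 97.8% |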
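Per-instance accuracies like these can be computed roughly as follows: for every time step, count how many times its true label has appeared so far in the episode, and bucket the correctness of the prediction by that count. This is a hypothetical helper for illustration, assuming integer targets and predictions of shape (episodes, steps); it is not taken from run_mann.py.

```python
import numpy as np

def accuracy_by_instance(targets, predictions, instances=(1, 2, 3, 4, 5, 10)):
    """Accuracy on the k-th presentation of each class, pooled over episodes.

    targets, predictions: integer arrays of shape (nb_episodes, seq_length).
    """
    correct = {k: [] for k in instances}
    for ep_targets, ep_preds in zip(targets, predictions):
        seen = {}  # label -> number of times it has been presented so far
        for y, y_hat in zip(ep_targets, ep_preds):
            seen[y] = seen.get(y, 0) + 1
            if seen[y] in correct:
                correct[seen[y]].append(y == y_hat)
    return {k: float(np.mean(v)) if v else float("nan") for k, v in correct.items()}

# One toy episode with two classes.
targets = np.array([[0, 1, 0, 1, 0]])
predictions = np.array([[2, 1, 0, 0, 0]])
acc = accuracy_by_instance(targets, predictions, instances=(1, 2, 3))
print(acc)  # → {1: 0.5, 2: 0.5, 3: 1.0}
```

The 1st-instance column stays close to chance for any model, since the label of a new class has not been shown yet; the later columns measure how well the model uses what it stored.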
