MMN

This is the implementation of the Moore Machine Network (MMN) introduced in the work "Learning Finite State Representations of Recurrent Policy Networks".

If you find it useful in your research, please cite it as:

@inproceedings{
  koul2018learning,
  title={Learning Finite State Representations of Recurrent Policy Networks},
  author={Anurag Koul and Alan Fern and Sam Greydanus},
  booktitle={International Conference on Learning Representations},
  year={2019},
  url={https://openreview.net/forum?id=S1gOpsCctm},
}

Also, here is the link to the poster presented at ICLR 2019.

Installation

  • Python 3.5+
  • PyTorch
  • gym_x
  • To install dependencies:
    pip install -r requirements.txt

Usage

We use main_mce.py, main_tomita.py and main_atari.py for experiments with the Mode Counter Environment (a.k.a. Gold Rush), Tomita grammars and Atari, respectively.

In the following, we describe usage with respect to main_atari.py; the same applies to the other scripts.

Parameters

usage: main_atari.py [-h] [--generate_train_data] [--generate_bn_data]
                     [--generate_max_steps GENERATE_MAX_STEPS] [--gru_train]
                     [--gru_test] [--gru_size GRU_SIZE] [--gru_lr GRU_LR]
                     [--bhx_train] [--ox_train] [--bhx_test] [--ox_test]
                     [--bgru_train] [--bgru_test] [--bhx_size BHX_SIZE]
                     [--bhx_suffix BHX_SUFFIX] [--ox_size OX_SIZE]
                     [--train_epochs TRAIN_EPOCHS] [--batch_size BATCH_SIZE]
                     [--bgru_lr BGRU_LR] [--gru_scratch] [--bx_scratch]
                     [--generate_fsm] [--evaluate_fsm]
                     [--bn_episodes BN_EPISODES] [--bn_epochs BN_EPOCHS]
                     [--no_cuda] [--env ENV] [--env_seed ENV_SEED]
                     [--result_dir RESULT_DIR]

GRU to FSM

optional arguments:
  -h, --help            show this help message and exit
  --generate_train_data
                        Generate Train Data
  --generate_bn_data    Generate Bottle-Neck Data
  --generate_max_steps GENERATE_MAX_STEPS
                        Maximum number of steps to be used for data generation
  --gru_train           Train GRU Network
  --gru_test            Test GRU Network
  --gru_size GRU_SIZE   No. of GRU Cells
  --gru_lr GRU_LR       Learning rate for GRU
  --bhx_train           Train bx network
  --ox_train            Train ox network
  --bhx_test            Test bx network
  --ox_test             Test ox network
  --bgru_train          Train binary gru network
  --bgru_test           Test binary gru network
  --bhx_size BHX_SIZE   binary encoding size
  --bhx_suffix BHX_SUFFIX
                         suffix for bhx folder
  --ox_size OX_SIZE     binary encoding size
  --train_epochs TRAIN_EPOCHS
                        No. of training episodes
  --batch_size BATCH_SIZE
                        batch size used for training
  --bgru_lr BGRU_LR     Learning rate for binary GRU
  --gru_scratch         use scratch gru for BGRU
  --bx_scratch          use scratch bx network for BGRU
  --generate_fsm        extract fsm from mmn net
  --evaluate_fsm        evaluate fsm
  --bn_episodes BN_EPISODES
                        No. of episodes for generating data for Bottleneck
                        Network
  --bn_epochs BN_EPOCHS
                        No. of Training epochs
  --no_cuda             no cuda usage
  --env ENV             Name of the environment
  --env_seed ENV_SEED   Seed for the environment
  --result_dir RESULT_DIR
                        Directory Path to store results

For most of our experiments, we set generate_max_steps = 100; you can change it depending on the environment you use. All other parameters were left at their default values, except ox_size, bhx_size and gru_size, which were set per experiment.
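To make the mapping from a command line to parameter values concrete, here is a hypothetical argparse sketch that declares only a handful of the flags listed above. The types and the default of 100 for --generate_max_steps are assumptions taken from the paragraph above, not the script's actual definitions; build_parser is an illustrative name.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    """Declare a small subset of main_atari.py's flags (sketch; types/defaults assumed)."""
    parser = argparse.ArgumentParser(description="GRU to FSM")
    parser.add_argument("--env", type=str, help="Name of the environment")
    parser.add_argument("--gru_size", type=int, help="No. of GRU Cells")
    parser.add_argument("--bhx_size", type=int, help="binary encoding size")
    parser.add_argument("--ox_size", type=int, help="binary encoding size")
    # 100 mirrors the value used in most of our experiments;
    # the real script's default may differ.
    parser.add_argument("--generate_max_steps", type=int, default=100,
                        help="Maximum number of steps to be used for data generation")
    # Boolean switches: present -> True, absent -> False.
    parser.add_argument("--bhx_train", action="store_true", help="Train bx network")
    parser.add_argument("--no_cuda", action="store_true", help="no cuda usage")
    return parser


if __name__ == "__main__":
    args = build_parser().parse_args(
        ["--env", "PongDeterministic-v4", "--bhx_train",
         "--bhx_size", "64", "--gru_size", "32"])
    print(args.env, args.gru_size, args.bhx_size, args.generate_max_steps)
    # prints: PongDeterministic-v4 32 64 100
```

Flags that are not passed (here --ox_size) come back as None, which is why each step only needs the sizes relevant to it.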

Use prepared scripts

Forming an MMN requires the multiple steps described in the Steps section below. These steps can also be executed sequentially for bhx_size=64, ox_size=100 using the following script, which can easily be customized for other environments.

./run_atari.sh PongDeterministic-v4
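The contents of run_atari.sh are not reproduced in this README. As a rough, hypothetical Python equivalent, assuming the script simply runs the non-optional Steps below in order with the same flags, a driver could look like this (mmn_pipeline and run_pipeline are illustrative names, not part of the repository):

```python
import subprocess
import sys


def mmn_pipeline(env, gru_size=32, bhx_size=64, ox_size=100, max_steps=100):
    """Return the non-optional Steps (2, 3, 5, 7, 9) as argument lists, in order."""
    base = [sys.executable, "main_atari.py", "--env", env,
            "--gru_size", str(gru_size),
            "--generate_max_steps", str(max_steps)]
    bhx = ["--bhx_size", str(bhx_size)]
    ox = ["--ox_size", str(ox_size)]
    return [
        base + ["--generate_bn_data"],          # step 2: bottleneck data
        base + ["--bhx_train"] + bhx,           # step 3: hidden-state QBN
        base + ["--ox_train"] + ox + bhx,       # step 5: observation QBN
        base + ["--bgru_train"] + ox + bhx,     # step 7: form / fine-tune the MMN
        base + ["--generate_fsm"] + bhx + ox,   # step 9: extract the Moore machine
    ]


def run_pipeline(commands, dry_run=True):
    """Print each command; also execute it unless dry_run is set."""
    for cmd in commands:
        print(" ".join(cmd))
        if not dry_run:
            # check=True raises CalledProcessError, so later steps never
            # run on top of a failed one.
            subprocess.run(cmd, check=True)


if __name__ == "__main__":
    run_pipeline(mmn_pipeline("PongDeterministic-v4"), dry_run=True)
```

Passing dry_run=False actually executes the commands; that assumes the working directory contains main_atari.py and a pre-trained GRU model.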

Steps

  1. Test RNN: We assume a pre-trained RNN model exists. This step is optional and evaluates that model's performance:

    python main_atari.py --env PongDeterministic-v4 --gru_test --gru_size 32
  2. Generate Bottleneck Data: Generates and stores the data used to train the quantized bottleneck networks (QBNs).

    python main_atari.py --env PongDeterministic-v4 --generate_bn_data --gru_size 32 --generate_max_steps 100
  3. Train BHX: Trains the QBN for the hidden state (hx). After each epoch, the QBN is inserted into the original RNN model and the overall model is evaluated in the environment. The best-performing QBN is saved.

    python main_atari.py --env PongDeterministic-v4 --bhx_train --bhx_size 64 --gru_size 32 --generate_max_steps 100

    After it's done, the model and plots will be saved here:

    results/Atari/PongDeterministic-v4/gru_32_bhx_64/
  4. Test BHX (optional): Inserts the saved BHX model into the original RNN model and evaluates the model in the environment.

    python main_atari.py --env PongDeterministic-v4 --bhx_test --bhx_size 64 --gru_size 32 --generate_max_steps 100
  5. Train OX: Trains the QBN for the learned observation features (X) given as input to the RNN.

    python main_atari.py --env PongDeterministic-v4 --ox_train --ox_size 100 --bhx_size 64 --gru_size 32 --generate_max_steps 100

    After it's done, the model and plots will be saved here:

    results/Atari/PongDeterministic-v4/gru_32_ox_100/
  6. Test OX (optional): Inserts the saved OX model into the original RNN model and evaluates the model in the environment.

    python main_atari.py --env PongDeterministic-v4 --ox_test --ox_size 100 --bhx_size 64 --gru_size 32 --generate_max_steps 100
  7. MMN: We form the Moore Machine Network by inserting both the BHX and OX QBNs into the original RNN model, then evaluate the MMN's performance in the environment. If performance drops, which can be caused by errors accumulated across both QBNs, the MMN is fine-tuned.

    python main_atari.py --env PongDeterministic-v4 --bgru_train --ox_size 100 --bhx_size 64 --gru_size 32 --generate_max_steps 100

    When the fine-tuning is done, model and plots will be saved here:

    results/Atari/PongDeterministic-v4/gru_32_hx_(64,100)_bgru
  8. Test MMN (optional): Loads and tests the saved MMN model.

    python main_atari.py --env PongDeterministic-v4 --bgru_test --bhx_size 64 --ox_size 100 --gru_size 32 --generate_max_steps 100
  9. Extract Moore Machine: In this final step, the quantized observation and hidden-state spaces are enumerated to form a Moore machine, which is then minimized.

    python main_atari.py --env PongDeterministic-v4 --generate_fsm --bhx_size 64 --ox_size 100 --gru_size 32 --generate_max_steps 100

    The final results before and after minimization are stored in text files (fsm.txt and minimized_moore_machine.txt) in:

    results/Atari/PongDeterministic-v4/gru_32_hx_(64,100)_bgru/
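Step 9 builds a Moore machine from the quantized states and observations and then minimizes it. The repository's own routine is not shown in this README; the following is a generic sketch using standard partition refinement (Moore's algorithm), assuming transitions were recorded as (hidden-state id, observation id, next hidden-state id) triples and that each quantized hidden state emits one action. transition_table and minimize_moore are hypothetical names, and treating a missing transition as a distinct "undefined" target is an assumption of this sketch.

```python
def transition_table(triples):
    """Map (state, observation) -> next_state from recorded (h, o, h_next) triples."""
    return {(h, o): h_next for h, o, h_next in triples}


def minimize_moore(states, observations, delta, outputs):
    """Minimize a Moore machine by partition refinement (Moore's algorithm).

    states, observations: lists of ids; delta: (state, obs) -> state;
    outputs: state -> action. A missing transition is treated as its own
    "undefined" target (None), an assumption of this sketch.
    Returns a dict mapping each state to the id of its equivalence block.
    """
    # Initial partition: states emitting the same action share a block.
    first_block = {}
    block = {}
    for s in states:
        block[s] = first_block.setdefault(outputs[s], len(first_block))
    while True:
        signatures = {}
        refined = {}
        for s in states:
            # Signature: current block plus the block reached on every observation.
            succ = tuple(block.get(delta.get((s, o))) for o in observations)
            refined[s] = signatures.setdefault((block[s], succ), len(signatures))
        if len(signatures) == len(set(block.values())):
            return refined  # no block was split: the partition is stable
        block = refined


if __name__ == "__main__":
    states, observations = ["h0", "h1", "h2", "h3"], ["o0", "o1"]
    triples = [("h0", "o0", "h1"), ("h0", "o1", "h0"), ("h1", "o0", "h0"),
               ("h1", "o1", "h1"), ("h2", "o0", "h3"), ("h2", "o1", "h2"),
               ("h3", "o0", "h2"), ("h3", "o1", "h3")]
    outputs = {"h0": "NOOP", "h1": "FIRE", "h2": "NOOP", "h3": "FIRE"}
    blocks = minimize_moore(states, observations, transition_table(triples), outputs)
    print(len(states), "->", len(set(blocks.values())))  # 4 -> 2
```

In this toy machine h0/h2 and h1/h3 emit the same actions and move between each other identically, so they collapse into two states; this is the kind of reduction reported in the "Before/After Minimization" columns of the result tables.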

Using pre-trained models

To make results easy to reproduce, we provide GRU models pre-trained on different environments. You can use them to train new QBNs and reproduce the results presented in the paper. The models are in the results/Atari/ directory. The GRU cell size can be read from the model's path: e.g., if a model is saved in a folder named gru_32, its GRU cell size is 32. With a pre-trained GRU model in place, follow the Steps above to train the QBNs.
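Reading the GRU size off a folder name is easy to script. A small sketch, assuming the layout results/Atari/<env>/gru_<N>... seen in the output paths of the Steps; gru_size_from_path is a hypothetical helper, not part of the repository.

```python
import re
from pathlib import Path, PurePosixPath
from typing import Optional

# Matches a folder name starting with gru_<N> followed by "_" or the end,
# e.g. gru_32, gru_32_bhx_64, gru_32_hx_(64,100)_bgru.
_GRU_DIR = re.compile(r"gru_(\d+)(?=_|$)")


def gru_size_from_path(path: str) -> Optional[int]:
    """Return the GRU cell size encoded in the first gru_<N> folder of a path."""
    for part in PurePosixPath(path).parts:
        match = _GRU_DIR.match(part)
        if match:
            return int(match.group(1))
    return None


if __name__ == "__main__":
    root = Path("results/Atari")
    # Assumed layout: results/Atari/<env>/gru_<N>/...
    folders = sorted(root.glob("*/gru_*")) if root.is_dir() else []
    for folder in folders:
        print(folder.as_posix(), gru_size_from_path(folder.as_posix()))
```

The value returned is what you would pass as --gru_size in the Steps.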

Results

MCE

The following table presents the Mode Counter Environment (MCE) results: the number of states and observations of the Moore machines (MMs) extracted from the MMNs, both before and after minimization. Moore Machine extraction for MCE (Table 1 in the paper):

Game                        Bh  Bf  Fine-Tuning Score    Before Minimization  After Minimization
                                    Before(%)  After(%)  |H|  |O|  Acc(%)     |H|  |O|  Acc(%)
Amnesia (gold rush read)    4   4   98         100       7    5    100        4    4    100
                            4   8   99         100       7    7    100        4    4    100
                            8   4   100        -         6    5    100        4    4    100
                            8   8   99         100       7    7    100        4    4    100
Blind (gold rush blind)     4   4   100        -         12   6    100        10   1    100
                            4   8   100        -         12   8    100        10   1    100
                            8   4   100        -         15   6    100        10   1    100
                            8   8   78         100       13   8    100        10   1    100
Tracker (gold rush sneak)   4   4   98         98        58   5    98         50   4    98
                            4   8   99         100       23   5    100        10   4    100
                            8   4   98         100       91   5    100        10   4    100
                            8   8   99         100       85   5    100        10   4    100

Tomita Grammar

The table below presents test results for the trained RNNs, giving the accuracy over a test set of 100 strings drawn from the same distribution as the training data. Moore Machine extraction for Tomita grammars (Table 2 in the paper):

Grammar  RNN Acc(%)  Bh  Fine-Tuning Score    Before Minimization  After Minimization
                         Before(%)  After(%)  |H|  Acc(%)          |H|  Acc(%)
1        100         8   100        -         13   100             2    100
         100         16  100        -         28   100             2    100
2        100         8   100        -         13   100             3    100
         100         16  100        -         14   100             3    100
3        100         8   100        -         34   100             5    100
         100         16  100        -         39   100             5    100
4        100         8   100        -         17   100             4    100
         100         16  100        -         18   100             4    100
5        100         8   95         96        192  96              115  96
         100         16  100        -         316  100             4    100
6        99          8   98         98        100  98              12   98
         99          16  99         99        518  99              11   99
7        100         8   100        -         25   100             5    100
         100         16  100        -         107  100             5    100

Control Tasks

To run the whole pipeline on control tasks, you only need to run the run_control.sh script. For example:

sh run_control.sh Acrobot-v1 32 64 64

We also ran additional experiments on control tasks; the results are presented in the following table:

Game (# of actions) Bh   Bf  Before Minimization  After Minimization
                             |H|   |O|   Score    |H|   |O|   Score
Cart Pole (2)       64   64  27    859   500      4     32    500
Lunar Lander (4)    128  64  1502  1165  198      52    89    115
Acrobot (3)         64   64  769   649   -73.95   11    23    -89.4

Atari

This table shows the performance of the trained MMNs before and after fine-tuning for different combinations of Bh and Bf. It extends Table 3 of the paper with results for a few more games. Results may vary slightly.

Game(# of actions) RNN(score) Bh Bf Fine-Tuning Score Before Minimization After Minimization
Before After |H| |O| Score |H| |O| Score
Pong(3) 21 64 100 20 21 380 374 21 4 12 21
64 400 20 21 373 372 21 3 10 21
128 100 20 21 383 373 21 3 12 21
128 400 20 21 379 371 21 3 11 21
Freeway(3) 21 64 100 21 - 1 1 21 1 1 21
64 400 21 - 1 1 21 1 1 21
128 100 21 - 1 1 21 1 1 21
128 400 21 - 1 1 21 1 1 21
Breakout(4) 773 64 100 32 423 1898 1874 423 8 30 423
64 400 25 415 1888 1871 415 8 30 415
128 100 41 377 1583 1514 377 11 27 377
128 400 85 379 1729 1769 379 8 30 379
Space Invaders(4) 1820 64 100 520 1335 1495 1502 1335 8 29 1335
64 400 365 1235 1625 1620 1235 12 29 1235
128 100 390 1040 1563 1457 1040 12 35 1040
128 400 520 1430 1931 1921 1430 6 27 1430
Bowling(6) 60 64 100 60 - 49 1 60 33 1 60
64 400 60 - 49 1 60 33 1 60
128 100 60 - 26 1 60 24 1 60
128 400 60 - 26 1 60 24 1 60
Boxing(18) 100 64 100 94 100 1173 1167 100 13 79 100
64 400 98 100 2621 2605 100 14 119 100
128 100 94 97 2499 2482 97 14 106 97
128 400 97 100 1173 1169 100 14 88 100
Chopper Command(18) 5300 64 100 4000 3710 3731 4000 38 182 1890