Skip to content

Commit 7e130fe

Browse files
committed
models.py: add multiple-layer LSTM
1 parent a4cb8ef commit 7e130fe

File tree

1 file changed

+32
-8
lines changed

1 file changed

+32
-8
lines changed

a3c/models.py

+32-8
Original file line numberDiff line numberDiff line change
@@ -26,31 +26,55 @@ def __init__(self, state_dim, action_cnt):
2626
rnn_in = tf.expand_dims(self.states, [0])
2727

2828
# create LSTM
29+
lstm_layers = 2
2930
lstm_state_dim = 256
30-
lstm_cell = rnn.BasicLSTMCell(lstm_state_dim)
31+
lstm_cell_list = []
32+
for i in xrange(lstm_layers):
33+
lstm_cell_list.append(rnn.BasicLSTMCell(lstm_state_dim))
34+
stacked_cell = rnn.MultiRNNCell(lstm_cell_list)
3135

32-
c_init = np.zeros([1, lstm_cell.state_size.c], np.float32)
33-
h_init = np.zeros([1, lstm_cell.state_size.h], np.float32)
34-
self.lstm_state_init = (c_init, h_init)
36+
self.lstm_state_init = []
37+
self.lstm_state_in = []
38+
lstm_state_in = []
39+
for i in xrange(lstm_layers):
40+
c_init = np.zeros([1, lstm_state_dim], np.float32)
41+
h_init = np.zeros([1, lstm_state_dim], np.float32)
42+
self.lstm_state_init.append((c_init, h_init))
43+
44+
c_in = tf.placeholder(tf.float32, [1, lstm_state_dim])
45+
h_in = tf.placeholder(tf.float32, [1, lstm_state_dim])
46+
self.lstm_state_in.append((c_in, h_in))
47+
lstm_state_in.append(rnn.LSTMStateTuple(c_in, h_in))
48+
49+
self.lstm_state_init = tuple(self.lstm_state_init)
50+
self.lstm_state_in = tuple(self.lstm_state_in)
51+
lstm_state_in = tuple(lstm_state_in)
3552

3653
c_in = tf.placeholder(tf.float32, [1, lstm_cell.state_size.c])
3754
h_in = tf.placeholder(tf.float32, [1, lstm_cell.state_size.h])
3855
self.lstm_state_in = (c_in, h_in)
3956

4057
lstm_outputs, lstm_state_out = tf.nn.dynamic_rnn(
41-
lstm_cell, rnn_in,
42-
initial_state=rnn.LSTMStateTuple(c_in, h_in))
58+
stacked_cell, rnn_in, initial_state=lstm_state_in)
59+
60+
self.lstm_state_out = []
61+
for i in xrange(lstm_layers):
62+
self.lstm_state_out.append(
63+
(lstm_state_out[i].c, lstm_state_out[i].h))
64+
self.lstm_state_out = tuple(self.lstm_state_out)
4365

4466
rnn_out = tf.reshape(lstm_outputs, [-1, lstm_state_dim])
4567
c_out, h_out = lstm_state_out
4668
self.lstm_state_out = (c_out[:1, :], h_out[:1, :])
4769

4870
# actor
49-
self.action_scores = layers.linear(rnn_out, action_cnt)
71+
actor_h1 = layers.relu(rnn_out, 64)
72+
self.action_scores = layers.linear(actor_h1, action_cnt)
5073
self.action_probs = tf.nn.softmax(self.action_scores)
5174

5275
# critic
53-
self.state_values = tf.reshape(layers.linear(rnn_out, 1), [-1])
76+
critic_h1 = layers.relu(rnn_out, 64)
77+
self.state_values = tf.reshape(layers.linear(critic_h1, 1), [-1])
5478

5579
self.trainable_vars = tf.get_collection(
5680
tf.GraphKeys.TRAINABLE_VARIABLES, tf.get_variable_scope().name)

0 commit comments

Comments
 (0)