@@ -26,31 +26,55 @@ def __init__(self, state_dim, action_cnt):
26
26
rnn_in = tf .expand_dims (self .states , [0 ])
27
27
28
28
# create LSTM
29
+ lstm_layers = 2
29
30
lstm_state_dim = 256
30
- lstm_cell = rnn .BasicLSTMCell (lstm_state_dim )
31
+ lstm_cell_list = []
32
+ for i in xrange (lstm_layers ):
33
+ lstm_cell_list .append (rnn .BasicLSTMCell (lstm_state_dim ))
34
+ stacked_cell = rnn .MultiRNNCell (lstm_cell_list )
31
35
32
- c_init = np .zeros ([1 , lstm_cell .state_size .c ], np .float32 )
33
- h_init = np .zeros ([1 , lstm_cell .state_size .h ], np .float32 )
34
- self .lstm_state_init = (c_init , h_init )
36
+ self .lstm_state_init = []
37
+ self .lstm_state_in = []
38
+ lstm_state_in = []
39
+ for i in xrange (lstm_layers ):
40
+ c_init = np .zeros ([1 , lstm_state_dim ], np .float32 )
41
+ h_init = np .zeros ([1 , lstm_state_dim ], np .float32 )
42
+ self .lstm_state_init .append ((c_init , h_init ))
43
+
44
+ c_in = tf .placeholder (tf .float32 , [1 , lstm_state_dim ])
45
+ h_in = tf .placeholder (tf .float32 , [1 , lstm_state_dim ])
46
+ self .lstm_state_in .append ((c_in , h_in ))
47
+ lstm_state_in .append (rnn .LSTMStateTuple (c_in , h_in ))
48
+
49
+ self .lstm_state_init = tuple (self .lstm_state_init )
50
+ self .lstm_state_in = tuple (self .lstm_state_in )
51
+ lstm_state_in = tuple (lstm_state_in )
35
52
36
53
c_in = tf .placeholder (tf .float32 , [1 , lstm_cell .state_size .c ])
37
54
h_in = tf .placeholder (tf .float32 , [1 , lstm_cell .state_size .h ])
38
55
self .lstm_state_in = (c_in , h_in )
39
56
40
57
lstm_outputs , lstm_state_out = tf .nn .dynamic_rnn (
41
- lstm_cell , rnn_in ,
42
- initial_state = rnn .LSTMStateTuple (c_in , h_in ))
58
+ stacked_cell , rnn_in , initial_state = lstm_state_in )
59
+
60
+ self .lstm_state_out = []
61
+ for i in xrange (lstm_layers ):
62
+ self .lstm_state_out .append (
63
+ (lstm_state_out [i ].c , lstm_state_out [i ].h ))
64
+ self .lstm_state_out = tuple (self .lstm_state_out )
43
65
44
66
rnn_out = tf .reshape (lstm_outputs , [- 1 , lstm_state_dim ])
45
67
c_out , h_out = lstm_state_out
46
68
self .lstm_state_out = (c_out [:1 , :], h_out [:1 , :])
47
69
48
70
# actor
49
- self .action_scores = layers .linear (rnn_out , action_cnt )
71
+ actor_h1 = layers .relu (rnn_out , 64 )
72
+ self .action_scores = layers .linear (actor_h1 , action_cnt )
50
73
self .action_probs = tf .nn .softmax (self .action_scores )
51
74
52
75
# critic
53
- self .state_values = tf .reshape (layers .linear (rnn_out , 1 ), [- 1 ])
76
+ critic_h1 = layers .relu (rnn_out , 64 )
77
+ self .state_values = tf .reshape (layers .linear (critic_h1 , 1 ), [- 1 ])
54
78
55
79
self .trainable_vars = tf .get_collection (
56
80
tf .GraphKeys .TRAINABLE_VARIABLES , tf .get_variable_scope ().name )
0 commit comments