@@ -1,6 +1,7 @@
 #!/usr/bin/env python3
 # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
 
+import dataclasses
 import logging
 import random
 from typing import Optional
@@ -16,25 +17,142 @@
 logger = logging.getLogger(__name__)
 
 
+@dataclasses.dataclass
+class MemoryBuffer:
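+    """Preallocated, fixed-size tensor storage for transitions, addressed by row index."""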
+    state: torch.Tensor
+    action: torch.Tensor
+    reward: torch.Tensor
+    next_state: torch.Tensor
+    next_action: torch.Tensor
+    terminal: torch.Tensor
+    possible_next_actions: Optional[torch.Tensor]
+    possible_next_actions_mask: Optional[torch.Tensor]
+    possible_actions: Optional[torch.Tensor]
+    possible_actions_mask: Optional[torch.Tensor]
+    time_diff: torch.Tensor
+    policy_id: torch.Tensor
+
+    @torch.no_grad()  # type: ignore
+    def slice(self, indices):
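+        # Gather the selected rows into a new MemoryBuffer; optional fields stay None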
+        return MemoryBuffer(
+            state=self.state[indices],
+            action=self.action[indices],
+            reward=self.reward[indices],
+            next_state=self.next_state[indices],
+            next_action=self.next_action[indices],
+            terminal=self.terminal[indices],
+            possible_next_actions=self.possible_next_actions[indices]
+            if self.possible_next_actions is not None
+            else None,
+            possible_next_actions_mask=self.possible_next_actions_mask[indices]
+            if self.possible_next_actions_mask is not None
+            else None,
+            possible_actions=self.possible_actions[indices]
+            if self.possible_actions is not None
+            else None,
+            possible_actions_mask=self.possible_actions_mask[indices]
+            if self.possible_actions_mask is not None
+            else None,
+            time_diff=self.time_diff[indices],
+            policy_id=self.policy_id[indices],
+        )
+
+    @torch.no_grad()  # type: ignore
+    def insert_at(
+        self,
+        idx: int,
+        state: torch.Tensor,
+        action: torch.Tensor,
+        reward: float,
+        next_state: torch.Tensor,
+        next_action: torch.Tensor,
+        terminal: bool,
+        possible_next_actions: Optional[torch.Tensor],
+        possible_next_actions_mask: Optional[torch.Tensor],
+        time_diff: float,
+        possible_actions: Optional[torch.Tensor],
+        possible_actions_mask: Optional[torch.Tensor],
+        policy_id: int,
+    ):
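+        # Overwrite row idx in place; scalar arguments broadcast into (1,)-shaped slots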
+        self.state[idx] = state
+        self.action[idx] = action
+        self.reward[idx] = reward
+        self.next_state[idx] = next_state
+        self.next_action[idx] = next_action
+        self.terminal[idx] = terminal
+        if self.possible_actions is not None:
+            self.possible_actions[idx] = possible_actions
+        if self.possible_actions_mask is not None:
+            self.possible_actions_mask[idx] = possible_actions_mask
+        if self.possible_next_actions is not None:
+            self.possible_next_actions[idx] = possible_next_actions
+        if self.possible_next_actions_mask is not None:
+            self.possible_next_actions_mask[idx] = possible_next_actions_mask
+        self.time_diff[idx] = time_diff
+        self.policy_id[idx] = policy_id
+
+    @classmethod
+    def create(
+        cls,
+        max_size: int,
+        state_dim: int,
+        action_dim: int,
+        max_possible_actions: Optional[int],
+        has_possible_actions: bool,
+    ):
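+        # Preallocate zero-filled storage for max_size rows. Mask tensors are allocated
+        # whenever max_possible_actions is given; the full possible-action tensors only
+        # when has_possible_actions is True (the parametric-action case).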
+        return cls(
+            state=torch.zeros((max_size, state_dim)),
+            action=torch.zeros((max_size, action_dim)),
+            reward=torch.zeros((max_size, 1)),
+            next_state=torch.zeros((max_size, state_dim)),
+            next_action=torch.zeros((max_size, action_dim)),
+            terminal=torch.zeros((max_size, 1), dtype=torch.uint8),
+            possible_next_actions=torch.zeros(
+                (max_size, max_possible_actions, action_dim)
+            )
+            if has_possible_actions
+            else None,
+            possible_next_actions_mask=torch.zeros((max_size, max_possible_actions))
+            if max_possible_actions
+            else None,
+            possible_actions=torch.zeros((max_size, max_possible_actions, action_dim))
+            if has_possible_actions
+            else None,
+            possible_actions_mask=torch.zeros((max_size, max_possible_actions))
+            if max_possible_actions
+            else None,
+            time_diff=torch.zeros((max_size, 1)),
+            policy_id=torch.zeros((max_size, 1), dtype=torch.long),
+        )
+
+
 class OpenAIGymMemoryPool:
-    def __init__(self, max_replay_memory_size):
+    def __init__(self, max_replay_memory_size: int):
         """
         Creates an OpenAIGymMemoryPool object.
 
         :param max_replay_memory_size: Upper bound on the number of transitions
             to store in replay memory.
         """
-        self.replay_memory = []
         self.max_replay_memory_size = max_replay_memory_size
         self.memory_num = 0
-        self.skip_insert_until = self.max_replay_memory_size
+
+        # Not initialized up front; the tensor shapes are unknown until the first insert
+        self.memory_buffer: Optional[MemoryBuffer] = None
 
     @property
     def size(self):
-        return len(self.replay_memory)
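+        # memory_num counts every insertion attempt, so cap it at the buffer capacity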
+        return min(self.memory_num, self.max_replay_memory_size)
+
+    @property
+    def state_dim(self):
+        assert self.memory_buffer is not None
+        return self.memory_buffer.state.shape[1]
 
-    def shuffle(self):
-        random.shuffle(self.replay_memory)
+    @property
+    def action_dim(self):
+        assert self.memory_buffer is not None
+        return self.memory_buffer.action.shape[1]
 
     def sample_memories(self, batch_size, model_type, chunk=None):
         """
@@ -49,72 +167,63 @@ def sample_memories(self, batch_size, model_type, chunk=None):
         :param model_type: Model type (discrete, parametric).
         :param chunk: Index of chunk of data (for deterministic sampling).
         """
-        cols = [[], [], [], [], [], [], [], [], [], [], [], []]
-
         if chunk is None:
-            indices = np.random.randint(0, len(self.replay_memory), size=batch_size)
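+            # Sample uniformly with replacement from the filled portion of the buffer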
+            indices = torch.randint(0, self.size, size=(batch_size,))
         else:
             start_idx = chunk * batch_size
             end_idx = start_idx + batch_size
             indices = range(start_idx, end_idx)
 
-        for idx in indices:
-            memory = self.replay_memory[idx]
-            for col, value in zip(cols, memory):
-                col.append(value)
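+        # Fetch the whole batch with one vectorized gather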
+        memory = self.memory_buffer.slice(indices)
 
-        states = stack(cols[0])
-        next_states = stack(cols[3])
+        states = memory.state
+        next_states = memory.next_state
 
         assert states.dim() == 2
         assert next_states.dim() == 2
 
         if model_type == ModelType.PYTORCH_PARAMETRIC_DQN.value:
-            num_possible_actions = len(cols[7][0])
+            num_possible_actions = memory.possible_actions_mask.shape[1]
 
-            actions = stack(cols[1])
-            next_actions = stack(cols[4])
+            actions = memory.action
+            next_actions = memory.next_action
 
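+            # Tile states so each row pairs a state with one candidate action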
             tiled_states = states.repeat(1, num_possible_actions).reshape(
                 -1, states.shape[1]
             )
-            possible_actions = torch.cat(cols[8])
+            possible_actions = memory.possible_actions.reshape(-1, actions.shape[1])
             possible_actions_state_concat = torch.cat(
                 (tiled_states, possible_actions), dim=1
             )
-            possible_actions_mask = stack(cols[9])
+            possible_actions_mask = memory.possible_actions_mask
 
             tiled_next_states = next_states.repeat(1, num_possible_actions).reshape(
                 -1, next_states.shape[1]
             )
-            possible_next_actions = torch.cat(cols[6])
+            possible_next_actions = memory.possible_next_actions.reshape(
+                -1, actions.shape[1]
+            )
             possible_next_actions_state_concat = torch.cat(
                 (tiled_next_states, possible_next_actions), dim=1
             )
-            possible_next_actions_mask = stack(cols[7])
+            possible_next_actions_mask = memory.possible_next_actions_mask
         else:
             possible_actions = None
             possible_actions_state_concat = None
             possible_next_actions = None
             possible_next_actions_state_concat = None
-            if cols[7] is None or cols[7][0] is None:
-                possible_next_actions_mask = None
-            else:
-                possible_next_actions_mask = stack(cols[7])
-            if cols[9] is None or cols[9][0] is None:
-                possible_actions_mask = None
-            else:
-                possible_actions_mask = stack(cols[9])
+            possible_next_actions_mask = memory.possible_next_actions_mask
+            possible_actions_mask = memory.possible_actions_mask
 
-        actions = stack(cols[1])
-        next_actions = stack(cols[4])
+        actions = memory.action
+        next_actions = memory.next_action
 
         assert len(actions.size()) == 2
         assert len(next_actions.size()) == 2
 
-        rewards = torch.tensor(cols[2], dtype=torch.float32).reshape(-1, 1)
-        not_terminal = (1 - torch.tensor(cols[5], dtype=torch.int32)).reshape(-1, 1)
-        time_diffs = torch.tensor(cols[10], dtype=torch.int32).reshape(-1, 1)
+        rewards = memory.reward
+        not_terminal = 1 - memory.terminal
+        time_diffs = memory.time_diff
 
         return TrainingDataPage(
             states=states,
@@ -144,32 +253,58 @@ def insert_into_memory(
         time_diff: float,
         possible_actions: Optional[torch.Tensor],
         possible_actions_mask: Optional[torch.Tensor],
-        policy_id: str,
+        policy_id: int,
     ):
         """
         Inserts a transition into replay memory in such a way that retrieving
         transitions uniformly at random will be equivalent to reservoir sampling.
         """
-        item = (
-            state,
-            action,
-            reward,
-            next_state,
-            next_action,
-            terminal,
-            possible_next_actions,
-            possible_next_actions_mask,
-            possible_actions,
-            possible_actions_mask,
-            time_diff,
-            policy_id,
-        )
 
+        if self.memory_buffer is None:
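+            # First insert: infer tensor shapes from this transition and allocate the full buffer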
+            assert state.shape == next_state.shape
+            assert len(state.shape) == 1
+            assert action.shape == next_action.shape
+            assert len(action.shape) == 1
+            if possible_actions_mask is not None:
+                assert possible_next_actions_mask is not None
+                assert possible_actions_mask.shape == possible_next_actions_mask.shape
+                assert len(possible_actions_mask.shape) == 1
+                max_possible_actions = possible_actions_mask.shape[0]
+            else:
+                max_possible_actions = None
+
+            assert (possible_actions is not None) == (possible_next_actions is not None)
+
+            self.memory_buffer = MemoryBuffer.create(
+                max_size=self.max_replay_memory_size,
+                state_dim=state.shape[0],
+                action_dim=action.shape[0],
+                max_possible_actions=max_possible_actions,
+                has_possible_actions=possible_actions is not None,
+            )
+
+        insert_idx = None
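+        # Reservoir sampling: fill slots sequentially until full, then overwrite a
+        # uniformly chosen slot with probability max_replay_memory_size / memory_num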
         if self.memory_num < self.max_replay_memory_size:
-            self.replay_memory.append(item)
-        elif self.memory_num >= self.skip_insert_until:
-            p = float(self.max_replay_memory_size) / self.memory_num
-            self.skip_insert_until += np.random.geometric(p)
-            rand_index = np.random.randint(self.max_replay_memory_size)
-            self.replay_memory[rand_index] = item
+            insert_idx = self.memory_num
+        else:
+            rand_idx = torch.randint(0, self.memory_num, size=(1,)).item()
+            if rand_idx < self.max_replay_memory_size:
+                insert_idx = rand_idx  # type: ignore
+
+        if insert_idx is not None:
+            self.memory_buffer.insert_at(
+                insert_idx,
+                state,
+                action,
+                reward,
+                next_state,
+                next_action,
+                terminal,
+                possible_next_actions,
+                possible_next_actions_mask,
+                time_diff,
+                possible_actions,
+                possible_actions_mask,
+                policy_id,
+            )
         self.memory_num += 1