
Commit be985fb

kittipatv authored and facebook-github-bot committed
Optimizing replay memory
Summary: Pull Request resolved: #175

Reviewed By: czxttkl

Differential Revision: D17974234

fbshipit-source-id: 3f54b759e669d534bba70ff2ea6d0d332b38c15c
1 parent adfa6ab commit be985fb
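
The diff below replaces the old list-of-tuples replay memory with a MemoryBuffer of preallocated tensors, created lazily once the first transition reveals the state and action shapes, while keeping the reservoir-sampling guarantee for insertion with a single torch.randint draw. The following is a minimal sketch of that strategy, not the committed code; the class and variable names (TinyReplayBuffer, n_seen) are made up for illustration.

import torch


class TinyReplayBuffer:
    """Toy preallocated-tensor buffer with reservoir-style eviction (illustrative only)."""

    def __init__(self, max_size: int):
        self.max_size = max_size
        self.n_seen = 0
        self.states = None  # allocated lazily, once the state shape is known

    def insert(self, state: torch.Tensor) -> None:
        if self.states is None:
            # First insert: the state dimension is now known, so preallocate the whole buffer.
            self.states = torch.zeros((self.max_size, state.shape[0]))
        if self.n_seen < self.max_size:
            idx = self.n_seen  # still filling the buffer
        else:
            # Keep the new transition with probability max_size / (n_seen + 1).
            j = torch.randint(0, self.n_seen + 1, (1,)).item()
            idx = j if j < self.max_size else None
        if idx is not None:
            self.states[idx] = state
        self.n_seen += 1

    def sample(self, batch_size: int) -> torch.Tensor:
        size = min(self.n_seen, self.max_size)
        indices = torch.randint(0, size, (batch_size,))
        return self.states[indices]  # one fancy-indexing op, no Python loop per row

Writing into a fixed slot only when the drawn index lands inside the buffer is what keeps every transition seen so far equally likely to be retained once the buffer is full, and storing rows in preallocated tensors lets sampling become a single indexing operation instead of building Python lists per column.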

6 files changed: +207 −62 lines changed

Lines changed: 191 additions & 56 deletions
@@ -1,6 +1,7 @@
 #!/usr/bin/env python3
 # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
 
+import dataclasses
 import logging
 import random
 from typing import Optional
@@ -16,25 +17,142 @@
 logger = logging.getLogger(__name__)
 
 
+@dataclasses.dataclass
+class MemoryBuffer:
+    state: torch.Tensor
+    action: torch.Tensor
+    reward: torch.Tensor
+    next_state: torch.Tensor
+    next_action: torch.Tensor
+    terminal: torch.Tensor
+    possible_next_actions: Optional[torch.Tensor]
+    possible_next_actions_mask: Optional[torch.Tensor]
+    possible_actions: Optional[torch.Tensor]
+    possible_actions_mask: Optional[torch.Tensor]
+    time_diff: torch.Tensor
+    policy_id: torch.Tensor
+
+    @torch.no_grad()  # type: ignore
+    def slice(self, indices):
+        return MemoryBuffer(
+            state=self.state[indices],
+            action=self.action[indices],
+            reward=self.reward[indices],
+            next_state=self.next_state[indices],
+            next_action=self.next_action[indices],
+            terminal=self.terminal[indices],
+            possible_next_actions=self.possible_next_actions[indices]
+            if self.possible_next_actions is not None
+            else None,
+            possible_next_actions_mask=self.possible_next_actions_mask[indices]
+            if self.possible_next_actions_mask is not None
+            else None,
+            possible_actions=self.possible_actions[indices]
+            if self.possible_actions is not None
+            else None,
+            possible_actions_mask=self.possible_actions_mask[indices]
+            if self.possible_actions_mask is not None
+            else None,
+            time_diff=self.time_diff[indices],
+            policy_id=self.policy_id[indices],
+        )
+
+    @torch.no_grad()  # type: ignore
+    def insert_at(
+        self,
+        idx: int,
+        state: torch.Tensor,
+        action: torch.Tensor,
+        reward: float,
+        next_state: torch.Tensor,
+        next_action: torch.Tensor,
+        terminal: bool,
+        possible_next_actions: Optional[torch.Tensor],
+        possible_next_actions_mask: Optional[torch.Tensor],
+        time_diff: float,
+        possible_actions: Optional[torch.Tensor],
+        possible_actions_mask: Optional[torch.Tensor],
+        policy_id: int,
+    ):
+        self.state[idx] = state
+        self.action[idx] = action
+        self.reward[idx] = reward
+        self.next_state[idx] = next_state
+        self.next_action[idx] = next_action
+        self.terminal[idx] = terminal
+        if self.possible_actions is not None:
+            self.possible_actions[idx] = possible_actions
+        if self.possible_actions_mask is not None:
+            self.possible_actions_mask[idx] = possible_actions_mask
+        if self.possible_next_actions is not None:
+            self.possible_next_actions[idx] = possible_next_actions
+        if self.possible_next_actions_mask is not None:
+            self.possible_next_actions_mask[idx] = possible_next_actions_mask
+        self.time_diff[idx] = time_diff
+        self.policy_id[idx] = policy_id
+
+    @classmethod
+    def create(
+        cls,
+        max_size: int,
+        state_dim: int,
+        action_dim: int,
+        max_possible_actions: Optional[int],
+        has_possble_actions: bool,
+    ):
+        return cls(
+            state=torch.zeros((max_size, state_dim)),
+            action=torch.zeros((max_size, action_dim)),
+            reward=torch.zeros((max_size, 1)),
+            next_state=torch.zeros((max_size, state_dim)),
+            next_action=torch.zeros((max_size, action_dim)),
+            terminal=torch.zeros((max_size, 1), dtype=torch.uint8),
+            possible_next_actions=torch.zeros(
+                (max_size, max_possible_actions, action_dim)
+            )
+            if has_possble_actions
+            else None,
+            possible_next_actions_mask=torch.zeros((max_size, max_possible_actions))
+            if max_possible_actions
+            else None,
+            possible_actions=torch.zeros((max_size, max_possible_actions, action_dim))
+            if has_possble_actions
+            else None,
+            possible_actions_mask=torch.zeros((max_size, max_possible_actions))
+            if max_possible_actions
+            else None,
+            time_diff=torch.zeros((max_size, 1)),
+            policy_id=torch.zeros((max_size, 1), dtype=torch.long),
+        )
+
+
 class OpenAIGymMemoryPool:
-    def __init__(self, max_replay_memory_size):
+    def __init__(self, max_replay_memory_size: int):
         """
         Creates an OpenAIGymMemoryPool object.
 
         :param max_replay_memory_size: Upper bound on the number of transitions
             to store in replay memory.
         """
-        self.replay_memory = []
         self.max_replay_memory_size = max_replay_memory_size
         self.memory_num = 0
-        self.skip_insert_until = self.max_replay_memory_size
+
+        # Not initializing in the beginning because we don't know the shapes
+        self.memory_buffer: Optional[MemoryBuffer] = None
 
     @property
     def size(self):
-        return len(self.replay_memory)
+        return min(self.memory_num, self.max_replay_memory_size)
+
+    @property
+    def state_dim(self):
+        assert self.memory_buffer is not None
+        return self.memory_buffer.state.shape[1]
 
-    def shuffle(self):
-        random.shuffle(self.replay_memory)
+    @property
+    def action_dim(self):
+        assert self.memory_buffer is not None
+        return self.memory_buffer.action.shape[1]
 
     def sample_memories(self, batch_size, model_type, chunk=None):
         """
@@ -49,72 +167,63 @@ def sample_memories(self, batch_size, model_type, chunk=None):
         :param model_type: Model type (discrete, parametric).
         :param chunk: Index of chunk of data (for deterministic sampling).
         """
-        cols = [[], [], [], [], [], [], [], [], [], [], [], []]
-
         if chunk is None:
-            indices = np.random.randint(0, len(self.replay_memory), size=batch_size)
+            indices = torch.randint(0, self.size, size=(batch_size,))
         else:
             start_idx = chunk * batch_size
             end_idx = start_idx + batch_size
             indices = range(start_idx, end_idx)
 
-        for idx in indices:
-            memory = self.replay_memory[idx]
-            for col, value in zip(cols, memory):
-                col.append(value)
+        memory = self.memory_buffer.slice(indices)
 
-        states = stack(cols[0])
-        next_states = stack(cols[3])
+        states = memory.state
+        next_states = memory.next_state
 
         assert states.dim() == 2
         assert next_states.dim() == 2
 
         if model_type == ModelType.PYTORCH_PARAMETRIC_DQN.value:
-            num_possible_actions = len(cols[7][0])
+            num_possible_actions = memory.possible_actions_mask.shape[1]
 
-            actions = stack(cols[1])
-            next_actions = stack(cols[4])
+            actions = memory.action
+            next_actions = memory.next_action
 
             tiled_states = states.repeat(1, num_possible_actions).reshape(
                 -1, states.shape[1]
             )
-            possible_actions = torch.cat(cols[8])
+            possible_actions = memory.possible_actions.reshape(-1, actions.shape[1])
             possible_actions_state_concat = torch.cat(
                 (tiled_states, possible_actions), dim=1
             )
-            possible_actions_mask = stack(cols[9])
+            possible_actions_mask = memory.possible_actions_mask
 
             tiled_next_states = next_states.repeat(1, num_possible_actions).reshape(
                 -1, next_states.shape[1]
             )
-            possible_next_actions = torch.cat(cols[6])
+            possible_next_actions = memory.possible_next_actions.reshape(
+                -1, actions.shape[1]
+            )
             possible_next_actions_state_concat = torch.cat(
                 (tiled_next_states, possible_next_actions), dim=1
             )
-            possible_next_actions_mask = stack(cols[7])
+            possible_next_actions_mask = memory.possible_next_actions_mask
         else:
             possible_actions = None
             possible_actions_state_concat = None
             possible_next_actions = None
             possible_next_actions_state_concat = None
-            if cols[7] is None or cols[7][0] is None:
-                possible_next_actions_mask = None
-            else:
-                possible_next_actions_mask = stack(cols[7])
-            if cols[9] is None or cols[9][0] is None:
-                possible_actions_mask = None
-            else:
-                possible_actions_mask = stack(cols[9])
+            possible_next_actions_mask = memory.possible_next_actions_mask
+            possible_actions_mask = memory.possible_actions_mask
 
-        actions = stack(cols[1])
-        next_actions = stack(cols[4])
+        actions = memory.action
+        next_actions = memory.next_action
 
         assert len(actions.size()) == 2
         assert len(next_actions.size()) == 2
 
-        rewards = torch.tensor(cols[2], dtype=torch.float32).reshape(-1, 1)
-        not_terminal = (1 - torch.tensor(cols[5], dtype=torch.int32)).reshape(-1, 1)
-        time_diffs = torch.tensor(cols[10], dtype=torch.int32).reshape(-1, 1)
+        rewards = memory.reward
+        not_terminal = 1 - memory.terminal
+        time_diffs = memory.time_diff
 
         return TrainingDataPage(
             states=states,
@@ -144,32 +253,58 @@ def insert_into_memory(
         time_diff: float,
         possible_actions: Optional[torch.Tensor],
         possible_actions_mask: Optional[torch.Tensor],
-        policy_id: str,
+        policy_id: int,
     ):
         """
         Inserts transition into replay memory in such a way that retrieving
         transitions uniformly at random will be equivalent to reservoir sampling.
         """
-        item = (
-            state,
-            action,
-            reward,
-            next_state,
-            next_action,
-            terminal,
-            possible_next_actions,
-            possible_next_actions_mask,
-            possible_actions,
-            possible_actions_mask,
-            time_diff,
-            policy_id,
-        )
 
+        if self.memory_buffer is None:
+            assert state.shape == next_state.shape
+            assert len(state.shape) == 1
+            assert action.shape == next_action.shape
+            assert len(action.shape) == 1
+            if possible_actions_mask is not None:
+                assert possible_next_actions_mask is not None
+                assert possible_actions_mask.shape == possible_next_actions_mask.shape
+                assert len(possible_actions_mask.shape) == 1
+                max_possible_actions = possible_actions_mask.shape[0]
+            else:
+                max_possible_actions = None
+
+            assert (possible_actions is not None) == (possible_next_actions is not None)
+
+            self.memory_buffer = MemoryBuffer.create(
+                max_size=self.max_replay_memory_size,
+                state_dim=state.shape[0],
+                action_dim=action.shape[0],
+                max_possible_actions=max_possible_actions,
+                has_possble_actions=possible_actions is not None,
+            )
+
+        insert_idx = None
         if self.memory_num < self.max_replay_memory_size:
-            self.replay_memory.append(item)
-        elif self.memory_num >= self.skip_insert_until:
-            p = float(self.max_replay_memory_size) / self.memory_num
-            self.skip_insert_until += np.random.geometric(p)
-            rand_index = np.random.randint(self.max_replay_memory_size)
-            self.replay_memory[rand_index] = item
+            insert_idx = self.memory_num
+        else:
+            rand_idx = torch.randint(0, self.memory_num, size=(1,)).item()
+            if rand_idx < self.max_replay_memory_size:
+                insert_idx = rand_idx  # type: ignore
+
+        if insert_idx is not None:
+            self.memory_buffer.insert_at(
+                insert_idx,
+                state,
+                action,
+                reward,
+                next_state,
+                next_action,
+                terminal,
+                possible_next_actions,
+                possible_next_actions_mask,
+                time_diff,
+                possible_actions,
+                possible_actions_mask,
+                policy_id,
+            )
         self.memory_num += 1
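
One detail worth noting in the parametric-DQN branch of sample_memories above: each sampled state is tiled once per candidate action and concatenated with the flattened possible-action tensor, so the per-transition Python loops of the old code become a few batched tensor ops. The toy shape check below uses hypothetical dimensions and is not part of the commit; it only illustrates how the repeat/reshape/cat calls line up.

import torch

batch, state_dim, action_dim, num_possible_actions = 2, 3, 2, 4
states = torch.randn(batch, state_dim)
possible_actions = torch.randn(batch, num_possible_actions, action_dim)

# Repeat each state once per candidate action, then pair it with that action.
tiled_states = states.repeat(1, num_possible_actions).reshape(-1, state_dim)
flat_actions = possible_actions.reshape(-1, action_dim)
state_action_concat = torch.cat((tiled_states, flat_actions), dim=1)

# Row i * num_possible_actions + j holds state i paired with its j-th candidate action.
assert state_action_concat.shape == (batch * num_possible_actions, state_dim + action_dim)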

ml/rl/test/gym/run_gym.py

Lines changed: 3 additions & 3 deletions
@@ -123,8 +123,8 @@ def create_replay_buffer(
     replay_buffer = OpenAIGymMemoryPool(params.max_replay_memory_size)
     if path_to_pickled_transitions:
         create_stored_policy_offline_dataset(replay_buffer, path_to_pickled_transitions)
-        replay_state_dim = replay_buffer.replay_memory[0][0].shape[0]
-        replay_action_dim = replay_buffer.replay_memory[0][1].shape[0]
+        replay_state_dim = replay_buffer.state_dim
+        replay_action_dim = replay_buffer.action_dim
         assert replay_state_dim == env.state_dim
         assert replay_action_dim == env.action_dim
     elif offline_train:
@@ -490,7 +490,7 @@ def train_gym_online_rl(
         if (
             total_timesteps % train_every_ts == 0
            and total_timesteps > train_after_ts
-            and len(replay_buffer.replay_memory) >= trainer.minibatch_size
+            and replay_buffer.size >= trainer.minibatch_size
             and not (stop_training_after_solved and solved)
         ):
             for _ in range(num_train_batches):

ml/rl/test/gym/world_model/mdnrnn_gym.py

Lines changed: 1 addition & 1 deletion
@@ -454,7 +454,7 @@ def concat_batch(batch):
             dataset.insert(
                 state=state_embed,
                 action=torch.tensor(action_batch[i][hidden_idx + 1]),  # type: ignore
-                reward=reward_batch[i][hidden_idx + 1],  # type: ignore
+                reward=float(reward_batch[i][hidden_idx + 1]),  # type: ignore
                 next_state=next_state_embed,
                 next_action=torch.tensor(
                     next_action_batch[i][next_hidden_idx + 1]  # type: ignore

ml/rl/test/gym/world_model/state_embed_gym.py

Lines changed: 2 additions & 1 deletion
@@ -220,7 +220,8 @@ def run_gym(
     for row in embed_rl_dataset.rows:
         replay_buffer.insert_into_memory(**row)
 
-    state_mem = torch.cat([m[0] for m in replay_buffer.replay_memory])
+    assert replay_buffer.memory_buffer is not None
+    state_mem = replay_buffer.memory_buffer.state
     state_min_value = torch.min(state_mem).item()
     state_max_value = torch.max(state_mem).item()
     state_embed_env = StateEmbedGymEnvironment(

ml/rl/workflow/dqn_workflow.py

Lines changed: 6 additions & 0 deletions
@@ -13,6 +13,7 @@
 from ml.rl.json_serialize import from_json
 from ml.rl.parameters import (
     DiscreteActionModelParameters,
+    EvaluationParameters,
     NormalizationParameters,
     RainbowDQNParameters,
     RLParameters,
@@ -117,12 +118,17 @@ def single_process_main(gpu_index, *args):
     rl_parameters = from_json(params["rl"], RLParameters)
     training_parameters = from_json(params["training"], TrainingParameters)
     rainbow_parameters = from_json(params["rainbow"], RainbowDQNParameters)
+    if "evaluation" in params:
+        evaluation_parameters = from_json(params["evaluation"], EvaluationParameters)
+    else:
+        evaluation_parameters = EvaluationParameters()
 
     model_params = DiscreteActionModelParameters(
         actions=action_names,
         rl=rl_parameters,
         training=training_parameters,
         rainbow=rainbow_parameters,
+        evaluation=evaluation_parameters,
     )
     state_normalization = BaseWorkflow.read_norm_file(params["state_norm_data_path"])
