Skip to content

Commit 334a08c

Browse files
yiwan-rl authored and facebook-github-bot committed
add the missing .shape to q value networks
Summary: Add the missing .shape to q value networks.
Reviewed By: rodrigodesalvobraz
Differential Revision: D67272397
fbshipit-source-id: 88cf32b9efa43fd9051c8de17fcc82ec3214fbe7
1 parent f35e798 commit 334a08c

File tree

1 file changed

+7
-7
lines changed

1 file changed

+7
-7
lines changed

pearl/neural_networks/sequential_decision_making/q_value_networks.py

Lines changed: 7 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -172,7 +172,7 @@ def get_q_values(
172172
q_values = self.forward(x).squeeze(
173173
-1
174174
) # (batch_size, number_of_actions_to_query)
175-
return q_values if len(action_batch) == 3 else q_values.squeeze(-1)
175+
return q_values if len(action_batch.shape) == 3 else q_values.squeeze(-1)
176176

177177
@property
178178
def state_dim(self) -> int:
@@ -239,7 +239,7 @@ def get_q_values(
239239
q_values, # (batch_size x num actions x 1)
240240
) # (batch_size x number of query actions x 1)
241241
q_values = q_values.squeeze(-1) # (batch_size x number of query actions)
242-
return q_values if len(action_batch) == 3 else q_values.squeeze(-1)
242+
return q_values if len(action_batch.shape) == 3 else q_values.squeeze(-1)
243243

244244
@property
245245
def state_dim(self) -> int:
@@ -324,7 +324,7 @@ def get_q_value_distribution(
324324
q_values = self.forward(
325325
x
326326
) # (batch_size, number_of_actions_to_query, number_of_quantiles)
327-
return q_values if len(action_batch) == 3 else q_values.squeeze(-2)
327+
return q_values if len(action_batch.shape) == 3 else q_values.squeeze(-2)
328328

329329
@property
330330
def quantiles(self) -> Tensor:
@@ -503,7 +503,7 @@ def get_q_values(
503503
state_value + advantage - advantage_mean
504504
) # shape: (batch_size, number of query actions)
505505

506-
return q_values if len(action_batch) == 3 else q_values.squeeze(-1)
506+
return q_values if len(action_batch.shape) == 3 else q_values.squeeze(-1)
507507

508508

509509
"""
@@ -590,7 +590,7 @@ def get_q_values(
590590
q_values = self._interaction_features.forward(x).squeeze(
591591
-1
592592
) # (batch_size, number_of_actions_to_query)
593-
return q_values if len(action_batch) == 3 else q_values.squeeze(-1)
593+
return q_values if len(action_batch.shape) == 3 else q_values.squeeze(-1)
594594

595595
@property
596596
def state_dim(self) -> int:
@@ -695,7 +695,7 @@ def get_q_values(
695695
q_values = self.forward(x, z=z, persistent=persistent).squeeze(
696696
-1
697697
) # (batch_size, number_of_actions_to_query)
698-
return q_values if len(action_batch) == 3 else q_values.squeeze(-1)
698+
return q_values if len(action_batch.shape) == 3 else q_values.squeeze(-1)
699699

700700
@property
701701
def state_dim(self) -> int:
@@ -806,7 +806,7 @@ def get_q_values(
806806
q_values = self._model_fc(x).reshape(
807807
batch_size, num_query_actions
808808
) # (batch_size, number_of_actions_to_query)
809-
return q_values if len(action_batch) == 3 else q_values.squeeze(-1)
809+
return q_values if len(action_batch.shape) == 3 else q_values.squeeze(-1)
810810

811811
@property
812812
def state_dim(self) -> int:

0 commit comments

Comments
 (0)