@@ -172,7 +172,7 @@ def get_q_values(
172
172
q_values = self .forward (x ).squeeze (
173
173
- 1
174
174
) # (batch_size, number_of_actions_to_query)
175
- return q_values if len (action_batch ) == 3 else q_values .squeeze (- 1 )
175
+ return q_values if len (action_batch . shape ) == 3 else q_values .squeeze (- 1 )
176
176
177
177
@property
178
178
def state_dim (self ) -> int :
@@ -239,7 +239,7 @@ def get_q_values(
239
239
q_values , # (batch_size x num actions x 1)
240
240
) # (batch_size x number of query actions x 1)
241
241
q_values = q_values .squeeze (- 1 ) # (batch_size x number of query actions)
242
- return q_values if len (action_batch ) == 3 else q_values .squeeze (- 1 )
242
+ return q_values if len (action_batch . shape ) == 3 else q_values .squeeze (- 1 )
243
243
244
244
@property
245
245
def state_dim (self ) -> int :
@@ -324,7 +324,7 @@ def get_q_value_distribution(
324
324
q_values = self .forward (
325
325
x
326
326
) # (batch_size, number_of_actions_to_query, number_of_quantiles)
327
- return q_values if len (action_batch ) == 3 else q_values .squeeze (- 2 )
327
+ return q_values if len (action_batch . shape ) == 3 else q_values .squeeze (- 2 )
328
328
329
329
@property
330
330
def quantiles (self ) -> Tensor :
@@ -503,7 +503,7 @@ def get_q_values(
503
503
state_value + advantage - advantage_mean
504
504
) # shape: (batch_size, number of query actions)
505
505
506
- return q_values if len (action_batch ) == 3 else q_values .squeeze (- 1 )
506
+ return q_values if len (action_batch . shape ) == 3 else q_values .squeeze (- 1 )
507
507
508
508
509
509
"""
@@ -590,7 +590,7 @@ def get_q_values(
590
590
q_values = self ._interaction_features .forward (x ).squeeze (
591
591
- 1
592
592
) # (batch_size, number_of_actions_to_query)
593
- return q_values if len (action_batch ) == 3 else q_values .squeeze (- 1 )
593
+ return q_values if len (action_batch . shape ) == 3 else q_values .squeeze (- 1 )
594
594
595
595
@property
596
596
def state_dim (self ) -> int :
@@ -695,7 +695,7 @@ def get_q_values(
695
695
q_values = self .forward (x , z = z , persistent = persistent ).squeeze (
696
696
- 1
697
697
) # (batch_size, number_of_actions_to_query)
698
- return q_values if len (action_batch ) == 3 else q_values .squeeze (- 1 )
698
+ return q_values if len (action_batch . shape ) == 3 else q_values .squeeze (- 1 )
699
699
700
700
@property
701
701
def state_dim (self ) -> int :
@@ -806,7 +806,7 @@ def get_q_values(
806
806
q_values = self ._model_fc (x ).reshape (
807
807
batch_size , num_query_actions
808
808
) # (batch_size, number_of_actions_to_query)
809
- return q_values if len (action_batch ) == 3 else q_values .squeeze (- 1 )
809
+ return q_values if len (action_batch . shape ) == 3 else q_values .squeeze (- 1 )
810
810
811
811
@property
812
812
def state_dim (self ) -> int :
0 commit comments