@@ -390,9 +390,7 @@ def __init__(
390390
391391 def reset (self ) -> None :
392392 self .observations = np .zeros ((self .buffer_size , self .n_envs , * self .obs_shape ), dtype = self .observation_space .dtype )
393- self .actions = np .zeros (
394- (self .buffer_size , self .n_envs , self .action_dim ), dtype = self ._maybe_cast_dtype (self .action_space .dtype )
395- )
393+ self .actions = np .zeros ((self .buffer_size , self .n_envs , self .action_dim ), dtype = self .action_space .dtype )
396394 self .rewards = np .zeros ((self .buffer_size , self .n_envs ), dtype = np .float32 )
397395 self .returns = np .zeros ((self .buffer_size , self .n_envs ), dtype = np .float32 )
398396 self .episode_starts = np .zeros ((self .buffer_size , self .n_envs ), dtype = np .float32 )
@@ -514,29 +512,15 @@ def _get_samples(
514512 ) -> RolloutBufferSamples :
515513 data = (
516514 self .observations [batch_inds ],
517- self .actions [batch_inds ],
515+ # Cast to float32 (backward compatible), otherwise this would lead to RuntimeError for MultiBinary space
516+ self .actions [batch_inds ].astype (np .float32 , copy = False ),
518517 self .values [batch_inds ].flatten (),
519518 self .log_probs [batch_inds ].flatten (),
520519 self .advantages [batch_inds ].flatten (),
521520 self .returns [batch_inds ].flatten (),
522521 )
523522 return RolloutBufferSamples (* tuple (map (self .to_torch , data )))
524523
525- @staticmethod
526- def _maybe_cast_dtype (dtype : np .typing .DTypeLike ) -> np .typing .DTypeLike :
527- """
528- Cast `np.int8` action datatype to `np.float32`, keep the others dtype unchanged.
529- Otherwise, this would lead to
530- "RuntimeError: result type Float can't be cast to the desired output type Char"
531- when trying to compute the log prob for MultiBinary space.
532-
533- :param dtype: The original action space dtype
534- :return: ``np.float32`` if the dtype was int8, the original dtype otherwise.
535- """
536- if dtype == np .int8 :
537- return np .float32
538- return dtype
539-
540524
541525class DictReplayBuffer (ReplayBuffer ):
542526 """
@@ -765,9 +749,7 @@ def reset(self) -> None:
765749 self .observations [key ] = np .zeros (
766750 (self .buffer_size , self .n_envs , * obs_input_shape ), dtype = self .observation_space [key ].dtype
767751 )
768- self .actions = np .zeros (
769- (self .buffer_size , self .n_envs , self .action_dim ), dtype = self ._maybe_cast_dtype (self .action_space .dtype )
770- )
752+ self .actions = np .zeros ((self .buffer_size , self .n_envs , self .action_dim ), dtype = self .action_space .dtype )
771753 self .rewards = np .zeros ((self .buffer_size , self .n_envs ), dtype = np .float32 )
772754 self .returns = np .zeros ((self .buffer_size , self .n_envs ), dtype = np .float32 )
773755 self .episode_starts = np .zeros ((self .buffer_size , self .n_envs ), dtype = np .float32 )
@@ -853,7 +835,8 @@ def _get_samples( # type: ignore[override]
853835 ) -> DictRolloutBufferSamples :
854836 return DictRolloutBufferSamples (
855837 observations = {key : self .to_torch (obs [batch_inds ]) for (key , obs ) in self .observations .items ()},
856- actions = self .to_torch (self .actions [batch_inds ]),
838+ # Cast to float32 (backward compatible), otherwise this would lead to RuntimeError for MultiBinary space
839+ actions = self .to_torch (self .actions [batch_inds ].astype (np .float32 , copy = False )),
857840 old_values = self .to_torch (self .values [batch_inds ].flatten ()),
858841 old_log_prob = self .to_torch (self .log_probs [batch_inds ].flatten ()),
859842 advantages = self .to_torch (self .advantages [batch_inds ].flatten ()),
0 commit comments