File tree Expand file tree Collapse file tree 1 file changed +2
-2
lines changed
Expand file tree Collapse file tree 1 file changed +2
-2
lines changed Original file line number Diff line number Diff line change @@ -585,7 +585,7 @@ def _get_samples(
585585 ) -> RolloutBufferSamples :
586586 data = (
587587 self .observations [batch_inds ],
588- self .actions [batch_inds ],
588+ self .actions [batch_inds ]. astype ( np . float32 , copy = False ) ,
589589 self .values [batch_inds ].flatten (),
590590 self .log_probs [batch_inds ].flatten (),
591591 self .advantages [batch_inds ].flatten (),
@@ -907,7 +907,7 @@ def _get_samples( # type: ignore[override]
907907 ) -> DictRolloutBufferSamples :
908908 return DictRolloutBufferSamples (
909909 observations = {key : self .to_torch (obs [batch_inds ]) for (key , obs ) in self .observations .items ()},
910- actions = self .to_torch (self .actions [batch_inds ]),
910+ actions = self .to_torch (self .actions [batch_inds ]. astype ( np . float32 , copy = False ) ),
911911 old_values = self .to_torch (self .values [batch_inds ].flatten ()),
912912 old_log_prob = self .to_torch (self .log_probs [batch_inds ].flatten ()),
913913 advantages = self .to_torch (self .advantages [batch_inds ].flatten ()),
You can’t perform that action at this time.
0 commit comments