Skip to content

Commit

Permalink
fix the way we detect np.float64 observations
Browse files Browse the repository at this point in the history
Summary: The current way to detect np array of type float64 is incorrect. np.float64 is the type of a single number, not the type of an array.

Reviewed By: rodrigodesalvobraz

Differential Revision: D67272396

fbshipit-source-id: 00b4255e65cb65c776a50e65e8f341ad18cbb64d
  • Loading branch information
yiwan-rl authored and facebook-github-bot committed Dec 19, 2024
1 parent 00ed853 commit f35e798
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions pearl/utils/instantiations/environments/gym_environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def reset(self, seed: int | None = None) -> tuple[Observation, ActionSpace]:
# TODO: Deprecate this part at some point and only support new
# version of Gymnasium?
observation = list(reset_result.values())[0] # pyre-ignore
if isinstance(observation, np.float64):
if isinstance(observation, np.ndarray) and observation.dtype == np.float64:
observation = observation.astype(np.float32)
return observation, self.action_space

Expand Down Expand Up @@ -151,7 +151,7 @@ def step(self, action: Action) -> ActionResult:
else:
available_action_space = None

if isinstance(observation, np.float64):
if isinstance(observation, np.ndarray) and observation.dtype == np.float64:
observation = observation.astype(np.float32)
if isinstance(reward, np.float64):
reward = reward.astype(np.float32)
Expand Down

0 comments on commit f35e798

Please sign in to comment.