11import warnings
22from abc import ABC , abstractmethod
3- from collections .abc import Generator
4- from typing import Any , Optional , Union
3+ from collections .abc import Generator , Mapping
4+ from dataclasses import InitVar , dataclass , field
5+ from types import MappingProxyType
6+ from typing import Any , ClassVar , Optional , Union
57
68import numpy as np
79import torch as th
1113from stable_baselines3 .common .type_aliases import (
1214 DictReplayBufferSamples ,
1315 DictRolloutBufferSamples ,
16+ DTypeLike ,
1417 ReplayBufferSamples ,
1518 RolloutBufferSamples ,
1619)
2427 psutil = None
2528
2629
@dataclass
class BufferDTypes:
    """
    Data class representing the data types used by a buffer.

    The dtype-like values given at construction time are normalized to
    ``np.dtype`` in ``__post_init__``. When the observation dtypes are given
    as a mapping, the converted mapping is stored behind a read-only
    ``MappingProxyType``.

    :param _observations: Datatype of observation space, either a single
        dtype-like value or a mapping from key to dtype-like value
    :param _actions: Datatype of action space
    """

    # Torch dtype names (``repr`` without the ``torch.`` prefix) mapped to the
    # numpy attribute name that is looked up instead of the torch name.
    MAP_TORCH_DTYPES: ClassVar[dict] = dict(complex32="complex64", float="float32", bfloat16="float32", bool="bool_")
    # Init-only inputs: consumed by ``__post_init__`` and not stored as fields.
    _observations: InitVar[Union[DTypeLike, Mapping[str, DTypeLike]]]
    _actions: InitVar[DTypeLike]
    # Normalized results, filled in by ``__post_init__`` (hence ``init=False``).
    observations: Union[np.dtype, MappingProxyType[str, np.dtype]] = field(init=False)
    actions: np.dtype = field(init=False)

    def __post_init__(self, _observations: Union[DTypeLike, Mapping[str, DTypeLike]], _actions: DTypeLike):
        """
        Convert the init-only dtype-like inputs to ``np.dtype`` and store them.

        :param _observations: Single dtype-like value or mapping of them
        :param _actions: Dtype-like value for the actions
        """
        if isinstance(_observations, Mapping):
            # One dtype per key; wrap in a proxy so the stored mapping cannot be edited in place
            self.observations = MappingProxyType({k: self.to_numpy_dtype(v) for k, v in _observations.items()})
        else:
            self.observations = self.to_numpy_dtype(_observations)
        self.actions = self.to_numpy_dtype(_actions)

    @classmethod
    def to_numpy_dtype(cls, dtype_like: DTypeLike) -> np.dtype:
        """
        Convert a dtype-like object to a ``np.dtype``.

        Supported inputs are ``np.dtype`` instances (returned unchanged),
        ``th.dtype`` instances (resolved by name, using ``MAP_TORCH_DTYPES``
        for names that differ), numpy scalar types, strings and ``None``
        (the last three are forwarded to ``np.dtype``).

        :param dtype_like: Object to convert
        :return: The corresponding numpy dtype
        :raises TypeError: If the torch dtype has no numpy attribute of the
            resolved name, if the string is rejected by ``np.dtype``, or if
            the object is of an unsupported kind
        """
        if isinstance(dtype_like, np.dtype):
            return dtype_like
        elif isinstance(dtype_like, th.dtype):
            # repr(th.float32) == "torch.float32": strip the prefix to get the bare dtype name
            torch_dtype_name = repr(dtype_like).removeprefix("torch.")
            numpy_dtype_name = cls.MAP_TORCH_DTYPES.get(torch_dtype_name, torch_dtype_name)
            try:
                return np.dtype(getattr(np, numpy_dtype_name))
            except AttributeError as e:
                raise TypeError(f"Cannot cast torch dtype '{torch_dtype_name}' to numpy.dtype implicitly.") from e
        elif isinstance(dtype_like, type) and issubclass(dtype_like, np.generic):
            return np.dtype(dtype_like)
        elif isinstance(dtype_like, str):
            try:
                return np.dtype(dtype_like)
            except TypeError as e:
                raise TypeError(f"Cannot interpret str '{dtype_like}' as a valid numpy datatype.") from e
        elif dtype_like is None:
            # Let numpy pick its default dtype
            return np.dtype(dtype_like)
        raise TypeError(f"Cannot interpret unknown object '{dtype_like}' as a valid numpy datatype.")
73+
74+
2775class BaseBuffer (ABC ):
2876 """
2977 Base class that represent a buffer (rollout or replay)
@@ -46,7 +94,6 @@ def __init__(
4694 action_space : spaces .Space ,
4795 device : Union [th .device , str ] = "auto" ,
4896 n_envs : int = 1 ,
49- dtypes : Optional [dict ] = None ,
5097 ):
5198 super ().__init__ ()
5299 self .buffer_size = buffer_size
@@ -62,38 +109,13 @@ def __init__(
62109
63110 # unify the dtype decision logic for all buffer classes
64111 # see https://github.com/DLR-RM/stable-baselines3/issues/2162
65- dtypes = dtypes or dict ()
66- dtypes = dtypes .copy ()
67- object_dtype = np .dtype (object )
68-
69- # Ensure dtypes override is valid for dict observations
70- obs_dtype : Union [dict , np .dtype ]
71112 if isinstance (observation_space , spaces .Dict ):
72- if dtypes .get ("observations" ):
73- if not isinstance (dtypes ["observations" ], dict ):
74- dtypes ["observations" ] = {key : np .dtype (dtypes ["observations" ]) for key in self .obs_shape }
75- else :
76- dtypes ["observations" ] = {key : np .dtype (dtype ) for (key , dtype ) in dtypes ["observations" ].items ()}
77- obs_dtype = {key : np .dtype (space .dtype ) for (key , space ) in observation_space .spaces .items ()} # type: ignore[misc]
113+ self .dtypes = BufferDTypes (
114+ {key : space .dtype for (key , space ) in observation_space .spaces .items ()},
115+ action_space .dtype ,
116+ )
78117 else :
79- obs_dtype = np .dtype (observation_space .dtype )
80-
81- # Validate the dtypes
82- self .dtypes = dict (
83- observations = dtypes .get ("observations" , obs_dtype ), actions = np .dtype (dtypes .get ("actions" , action_space .dtype ))
84- )
85- for space , dtype in self .dtypes .items ():
86- if not isinstance (dtype , dict ):
87- dtype = {"" : dtype }
88- for key , subspace_dtype in dtype .items ():
89- if subspace_dtype == object_dtype :
90- if key :
91- key = f"[{ key } ]"
92- warnings .warn (
93- f"An object dtype has been assigned to { space } { key } , you are likely using a custom "
94- f"environment, please use it with caution and ensure that { space } { key } is properly "
95- "dereferenced / copied within each step to avoid unwanted consequences."
96- )
118+ self .dtypes = BufferDTypes (observation_space .dtype , action_space .dtype )
97119
98120 @staticmethod
99121 def swap_and_flatten (arr : np .ndarray ) -> np .ndarray :
@@ -227,9 +249,8 @@ def __init__(
227249 n_envs : int = 1 ,
228250 optimize_memory_usage : bool = False ,
229251 handle_timeout_termination : bool = True ,
230- dtypes : Optional [dict ] = None ,
231252 ):
232- super ().__init__ (buffer_size , observation_space , action_space , device , n_envs = n_envs , dtypes = dtypes )
253+ super ().__init__ (buffer_size , observation_space , action_space , device , n_envs = n_envs )
233254
234255 # Adjust buffer size
235256 self .buffer_size = max (buffer_size // n_envs , 1 )
@@ -247,16 +268,14 @@ def __init__(
247268 )
248269 self .optimize_memory_usage = optimize_memory_usage
249270
250- self .observations = np .zeros ((self .buffer_size , self .n_envs , * self .obs_shape ), dtype = self .dtypes [ " observations" ] )
271+ self .observations = np .zeros ((self .buffer_size , self .n_envs , * self .obs_shape ), dtype = self .dtypes . observations )
251272
252273 if not optimize_memory_usage :
253274 # When optimizing memory, `observations` contains also the next observation
254- self .next_observations = np .zeros (
255- (self .buffer_size , self .n_envs , * self .obs_shape ), dtype = self .dtypes ["observations" ]
256- )
275+ self .next_observations = np .zeros ((self .buffer_size , self .n_envs , * self .obs_shape ), dtype = self .dtypes .observations )
257276
258277 self .actions = np .zeros (
259- (self .buffer_size , self .n_envs , self .action_dim ), dtype = self ._maybe_cast_dtype (self .dtypes [ " actions" ] )
278+ (self .buffer_size , self .n_envs , self .action_dim ), dtype = self ._maybe_cast_dtype (self .dtypes . actions )
260279 )
261280
262281 self .rewards = np .zeros ((self .buffer_size , self .n_envs ), dtype = np .float32 )
@@ -420,17 +439,16 @@ def __init__(
420439 gae_lambda : float = 1 ,
421440 gamma : float = 0.99 ,
422441 n_envs : int = 1 ,
423- dtypes : Optional [dict ] = None ,
424442 ):
425- super ().__init__ (buffer_size , observation_space , action_space , device , n_envs = n_envs , dtypes = dtypes )
443+ super ().__init__ (buffer_size , observation_space , action_space , device , n_envs = n_envs )
426444 self .gae_lambda = gae_lambda
427445 self .gamma = gamma
428446 self .generator_ready = False
429447 self .reset ()
430448
431449 def reset (self ) -> None :
432- self .observations = np .zeros ((self .buffer_size , self .n_envs , * self .obs_shape ), dtype = self .dtypes [ " observations" ] )
433- self .actions = np .zeros ((self .buffer_size , self .n_envs , self .action_dim ), dtype = self .dtypes [ " actions" ] )
450+ self .observations = np .zeros ((self .buffer_size , self .n_envs , * self .obs_shape ), dtype = self .dtypes . observations )
451+ self .actions = np .zeros ((self .buffer_size , self .n_envs , self .action_dim ), dtype = self .dtypes . actions )
434452 self .rewards = np .zeros ((self .buffer_size , self .n_envs ), dtype = np .float32 )
435453 self .returns = np .zeros ((self .buffer_size , self .n_envs ), dtype = np .float32 )
436454 self .episode_starts = np .zeros ((self .buffer_size , self .n_envs ), dtype = np .float32 )
@@ -592,9 +610,8 @@ def __init__(
592610 n_envs : int = 1 ,
593611 optimize_memory_usage : bool = False ,
594612 handle_timeout_termination : bool = True ,
595- dtypes : Optional [dict ] = None ,
596613 ):
597- super (ReplayBuffer , self ).__init__ (buffer_size , observation_space , action_space , device , n_envs = n_envs , dtypes = dtypes )
614+ super (ReplayBuffer , self ).__init__ (buffer_size , observation_space , action_space , device , n_envs = n_envs )
598615
599616 assert isinstance (self .obs_shape , dict ), "DictReplayBuffer must be used with Dict obs space only"
600617 self .buffer_size = max (buffer_size // n_envs , 1 )
@@ -609,16 +626,16 @@ def __init__(
609626 self .optimize_memory_usage = optimize_memory_usage
610627
611628 self .observations = {
612- key : np .zeros ((self .buffer_size , self .n_envs , * _obs_shape ), dtype = self .dtypes [ " observations" ] [key ])
629+ key : np .zeros ((self .buffer_size , self .n_envs , * _obs_shape ), dtype = self .dtypes . observations [key ])
613630 for key , _obs_shape in self .obs_shape .items ()
614631 }
615632 self .next_observations = {
616- key : np .zeros ((self .buffer_size , self .n_envs , * _obs_shape ), dtype = self .dtypes [ " observations" ] [key ])
633+ key : np .zeros ((self .buffer_size , self .n_envs , * _obs_shape ), dtype = self .dtypes . observations [key ])
617634 for key , _obs_shape in self .obs_shape .items ()
618635 }
619636
620637 self .actions = np .zeros (
621- (self .buffer_size , self .n_envs , self .action_dim ), dtype = self ._maybe_cast_dtype (self .dtypes [ " actions" ] )
638+ (self .buffer_size , self .n_envs , self .action_dim ), dtype = self ._maybe_cast_dtype (self .dtypes . actions )
622639 )
623640 self .rewards = np .zeros ((self .buffer_size , self .n_envs ), dtype = np .float32 )
624641 self .dones = np .zeros ((self .buffer_size , self .n_envs ), dtype = np .float32 )
@@ -772,9 +789,8 @@ def __init__(
772789 gae_lambda : float = 1 ,
773790 gamma : float = 0.99 ,
774791 n_envs : int = 1 ,
775- dtypes : Optional [dict ] = None ,
776792 ):
777- super (RolloutBuffer , self ).__init__ (buffer_size , observation_space , action_space , device , n_envs = n_envs , dtypes = dtypes )
793+ super (RolloutBuffer , self ).__init__ (buffer_size , observation_space , action_space , device , n_envs = n_envs )
778794
779795 assert isinstance (self .obs_shape , dict ), "DictRolloutBuffer must be used with Dict obs space only"
780796
@@ -788,9 +804,9 @@ def reset(self) -> None:
788804 self .observations = {}
789805 for key , obs_input_shape in self .obs_shape .items ():
790806 self .observations [key ] = np .zeros (
791- (self .buffer_size , self .n_envs , * obs_input_shape ), dtype = self .dtypes [ " observations" ] [key ]
807+ (self .buffer_size , self .n_envs , * obs_input_shape ), dtype = self .dtypes . observations [key ]
792808 )
793- self .actions = np .zeros ((self .buffer_size , self .n_envs , self .action_dim ), dtype = self .dtypes [ " actions" ] )
809+ self .actions = np .zeros ((self .buffer_size , self .n_envs , self .action_dim ), dtype = self .dtypes . actions )
794810 self .rewards = np .zeros ((self .buffer_size , self .n_envs ), dtype = np .float32 )
795811 self .returns = np .zeros ((self .buffer_size , self .n_envs ), dtype = np .float32 )
796812 self .episode_starts = np .zeros ((self .buffer_size , self .n_envs ), dtype = np .float32 )
0 commit comments