@@ -224,19 +224,20 @@ def get_device(device: th.device | str = "auto") -> th.device:
224224 """
225225 Retrieve PyTorch device.
226226 It checks that the requested device is available first.
227- For now, it supports only cpu and cuda .
228- By default, it tries to use the gpu .
227+ For now, it supports only CPU and CUDA .
228+ By default, it tries to use the GPU .
229229
230- :param device: One for 'auto', 'cuda', 'cpu'
230+ :param device: One of "auto", "cuda", "cpu",
231+ or any PyTorch supported device (for instance "mps")
231232 :return: Supported Pytorch device
232233 """
233- # Cuda by default
234+ # MPS/CUDA by default
234235 if device == "auto" :
235- device = "cuda"
236+ device = get_available_accelerator ()
236237 # Force conversion to th.device
237238 device = th .device (device )
238239
239- # Cuda not available
240+ # CUDA not available
240241 if device .type == th .device ("cuda" ).type and not th .cuda .is_available ():
241242 return th .device ("cpu" )
242243
@@ -597,6 +598,20 @@ def should_collect_more_steps(
597598 )
598599
599600
601+ def get_available_accelerator () -> str :
602+ """
603+ Return the available accelerator
604+ (currently checking only for CUDA and MPS device)
605+ """
606+ if hasattr (th , "has_mps" ) and th .backends .mps .is_available ():
607+ # MacOS Metal GPU
608+ return "mps"
609+ elif th .cuda .is_available ():
610+ return "cuda"
611+ else :
612+ return "cpu"
613+
614+
600615def get_system_info (print_info : bool = True ) -> tuple [dict [str , str ], str ]:
601616 """
602617 Retrieve system and python env info for the current system.
@@ -612,7 +627,7 @@ def get_system_info(print_info: bool = True) -> tuple[dict[str, str], str]:
612627 "Python" : platform .python_version (),
613628 "Stable-Baselines3" : sb3 .__version__ ,
614629 "PyTorch" : th .__version__ ,
615- "GPU Enabled " : str ( th . cuda . is_available () ),
630+ "Accelerator " : get_available_accelerator ( ),
616631 "Numpy" : np .__version__ ,
617632 "Cloudpickle" : cloudpickle .__version__ ,
618633 "Gymnasium" : gym .__version__ ,
0 commit comments