Skip to content

Commit d644e4b

Browse files
araffinoursland
authored andcommitted
Use MPS device when available
1 parent 69d3d3d commit d644e4b

File tree

2 files changed

+24
-7
lines changed

2 files changed

+24
-7
lines changed

docs/misc/changelog.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@ Breaking Changes:
1515
New Features:
1616
^^^^^^^^^^^^^
1717
- Added official support for Python 3.13
18+
- Use MacOS Metal "mps" device when available
19+
- Save cloudpickle version
1820

1921
Bug Fixes:
2022
^^^^^^^^^^

stable_baselines3/common/utils.py

Lines changed: 22 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -224,19 +224,20 @@ def get_device(device: th.device | str = "auto") -> th.device:
224224
"""
225225
Retrieve PyTorch device.
226226
It checks that the requested device is available first.
227-
For now, it supports only cpu and cuda.
228-
By default, it tries to use the gpu.
227+
For now, it supports only CPU and CUDA.
228+
By default, it tries to use the GPU.
229229
230-
:param device: One for 'auto', 'cuda', 'cpu'
230+
:param device: One of "auto", "cuda", "cpu",
231+
or any PyTorch supported device (for instance "mps")
231232
:return: Supported Pytorch device
232233
"""
233-
# Cuda by default
234+
# MPS/CUDA by default
234235
if device == "auto":
235-
device = "cuda"
236+
device = get_available_accelerator()
236237
# Force conversion to th.device
237238
device = th.device(device)
238239

239-
# Cuda not available
240+
# CUDA not available
240241
if device.type == th.device("cuda").type and not th.cuda.is_available():
241242
return th.device("cpu")
242243

@@ -597,6 +598,20 @@ def should_collect_more_steps(
597598
)
598599

599600

601+
def get_available_accelerator() -> str:
602+
"""
603+
Return the available accelerator
604+
(currently checking only for CUDA and MPS device)
605+
"""
606+
if hasattr(th, "has_mps") and th.backends.mps.is_available():
607+
# MacOS Metal GPU
608+
return "mps"
609+
elif th.cuda.is_available():
610+
return "cuda"
611+
else:
612+
return "cpu"
613+
614+
600615
def get_system_info(print_info: bool = True) -> tuple[dict[str, str], str]:
601616
"""
602617
Retrieve system and python env info for the current system.
@@ -612,7 +627,7 @@ def get_system_info(print_info: bool = True) -> tuple[dict[str, str], str]:
612627
"Python": platform.python_version(),
613628
"Stable-Baselines3": sb3.__version__,
614629
"PyTorch": th.__version__,
615-
"GPU Enabled": str(th.cuda.is_available()),
630+
"Accelerator": get_available_accelerator(),
616631
"Numpy": np.__version__,
617632
"Cloudpickle": cloudpickle.__version__,
618633
"Gymnasium": gym.__version__,

0 commit comments

Comments
 (0)