Supporting PyTorch GPU compatibility on Apple Silicon chips #914

Open
ryanrudes opened this issue May 20, 2022 · 17 comments · May be fixed by #951 or deathcoder/stable-baselines3#1
Labels
enhancement New feature or request

Comments

@ryanrudes

ryanrudes commented May 20, 2022

🚀 Feature

PyTorch recently released support for GPU acceleration using the Apple Silicon chips. This should be supported in stable-baselines3 by the "mps" device (I believe).

Minimal Example

from stable_baselines3 import PPO
import gym

env = gym.make("QbertNoFrameskip-v0")
ppo = PPO("CnnPolicy", env, device = "mps")

ppo.learn(total_timesteps = 1000000)

The Apple Silicon GPU device is not automatically recognized by stable-baselines3 at the moment, so it defaults to "cpu". If you try to force it to use the "mps" device, the following stack trace appears.
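
A minimal sketch of what automatic device selection could look like, assuming the caller supplies the availability flags. The function name `resolve_device` and its parameters are hypothetical and are not SB3's API; keeping the logic free of torch imports just makes the preference order easy to check.

```python
def resolve_device(requested: str = "auto", cuda_available: bool = False,
                   mps_available: bool = False) -> str:
    """Return the device string to use.

    Hypothetical helper, not part of stable-baselines3. An explicit request
    is returned unchanged; "auto" prefers CUDA, then MPS, then CPU.
    """
    if requested != "auto":
        return requested
    if cuda_available:
        return "cuda"
    if mps_available:
        return "mps"
    return "cpu"


# On an Apple Silicon machine without CUDA, "auto" resolves to "mps" here.
print(resolve_device("auto", cuda_available=False, mps_available=True))
```

With PyTorch 1.12 or later, the two flags could come from `torch.cuda.is_available()` and `torch.backends.mps.is_available()`.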

A.L.E: Arcade Learning Environment (version 0.7.4+069f8bd)
[Powered by Stella]
Traceback (most recent call last):
  File "/Users/ryanrudes/Desktop/pydt/train_min.py", line 7, in <module>
    ppo.learn(total_timesteps = 1000000)
  File "/Users/ryanrudes/miniforge3/lib/python3.9/site-packages/stable_baselines3-1.5.1a6-py3.9.egg/stable_baselines3/ppo/ppo.py", line 310, in learn
    return super().learn(
  File "/Users/ryanrudes/miniforge3/lib/python3.9/site-packages/stable_baselines3-1.5.1a6-py3.9.egg/stable_baselines3/common/on_policy_algorithm.py", line 247, in learn
    continue_training = self.collect_rollouts(self.env, callback, self.rollout_buffer, n_rollout_steps=self.n_steps)
  File "/Users/ryanrudes/miniforge3/lib/python3.9/site-packages/stable_baselines3-1.5.1a6-py3.9.egg/stable_baselines3/common/on_policy_algorithm.py", line 166, in collect_rollouts
    actions, values, log_probs = self.policy(obs_tensor)
  File "/Users/ryanrudes/miniforge3/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/Users/ryanrudes/miniforge3/lib/python3.9/site-packages/stable_baselines3-1.5.1a6-py3.9.egg/stable_baselines3/common/policies.py", line 592, in forward
    distribution = self._get_action_dist_from_latent(latent_pi)
  File "/Users/ryanrudes/miniforge3/lib/python3.9/site-packages/stable_baselines3-1.5.1a6-py3.9.egg/stable_baselines3/common/policies.py", line 610, in _get_action_dist_from_latent
    return self.action_dist.proba_distribution(action_logits=mean_actions)
  File "/Users/ryanrudes/miniforge3/lib/python3.9/site-packages/stable_baselines3-1.5.1a6-py3.9.egg/stable_baselines3/common/distributions.py", line 274, in proba_distribution
    self.distribution = Categorical(logits=action_logits)
  File "/Users/ryanrudes/miniforge3/lib/python3.9/site-packages/torch/distributions/categorical.py", line 60, in __init__
    self.logits = logits - logits.logsumexp(dim=-1, keepdim=True)
NotImplementedError: Could not run 'aten::amax.out' with arguments from the 'MPS' backend. This could be because the operator doesn't exist for this backend, or was omitted during the selective/custom build process (if using custom build). If you are a Facebook employee using PyTorch on mobile, please visit https://fburl.com/ptmfixes for possible resolutions. 'aten::amax.out' is only available for these backends: [Dense, Conjugate, UNKNOWN_TENSOR_TYPE_ID, QuantizedXPU, UNKNOWN_TENSOR_TYPE_ID, UNKNOWN_TENSOR_TYPE_ID, UNKNOWN_TENSOR_TYPE_ID, UNKNOWN_TENSOR_TYPE_ID, UNKNOWN_TENSOR_TYPE_ID, SparseCPU, SparseCUDA, SparseHIP, UNKNOWN_TENSOR_TYPE_ID, UNKNOWN_TENSOR_TYPE_ID, UNKNOWN_TENSOR_TYPE_ID, SparseXPU, UNKNOWN_TENSOR_TYPE_ID, SparseVE, UNKNOWN_TENSOR_TYPE_ID, UNKNOWN_TENSOR_TYPE_ID, UNKNOWN_TENSOR_TYPE_ID, UNKNOWN_TENSOR_TYPE_ID, UNKNOWN_TENSOR_TYPE_ID, NestedTensorCUDA, UNKNOWN_TENSOR_TYPE_ID, UNKNOWN_TENSOR_TYPE_ID, UNKNOWN_TENSOR_TYPE_ID, UNKNOWN_TENSOR_TYPE_ID, UNKNOWN_TENSOR_TYPE_ID, UNKNOWN_TENSOR_TYPE_ID].

CPU: registered at /Users/runner/work/pytorch/pytorch/pytorch/build/aten/src/ATen/RegisterCPU.cpp:37386 [kernel]
Meta: registered at /Users/runner/work/pytorch/pytorch/pytorch/build/aten/src/ATen/RegisterMeta.cpp:31637 [kernel]
BackendSelect: fallthrough registered at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/core/BackendSelectFallbackKernel.cpp:3 [backend fallback]
Python: registered at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/core/PythonFallbackKernel.cpp:133 [backend fallback]
Named: registered at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/core/NamedRegistrations.cpp:7 [backend fallback]
Conjugate: registered at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/ConjugateFallback.cpp:18 [backend fallback]
Negative: registered at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/native/NegateFallback.cpp:18 [backend fallback]
ZeroTensor: registered at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/ZeroTensorFallback.cpp:86 [backend fallback]
ADInplaceOrView: registered at /Users/runner/work/pytorch/pytorch/pytorch/torch/csrc/autograd/generated/ADInplaceOrViewType_1.cpp:3288 [kernel]
AutogradOther: registered at /Users/runner/work/pytorch/pytorch/pytorch/torch/csrc/autograd/generated/VariableType_1.cpp:11242 [autograd kernel]
AutogradCPU: registered at /Users/runner/work/pytorch/pytorch/pytorch/torch/csrc/autograd/generated/VariableType_1.cpp:11242 [autograd kernel]
AutogradCUDA: registered at /Users/runner/work/pytorch/pytorch/pytorch/torch/csrc/autograd/generated/VariableType_1.cpp:11242 [autograd kernel]
UNKNOWN_TENSOR_TYPE_ID: registered at /Users/runner/work/pytorch/pytorch/pytorch/torch/csrc/autograd/generated/VariableType_1.cpp:11242 [autograd kernel]
AutogradXLA: registered at /Users/runner/work/pytorch/pytorch/pytorch/torch/csrc/autograd/generated/VariableType_1.cpp:11242 [autograd kernel]
AutogradMPS: registered at /Users/runner/work/pytorch/pytorch/pytorch/torch/csrc/autograd/generated/VariableType_1.cpp:11242 [autograd kernel]
AutogradIPU: registered at /Users/runner/work/pytorch/pytorch/pytorch/torch/csrc/autograd/generated/VariableType_1.cpp:11242 [autograd kernel]
AutogradXPU: registered at /Users/runner/work/pytorch/pytorch/pytorch/torch/csrc/autograd/generated/VariableType_1.cpp:11242 [autograd kernel]
AutogradHPU: registered at /Users/runner/work/pytorch/pytorch/pytorch/torch/csrc/autograd/generated/VariableType_1.cpp:11242 [autograd kernel]
UNKNOWN_TENSOR_TYPE_ID: registered at /Users/runner/work/pytorch/pytorch/pytorch/torch/csrc/autograd/generated/VariableType_1.cpp:11242 [autograd kernel]
AutogradLazy: registered at /Users/runner/work/pytorch/pytorch/pytorch/torch/csrc/autograd/generated/VariableType_1.cpp:11242 [autograd kernel]
AutogradPrivateUse1: registered at /Users/runner/work/pytorch/pytorch/pytorch/torch/csrc/autograd/generated/VariableType_1.cpp:11242 [autograd kernel]
AutogradPrivateUse2: registered at /Users/runner/work/pytorch/pytorch/pytorch/torch/csrc/autograd/generated/VariableType_1.cpp:11242 [autograd kernel]
AutogradPrivateUse3: registered at /Users/runner/work/pytorch/pytorch/pytorch/torch/csrc/autograd/generated/VariableType_1.cpp:11242 [autograd kernel]
Tracer: registered at /Users/runner/work/pytorch/pytorch/pytorch/torch/csrc/autograd/generated/TraceType_1.cpp:11951 [kernel]
AutocastCPU: fallthrough registered at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/autocast_mode.cpp:481 [backend fallback]
Autocast: fallthrough registered at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/autocast_mode.cpp:324 [backend fallback]
Batched: registered at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/BatchingRegistrations.cpp:1064 [backend fallback]
VmapMode: fallthrough registered at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/VmapModeRegistrations.cpp:33 [backend fallback]
Functionalize: registered at /Users/runner/work/pytorch/pytorch/pytorch/build/aten/src/ATen/RegisterFunctionalization_3.cpp:12118 [kernel]
PythonTLSSnapshot: registered at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/core/PythonFallbackKernel.cpp:137 [backend fallback]
@ryanrudes added the "enhancement" (New feature or request) label on May 20, 2022
@Miffyli changed the title from "Supporting PyTorch GPU compatibility on Silicon chips" to "Supporting PyTorch GPU compatibility on Apple Silicon chips" on May 20, 2022
@Miffyli
Collaborator

Miffyli commented May 20, 2022

Ideally yes, SB3 should support that device too (not a big change), but it seems that full support would currently require changing some operation calls. Those need to be addressed first (or we wait until torch offers equivalent functions on all platforms), and the changes should not interfere with existing code at all; otherwise this could introduce lots of hidden changes.

@araffin
Member

araffin commented Jul 4, 2022

@ryanrudes could you test again? (there was a PyTorch release recently)

And maybe test with PyTorch nightly build, it apparently works: DLR-RM/rl-baselines3-zoo#267

I will update the "auto" device behavior in case it does ;)

@araffin araffin linked a pull request Jul 4, 2022 that will close this issue
14 tasks
@qgallouedec
Collaborator

from stable_baselines3 import PPO
import gym

env = gym.make("Pendulum-v1")
ppo = PPO("MlpPolicy", env, device="mps")

ppo.learn(total_timesteps=1000)

With current pip version of PyTorch (1.12), it raises the following exception:

Traceback (most recent call last):
  File "/Users/quentingallouedec/stable-baselines3/try_mps.py", line 7, in <module>
    ppo.learn(total_timesteps=1000)
  File "/Users/quentingallouedec/stable-baselines3/stable_baselines3/ppo/ppo.py", line 310, in learn
    return super().learn(
  File "/Users/quentingallouedec/stable-baselines3/stable_baselines3/common/on_policy_algorithm.py", line 267, in learn
    self.train()
  File "/Users/quentingallouedec/stable-baselines3/stable_baselines3/ppo/ppo.py", line 270, in train
    loss.backward()
  File "/Users/quentingallouedec/stable-baselines3/env/lib/python3.9/site-packages/torch/_tensor.py", line 396, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
  File "/Users/quentingallouedec/stable-baselines3/env/lib/python3.9/site-packages/torch/autograd/__init__.py", line 173, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
NotImplementedError: The operator 'aten::logical_and.out' is not current implemented for the MPS device. If you want this op to be added in priority during the prototype phase of this feature, please comment on https://github.com/pytorch/pytorch/issues/77764. As a temporary fix, you can set the environment variable `PYTORCH_ENABLE_MPS_FALLBACK=1` to use the CPU as a fallback for this op. WARNING: this will be slower than running natively on MPS.

Following the suggestion of the traceback (export PYTORCH_ENABLE_MPS_FALLBACK=1) works for me.

So it is not completely stable for the moment. We may have to wait until the next release...
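
For scripts where exporting the variable in the shell is inconvenient, a minimal sketch of setting it from Python instead. The assumption here (not confirmed in this thread) is that PyTorch reads the variable when it is loaded, so the assignment has to come before `import torch`.

```python
import os

# Assumption: the variable must be set before `import torch` runs,
# so this goes at the very top of the training script.
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

# Import torch / stable_baselines3 only after this point.
print(os.environ["PYTORCH_ENABLE_MPS_FALLBACK"])
```

Operators that fall back to the CPU will still be slower than native MPS kernels, as the warning in the traceback says.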

@araffin
Member

araffin commented Jul 4, 2022

thanks @qgallouedec for the feedback =)

We do need to wait for more coverage yes, issue is here: pytorch/pytorch#77764

@qgallouedec
Collaborator

> So it is not completely stable for the moment. We may have to wait until the next release...

The bug no longer occurs with the new version of PyTorch (1.12.1)

@traderpedroso

TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64. Please use float32 instead.
The MPS framework only supports float32, so I added dtype=th.float32 in utils.py:

def obs_as_tensor(
    obs: Union[np.ndarray, Dict[Union[str, int], np.ndarray]], device: th.device
) -> Union[th.Tensor, TensorDict]:
    """
    Moves the observation to the given device.

    :param obs:
    :param device: PyTorch device
    :return: PyTorch tensor of the observation on a desired device.
    """
    if isinstance(obs, np.ndarray):
        return th.as_tensor(obs, dtype=th.float32).to(device)
    elif isinstance(obs, dict):
        return {key: th.as_tensor(_obs).to(device) for (key, _obs) in obs.items()}
    else:
        raise Exception(f"Unrecognized type of observation {type(obs)}")

This fixed the float64-to-float32 conversion problem! But then, when training the model, I got:

NotImplementedError: The operator 'aten::multinomial' is not current implemented for the MPS device. If you want this op to be added in priority during the prototype phase of this feature, please comment on https://github.com/pytorch/pytorch/issues/77764. As a temporary fix, you can set the environment variable PYTORCH_ENABLE_MPS_FALLBACK=1 to use the CPU as a fallback for this op. WARNING: this will be slower than running natively on MPS.

So then I added export PYTORCH_ENABLE_MPS_FALLBACK=1, and the same error continues.

@qgallouedec
Collaborator

For the moment, consider that SB3 is not compatible with MPS. But we are working on it: #951

Have you seen something in the documentation about float64 and MPS? (You can answer in the PR conversation.)

@traderpedroso

traderpedroso commented Aug 19, 2022

> For the moment, consider that SB3 is not compatible with MPS. But we are working on it: #951
>
> Have you seen something in the documentation about float64 and MPS? (You can answer in the PR conversation.)

Yes, for the MPS framework:

For a more extensive list of which data types do and don’t run:

Avoid Float64 on all Apple devices. Even if the hardware supports Double physically (AMD or Intel), the Metal API doesn’t let you access it.
Avoid BFloat16. That is natively supported by the latest Nvidia GPUs, but not supported in Metal. Also don’t try to use TF18/TF32 or Int4.
All standard integer types (UInt8, UInt16, UInt32, UInt64) and their signed counterparts work natively on Apple devices. Not exactly 8-bit integers, which are cast to 16-bit integers before being stored into registers, but those aren’t going to harm performance. Yes, they run 64-bit integers on the Apple GPU and not 64-bit floats. Metal allows you to use 64-bit integers in shaders on AMD and Intel, but the arithmetic there might just happen through emulation (slow). I think that’s where I experienced the crash in MPSGraph previously - trying to run an operation on UInt64 on my Intel Mac mini.

https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf

@traderpedroso

traderpedroso commented Apr 9, 2023

This modified version of obs_as_tensor should work. Make these changes in stable_baselines3/common/utils.py. The modified obs_as_tensor function automatically converts the observation to float32 if the device is an MPS device.

Original

def obs_as_tensor(
    obs: Union[np.ndarray, Dict[Union[str, int], np.ndarray]], device: th.device
) -> Union[th.Tensor, TensorDict]:
    """
    Moves the observation to the given device.

    :param obs:
    :param device: PyTorch device
    :return: PyTorch tensor of the observation on a desired device.
    """
    if isinstance(obs, np.ndarray):
        return th.as_tensor(obs, device=device)
    elif isinstance(obs, dict):
        return {key: th.as_tensor(_obs, device=device) for (key, _obs) in obs.items()}
    else:
        raise Exception(f"Unrecognized type of observation {type(obs)}")

Workaround

def obs_as_tensor(obs: Union[np.ndarray, Dict[Union[str, int], np.ndarray]], device: th.device) -> Union[th.Tensor, TensorDict]:
    """
    Moves the observation to the given device.

    :param obs:
    :param device: PyTorch device
    :return: PyTorch tensor of the observation on a desired device.
    """
    dtype = th.float32 if device.type == "mps" else None

    if isinstance(obs, np.ndarray):
        return th.as_tensor(obs, device=device, dtype=dtype)
    elif isinstance(obs, dict):
        return {key: th.as_tensor(_obs, device=device, dtype=dtype) for (key, _obs) in obs.items()}
    else:
        raise ValueError(f"Unsupported observation format: {obs}")

Although it works normally, the CPU continues to be more performant than the MPS. Honestly, I don't know if this workaround is worth it, but it worked nonetheless.
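
Note that the workaround above casts every observation to float32 on MPS, including integer ones. A narrower variant, sketched here with NumPy only, would downcast just the float64 arrays before they are turned into tensors. The helper name `downcast_float64` is hypothetical and not part of stable-baselines3.

```python
from typing import Dict, Union

import numpy as np

Obs = Union[np.ndarray, Dict[str, np.ndarray]]


def downcast_float64(obs: Obs) -> Obs:
    """Cast float64 arrays to float32 and leave every other dtype untouched.

    Hypothetical helper, not part of stable-baselines3.
    """
    if isinstance(obs, np.ndarray):
        return obs.astype(np.float32) if obs.dtype == np.float64 else obs
    if isinstance(obs, dict):
        return {key: downcast_float64(value) for key, value in obs.items()}
    raise TypeError(f"Unrecognized type of observation {type(obs)}")


obs = {
    "state": np.zeros(3, dtype=np.float64),
    "image": np.zeros((2, 2), dtype=np.uint8),
}
converted = downcast_float64(obs)
print({key: value.dtype for key, value in converted.items()})
```

The result could then be passed to `th.as_tensor(..., device=device)` as in the original function; whether this is preferable to the blanket cast depends on how the policy preprocesses integer observations.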

@ScharanCysne

Sorry @traderpedroso, but I don't see the difference between the workaround and the original code.

@traderpedroso

> Sorry @traderpedroso, but I don't see the difference between the workaround and the original code.

Apologies, I hadn't noticed that I duplicated the functions. I have now updated the code. Thank you for pointing it out.

@hom-bahrani

@traderpedroso thanks for creating this issue. The M1 Pro already comes with quite a lot of CPU cores, which distribute training nicely. I was wondering if you have done any benchmarks and observed any significant performance improvements with the mps device?

@traderpedroso

> @traderpedroso thanks for creating this issue. The M1 Pro already comes with quite a lot of CPU cores, which distribute training nicely. I was wondering if you have done any benchmarks and observed any significant performance improvements with the mps device?

I must admit that I was profoundly disheartened by the limitations of MPS, particularly the lack of support. After numerous tests, I found that for reinforcement learning, CPUs have consistently been the best choice, or at most TPUs. When it comes to GPUs, whether Nvidia, AMD, or MPS, performance has been largely indistinguishable in my experience. Nevertheless, when combining PyTorch with MPS for NLP and image processing tasks, we do see a clear performance boost, as exemplified below.

import sys
import platform
import torch
import pandas as pd
import sklearn as sk

has_gpu = torch.cuda.is_available()
# torch.has_mps only reports build-time support and is deprecated;
# torch.backends.mps.is_available() checks that the device is usable.
has_mps = torch.backends.mps.is_available()
device = "mps" if has_mps else "cuda" if has_gpu else "cpu"

print(f"Python Platform: {platform.platform()}")
print(f"PyTorch Version: {torch.__version__}")
print()
print(f"Python {sys.version}")
print(f"Pandas {pd.__version__}")
print(f"Scikit-Learn {sk.__version__}")
print("GPU is", "available" if has_gpu else "NOT AVAILABLE")
print("MPS (Apple Metal) is", "AVAILABLE" if has_mps else "NOT AVAILABLE")
print(f"Target device is {device}")

import torch
from torch import nn, optim
import torch.nn.functional as F
import torchvision
from torchvision import datasets, transforms

EPOCHS = 5

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 20, 5, 1)
        self.conv2 = nn.Conv2d(20, 50, 5, 1)
        self.fc1 = nn.Linear(4*4*50, 500)
        self.fc2 = nn.Linear(500, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2, 2)
        x = x.view(-1, 4*4*50)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

def train(model, device, train_loader, optimizer, epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % 10 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))

def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item() # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)

    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))


def main():
    print("PyTorch version:", torch.__version__)
    print("Torchvision version:", torchvision.__version__)

    # device = torch.device("mps")
    print("Using Device: ", device)

    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST('../data', train=True, download=True,
                       transform=transforms.Compose([
                           transforms.ToTensor(),
                           transforms.Normalize((0.1307,), (0.3081,))
                       ])),
        batch_size=64, shuffle=True)
    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST('../data', train=False, transform=transforms.Compose([
                           transforms.ToTensor(),
                           transforms.Normalize((0.1307,), (0.3081,))
                       ])),
        batch_size=64, shuffle=True)


    model = Net().to(device)
    optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)

    for epoch in range(1, EPOCHS + 1):
        train(model, device, train_loader, optimizer, epoch)
        test(model, device, test_loader)

if __name__ == "__main__":
    main()
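
Since the thread asks about benchmarks, here is a minimal timing sketch using only the standard library. The helper name `time_callable` and its `sync` hook are hypothetical. The hook is there because GPU backends typically run asynchronously, so a comparison on "mps" would presumably pass something like `torch.mps.synchronize` (available in recent PyTorch releases) to avoid timing only the kernel launches.

```python
import time
from typing import Callable, Optional


def time_callable(fn: Callable[[], object], repeats: int = 5, warmup: int = 1,
                  sync: Optional[Callable[[], None]] = None) -> float:
    """Return the mean wall-clock seconds per call of `fn`."""
    # Warm-up calls are not timed (first calls often include setup costs).
    for _ in range(warmup):
        fn()
    if sync is not None:
        sync()
    start = time.perf_counter()
    for _ in range(repeats):
        fn()
    # Wait for pending device work before stopping the clock.
    if sync is not None:
        sync()
    return (time.perf_counter() - start) / repeats


# Stand-in workload; replace it with one training step per device to compare.
mean_seconds = time_callable(lambda: sum(i * i for i in range(10_000)))
print(f"{mean_seconds:.6f} s per call")
```

Running the same workload once with the model on "cpu" and once on "mps" would give a like-for-like number for the question above.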

@tty666

tty666 commented Jun 30, 2023

  • OS: macOS-13.4.1-arm64-arm-64bit Darwin Kernel Version 22.5.0: Thu Jun 8 22:22:20 PDT 2023; root:xnu-8796.121.3~7/RELEASE_ARM64_T6000
  • Python: 3.8.16
  • Stable-Baselines3: 1.8.0
  • PyTorch: 2.1.0.dev20230629
  • GPU Enabled: False
  • Numpy: 1.23.2
  • Gym: 0.21.0

Same problem here. It's really painful not to be able to use MPS properly today :(
At the very least, we should make this modification:

def obs_as_tensor(
    obs: Union[np.ndarray, Dict[Union[str, int], np.ndarray]], device: th.device
) -> Union[th.Tensor, Dict[Union[str, int], th.Tensor]]:
    """
    Moves the observation to the given device.

    :param obs:
    :param device: PyTorch device
    :return: PyTorch tensor of the observation on a desired device.
    """
    if isinstance(obs, np.ndarray):
        # Convert the numpy array to float32 before moving it to the device
        return th.as_tensor(obs.astype(np.float32), device=device)
    elif isinstance(obs, dict):
        # Convert the numpy arrays in the dict to float32 before moving them to the device
        return {key: th.as_tensor(_obs.astype(np.float32), device=device) for (key, _obs) in obs.items()}
    else:
        raise Exception(f"Unrecognized type of observation {type(obs)}")

This avoids the following error:

TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64. Please use float32 instead.

But I have no idea whether MPS can actually speed up PPO models...

@gy2256

gy2256 commented Nov 13, 2023

Any updates on this feature request? Is it possible to use MPS with stable baselines3 now?

@nize

nize commented Dec 17, 2023

Could the release of MLX play a role in improving the performance of stable-baselines3 on Apple silicon? This post discusses why MLX was not implemented within PyTorch rather than as an alternative to it.

Unfortunately (?) it seems that stable-baselines3 would need to support MLX in addition to PyTorch to reap the benefits.

@araffin
Member

araffin commented Jan 10, 2024

> Any updates on this feature request? Is it possible to use MPS with stable baselines3 now?

Please have a look at the PR and the other comments. You can give it a try using device="mps"; the main bottleneck so far has been PyTorch. Please report any issues on the associated PR.

> Could the release of MLX play any role for improving performance of stable-baselines3 on Apple silicon?

If you want a performance boost (not only on Apple silicon), I would recommend having a look at SBX (SB3 + Jax): https://github.com/araffin/sbx

10 participants