Few questions about the library regarding performance #62

Open
KamatMayur opened this issue Nov 21, 2024 · 2 comments
Labels
question Further information is requested

KamatMayur commented Nov 21, 2024

I went through some of the PPO code, and it still seems to use NumPy arrays for the observation and action spaces. So I'm guessing that the environments are still on the CPU and that their steps also take place on the CPU (correct me if I'm wrong). Does this mean that the library doesn't address the issue of CPU-GPU data transfer, and that all the speedup comes from the JAX JIT-compiled code used only for the policy optimization part? Am I right?

Could you please add support for custom environments created on the GPU? Maybe I could contribute. I currently have a jitted JAX environment that lives entirely on the GPU, and I wanted to use an existing library to train the policy, but apparently every library builds its code on some conventional environment library. My environment is built from scratch in JAX and provides the main step() and reset() functions. Right now my only option seems to be implementing the policy and optimization code on my own, so I would like to know whether there is any way your library can be used with a custom environment where all the functions are stateless.
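For concreteness, a stateless environment of the kind described here might look like the sketch below. Everything in it is an assumption made for illustration (the `EnvState` fields, the toy 1-D dynamics, the 10-step episode limit); it is not the actual environment from this issue. It only shows the pattern of pure `reset()` and `step()` functions that are batched with `jax.vmap` and compiled with `jax.jit`.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class EnvState(NamedTuple):
    """Hypothetical environment state: a 1-D position and a step counter."""
    pos: jnp.ndarray
    t: jnp.ndarray


def reset(key):
    """Pure function: build an initial state and observation from a PRNG key."""
    pos = jax.random.uniform(key, (), minval=-1.0, maxval=1.0)
    state = EnvState(pos=pos, t=jnp.zeros((), dtype=jnp.int32))
    return state, jnp.reshape(pos, (1,))


def step(state, action):
    """Pure function: return (next_state, observation, reward, done)."""
    pos = jnp.clip(state.pos + 0.1 * action, -1.0, 1.0)
    t = state.t + 1
    reward = -jnp.abs(pos)   # toy reward: stay close to the origin
    done = t >= 10           # toy episode limit
    return EnvState(pos=pos, t=t), jnp.reshape(pos, (1,)), reward, done


# Batch over environments with vmap, then compile once with jit.
batched_reset = jax.jit(jax.vmap(reset))
batched_step = jax.jit(jax.vmap(step))

num_envs = 4
keys = jax.random.split(jax.random.PRNGKey(0), num_envs)
states, obs = batched_reset(keys)
actions = jnp.ones((num_envs,))
states, obs, rewards, dones = batched_step(states, actions)
```

Because the state is passed in and returned explicitly, all arrays stay on whichever device JAX is using until something converts them to NumPy.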

araffin (Owner) commented Nov 21, 2024

> So I'm guessing that the environments are still on the CPU and their steps also take place on the CPU.

Yes.

> doesn't address the issue of CPU-GPU data transfer.

For PPO, when not using images (and when not using Isaac Sim), there is no need for a GPU (see the runtime reports in https://rlj.cs.umass.edu/2024/papers/Paper18.html and several issues about this on the SB3 repo).

> just due to the optimization taking place using JAX JIT-compiled code only for the policy optimization part.

And for the policy/value prediction.

> Could you guys please add support for custom environments created on the GPU?

This is currently not planned (as SBX is based on SB3).
However, if you want no data transfer, you will need:

  1. A custom vec env (we have some examples, including for Isaac Sim, in the SB3 doc)
  2. A custom replay buffer (should not be too hard)
  3. Possibly to remove some calls to .numpy() in SB3 (I'm not 100% sure about that part)

If you do so, please open source your fork of SBX; it might be helpful for others.


yhs0602 commented Jan 10, 2025

@KamatMayur
Hello, I have implemented some monkey-patches for stable-baselines3.

This code is provided under the same license as stable-baselines3 (MIT License).
You can find the implementation here:
https://github.com/yhs0602/minecraft-simulator-benchmark/blob/main/experiments/optim_dummy_vec_env.py
