2 changes: 1 addition & 1 deletion CONTRIBUTING.md
@@ -33,7 +33,7 @@ cd stable-baselines3/
2. Install Stable-Baselines3 in develop mode, with support for building the docs and running tests:

```bash
-pip install -e .[docs,tests,extra]
+pip install -e '.[docs,tests,extra]'
```

## Codestyle
2 changes: 1 addition & 1 deletion README.md
@@ -210,7 +210,7 @@ Actions `gymnasium.spaces`:
## Testing the installation
### Install dependencies
```sh
-pip install -e .[docs,tests,extra]
+pip install -e '.[docs,tests,extra]'
```
### Run tests
All unit tests in stable baselines3 can be run using the `pytest` runner:
109 changes: 97 additions & 12 deletions docs/guide/export.rst
@@ -37,8 +37,7 @@ If you are using PyTorch 2.0+ and ONNX Opset 14+, you can easily export SB3 poli

.. warning::

-The following returns normalized actions and doesn't include the `post-processing <https://github.com/DLR-RM/stable-baselines3/blob/a9273f968eaf8c6e04302a07d803eebfca6e7e86/stable_baselines3/common/policies.py#L370-L377>`_ step that is done with continuous actions
-(clip or unscale the action to the correct space).
+The following returns normalized actions and doesn't include the `post-processing <https://github.com/DLR-RM/stable-baselines3/blob/a9273f968eaf8c6e04302a07d803eebfca6e7e86/stable_baselines3/common/policies.py#L370-L377>`_ step that is done with continuous actions (clip or unscale the action to the correct space).
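The post-processing the warning refers to can be sketched in a few lines of numpy. This is a minimal sketch, not SB3's exact code: the ``[-2, 2]`` bounds below are hypothetical (they happen to match Pendulum-v1), and the unscaling formula mirrors the linear mapping SB3 applies to squashed actions:

```python
import numpy as np

# Hypothetical action-space bounds for illustration (Pendulum-v1 uses [-2, 2])
low, high = np.array([-2.0]), np.array([2.0])

def unscale_action(scaled_action: np.ndarray) -> np.ndarray:
    """Map a squashed action in [-1, 1] back to [low, high]."""
    return low + 0.5 * (scaled_action + 1.0) * (high - low)

def clip_action(action: np.ndarray) -> np.ndarray:
    """Clip an unsquashed action to the valid bounds."""
    return np.clip(action, low, high)

print(unscale_action(np.array([0.0])))  # [0.]
print(clip_action(np.array([3.0])))  # [2.]
```

Which of the two applies depends on the algorithm: policies with a squashed output (e.g. SAC) need unscaling, while unbounded outputs need clipping.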


.. code-block:: python
@@ -195,10 +194,22 @@ There is a draft PR in the RL Zoo about C++ export: https://github.com/DLR-RM/rl
Export to ONNX-JS / ONNX Runtime Web
------------------------------------

-See https://onnxruntime.ai/docs/tutorials/web/build-web-app.html
+Official documentation: https://onnxruntime.ai/docs/tutorials/web/build-web-app.html

Full example code: https://github.com/JonathanColetti/CarDodgingGym

Demo: https://jonathancoletti.github.io/CarDodgingGym

The code linked above is a complete example (using a car dodging environment) that:

1. Creates and trains a PPO model
2. Exports the model to ONNX, along with the normalization statistics as JSON
3. Runs inference in the browser with onnxruntime-web, applying the same normalization to achieve similar results
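Steps 2 and 3 hinge on applying the same observation normalization in the browser as during training. Below is a minimal Python sketch of the idea; the JSON layout and file name are assumptions, but the formula matches what ``VecNormalize`` computes:

```python
import numpy as np

# Hypothetical stats exported at training time from a VecNormalize wrapper, e.g.:
# json.dump({"mean": env.obs_rms.mean.tolist(), "var": env.obs_rms.var.tolist(),
#            "epsilon": 1e-8, "clip_obs": 10.0}, open("vec_normalize_stats.json", "w"))
stats = {"mean": [0.0, 0.0, 0.0], "var": [1.0, 1.0, 1.0], "epsilon": 1e-8, "clip_obs": 10.0}

def normalize_obs(obs: np.ndarray) -> np.ndarray:
    """Normalize an observation the way VecNormalize does:
    (obs - mean) / sqrt(var + epsilon), clipped to [-clip_obs, clip_obs]."""
    mean, var = np.array(stats["mean"]), np.array(stats["var"])
    norm = (obs - mean) / np.sqrt(var + stats["epsilon"])
    return np.clip(norm, -stats["clip_obs"], stats["clip_obs"])

print(normalize_obs(np.array([1.0, 0.0, 0.0])))
```

The same arithmetic then has to be reimplemented in JavaScript before feeding the observation tensor to the ONNX session.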

Below is a simple example of converting a model to ONNX and then running inference (without post-processing) using ONNX Runtime Web.

.. code-block:: python

import torch as th

from stable_baselines3 import SAC
@@ -210,8 +221,7 @@
self.actor = actor

def forward(self, observation: th.Tensor) -> th.Tensor:
-# NOTE: You may have to postprocess (unnormalize) actions
-# to the correct bounds (see commented code below)
+# NOTE: You may have to postprocess (unnormalize or rescale) the actions to the correct bounds
return self.actor(observation, deterministic=True)


)

.. code-block:: javascript
-// Install using `npm install onnxruntime-web` or using cdn
+// Install using `npm install onnxruntime-web` (tested with version 1.19) or using cdn

import * as ort from 'onnxruntime-web';

async function runInference() {
@@ -254,11 +264,86 @@
runInference();


-Export to tensorflowjs
-----------------------
+Export to TensorFlow.js
+-----------------------

.. warning::

As of November 2025, `onnx2tf <https://github.com/PINTO0309/onnx2tf>`_ does not support TensorFlow.js. Therefore, `tfjs-converter <https://github.com/tensorflow/tfjs-converter>`_ is used instead. However, tfjs-converter is not currently maintained and requires older opsets and TensorFlow versions.


In order for this to work, you have to do multiple conversions: SB3 => ONNX => TensorFlow => TensorFlow.js.

The opset version needs to be changed for this conversion (``opset_version=14`` is currently required); refer to the code above for more stable usage with a higher opset.

The following is a simple example that showcases the full conversion + inference.

Please refer to the previous sections for the first step (SB3 => ONNX).
The main difference is that you need to specify ``opset_version=14``.

.. code-block:: python

# Tested with Python 3.10
# First, install these dependencies in a fresh env:
"""
pip install --use-deprecated=legacy-resolver tensorflow==2.13.0 keras==2.13.1 onnx==1.16.0 onnx-tf==1.9.0 tensorflow-probability==0.21.0 tensorflowjs==4.15.0 jax==0.4.26 jaxlib==0.4.26
"""
# Then run this code block.
# If there are no errors (the folder is structured correctly), convert with:
"""
tensorflowjs_converter --input_format=tf_saved_model --output_format=tfjs_graph_model tf_model tfjs_model
"""

# If you get an error when exporting using `tensorflowjs_converter`, upgrade tensorflow:
"""
pip install --upgrade tensorflow tensorflow-decision-forests tensorflowjs
"""
# and retry the conversion; it should work (do not rerun this code block):
"""
tensorflowjs_converter --input_format=tf_saved_model --output_format=tfjs_graph_model tf_model tfjs_model
"""

-TODO: contributors help is welcomed!
-Probably a good starting point: https://github.com/elliotwaite/pytorch-to-javascript-with-onnx-js
import onnx
import onnx_tf.backend
import tensorflow as tf

ONNX_FILE_PATH = "my_sac_actor.onnx"
MODEL_PATH = "tf_model"

onnx_model = onnx.load(ONNX_FILE_PATH)
onnx.checker.check_model(onnx_model)
print(onnx.helper.printable_graph(onnx_model.graph))

print('Converting ONNX to TF...')
tf_rep = onnx_tf.backend.prepare(onnx_model)
tf_rep.export_graph(MODEL_PATH)
# After this do not forget to use `tensorflowjs_converter`


.. code-block:: javascript

import * as tf from 'https://cdn.jsdelivr.net/npm/@tensorflow/[email protected]/+esm';
// Post-processing is not included
async function runInference() {
const MODEL_URL = './tfjs_model/model.json';

const model = await tf.loadGraphModel(MODEL_URL);

// Observation size is 3 for Pendulum-v1
const inputData = [1.0, 0.0, 0.0];
const inputTensor = tf.tensor2d([inputData], [1, 3]);

const resultTensor = model.execute(inputTensor);

const action = await resultTensor.data();

console.log('Predicted action=', action);

inputTensor.dispose();
resultTensor.dispose();
}

runInference();


Export to TFLite / Coral (Edge TPU)
3 changes: 3 additions & 0 deletions docs/misc/changelog.rst
@@ -44,6 +44,9 @@ Documentation:
- Added sb3-plus to projects page
- Added example usage of ONNX JS
- Updated link to paper of community project DeepNetSlice (@AlexPasqua)
- Added example usage of TensorFlow JS
- Included exact dependency versions in the ONNX JS example and the example project
- Made step 2 (`pip install`) of `CONTRIBUTING.md` more robust


Release 2.7.0 (2025-07-25)