2 changes: 1 addition & 1 deletion CONTRIBUTING.md
@@ -33,7 +33,7 @@ cd stable-baselines3/
2. Install Stable-Baselines3 in develop mode, with support for building the docs and running tests:

```bash
-pip install -e .[docs,tests,extra]
+pip install -e '.[docs,tests,extra]'
```

## Codestyle
120 changes: 114 additions & 6 deletions docs/guide/export.rst
@@ -195,7 +195,14 @@ There is a draft PR in the RL Zoo about C++ export: https://github.com/DLR-RM/rl
Export to ONNX-JS / ONNX Runtime Web
------------------------------------

See https://onnxruntime.ai/docs/tutorials/web/build-web-app.html
Full example code: https://github.com/JonathanColetti/CarDodgingGym and demo: https://jonathancoletti.github.io/CarDodgingGym/
Member comment: Please keep the link to the official doc too

Member comment: and cool demo =)


The code linked above is a complete example (using a car dodging game) that:
1. Creates/trains a PPO model
2. Exports the model to ONNX, along with the normalization statistics in JSON (see the sketch below)
3. Normalizes observations and runs inference with onnxruntime-web, reproducing the Python results

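For step 2, exporting the normalization statistics might look like the following minimal sketch (hedged: the environment, ``vec_env``, and the file name are illustrative, not the linked project's exact code; it assumes training used ``VecNormalize``):

.. code-block:: python

import json

import gymnasium as gym

from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import DummyVecEnv, VecNormalize

# Illustrative setup: a PPO model trained inside a VecNormalize wrapper
vec_env = VecNormalize(DummyVecEnv([lambda: gym.make("Pendulum-v1")]))
model = PPO("MlpPolicy", vec_env).learn(1_000)

# Export the running observation statistics so the JS side
# can normalize observations the same way before calling the ONNX model
stats = {
    "mean": vec_env.obs_rms.mean.tolist(),
    "var": vec_env.obs_rms.var.tolist(),
    "epsilon": float(vec_env.epsilon),
    "clip_obs": float(vec_env.clip_obs),
}
with open("normalization_stats.json", "w") as f:
    json.dump(stats, f)
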
Below is a simple example that converts a model to ONNX and then runs inference with ONNX Runtime Web (no postprocessing included)

.. code-block:: python

@@ -210,8 +217,7 @@
        self.actor = actor

    def forward(self, observation: th.Tensor) -> th.Tensor:
-        # NOTE: You may have to postprocess (unnormalize) actions
-        # to the correct bounds (see commented code below)
+        # NOTE: You may have to postprocess (unnormalize or renormalize) actions
        return self.actor(observation, deterministic=True)


@@ -232,7 +238,7 @@ See https://onnxruntime.ai/docs/tutorials/web/build-web-app.html

.. code-block:: javascript

-// Install using `npm install onnxruntime-web` or using cdn
+// Install using `npm install onnxruntime-web` (tested with version 1.19) or via CDN
import * as ort from 'onnxruntime-web';

async function runInference() {
@@ -257,8 +263,110 @@
Export to tensorflowjs
----------------------

-TODO: contributors help is welcomed!
-Probably a good starting point: https://github.com/elliotwaite/pytorch-to-javascript-with-onnx-js
.. warning::

As of this writing (November 2025), onnx2tf (https://github.com/PINTO0309/onnx2tf) does not support TensorFlow.js, so tfjs-converter (https://github.com/tensorflow/tfjs-converter) is used instead. That converter is no longer maintained and requires old opsets and TensorFlow versions.
Therefore, it is recommended to use ONNX Runtime if you need a higher opset.

For this to work, you must convert in stages: SB3 => ONNX => TensorFlow => TensorFlow.js.

The following is a simple example that shows the full conversion and inference pipeline.

Member comment: maybe just refer to the code above and say that the opset needs to be changed

.. code-block:: python

import torch as th
from stable_baselines3 import SAC


class OnnxablePolicy(th.nn.Module):
    def __init__(self, actor: th.nn.Module):
        super().__init__()
        self.actor = actor

    def forward(self, observation: th.Tensor) -> th.Tensor:
        # NOTE: You may have to postprocess (unnormalize) actions
        # to the correct bounds
        return self.actor(observation, deterministic=True)


# Example: model = SAC("MlpPolicy", "Pendulum-v1")
SAC("MlpPolicy", "Pendulum-v1").save("PathToTrainedModel.zip")
model = SAC.load("PathToTrainedModel.zip", device="cpu")
onnxable_model = OnnxablePolicy(model.policy.actor)

observation_size = model.observation_space.shape
dummy_input = th.randn(1, *observation_size)
th.onnx.export(
    onnxable_model,
    dummy_input,
    "my_sac_actor.onnx",
    opset_version=12,  # because of the outdated tf-js converter, you have to use an old opset
    input_names=["input"],
)

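Before converting further, it can be worth sanity-checking the exported ONNX file with ONNX Runtime (a minimal sketch; it assumes ``onnxruntime`` is installed and reuses the input name "input" from the export above):

.. code-block:: python

import numpy as np
import onnxruntime as ort

# Run the exported actor once to check that the export worked
ort_session = ort.InferenceSession("my_sac_actor.onnx")
observation = np.zeros((1, 3), dtype=np.float32)  # Pendulum-v1 observations are 3-dimensional
onnx_action = ort_session.run(None, {"input": observation})[0]
print("ONNX action:", onnx_action)
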
.. code-block:: python

# Tested with Python 3.10
# First, install these dependencies in a fresh env:
"""
pip install --use-deprecated=legacy-resolver tensorflow==2.13.0 keras==2.13.1 onnx==1.16.0 onnx-tf==1.9.0 tensorflow-probability==0.21.0 tensorflowjs==4.15.0 jax==0.4.26 jaxlib==0.4.26
"""
# Then run this code block.
# If there are no errors (and the folder structure is correct), convert the SavedModel:
"""
tensorflowjs_converter --input_format=tf_saved_model --output_format=tfjs_graph_model tf_model tfjs_model
"""

# If `tensorflowjs_converter` fails, upgrade tensorflow:
"""
pip install --upgrade tensorflow tensorflow-decision-forests tensorflowjs
"""
# and retry the converter; it should then work (do not rerun this code block):
"""
tensorflowjs_converter --input_format=tf_saved_model --output_format=tfjs_graph_model tf_model tfjs_model
"""

import onnx
import onnx_tf.backend
import tensorflow as tf

ONNX_FILE_PATH = "my_sac_actor.onnx"
MODEL_PATH = "tf_model"

onnx_model = onnx.load(ONNX_FILE_PATH)
onnx.checker.check_model(onnx_model)
print(onnx.helper.printable_graph(onnx_model.graph))

print('Converting ONNX to TF...')
tf_rep = onnx_tf.backend.prepare(onnx_model)
tf_rep.export_graph(MODEL_PATH)
# After this do not forget to use `tensorflowjs_converter`

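Optionally, you can sanity-check the intermediate SavedModel before running ``tensorflowjs_converter`` (a hedged sketch: "serving_default" is the default TensorFlow signature name, and "input" is the ONNX input name used in the export above):

.. code-block:: python

import numpy as np
import tensorflow as tf

# Load the SavedModel produced by onnx-tf and run a dummy observation through it
loaded = tf.saved_model.load("tf_model")
infer = loaded.signatures["serving_default"]
observation = tf.constant(np.zeros((1, 3), dtype=np.float32))
print(infer(input=observation))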

.. code-block:: javascript

import * as tf from 'https://cdn.jsdelivr.net/npm/@tensorflow/tfjs/+esm';
// Post processing not included
async function runInference() {
const MODEL_URL = './tfjs_model/model.json';

const model = await tf.loadGraphModel(MODEL_URL);

// Observation_size is 3 for Pendulum-v1
const inputData = [1.0, 0.0, 0.0];
const inputTensor = tf.tensor2d([inputData], [1, 3]);

const resultTensor = model.execute(inputTensor);

const action = await resultTensor.data();

console.log('Predicted action=', action);

inputTensor.dispose();
resultTensor.dispose();
}

runInference();

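The JavaScript snippets above do not include postprocessing. For SAC, the actor outputs squashed actions in [-1, 1] that must be rescaled to the action space bounds; below is a minimal sketch of that rescaling (shown in Python; it mirrors SB3's ``unscale_action`` helper, and the same arithmetic applies in JS):

.. code-block:: python

import numpy as np

def unscale_action(scaled_action: np.ndarray, low: np.ndarray, high: np.ndarray) -> np.ndarray:
    # Map a squashed action in [-1, 1] back to [low, high]
    return low + 0.5 * (scaled_action + 1.0) * (high - low)

# Pendulum-v1 actions live in [-2, 2]
action = unscale_action(np.array([0.5]), low=np.array([-2.0]), high=np.array([2.0]))
print(action)  # [1.]
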

Export to TFLite / Coral (Edge TPU)
3 changes: 3 additions & 0 deletions docs/misc/changelog.rst
@@ -44,6 +44,9 @@ Documentation:
- Added sb3-plus to projects page
- Added example usage of ONNX JS
- Updated link to paper of community project DeepNetSlice (@AlexPasqua)
+- Added example usage of Tensorflow JS
+- Included exact versions in ONNX JS and example project
+- Made step 2 (`pip install`) of `CONTRIBUTING.md` more robust


Release 2.7.0 (2025-07-25)