Commit bab847b

Add tensorflow.js export example (#2190)

* Update documentation with tensorflow.js/ONNX-js, robust command in contributing and update changelog
* fix doc building
* Add official documentation and changed tensorflow description
* Opset works at version 14 and change ONNX_FILE_PATH to the sac model
* README also does not contain quotes in pip install command
* Update export.rst

Co-authored-by: Antonin RAFFIN <[email protected]>

1 parent b018e4b

File tree

4 files changed: +102 -14 lines changed


CONTRIBUTING.md

Lines changed: 1 addition & 1 deletion

@@ -33,7 +33,7 @@ cd stable-baselines3/
 2. Install Stable-Baselines3 in develop mode, with support for building the docs and running tests:

 ```bash
-pip install -e .[docs,tests,extra]
+pip install -e '.[docs,tests,extra]'
 ```

 ## Codestyle

README.md

Lines changed: 1 addition & 1 deletion

@@ -210,7 +210,7 @@ Actions `gymnasium.spaces`:
 ## Testing the installation
 ### Install dependencies
 ```sh
-pip install -e .[docs,tests,extra]
+pip install -e '.[docs,tests,extra]'
 ```
 ### Run tests
 All unit tests in stable baselines3 can be run using `pytest` runner:

docs/guide/export.rst

Lines changed: 97 additions & 12 deletions

@@ -37,8 +37,7 @@ If you are using PyTorch 2.0+ and ONNX Opset 14+, you can easily export SB3 poli

 .. warning::

-  The following returns normalized actions and doesn't include the `post-processing <https://github.com/DLR-RM/stable-baselines3/blob/a9273f968eaf8c6e04302a07d803eebfca6e7e86/stable_baselines3/common/policies.py#L370-L377>`_ step that is done with continuous actions
-  (clip or unscale the action to the correct space).
+  The following returns normalized actions and doesn't include the `post-processing <https://github.com/DLR-RM/stable-baselines3/blob/a9273f968eaf8c6e04302a07d803eebfca6e7e86/stable_baselines3/common/policies.py#L370-L377>`_ step that is done with continuous actions (clip or unscale the action to the correct space).


 .. code-block:: python
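The post-processing that this warning refers to amounts to mapping the normalized action back to the action-space bounds. A minimal sketch (hypothetical helper names, assuming a 1D Box action space with known bounds):

```python
def unscale_action(scaled_action: float, low: float, high: float) -> float:
    """Map a squashed action in [-1, 1] back to [low, high] (done for e.g. SAC)."""
    return low + 0.5 * (scaled_action + 1.0) * (high - low)


def clip_action(action: float, low: float, high: float) -> float:
    """Clip a raw action to the action-space bounds (done for e.g. PPO/TD3)."""
    return min(max(action, low), high)


# Pendulum-v1 has action space Box(-2.0, 2.0)
print(unscale_action(0.5, -2.0, 2.0))  # 1.0
print(clip_action(3.0, -2.0, 2.0))     # 2.0
```

For batched, multi-dimensional actions the same formulas apply element-wise (NumPy broadcasting handles that in the linked SB3 code).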
@@ -195,10 +194,22 @@ There is a draft PR in the RL Zoo about C++ export: https://github.com/DLR-RM/rl
 Export to ONNX-JS / ONNX Runtime Web
 ------------------------------------

-See https://onnxruntime.ai/docs/tutorials/web/build-web-app.html
+Official documentation: https://onnxruntime.ai/docs/tutorials/web/build-web-app.html
+
+Full example code: https://github.com/JonathanColetti/CarDodgingGym
+
+Demo: https://jonathancoletti.github.io/CarDodgingGym
+
+The code linked above is a complete example (using a car-dodging environment) that:
+
+1. Creates and trains a PPO model
+2. Exports the model to ONNX, along with the normalization statistics in JSON
+3. Runs inference in the browser with onnxruntime-web, applying the same normalization to achieve similar results
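Step 2 and step 3 above hinge on shipping the normalization statistics alongside the model. A stdlib-only sketch of the idea (the field names are illustrative, not the linked project's actual schema; the formula is the VecNormalize-style observation normalization):

```python
import json
import math

# Hypothetical running statistics of the kind VecNormalize tracks during training.
stats = {"mean": [0.1, -0.2, 0.0], "var": [1.5, 0.9, 2.0], "clip_obs": 10.0, "epsilon": 1e-8}

# Step 2: serialize the statistics so they can be shipped next to the ONNX model.
blob = json.dumps(stats)

# Step 3 (what the browser side must replicate before calling the model):
# clip((obs - mean) / sqrt(var + eps), -clip_obs, clip_obs), element-wise.
def normalize(obs, stats):
    return [
        max(-stats["clip_obs"], min((o - m) / math.sqrt(v + stats["epsilon"]), stats["clip_obs"]))
        for o, m, v in zip(obs, stats["mean"], stats["var"])
    ]

print(normalize([1.0, 0.0, 0.0], stats))
```

If you trained with ``VecNormalize``, feeding the browser model raw observations without this step will silently degrade the policy, which is why the example stores the statistics explicitly.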
+Below is a simple example that converts a model to ONNX and then runs inference (without post-processing) in ONNX-JS.

 .. code-block:: python

     import torch as th

     from stable_baselines3 import SAC
@@ -210,8 +221,7 @@ See https://onnxruntime.ai/docs/tutorials/web/build-web-app.html
         self.actor = actor

     def forward(self, observation: th.Tensor) -> th.Tensor:
-        # NOTE: You may have to postprocess (unnormalize) actions
-        # to the correct bounds (see commented code below)
+        # NOTE: You may have to postprocess (unnormalize or renormalize) the actions
         return self.actor(observation, deterministic=True)
@@ -231,8 +241,8 @@ See https://onnxruntime.ai/docs/tutorials/web/build-web-app.html
     )

 .. code-block:: javascript

-    // Install using `npm install onnxruntime-web` or using cdn
+    // Install using `npm install onnxruntime-web` (tested with version 1.19) or using a CDN
     import * as ort from 'onnxruntime-web';

     async function runInference() {
@@ -254,11 +264,86 @@ See https://onnxruntime.ai/docs/tutorials/web/build-web-app.html
     runInference();


-Export to tensorflowjs
-----------------------
+Export to TensorFlow.js
+-----------------------
+
+.. warning::
+
+  As of November 2025, `onnx2tf <https://github.com/PINTO0309/onnx2tf>`_ does not support TensorFlow.js. Therefore, `tfjs-converter <https://github.com/tensorflow/tfjs-converter>`_ is used instead. However, tfjs-converter is not currently maintained and requires older opsets and TensorFlow versions.
+
+
+To make this work, you have to chain multiple conversions: SB3 => ONNX => TensorFlow => TensorFlow.js.
+
+The opset version needs to be lowered for this conversion (``opset_version=14`` is currently required). For more stable usage with a higher opset, refer to the ONNX example above.
+
+The following is a simple example that showcases the full conversion and inference.
+
+Please refer to the previous sections for the first step (SB3 => ONNX).
+The main difference is that you need to specify ``opset_version=14``.
+
+.. code-block:: python
+
+    # Tested with Python 3.10
+    # First, install these dependencies in a fresh env:
+    """
+    pip install --use-deprecated=legacy-resolver tensorflow==2.13.0 keras==2.13.1 onnx==1.16.0 onnx-tf==1.9.0 tensorflow-probability==0.21.0 tensorflowjs==4.15.0 jax==0.4.26 jaxlib==0.4.26
+    """
+    # Then run this code block.
+    # If there are no errors (the folder structure is correct), run:
+    """
+    tensorflowjs_converter --input_format=tf_saved_model --output_format=tfjs_graph_model tf_model tfjs_model
+    """
+
+    # If you get an error when exporting with `tensorflowjs_converter`, upgrade TensorFlow:
+    """
+    pip install --upgrade tensorflow tensorflow-decision-forests tensorflowjs
+    """
+    # and retry the `tensorflowjs_converter` command; it should work (do not rerun this code block).
+
-TODO: contributors help is welcomed!
-Probably a good starting point: https://github.com/elliotwaite/pytorch-to-javascript-with-onnx-js
+    import onnx
+    import onnx_tf.backend
+    import tensorflow as tf
+
+    ONNX_FILE_PATH = "my_sac_actor.onnx"
+    MODEL_PATH = "tf_model"
+
+    onnx_model = onnx.load(ONNX_FILE_PATH)
+    onnx.checker.check_model(onnx_model)
+    print(onnx.helper.printable_graph(onnx_model.graph))
+
+    print('Converting ONNX to TF...')
+    tf_rep = onnx_tf.backend.prepare(onnx_model)
+    tf_rep.export_graph(MODEL_PATH)
+    # After this, do not forget to run `tensorflowjs_converter`
+
+
+.. code-block:: javascript
+
+    import * as tf from 'https://cdn.jsdelivr.net/npm/@tensorflow/[email protected]/+esm';
+
+    // Post-processing not included
+    async function runInference() {
+        const MODEL_URL = './tfjs_model/model.json';
+
+        const model = await tf.loadGraphModel(MODEL_URL);
+
+        // The observation size is 3 for Pendulum-v1
+        const inputData = [1.0, 0.0, 0.0];
+        const inputTensor = tf.tensor2d([inputData], [1, 3]);
+
+        const resultTensor = model.execute(inputTensor);
+
+        const action = await resultTensor.data();
+
+        console.log('Predicted action=', action);
+
+        inputTensor.dispose();
+        resultTensor.dispose();
+    }
+
+    runInference();

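As a side note on the ``inputData`` used in the browser snippets: Pendulum-v1's observation is ``[cos(theta), sin(theta), angular_velocity]``, so ``[1.0, 0.0, 0.0]`` is the upright, stationary pendulum. A stdlib sketch of constructing such an observation (helper name is hypothetical):

```python
import math

def pendulum_observation(theta: float, theta_dot: float) -> list[float]:
    """Build a Pendulum-v1 observation: [cos(theta), sin(theta), angular velocity]."""
    return [math.cos(theta), math.sin(theta), theta_dot]

# theta = 0 (upright), no angular velocity -> matches the inputData above
print(pendulum_observation(0.0, 0.0))  # [1.0, 0.0, 0.0]
```

Whatever environment you export, keep the batch dimension: the models expect shape ``[1, obs_size]``, hence ``tf.tensor2d([inputData], [1, 3])`` above.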
Export to TFLite / Coral (Edge TPU)

docs/misc/changelog.rst

Lines changed: 3 additions & 0 deletions

@@ -44,6 +44,9 @@ Documentation:
 - Added sb3-plus to projects page
 - Added example usage of ONNX JS
 - Updated link to paper of community project DeepNetSlice (@AlexPasqua)
+- Added example usage of TensorFlow JS
+- Included exact versions in the ONNX JS example and an example project
+- Made step 2 (`pip install`) of `CONTRIBUTING.md` more robust


 Release 2.7.0 (2025-07-25)
