Add tensorflow.js export example #2190
Merged (+102 −14)
Changes from all commits (6 commits):
- 3bc652c  Update documentation with tensorflow.js/ONNX-js, robust command in co… (JonathanColetti)
- 346c36b  fix doc building (JonathanColetti)
- 9d63c98  Add offical documentation and changed tensorflow description (JonathanColetti)
- efb4d65  Opset works at version 14 and change ONNX_FILE_PATH to the sac model (JonathanColetti)
- 96d8680  README also does not contain quotes in pip install command (JonathanColetti)
- 4fa7b61  Update export.rst (araffin)
@@ -37,8 +37,7 @@ If you are using PyTorch 2.0+ and ONNX Opset 14+, you can easily export SB3 poli

 .. warning::

-  The following returns normalized actions and doesn't include the `post-processing <https://github.com/DLR-RM/stable-baselines3/blob/a9273f968eaf8c6e04302a07d803eebfca6e7e86/stable_baselines3/common/policies.py#L370-L377>`_ step that is done with continuous actions
-  (clip or unscale the action to the correct space).
+  The following returns normalized actions and doesn't include the `post-processing <https://github.com/DLR-RM/stable-baselines3/blob/a9273f968eaf8c6e04302a07d803eebfca6e7e86/stable_baselines3/common/policies.py#L370-L377>`_ step that is done with continuous actions (clip or unscale the action to the correct space).


 .. code-block:: python
@@ -195,10 +194,22 @@ There is a draft PR in the RL Zoo about C++ export: https://github.com/DLR-RM/rl

 Export to ONNX-JS / ONNX Runtime Web
 ------------------------------------

-See https://onnxruntime.ai/docs/tutorials/web/build-web-app.html
+Official documentation: https://onnxruntime.ai/docs/tutorials/web/build-web-app.html
+
+Full example code: https://github.com/JonathanColetti/CarDodgingGym
+
+Demo: https://jonathancoletti.github.io/CarDodgingGym
+
+The code linked above is a complete example (using a car dodging environment) that:
+
+1. Creates/trains a PPO model
+2. Exports the model to ONNX, along with normalization stats in JSON
+3. Runs in the browser with normalization using onnxruntime-web to achieve similar results
+
+Below is a simple example of converting a model to ONNX and then running inference (without post-processing) in ONNX-JS.

 .. code-block:: python

     import torch as th

     from stable_baselines3 import SAC
@@ -210,8 +221,7 @@ See https://onnxruntime.ai/docs/tutorials/web/build-web-app.html
             self.actor = actor

         def forward(self, observation: th.Tensor) -> th.Tensor:
-            # NOTE: You may have to postprocess (unnormalize) actions
-            # to the correct bounds (see commented code below)
+            # NOTE: You may have to postprocess (unnormalize or renormalize)
             return self.actor(observation, deterministic=True)

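The post-processing mentioned in the ``NOTE`` comment above can be applied outside the exported graph, on the client side. A minimal sketch, assuming a squashed action in ``[-1, 1]`` that must be unscaled to the action-space bounds; the function name and example bounds are illustrative, not part of the exported model:

```python
def unscale_action(scaled_action, low, high):
    """Map an action from [-1, 1] back to [low, high], element-wise."""
    return [
        lo + 0.5 * (a + 1.0) * (hi - lo)
        for a, lo, hi in zip(scaled_action, low, high)
    ]

# Example: Pendulum-v1 has a single action bounded by [-2, 2]
print(unscale_action([0.5], low=[-2.0], high=[2.0]))  # → [1.0]
```

The same arithmetic is easy to port to JavaScript after reading the bounds from ``env.action_space``.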
@@ -231,8 +241,8 @@ See https://onnxruntime.ai/docs/tutorials/web/build-web-app.html
     )

 .. code-block:: javascript
-    // Install using `npm install onnxruntime-web` or using cdn

+    // Install using `npm install onnxruntime-web` (tested with version 1.19) or using cdn
     import * as ort from 'onnxruntime-web';

     async function runInference() {
@@ -254,11 +264,86 @@ See https://onnxruntime.ai/docs/tutorials/web/build-web-app.html
     runInference();


-Export to tensorflowjs
-----------------------
+Export to TensorFlow.js
+-----------------------
+
+.. warning::
+
+  As of November 2025, `onnx2tf <https://github.com/PINTO0309/onnx2tf>`_ does not support TensorFlow.js. Therefore, `tfjs-converter <https://github.com/tensorflow/tfjs-converter>`_ is used instead. However, tfjs-converter is not currently maintained and requires older opsets and TensorFlow versions.
+
+
+In order for this to work, you have to do multiple conversions: SB3 => ONNX => TensorFlow => TensorFlow.js.
+
+The opset version needs to be changed for the conversion (``opset_version=14`` is currently required). Please refer to the code above for more stable usage with a higher opset.
+
+The following is a simple example that showcases the full conversion and inference.
+
+Please refer to the previous sections for the first step (SB3 => ONNX).
+The main difference is that you need to specify ``opset_version=14``.
+
+.. code-block:: python
+
+    # Tested with Python 3.10.
+    # First, install these dependencies in a fresh env:
+    """
+    pip install --use-deprecated=legacy-resolver tensorflow==2.13.0 keras==2.13.1 onnx==1.16.0 onnx-tf==1.9.0 tensorflow-probability==0.21.0 tensorflowjs==4.15.0 jax==0.4.26 jaxlib==0.4.26
+    """
+    # Then run this code block.
+    # If there are no errors (the folder structure is correct), run:
+    """
+    tensorflowjs_converter --input_format=tf_saved_model --output_format=tfjs_graph_model tf_model tfjs_model
+    """
+
+    # If you get an error when exporting with `tensorflowjs_converter`, upgrade TensorFlow:
+    """
+    pip install --upgrade tensorflow tensorflow-decision-forests tensorflowjs
+    """
+    # and retry the `tensorflowjs_converter` command above (do not rerun this code block).
+
-TODO: contributors help is welcomed!
-Probably a good starting point: https://github.com/elliotwaite/pytorch-to-javascript-with-onnx-js
+    import onnx
+    import onnx_tf.backend
+    import tensorflow as tf
+
+    ONNX_FILE_PATH = "my_sac_actor.onnx"
+    MODEL_PATH = "tf_model"
+
+    onnx_model = onnx.load(ONNX_FILE_PATH)
+    onnx.checker.check_model(onnx_model)
+    print(onnx.helper.printable_graph(onnx_model.graph))
+
+    print("Converting ONNX to TF...")
+    tf_rep = onnx_tf.backend.prepare(onnx_model)
+    tf_rep.export_graph(MODEL_PATH)
+    # After this, do not forget to run `tensorflowjs_converter`
+
+
+.. code-block:: javascript
+
+    import * as tf from 'https://cdn.jsdelivr.net/npm/@tensorflow/[email protected]/+esm';
+    // Post-processing not included
+    async function runInference() {
+        const MODEL_URL = './tfjs_model/model.json';
+
+        const model = await tf.loadGraphModel(MODEL_URL);
+
+        // Observation size is 3 for Pendulum-v1
+        const inputData = [1.0, 0.0, 0.0];
+        const inputTensor = tf.tensor2d([inputData], [1, 3]);
+
+        const resultTensor = model.execute(inputTensor);
+
+        const action = await resultTensor.data();
+
+        console.log('Predicted action=', action);
+
+        inputTensor.dispose();
+        resultTensor.dispose();
+    }
+
+    runInference();
+

 Export to TFLite / Coral (Edge TPU)