-
Notifications
You must be signed in to change notification settings - Fork 2k
Add tensorflow.js export example #2190
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 2 commits
3bc652c
346c36b
9d63c98
efb4d65
96d8680
4fa7b61
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -195,7 +195,14 @@ There is a draft PR in the RL Zoo about C++ export: https://github.com/DLR-RM/rl | |
| Export to ONNX-JS / ONNX Runtime Web | ||
| ------------------------------------ | ||
|
|
||
| See https://onnxruntime.ai/docs/tutorials/web/build-web-app.html | ||
| Full example code: https://github.com/JonathanColetti/CarDodgingGym and demo: https://jonathancoletti.github.io/CarDodgingGym/ | ||
|
||
|
|
||
| The code linked above is a complete example (using car dodging) that: | ||
| 1. Creates/Trains a PPO model | ||
| 2. Exports the model to ONNX along with normalization stats in JSON | ||
| 3. Normalizes and Inferences using onnxruntime-web to achieve similar results | ||
|
|
||
| Below is a simple example with converting to ONNX then inferencing without postprocess in ONNX-JS | ||
|
|
||
| .. code-block:: python | ||
|
|
||
|
|
@@ -210,8 +217,7 @@ See https://onnxruntime.ai/docs/tutorials/web/build-web-app.html | |
| self.actor = actor | ||
|
|
||
| def forward(self, observation: th.Tensor) -> th.Tensor: | ||
| # NOTE: You may have to postprocess (unnormalize) actions | ||
| # to the correct bounds (see commented code below) | ||
| # NOTE: You may have to postprocess (unnormalize or renormalize) | ||
| return self.actor(observation, deterministic=True) | ||
|
|
||
|
|
||
|
|
@@ -232,7 +238,7 @@ See https://onnxruntime.ai/docs/tutorials/web/build-web-app.html | |
|
|
||
| .. code-block:: javascript | ||
|
|
||
| // Install using `npm install onnxruntime-web` or using cdn | ||
| // Install using `npm install onnxruntime-web` (tested with version 1.19) or using cdn | ||
| import * as ort from 'onnxruntime-web'; | ||
|
|
||
| async function runInference() { | ||
|
|
@@ -257,8 +263,110 @@ See https://onnxruntime.ai/docs/tutorials/web/build-web-app.html | |
| Export to tensorflowjs | ||
| ---------------------- | ||
|
|
||
| TODO: contributors help is welcomed! | ||
| Probably a good starting point: https://github.com/elliotwaite/pytorch-to-javascript-with-onnx-js | ||
| .. warning:: | ||
|
|
||
| As of writing this (November 2025), (https://github.com/PINTO0309/onnx2tf) does not support tensorflow js. Thus, (https://github.com/tensorflow/tfjs-converter) is used. This is not currently maintained and requires old opsets/tf versions. | ||
| Therefore, it is recommended you use onnx runtime for higher opsets | ||
|
|
||
| In order for this to work, you must convert (SB3 => ONNX => Tensorflow => Tensorflowjs) | ||
|
|
||
| The following is a simple example that showcases the full conversion + inference | ||
|
|
||
| .. code-block:: python | ||
|
||
|
|
||
| import torch as th | ||
| from stable_baselines3 import SAC | ||
|
|
||
|
|
||
| class OnnxablePolicy(th.nn.Module): | ||
| def __init__(self, actor: th.nn.Module): | ||
| super().__init__() | ||
| self.actor = actor | ||
|
|
||
| def forward(self, observation: th.Tensor) -> th.Tensor: | ||
| # NOTE: You may have to postprocess (unnormalize) actions | ||
| # to the correct bounds (see commented code below) | ||
| return self.actor(observation, deterministic=True) | ||
|
|
||
|
|
||
| # Example: model = SAC("MlpPolicy", "Pendulum-v1") | ||
| SAC("MlpPolicy", "Pendulum-v1").save("PathToTrainedModel.zip") | ||
| model = SAC.load("PathToTrainedModel.zip", device="cpu") | ||
| onnxable_model = OnnxablePolicy(model.policy.actor) | ||
|
|
||
| observation_size = model.observation_space.shape | ||
| dummy_input = th.randn(1, *observation_size) | ||
| th.onnx.export( | ||
| onnxable_model, | ||
| dummy_input, | ||
| "my_sac_actor.onnx", | ||
| opset_version=12, # because of the outdated tf-js converter you have to use an old opset | ||
| input_names=["input"], | ||
| ) | ||
|
|
||
| .. code-block:: python | ||
|
|
||
| # Tested with python3.10 | ||
| # Then install these dependencies in a fresh env | ||
| """ | ||
| pip install --use-deprecated=legacy-resolver tensorflow==2.13.0 keras==2.13.1 onnx==1.16.0 onnx-tf==1.9.0 tensorflow-probability==0.21.0 tensorflowjs==4.15.0 jax==0.4.26 jaxlib==0.4.26 | ||
| """ | ||
| # Then run this codeblock | ||
| # If there are no errors (the folder is structure correctly) then | ||
| """ | ||
| # tensorflowjs_converter --input_format=tf_saved_model --output_format=tfjs_graph_model tf_model tfjs_model | ||
| """ | ||
|
|
||
| # If you get an error exporting using `tensorflowjs_converter` then upgrade tensorflow | ||
| """ | ||
| pip install --upgrade tensorflow tensorflow-decision-forests tensorflowjs | ||
| """ | ||
| # And retry with and it should work (do not rerun this codeblock) | ||
| """ | ||
| tensorflowjs_converter --input_format=tf_saved_model --output_format=tfjs_graph_model tf_model tfjs_model | ||
| """ | ||
|
|
||
| import onnx | ||
| import onnx_tf.backend | ||
| import tensorflow as tf | ||
|
|
||
| ONNX_FILE_PATH = "ppo_cargame.onnx" | ||
| MODEL_PATH = "tf_model" | ||
|
|
||
| onnx_model = onnx.load(ONNX_FILE_PATH) | ||
| onnx.checker.check_model(onnx_model) | ||
| print(onnx.helper.printable_graph(onnx_model.graph)) | ||
|
|
||
| print('Converting ONNX to TF...') | ||
| tf_rep = onnx_tf.backend.prepare(onnx_model) | ||
| tf_rep.export_graph(MODEL_PATH) | ||
| # After this do not forget to use `tensorflowjs_converter` | ||
|
|
||
|
|
||
| .. code-block:: javascript | ||
|
|
||
| import * as tf from 'https://cdn.jsdelivr.net/npm/@tensorflow/[email protected]/+esm'; | ||
| // Post processing not included | ||
| async function runInference() { | ||
| const MODEL_URL = './tfjs_model/model.json'; | ||
|
|
||
| const model = await tf.loadGraphModel(MODEL_URL); | ||
|
|
||
| // Observation_size is 3 for Pendulum-v1 | ||
| const inputData = [1.0, 0.0, 0.0]; | ||
| const inputTensor = tf.tensor2d([inputData], [1, 3]); | ||
|
|
||
| const resultTensor = model.execute(inputTensor); | ||
|
|
||
| const action = await resultTensor.data(); | ||
|
|
||
| console.log('Predicted action=', action); | ||
|
|
||
| inputTensor.dispose(); | ||
| resultTensor.dispose(); | ||
| } | ||
|
|
||
| runInference(); | ||
|
|
||
|
|
||
| Export to TFLite / Coral (Edge TPU) | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.