Commit 4fa7b61

Update export.rst
1 parent 96d8680 commit 4fa7b61

1 file changed: +22 -49 lines changed

docs/guide/export.rst

Lines changed: 22 additions & 49 deletions
@@ -37,8 +37,7 @@ If you are using PyTorch 2.0+ and ONNX Opset 14+, you can easily export SB3 poli
 
 .. warning::
 
-  The following returns normalized actions and doesn't include the `post-processing <https://github.com/DLR-RM/stable-baselines3/blob/a9273f968eaf8c6e04302a07d803eebfca6e7e86/stable_baselines3/common/policies.py#L370-L377>`_ step that is done with continuous actions
-  (clip or unscale the action to the correct space).
+  The following returns normalized actions and doesn't include the `post-processing <https://github.com/DLR-RM/stable-baselines3/blob/a9273f968eaf8c6e04302a07d803eebfca6e7e86/stable_baselines3/common/policies.py#L370-L377>`_ step that is done with continuous actions (clip or unscale the action to the correct space).
 
 
 .. code-block:: python
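
The warning in the hunk above mentions the post-processing applied to continuous actions (clip or unscale). As a rough illustration only, the following plain-Python sketch shows the two cases: rescaling a squashed action from [-1, 1] to the action-space bounds, or clipping an unsquashed action to those bounds. The helper names and the Pendulum-like bounds are assumptions made for this example; this is not the SB3 implementation itself.

```python
from typing import List, Sequence


def unscale_action(
    scaled: Sequence[float], low: Sequence[float], high: Sequence[float]
) -> List[float]:
    """Linearly map each component from [-1, 1] to [low, high]."""
    return [lo + 0.5 * (a + 1.0) * (hi - lo) for a, lo, hi in zip(scaled, low, high)]


def clip_action(
    action: Sequence[float], low: Sequence[float], high: Sequence[float]
) -> List[float]:
    """Clamp each component to [low, high]."""
    return [min(max(a, lo), hi) for a, lo, hi in zip(action, low, high)]


# Hypothetical bounds, similar to a 1-D torque action space.
low, high = [-2.0], [2.0]

# Squashed policy output (already in [-1, 1]): rescale to the bounds.
print(unscale_action([0.5], low, high))

# Unsquashed policy output: clip to the bounds instead.
print(clip_action([3.1], low, high))
```

Which of the two applies depends on whether the exported policy squashes its output; when in doubt, compare the exported model's actions against ``model.predict`` on the same observation.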
@@ -195,19 +194,22 @@ There is a draft PR in the RL Zoo about C++ export: https://github.com/DLR-RM/rl
 Export to ONNX-JS / ONNX Runtime Web
 ------------------------------------
 
-The official documentation is located at: https://onnxruntime.ai/docs/tutorials/web/build-web-app.html
+Official documentation: https://onnxruntime.ai/docs/tutorials/web/build-web-app.html
 
-Full example code: https://github.com/JonathanColetti/CarDodgingGym and demo: https://jonathancoletti.github.io/CarDodgingGym/
+Full example code: https://github.com/JonathanColetti/CarDodgingGym
 
-The code linked above is a complete example (using car dodging) that:
-1. Creates/Trains a PPO model
+Demo: https://jonathancoletti.github.io/CarDodgingGym
+
+The code linked above is a complete example (using car dodging environment) that:
+
+1. Creates/Trains a PPO model
 2. Exports the model to ONNX along with normalization stats in JSON
-3. Normalizes and Inferences using onnxruntime-web to achieve similar results
+3. Runs in the browser with normalization using onnxruntime-web to achieve similar results
 
 Below is a simple example with converting to ONNX then inferencing without postprocess in ONNX-JS
 
 .. code-block:: python
-
+
   import torch as th
 
   from stable_baselines3 import SAC
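
Steps 2 and 3 of the list in the hunk above rely on exporting observation-normalization statistics next to the ONNX file and re-applying them at inference time. Below is a minimal sketch of that idea in plain Python, assuming the statistics are a per-dimension running mean and variance (as kept by a wrapper such as ``VecNormalize``). The JSON field names (``mean``, ``var``, ``epsilon``, ``clip_obs``) and the helper names are hypothetical and are not taken from the linked repository.

```python
import json
import math
from typing import Dict, List, Sequence


def dump_norm_stats(
    mean: Sequence[float],
    var: Sequence[float],
    epsilon: float = 1e-8,
    clip_obs: float = 10.0,
) -> str:
    """Serialize the normalization statistics to a JSON string."""
    return json.dumps(
        {"mean": list(mean), "var": list(var), "epsilon": epsilon, "clip_obs": clip_obs}
    )


def normalize_obs(obs: Sequence[float], stats: Dict) -> List[float]:
    """Apply (obs - mean) / sqrt(var + epsilon), then clip to +/- clip_obs."""
    limit = stats["clip_obs"]
    normalized = []
    for x, m, v in zip(obs, stats["mean"], stats["var"]):
        z = (x - m) / math.sqrt(v + stats["epsilon"])
        normalized.append(min(max(z, -limit), limit))
    return normalized


# Made-up statistics for a 2-D observation, round-tripped through JSON.
stats = json.loads(dump_norm_stats(mean=[0.5, -1.0], var=[4.0, 0.25]))
print(normalize_obs([2.5, -1.0], stats))
```

The same arithmetic would then be reproduced in JavaScript on the raw observation before building the input tensor for onnxruntime-web, so that the browser sees the same inputs as the trained policy did.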
@@ -219,7 +221,7 @@ Below is a simple example with converting to ONNX then inferencing without postp
           self.actor = actor
 
       def forward(self, observation: th.Tensor) -> th.Tensor:
-          # NOTE: You may have to postprocess (unnormalize or renormalize)
+          # NOTE: You may have to postprocess (unnormalize or renormalize)
           return self.actor(observation, deterministic=True)
 
 
@@ -239,7 +241,7 @@ Below is a simple example with converting to ONNX then inferencing without postp
   )
 
 .. code-block:: javascript
-
+
   // Install using `npm install onnxruntime-web` (tested with version 1.19) or using cdn
   import * as ort from 'onnxruntime-web';
 
@@ -262,55 +264,26 @@ Below is a simple example with converting to ONNX then inferencing without postp
   runInference();
 
 
-Export to tensorflowjs
-----------------------
+Export to TensorFlow.js
+-----------------------
 
 .. warning::
 
-  As of writing this (November 2025), (https://github.com/PINTO0309/onnx2tf) does not support tensorflow js. Thus, (https://github.com/tensorflow/tfjs-converter) is used. This is not currently maintained and requires old opsets/tf versions.
-
+  As of November 2025, `onnx2tf <https://github.com/PINTO0309/onnx2tf>`_ does not support TensorFlow.js. Therefore, `tfjs-converter <https://github.com/tensorflow/tfjs-converter>`_ is used instead. However, tfjs-converter is not currently maintained and requires older opsets and TensorFlow versions.
 
-In order for this to work, you must convert (SB3 => ONNX => Tensorflow => Tensorflowjs)
 
-The opset version needs to be changed for the conversion. Please refer to the code above for more stable usage with a higer opset.
+In order for this to work, you have to do multiple conversions: SB3 => ONNX => TensorFlow => TensorFlow.js.
 
-The following is a simple example that showcases the full conversion + inference
+The opset version needs to be changed for the conversion (``opset_version=14`` is currently required). Please refer to the code above for more stable usage with a higher opset.
 
-.. code-block:: python
-
-  import torch as th
-  from stable_baselines3 import SAC
+The following is a simple example that showcases the full conversion + inference.
 
-
-  class OnnxablePolicy(th.nn.Module):
-      def __init__(self, actor: th.nn.Module):
-          super().__init__()
-          self.actor = actor
-
-      def forward(self, observation: th.Tensor) -> th.Tensor:
-          # NOTE: You may have to postprocess (unnormalize) actions
-          # to the correct bounds (see commented code below)
-          return self.actor(observation, deterministic=True)
-
-
-  # Example: model = SAC("MlpPolicy", "Pendulum-v1")
-  SAC("MlpPolicy", "Pendulum-v1").save("PathToTrainedModel.zip")
-  model = SAC.load("PathToTrainedModel.zip", device="cpu")
-  onnxable_model = OnnxablePolicy(model.policy.actor)
-
-  observation_size = model.observation_space.shape
-  dummy_input = th.randn(1, *observation_size)
-  th.onnx.export(
-      onnxable_model,
-      dummy_input,
-      "my_sac_actor.onnx",
-      opset_version=14, # because of the outdated tf-js converter you have to use an old opset
-      input_names=["input"],
-  )
+Please refer to the previous sections for the first step (SB3 => ONNX).
+The main difference is that you need to specify ``opset_version=14``.
 
 .. code-block:: python
-
-  # Tested with python3.10
+
+  # Tested with python3.10
   # Then install these dependencies in a fresh env
   """
   pip install --use-deprecated=legacy-resolver tensorflow==2.13.0 keras==2.13.1 onnx==1.16.0 onnx-tf==1.9.0 tensorflow-probability==0.21.0 tensorflowjs==4.15.0 jax==0.4.26 jaxlib==0.4.26
