Commit 9c314ef

Merge branch 'master' into copilot/fix-2779f837-79a9-4b3c-b3a3-48a6ab7d437f
2 parents ff2c000 + bab847b commit 9c314ef

7 files changed (+167, -13 lines changed)


CONTRIBUTING.md

Lines changed: 1 addition & 1 deletion
@@ -33,7 +33,7 @@ cd stable-baselines3/
 2. Install Stable-Baselines3 in develop mode, with support for building the docs and running tests:

 ```bash
-pip install -e .[docs,tests,extra]
+pip install -e '.[docs,tests,extra]'
 ```

 ## Codestyle
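The commit only describes the quoting change as making the install step "more robust". One plausible reason, stated here as an assumption, is that shells treat `[...]` as a glob character class, and some shells (zsh, for instance) abort when such a pattern matches nothing. The sketch below uses Python's standard `fnmatch` and `shlex` modules to illustrate the idea; the file name `.d` is hypothetical.

```python
import fnmatch
import shlex

requirement = ".[docs,tests,extra]"

# Read as a glob, the brackets form a character class: the pattern matches
# a dot followed by exactly one of the listed characters.
matches_dot_d = fnmatch.fnmatchcase(".d", requirement)  # ".d" is a hypothetical file name
# The literal requirement string does not match its own glob interpretation.
matches_literal = fnmatch.fnmatchcase(requirement, requirement)

# shlex.quote wraps the argument in single quotes so the shell passes it through unchanged.
quoted = shlex.quote(requirement)

print(matches_dot_d, matches_literal, quoted)
```

The quoted form produced by `shlex.quote` is the same string the diff above switches to.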

README.md

Lines changed: 1 addition & 1 deletion
@@ -210,7 +210,7 @@ Actions `gymnasium.spaces`:
 ## Testing the installation
 ### Install dependencies
 ```sh
-pip install -e .[docs,tests,extra]
+pip install -e '.[docs,tests,extra]'
 ```
 ### Run tests
 All unit tests in stable baselines3 can be run using `pytest` runner:

docs/guide/export.rst

Lines changed: 154 additions & 7 deletions
@@ -29,16 +29,15 @@ to do inference in another framework.


 Export to ONNX
------------------
+--------------


 If you are using PyTorch 2.0+ and ONNX Opset 14+, you can easily export SB3 policies using the following code:


 .. warning::

-  The following returns normalized actions and doesn't include the `post-processing <https://github.com/DLR-RM/stable-baselines3/blob/a9273f968eaf8c6e04302a07d803eebfca6e7e86/stable_baselines3/common/policies.py#L370-L377>`_ step that is done with continuous actions
-  (clip or unscale the action to the correct space).
+  The following returns normalized actions and doesn't include the `post-processing <https://github.com/DLR-RM/stable-baselines3/blob/a9273f968eaf8c6e04302a07d803eebfca6e7e86/stable_baselines3/common/policies.py#L370-L377>`_ step that is done with continuous actions (clip or unscale the action to the correct space).


 .. code-block:: python
@@ -192,11 +191,159 @@ There is a draft PR in the RL Zoo about C++ export: https://github.com/DLR-RM/rl
   action_jit = loaded_module(dummy_input)


-Export to tensorflowjs / ONNX-JS
---------------------------------
+Export to ONNX-JS / ONNX Runtime Web
+------------------------------------

-TODO: contributors help is welcomed!
-Probably a good starting point: https://github.com/elliotwaite/pytorch-to-javascript-with-onnx-js
+Official documentation: https://onnxruntime.ai/docs/tutorials/web/build-web-app.html
+
+Full example code: https://github.com/JonathanColetti/CarDodgingGym
+
+Demo: https://jonathancoletti.github.io/CarDodgingGym
+
+The code linked above is a complete example (using a car-dodging environment) that:
+
+1. Creates and trains a PPO model
+2. Exports the model to ONNX along with normalization stats in JSON
+3. Runs in the browser with normalization using onnxruntime-web to achieve similar results
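Steps 2 and 3 of the list above depend on applying the same observation normalization at inference time as during training. The sketch below is a minimal illustration in plain Python, assuming the common running mean/variance scheme with clipping (the scheme used by SB3's `VecNormalize`). The JSON field names (`mean`, `var`, `epsilon`, `clip_obs`) and the numbers are hypothetical; the linked project's actual file layout is not shown here.

```python
import json
import math

# Hypothetical statistics; in practice they would be read from the training run.
stats = {"mean": [1.0, 0.0, -2.0], "var": [4.0, 1.0, 0.25], "epsilon": 1e-8, "clip_obs": 10.0}
stats_json = json.dumps(stats)  # this string is what would be shipped next to the .onnx file


def normalize_observation(observation, stats):
    """Apply (obs - mean) / sqrt(var + epsilon), then clip to [-clip_obs, clip_obs]."""
    normalized = []
    for value, mean, var in zip(observation, stats["mean"], stats["var"]):
        scaled = (value - mean) / math.sqrt(var + stats["epsilon"])
        normalized.append(max(-stats["clip_obs"], min(stats["clip_obs"], scaled)))
    return normalized


loaded_stats = json.loads(stats_json)
normalized_obs = normalize_observation([3.0, 0.5, 100.0], loaded_stats)
print(normalized_obs)
```

The same arithmetic would have to be repeated in JavaScript before feeding observations to the ONNX session in the browser.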
+
+Below is a simple example that exports a model to ONNX and then runs inference, without post-processing, in ONNX-JS:
+
+.. code-block:: python
+
+  import torch as th
+
+  from stable_baselines3 import SAC
+
+
+  class OnnxablePolicy(th.nn.Module):
+      def __init__(self, actor: th.nn.Module):
+          super().__init__()
+          self.actor = actor
+
+      def forward(self, observation: th.Tensor) -> th.Tensor:
+          # NOTE: You may have to postprocess (unnormalize or renormalize)
+          return self.actor(observation, deterministic=True)
+
+
+  # Example: model = SAC("MlpPolicy", "Pendulum-v1")
+  SAC("MlpPolicy", "Pendulum-v1").save("PathToTrainedModel.zip")
+  model = SAC.load("PathToTrainedModel.zip", device="cpu")
+  onnxable_model = OnnxablePolicy(model.policy.actor)
+
+  observation_size = model.observation_space.shape
+  dummy_input = th.randn(1, *observation_size)
+  th.onnx.export(
+      onnxable_model,
+      dummy_input,
+      "my_sac_actor.onnx",
+      opset_version=17,
+      input_names=["input"],
+  )
+
+.. code-block:: javascript
+
+  // Install using `npm install onnxruntime-web` (tested with version 1.19) or using cdn
+  import * as ort from 'onnxruntime-web';
+
+  async function runInference() {
+    const session = await ort.InferenceSession.create('my_sac_actor.onnx');
+
+    // The observation_size = 3 (for Pendulum-v1)
+    const inputData = Float32Array.from([0.1, -0.2, 0.3]);
+
+    const inputTensor = new ort.Tensor('float32', inputData, [1, 3]);
+
+    const results = await session.run({ input: inputTensor });
+
+    const outputName = session.outputNames[0];
+    const action = results[outputName].data;
+
+    console.log('Predicted action=', action);
+  }
+
+  runInference();
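The exported actor returns actions without the post-processing step mentioned in the warning earlier in this file (clip or unscale the action to the correct space). As a rough sketch of what that step could look like, assuming a squashed policy whose outputs lie in [-1, 1] and the Pendulum-v1 torque bounds of [-2, 2], the two helpers below (illustrative names, written in plain Python) show the arithmetic:

```python
def unscale_action(scaled_action, low, high):
    """Map each action component from [-1, 1] to [low, high]."""
    return [lo + 0.5 * (a + 1.0) * (hi - lo) for a, lo, hi in zip(scaled_action, low, high)]


def clip_action(action, low, high):
    """Clamp each action component to [low, high] (for policies that do not squash)."""
    return [max(lo, min(hi, a)) for a, lo, hi in zip(action, low, high)]


LOW, HIGH = [-2.0], [2.0]  # Pendulum-v1 torque bounds

unscaled = unscale_action([0.5], LOW, HIGH)
clipped = clip_action([3.1], LOW, HIGH)
print(unscaled, clipped)
```

Which of the two applies depends on whether the policy squashes its output; the same formula would need to be ported to JavaScript for use in the browser.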
+
+
+Export to TensorFlow.js
+-----------------------
+
+.. warning::
+
+  As of November 2025, `onnx2tf <https://github.com/PINTO0309/onnx2tf>`_ does not support TensorFlow.js. Therefore, `tfjs-converter <https://github.com/tensorflow/tfjs-converter>`_ is used instead. However, tfjs-converter is not currently maintained and requires older opsets and TensorFlow versions.
+
+
+In order for this to work, you have to do multiple conversions: SB3 => ONNX => TensorFlow => TensorFlow.js.
+
+The opset version needs to be changed for the conversion (``opset_version=14`` is currently required). Please refer to the code above for more stable usage with a higher opset.
+
+The following is a simple example that showcases the full conversion + inference.
+
+Please refer to the previous sections for the first step (SB3 => ONNX).
+The main difference is that you need to specify ``opset_version=14``.
+
+.. code-block:: python
+
+  # Tested with python3.10
+  # Then install these dependencies in a fresh env
+  """
+  pip install --use-deprecated=legacy-resolver tensorflow==2.13.0 keras==2.13.1 onnx==1.16.0 onnx-tf==1.9.0 tensorflow-probability==0.21.0 tensorflowjs==4.15.0 jax==0.4.26 jaxlib==0.4.26
+  """
+  # Then run this codeblock
+  # If there are no errors (the folder is structured correctly) then
+  """
+  # tensorflowjs_converter --input_format=tf_saved_model --output_format=tfjs_graph_model tf_model tfjs_model
+  """
+
+  # If you get an error exporting using `tensorflowjs_converter` then upgrade tensorflow
+  """
+  pip install --upgrade tensorflow tensorflow-decision-forests tensorflowjs
+  """
+  # Then retry the conversion; it should work (do not rerun this codeblock)
+  """
+  tensorflowjs_converter --input_format=tf_saved_model --output_format=tfjs_graph_model tf_model tfjs_model
+  """
+
+  import onnx
+  import onnx_tf.backend
+  import tensorflow as tf
+
+  ONNX_FILE_PATH = "my_sac_actor.onnx"
+  MODEL_PATH = "tf_model"
+
+  onnx_model = onnx.load(ONNX_FILE_PATH)
+  onnx.checker.check_model(onnx_model)
+  print(onnx.helper.printable_graph(onnx_model.graph))
+
+  print('Converting ONNX to TF...')
+  tf_rep = onnx_tf.backend.prepare(onnx_model)
+  tf_rep.export_graph(MODEL_PATH)
+  # After this do not forget to use `tensorflowjs_converter`
+
+
+.. code-block:: javascript
+
+  import * as tf from 'https://cdn.jsdelivr.net/npm/@tensorflow/[email protected]/+esm';
+  // Post processing not included
+  async function runInference() {
+    const MODEL_URL = './tfjs_model/model.json';
+
+    const model = await tf.loadGraphModel(MODEL_URL);
+
+    // Observation_size is 3 for Pendulum-v1
+    const inputData = [1.0, 0.0, 0.0];
+    const inputTensor = tf.tensor2d([inputData], [1, 3]);
+
+    const resultTensor = model.execute(inputTensor);
+
+    const action = await resultTensor.data();
+
+    console.log('Predicted action=', action);
+
+    inputTensor.dispose();
+    resultTensor.dispose();
+  }
+
+  runInference();


 Export to TFLite / Coral (Edge TPU)

docs/guide/integrations.rst

Lines changed: 1 addition & 1 deletion
@@ -9,7 +9,7 @@ Weights & Biases

 Weights & Biases provides a callback for experiment tracking that allows to visualize and share results.

-The full documentation is available here: https://docs.wandb.ai/guides/integrations/other/stable-baselines-3
+The full documentation is available here: https://docs.wandb.ai/models/integrations/stable-baselines-3

 .. code-block:: python

docs/misc/changelog.rst

Lines changed: 7 additions & 1 deletion
@@ -19,6 +19,7 @@ Bug Fixes:
 - Fixed env checker to properly handle ``Sequence`` observation spaces when nested inside composite spaces (``Dict``, ``Tuple``, ``OneOf``) (@copilot)
 - Update env checker to warn users when using Graph space (@dhruvmalik007).
 - Fixed memory leak in ``VecVideoRecorder`` where ``recorded_frames`` stayed in memory due to reference in the moviepy clip (@copilot)
+- Remove double space in `StopTrainingOnRewardThreshold` callback message (@sea-bass)

 `SB3-Contrib`_
 ^^^^^^^^^^^^^^
@@ -42,6 +43,11 @@ Documentation:
 - Documented Atari wrapper reset behavior where ``env.reset()`` may perform a no-op step instead of truly resetting when ``terminal_on_life_loss=True`` (default), and how to avoid this behavior by setting ``terminal_on_life_loss=False``
 - Clarified comment in ``_sample_action()`` method to better explain action scaling behavior for off-policy algorithms (@copilot)
 - Added sb3-plus to projects page
+- Added example usage of ONNX JS
+- Updated link to paper of community project DeepNetSlice (@AlexPasqua)
+- Added example usage of Tensorflow JS
+- Included exact versions in ONNX JS and example project
+- Made step 2 (`pip install`) of `CONTRIBUTING.md` more robust


 Release 2.7.0 (2025-07-25)
@@ -1898,4 +1904,4 @@ And all the contributors:
 @DavyMorgan @luizapozzobon @Bonifatius94 @theSquaredError @harveybellini @DavyMorgan @FieteO @jonasreiher @npit @WeberSamuel @troiganto
 @lutogniew @lbergmann1 @lukashass @BertrandDecoster @pseudo-rnd-thoughts @stefanbschneider @kyle-he @PatrickHelm @corentinlger
 @marekm4 @stagoverflow @rushitnshah @markscsmith @NickLucche @cschindlbeck @peteole @jak3122 @will-maclean
-@brn-dev @jmacglashan @kplers @MarcDcls @chrisgao99 @pstahlhofen @akanto @Trenza1ore
+@brn-dev @jmacglashan @kplers @MarcDcls @chrisgao99 @pstahlhofen @akanto @Trenza1ore @JonathanColetti

docs/misc/projects.rst

Lines changed: 2 additions & 1 deletion
@@ -228,7 +228,8 @@ intelligent agents to perform network slice placement.

 | Author: Alex Pasquali
 | Github: https://github.com/AlexPasqua/DeepNetSlice
-| Paper: **under review** (citation instructions on the project's README.md) -> see this Master's Thesis for the moment: https://etd.adm.unipi.it/theses/available/etd-01182023-110038/unrestricted/Tesi_magistrale_Pasquali_Alex.pdf
+| Paper: https://ieeexplore.ieee.org/document/10625023
+| Associated Master's Thesis: https://etd.adm.unipi.it/theses/available/etd-01182023-110038/unrestricted/Tesi_magistrale_Pasquali_Alex.pdf


 PokemonRedExperiments

stable_baselines3/common/callbacks.py

Lines changed: 1 addition & 1 deletion
@@ -565,7 +565,7 @@ def _on_step(self) -> bool:
         if self.verbose >= 1 and not continue_training:
             print(
                 f"Stopping training because the mean reward {self.parent.best_mean_reward:.2f} "
-                f" is above the threshold {self.reward_threshold}"
+                f"is above the threshold {self.reward_threshold}"
             )
         return continue_training
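The double space fixed in this hunk comes from Python's implicit concatenation of adjacent string literals: the first f-string ends with a space and the second used to start with one. A small standalone sketch, with hypothetical reward values standing in for the callback's attributes:

```python
best_mean_reward = 201.5  # hypothetical value, stands in for self.parent.best_mean_reward
reward_threshold = 200    # hypothetical value, stands in for self.reward_threshold

# Adjacent string literals are joined with nothing in between,
# so a trailing space plus a leading space yields two spaces.
before = (
    f"Stopping training because the mean reward {best_mean_reward:.2f} "
    f" is above the threshold {reward_threshold}"
)
after = (
    f"Stopping training because the mean reward {best_mean_reward:.2f} "
    f"is above the threshold {reward_threshold}"
)

print(repr(before))
print(repr(after))
```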