
Commit d3b7ba7

Copilot and araffin authored
Add plotting documentation (#2168)
* Initial plan
* Add comprehensive plotting documentation and update changelog

  Co-authored-by: araffin <[email protected]>
* Cleanup plotting guide
* Cleanup doc
* Reorganize plotting docs to highlight RL Zoo3 as recommended approach with CLI examples

  Co-authored-by: araffin <[email protected]>
* Fix hallucinations
* Cleanup doc

---------

Co-authored-by: copilot-swe-agent[bot] <[email protected]>
Co-authored-by: araffin <[email protected]>
Co-authored-by: Antonin RAFFIN <[email protected]>
1 parent 3b249c0 commit d3b7ba7

4 files changed

+200
-2
lines changed


docs/guide/plotting.rst

Lines changed: 195 additions & 0 deletions
@@ -0,0 +1,195 @@
1+
.. _plotting:
2+
3+
========
4+
Plotting
5+
========
6+
7+
8+
Stable Baselines3 provides utilities for plotting training results to monitor and visualize your agent's learning progress.
9+
The plotting functionality is provided by the ``results_plotter`` module, which can load monitor files created during training and generate various plots.
10+
11+
.. note::
12+
13+
For plotting, we recommend using the
14+
`RL Baselines3 Zoo plotting scripts <https://rl-baselines3-zoo.readthedocs.io/en/master/guide/plot.html>`_
15+
which provide plots with confidence intervals and publication-ready visualizations.
16+
17+
18+
Recommended Approach: RL Baselines3 Zoo Plotting
19+
================================================
20+
21+
The RL Baselines3 Zoo plotting scripts cover more advanced use cases, including:
22+
23+
- Comparing results across different environments
24+
- Publication-ready plots with confidence intervals
25+
- Evaluation plots with error bars
26+
27+
We recommend using the plotting scripts from `RL Baselines3 Zoo <https://github.com/DLR-RM/rl-baselines3-zoo>`_:
28+
29+
- `plot_train.py <https://github.com/DLR-RM/rl-baselines3-zoo/blob/master/rl_zoo3/plots/plot_train.py>`_: For training plots
30+
- `all_plots.py <https://github.com/DLR-RM/rl-baselines3-zoo/blob/master/rl_zoo3/plots/all_plots.py>`_: For evaluation plots and for post-processing the results
31+
- `plot_from_file.py <https://github.com/DLR-RM/rl-baselines3-zoo/blob/master/rl_zoo3/plots/plot_from_file.py>`_: For more advanced plotting from post-processed results
32+
33+
These scripts provide additional features not available in the basic SB3 plotting utilities.
34+
35+
36+
Installation
37+
------------
38+
39+
First, install RL Baselines3 Zoo:
40+
41+
.. code-block:: bash

  pip install rl_zoo3[plots]
44+
45+
Basic Training Plot Examples
46+
----------------------------
47+
48+
.. code-block:: bash

  # Train an agent
  python -m rl_zoo3.train --algo ppo --env CartPole-v1 -f logs/

  # Plot training results for a single algorithm
  python -m rl_zoo3.plots.plot_train --algo ppo --env CartPole-v1 --exp-folder logs/
55+
56+
57+
Evaluation and Comparison Plots
58+
-------------------------------
59+
60+
.. code-block:: bash

  # Generate evaluation plots and save post-processed results
  # in `logs/demo_plots.pkl` in order to use `plot_from_file`
  python -m rl_zoo3.plots.all_plots --algo ppo sac -e Pendulum-v1 -f logs/ -o logs/demo_plots

  # More advanced plotting from post-processed results (with confidence intervals)
  python -m rl_zoo3.plots.plot_from_file -i logs/demo_plots.pkl --rliable --ci-size 0.95
68+
69+
70+
For more examples, please read the
71+
`RL Baselines3 Zoo plotting guide <https://rl-baselines3-zoo.readthedocs.io/en/master/guide/plot.html>`_.
72+
73+
74+
Real-Time Monitoring
75+
====================
76+
77+
For real-time monitoring during training, consider using the plotting functions within callbacks
78+
(see the `Callbacks guide <callbacks.html>`_) or integrating with tools like `TensorBoard <tensorboard.html>`_ or Weights & Biases
79+
(see the `Integrations guide <integrations.html>`_).
80+
81+
Monitor File Format
82+
===================
83+
84+
The ``Monitor`` wrapper saves training data in CSV format with the following columns:
85+
86+
- ``r``: Episode reward
87+
- ``l``: Episode length (number of steps)
88+
- ``t``: Elapsed wall-clock time in seconds since the monitor started, recorded when the episode ended
89+
90+
Additional columns may be present if you log custom metrics in the environment's info dict.
91+
92+
.. note::
93+
94+
The plotting functions automatically handle multiple monitor files from the same directory,
95+
which occurs when using vectorized environments. The episodes are loaded and sorted by timestamp
96+
to maintain proper chronological order.
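
As a rough illustration of the format described above, the sketch below parses a small monitor file using only the standard library. It assumes the file starts with a JSON header line prefixed with ``#``, followed by the CSV columns. The sample contents and values are made up for illustration and do not come from a real run.

```python
import csv
import io
import json

# Hypothetical monitor file contents (illustrative values only).
# Assumption: the first line is a JSON header prefixed with "#",
# and the remaining lines are plain CSV with columns r, l, t.
sample_monitor_csv = """#{"t_start": 1700000000.0, "env_id": "CartPole-v1"}
r,l,t
21.0,21,0.52
35.0,35,1.10
12.0,12,1.31
"""

buffer = io.StringIO(sample_monitor_csv)

# Parse the JSON header after stripping the leading "#"
header = json.loads(buffer.readline()[1:])

# Parse the remaining lines as CSV rows keyed by column name
rows = list(csv.DictReader(buffer))

episode_rewards = [float(row["r"]) for row in rows]
episode_lengths = [int(row["l"]) for row in rows]

# Total number of environment steps recorded in this file
total_steps = sum(episode_lengths)

print(header["env_id"], len(rows), total_steps)
```

In practice, ``load_results`` handles this parsing (and the merging of several monitor files), so manual parsing like this is mainly useful for debugging.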
97+
98+
Basic SB3 Plotting (Simple Use Cases)
99+
======================================
100+
101+
Basic Plotting: Single Training Run
102+
-----------------------------------
103+
104+
The simplest way to plot training results is to use the ``plot_results`` function after training an agent.
105+
This function reads monitor files created by the ``Monitor`` wrapper and plots the episode rewards over time.
106+
107+
.. code-block:: python

  import os
  import gymnasium as gym
  import matplotlib.pyplot as plt

  from stable_baselines3 import PPO
  from stable_baselines3.common.monitor import Monitor
  from stable_baselines3.common.results_plotter import plot_results
  from stable_baselines3.common import results_plotter

  # Create log directory
  log_dir = "tmp/"
  os.makedirs(log_dir, exist_ok=True)

  # Create and wrap the environment with Monitor
  env = gym.make("CartPole-v1")
  env = Monitor(env, log_dir)

  # Train the agent
  model = PPO("MlpPolicy", env, verbose=1)
  model.learn(total_timesteps=20_000)

  # Plot the results
  plot_results([log_dir], 20_000, results_plotter.X_TIMESTEPS, "PPO CartPole")
  plt.show()
133+
134+
135+
Different Plotting Modes
136+
------------------------
137+
138+
The plotting functions support three different x-axis modes:
139+
140+
- ``X_TIMESTEPS``: Plot rewards vs. timesteps (default)
141+
- ``X_EPISODES``: Plot rewards vs. episode number
142+
- ``X_WALLTIME``: Plot rewards vs. wall-clock time in hours
143+
144+
.. code-block:: python

  import matplotlib.pyplot as plt
  from stable_baselines3.common import results_plotter
  from stable_baselines3.common.results_plotter import plot_results

  # Directory containing the monitor files (same as in the previous example)
  log_dir = "tmp/"

  # Plot by timesteps (shows sample efficiency)
  # plot_results([log_dir], None, results_plotter.X_TIMESTEPS, "Rewards vs Timesteps")
  # Plot by episodes
  plot_results([log_dir], None, results_plotter.X_EPISODES, "Rewards vs Episodes")
  # Plot by wall-clock time (in hours)
  # plot_results([log_dir], None, results_plotter.X_WALLTIME, "Rewards vs Time")

  plt.tight_layout()
  plt.show()
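
To make the three x-axis modes more concrete, here is a minimal pure-Python sketch of how each x-axis can be derived from the monitor columns. The helper ``to_xy`` and the sample data are hypothetical; the sketch assumes that timesteps are the cumulative sum of episode lengths, episodes are a zero-based index, and wall-clock time is ``t`` converted from seconds to hours.

```python
from itertools import accumulate

# Hypothetical episode data, shaped like the monitor columns r, l, t
episode_rewards = [21.0, 35.0, 12.0]
episode_lengths = [21, 35, 12]
episode_times_s = [1800.0, 3600.0, 5400.0]  # elapsed seconds


def to_xy(x_axis):
    """Illustrative helper (not part of SB3): build (x, y) for a given x-axis mode."""
    if x_axis == "timesteps":
        # Cumulative number of environment steps at the end of each episode
        x = list(accumulate(episode_lengths))
    elif x_axis == "episodes":
        # Zero-based episode index
        x = list(range(len(episode_rewards)))
    elif x_axis == "walltime_hrs":
        # Elapsed time converted from seconds to hours
        x = [t / 3600.0 for t in episode_times_s]
    else:
        raise NotImplementedError(f"Unsupported {x_axis=}")
    return x, episode_rewards


x_timesteps, _ = to_xy("timesteps")
x_episodes, _ = to_xy("episodes")
x_walltime, _ = to_xy("walltime_hrs")
print(x_timesteps, x_episodes, x_walltime)
```

The y-values are the episodic returns in every mode; only the x-axis changes.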
157+
158+
159+
Advanced Plotting with Manual Data Processing
160+
---------------------------------------------
161+
162+
For more control over the plotting, you can use the underlying functions to process the data manually:
163+
164+
.. code-block:: python

  import numpy as np
  import matplotlib.pyplot as plt
  from stable_baselines3.common.monitor import load_results
  from stable_baselines3.common.results_plotter import ts2xy, window_func

  # Directory containing the monitor files (same as in the previous examples)
  log_dir = "tmp/"

  # Load the results
  df = load_results(log_dir)

  # Convert dataframe (x=timesteps, y=episodic return)
  x, y = ts2xy(df, "timesteps")

  # Plot raw data
  plt.figure(figsize=(10, 6))
  plt.subplot(2, 1, 1)
  plt.scatter(x, y, s=2, alpha=0.6)
  plt.xlabel("Timesteps")
  plt.ylabel("Episode Reward")
  plt.title("Raw Episode Rewards")

  # Plot smoothed data with custom window
  plt.subplot(2, 1, 2)
  if len(x) >= 50:  # Only smooth if we have enough data
      x_smooth, y_smooth = window_func(x, y, 50, np.mean)
      plt.plot(x_smooth, y_smooth, linewidth=2)
  plt.xlabel("Timesteps")
  plt.ylabel("Average Episode Reward (50-episode window)")
  plt.title("Smoothed Episode Rewards")

  plt.tight_layout()
  plt.show()
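
The smoothing step above can be pictured as a trailing rolling mean. The sketch below is an independent pure-Python illustration, not the library's implementation: ``rolling_mean`` is a hypothetical helper, and it assumes that one smoothed value is produced per full window and aligned with the last x-value of that window.

```python
from statistics import mean


def rolling_mean(x, y, window):
    """Illustrative trailing-window mean (hypothetical helper, not part of SB3)."""
    if len(y) < window:
        # Not enough data points to fill a single window
        return [], []
    # One averaged value per full window, ending at index i
    y_smooth = [mean(y[i - window + 1 : i + 1]) for i in range(window - 1, len(y))]
    # Align each averaged value with the last x-value of its window
    return x[window - 1 :], y_smooth


x = [10, 20, 30, 40, 50]
y = [1.0, 2.0, 3.0, 4.0, 5.0]
x_s, y_s = rolling_mean(x, y, window=3)
print(x_s, y_s)
```

With a window of ``w`` episodes, the first ``w - 1`` points have no smoothed counterpart, which is why the example above only smooths when at least 50 episodes are available.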

docs/index.rst

Lines changed: 1 addition & 0 deletions
@@ -54,6 +54,7 @@ Main Features
   guide/rl_zoo
   guide/sb3_contrib
   guide/sbx
+  guide/plotting
   guide/imitation
   guide/migration
   guide/checking_nan

docs/misc/changelog.rst

Lines changed: 1 addition & 0 deletions
@@ -33,6 +33,7 @@ Others:
 
 Documentation:
 ^^^^^^^^^^^^^^
+- Added plotting documentation and examples
 - Added documentation clarifying gSDE (Generalized State-Dependent Exploration) inference behavior for PPO, SAC, and A2C algorithms

stable_baselines3/common/results_plotter.py

Lines changed: 3 additions & 2 deletions
@@ -47,9 +47,10 @@ def window_func(var_1: np.ndarray, var_2: np.ndarray, window: int, func: Callabl
 def ts2xy(data_frame: pd.DataFrame, x_axis: str) -> tuple[np.ndarray, np.ndarray]:
     """
     Decompose a data frame variable to x and ys
+    (y = episodic return)
 
     :param data_frame: the input data
-    :param x_axis: the axis for the x and y output
+    :param x_axis: the x-axis for the x and y output
         (can be X_TIMESTEPS='timesteps', X_EPISODES='episodes' or X_WALLTIME='walltime_hrs')
     :return: the x and y output
     """
@@ -64,7 +65,7 @@ def ts2xy(data_frame: pd.DataFrame, x_axis: str) -> tuple[np.ndarray, np.ndarray
         x_var = data_frame.t.values / 3600.0  # type: ignore[operator, assignment]
         y_var = data_frame.r.values
     else:
-        raise NotImplementedError
+        raise NotImplementedError(f"Unsupported {x_axis=}, please use one of {POSSIBLE_X_AXES}")
     return x_var, y_var  # type: ignore[return-value]
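
The new error message relies on the ``=`` specifier for f-strings (Python 3.8+), which prints both the expression and its ``repr``. The following standalone sketch shows what such a message looks like. ``check_x_axis`` is a hypothetical helper, and the contents of ``POSSIBLE_X_AXES`` are assumed from the values listed in the docstring above.

```python
# Assumed values, taken from the docstring of ts2xy shown in the diff
POSSIBLE_X_AXES = ["timesteps", "episodes", "walltime_hrs"]


def check_x_axis(x_axis):
    """Illustrative validation helper (hypothetical, not part of SB3)."""
    if x_axis not in POSSIBLE_X_AXES:
        # "{x_axis=}" expands to "x_axis='<value>'"
        raise NotImplementedError(f"Unsupported {x_axis=}, please use one of {POSSIBLE_X_AXES}")
    return x_axis


try:
    check_x_axis("steps")
except NotImplementedError as err:
    message = str(err)
    print(message)
```

Compared with a bare ``NotImplementedError``, the message tells the user which value was rejected and which ones are accepted.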
