Skip to content

Commit 3a1bfa8

Browse files
committed
Cleanup doc
1 parent 98d0cd5 commit 3a1bfa8

File tree

2 files changed

+14
-133
lines changed

2 files changed

+14
-133
lines changed

docs/guide/plotting.rst

Lines changed: 11 additions & 131 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
Plotting
55
========
66

7-
Stable Baselines3 provides utilities for plotting training results to monitor and visualize your agent's learning progress.
7+
Stable Baselines3 provides utilities for plotting training results to monitor and visualize your agent"s learning progress.
88
The main plotting functionality is provided by the ``results_plotter`` module, which can load monitor files created during training and generate various plots.
99

1010

@@ -66,43 +66,6 @@ The plotting functions support three different x-axis modes:
6666
plt.show()
6767
6868
69-
Plotting Multiple Runs
70-
======================
71-
72-
To plot multiple training runs together:
73-
74-
.. code-block:: python
75-
76-
import os
77-
import gymnasium as gym
78-
import matplotlib.pyplot as plt
79-
80-
from stable_baselines3 import PPO
81-
from stable_baselines3.common.monitor import Monitor
82-
from stable_baselines3.common.results_plotter import plot_results
83-
from stable_baselines3.common import results_plotter
84-
85-
# Train multiple agents with different runs
86-
runs = [("PPO_1", PPO), ("PPO_2", PPO)]
87-
log_dirs = []
88-
89-
for name, algorithm in runs:
90-
log_dir = f"logs/{name}/"
91-
os.makedirs(log_dir, exist_ok=True)
92-
log_dirs.append(log_dir)
93-
94-
env = gym.make("CartPole-v1")
95-
env = Monitor(env, log_dir)
96-
97-
model = algorithm("MlpPolicy", env, verbose=0)
98-
model.learn(total_timesteps=20_000)
99-
100-
# Plot all results together
101-
plot_results(log_dirs, 20_000, results_plotter.X_TIMESTEPS, "Algorithm Comparison")
102-
plt.legend(["", "PPO_1", "", "PPO_2"])
103-
plt.show()
104-
105-
10669
Advanced Plotting with Manual Data Processing
10770
=============================================
10871

@@ -118,113 +81,30 @@ For more control over the plotting, you can use the underlying functions to proc
11881
# Load the results
11982
df = load_results(log_dir)
12083
121-
# Convert to x, y coordinates
122-
x, y = ts2xy(df, 'timesteps')
84+
# Convert dataframe (x=timesteps, y=episodic return)
85+
x, y = ts2xy(df, "timesteps")
12386
12487
# Plot raw data
12588
plt.figure(figsize=(10, 6))
12689
plt.subplot(2, 1, 1)
12790
plt.scatter(x, y, s=2, alpha=0.6)
128-
plt.xlabel('Timesteps')
129-
plt.ylabel('Episode Reward')
130-
plt.title('Raw Episode Rewards')
91+
plt.xlabel("Timesteps")
92+
plt.ylabel("Episode Reward")
93+
plt.title("Raw Episode Rewards")
13194
13295
# Plot smoothed data with custom window
13396
plt.subplot(2, 1, 2)
13497
if len(x) >= 50: # Only smooth if we have enough data
13598
x_smooth, y_smooth = window_func(x, y, 50, np.mean)
13699
plt.plot(x_smooth, y_smooth, linewidth=2)
137-
plt.xlabel('Timesteps')
138-
plt.ylabel('Average Episode Reward (50-episode window)')
139-
plt.title('Smoothed Episode Rewards')
100+
plt.xlabel("Timesteps")
101+
plt.ylabel("Average Episode Reward (50-episode window)"")
102+
plt.title("Smoothed Episode Rewards")
140103
141104
plt.tight_layout()
142105
plt.show()
143106
144107
145-
Plotting Success Rates
146-
======================
147-
148-
For environments that support it (e.g., goal-conditioned environments), you can also plot success rates:
149-
150-
.. code-block:: python
151-
152-
import pandas as pd
153-
import numpy as np
154-
from stable_baselines3.common.monitor import load_results
155-
156-
# For environments that log success rates in info
157-
# The monitor will log 'is_success' if present in info dict
158-
df = load_results(log_dir)
159-
160-
# Check if success data is available
161-
if 'is_success' in df.columns:
162-
# Calculate rolling success rate
163-
window_size = 100
164-
success_rate = df['is_success'].rolling(window=window_size).mean()
165-
166-
plt.figure(figsize=(10, 4))
167-
plt.plot(success_rate)
168-
plt.xlabel('Episode')
169-
plt.ylabel('Success Rate')
170-
plt.title(f'Success Rate (rolling {window_size}-episode average)')
171-
plt.show()
172-
else:
173-
print("No success rate data available in monitor logs")
174-
175-
176-
Customizing Plot Appearance
177-
===========================
178-
179-
You can customize the plots by modifying matplotlib parameters:
180-
181-
.. code-block:: python
182-
183-
import matplotlib.pyplot as plt
184-
from stable_baselines3.common.results_plotter import plot_curves, ts2xy
185-
from stable_baselines3.common.monitor import load_results
186-
187-
# Load and process data
188-
df = load_results(log_dir)
189-
x, y = ts2xy(df, 'timesteps')
190-
191-
# Create custom plot
192-
plt.figure(figsize=(12, 6))
193-
194-
# Use the plot_curves function with custom figure size
195-
plot_curves([(x, y)], 'timesteps', 'Custom Training Progress', figsize=(12, 6))
196-
197-
# Customize appearance
198-
plt.grid(True, alpha=0.3)
199-
plt.xlabel('Training Timesteps', fontsize=12)
200-
plt.ylabel('Episode Reward', fontsize=12)
201-
plt.title('Training Progress with Custom Styling', fontsize=14, fontweight='bold')
202-
203-
plt.show()
204-
205-
206-
Saving Plots
207-
============
208-
209-
To save plots instead of displaying them:
210-
211-
.. code-block:: python
212-
213-
import matplotlib.pyplot as plt
214-
from stable_baselines3.common.results_plotter import plot_results
215-
from stable_baselines3.common import results_plotter
216-
217-
# Create the plot but don't show it
218-
plot_results([log_dir], None, results_plotter.X_TIMESTEPS, "Training Results")
219-
220-
# Save as high-quality image
221-
plt.savefig("training_results.png", dpi=300, bbox_inches='tight')
222-
plt.savefig("training_results.pdf", bbox_inches='tight') # Vector format
223-
224-
# Close the figure to free memory
225-
plt.close()
226-
227-
228108
Monitor File Format
229109
===================
230110
@@ -234,12 +114,12 @@ The ``Monitor`` wrapper saves training data in CSV format with the following col
234114
- ``l``: Episode length (number of steps)
235115
- ``t``: Timestamp (wall-clock time when episode ended)
236116
237-
Additional columns may be present if you log custom metrics in the environment's info dict.
117+
Additional columns may be present if you log custom metrics in the environment"s info dict.
238118
239119
.. note::
240120
241121
The plotting functions automatically handle multiple monitor files from the same directory,
242-
which occurs when using vectorized environments. The files are loaded and sorted by timestamp
122+
which occurs when using vectorized environments. The episodes are loaded and sorted by timestamp
243123
to maintain proper chronological order.
244124
245125

stable_baselines3/common/results_plotter.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,9 +47,10 @@ def window_func(var_1: np.ndarray, var_2: np.ndarray, window: int, func: Callabl
4747
def ts2xy(data_frame: pd.DataFrame, x_axis: str) -> tuple[np.ndarray, np.ndarray]:
4848
"""
4949
Decompose a data frame variable to x and ys
50+
(y = episodic return)
5051
5152
:param data_frame: the input data
52-
:param x_axis: the axis for the x and y output
53+
:param x_axis: the x-axis for the x and y output
5354
(can be X_TIMESTEPS='timesteps', X_EPISODES='episodes' or X_WALLTIME='walltime_hrs')
5455
:return: the x and y output
5556
"""
@@ -64,7 +65,7 @@ def ts2xy(data_frame: pd.DataFrame, x_axis: str) -> tuple[np.ndarray, np.ndarray
6465
x_var = data_frame.t.values / 3600.0 # type: ignore[operator, assignment]
6566
y_var = data_frame.r.values
6667
else:
67-
raise NotImplementedError
68+
raise NotImplementedError(f"Unsupported {x_axis=}, please use one of {POSSIBLE_X_AXES}")
6869
return x_var, y_var # type: ignore[return-value]
6970

7071

0 commit comments

Comments
 (0)