 Plotting
 ========
 
-Stable Baselines3 provides utilities for plotting training results to monitor and visualize your agent's learning progress.
+Stable Baselines3 provides utilities for plotting training results to monitor and visualize your agent's learning progress.
 The main plotting functionality is provided by the ``results_plotter`` module, which can load monitor files created during training and generate various plots.
 
 
@@ -66,43 +66,6 @@ The plotting functions support three different x-axis modes:
     plt.show()
 
 
-Plotting Multiple Runs
-======================
-
-To plot multiple training runs together:
-
-.. code-block:: python
-
-    import os
-    import gymnasium as gym
-    import matplotlib.pyplot as plt
-
-    from stable_baselines3 import PPO
-    from stable_baselines3.common.monitor import Monitor
-    from stable_baselines3.common.results_plotter import plot_results
-    from stable_baselines3.common import results_plotter
-
-    # Train multiple agents with different runs
-    runs = [("PPO_1", PPO), ("PPO_2", PPO)]
-    log_dirs = []
-
-    for name, algorithm in runs:
-        log_dir = f"logs/{name}/"
-        os.makedirs(log_dir, exist_ok=True)
-        log_dirs.append(log_dir)
-
-        env = gym.make("CartPole-v1")
-        env = Monitor(env, log_dir)
-
-        model = algorithm("MlpPolicy", env, verbose=0)
-        model.learn(total_timesteps=20_000)
-
-    # Plot all results together
-    plot_results(log_dirs, 20_000, results_plotter.X_TIMESTEPS, "Algorithm Comparison")
-    plt.legend(["", "PPO_1", "", "PPO_2"])
-    plt.show()
-
-
 Advanced Plotting with Manual Data Processing
 =============================================
 
@@ -118,113 +81,30 @@ For more control over the plotting, you can use the underlying functions to proc
     # Load the results
     df = load_results(log_dir)
 
-    # Convert to x, y coordinates
-    x, y = ts2xy(df, 'timesteps')
+    # Convert dataframe (x=timesteps, y=episodic return)
+    x, y = ts2xy(df, "timesteps")
 
     # Plot raw data
     plt.figure(figsize=(10, 6))
     plt.subplot(2, 1, 1)
     plt.scatter(x, y, s=2, alpha=0.6)
-    plt.xlabel('Timesteps')
-    plt.ylabel('Episode Reward')
-    plt.title('Raw Episode Rewards')
+    plt.xlabel("Timesteps")
+    plt.ylabel("Episode Reward")
+    plt.title("Raw Episode Rewards")
 
     # Plot smoothed data with custom window
     plt.subplot(2, 1, 2)
     if len(x) >= 50:  # Only smooth if we have enough data
         x_smooth, y_smooth = window_func(x, y, 50, np.mean)
         plt.plot(x_smooth, y_smooth, linewidth=2)
-    plt.xlabel('Timesteps')
-    plt.ylabel('Average Episode Reward (50-episode window)')
-    plt.title('Smoothed Episode Rewards')
+    plt.xlabel("Timesteps")
+    plt.ylabel("Average Episode Reward (50-episode window)")
+    plt.title("Smoothed Episode Rewards")
 
     plt.tight_layout()
     plt.show()
 
 
-Plotting Success Rates
-======================
-
-For environments that support it (e.g., goal-conditioned environments), you can also plot success rates:
-
-.. code-block:: python
-
-    import pandas as pd
-    import numpy as np
-    from stable_baselines3.common.monitor import load_results
-
-    # For environments that log success rates in info
-    # The monitor will log 'is_success' if present in info dict
-    df = load_results(log_dir)
-
-    # Check if success data is available
-    if 'is_success' in df.columns:
-        # Calculate rolling success rate
-        window_size = 100
-        success_rate = df['is_success'].rolling(window=window_size).mean()
-
-        plt.figure(figsize=(10, 4))
-        plt.plot(success_rate)
-        plt.xlabel('Episode')
-        plt.ylabel('Success Rate')
-        plt.title(f'Success Rate (rolling {window_size}-episode average)')
-        plt.show()
-    else:
-        print("No success rate data available in monitor logs")
-
-
-Customizing Plot Appearance
-===========================
-
-You can customize the plots by modifying matplotlib parameters:
-
-.. code-block:: python
-
-    import matplotlib.pyplot as plt
-    from stable_baselines3.common.results_plotter import plot_curves, ts2xy
-    from stable_baselines3.common.monitor import load_results
-
-    # Load and process data
-    df = load_results(log_dir)
-    x, y = ts2xy(df, 'timesteps')
-
-    # Create custom plot
-    plt.figure(figsize=(12, 6))
-
-    # Use the plot_curves function with custom figure size
-    plot_curves([(x, y)], 'timesteps', 'Custom Training Progress', figsize=(12, 6))
-
-    # Customize appearance
-    plt.grid(True, alpha=0.3)
-    plt.xlabel('Training Timesteps', fontsize=12)
-    plt.ylabel('Episode Reward', fontsize=12)
-    plt.title('Training Progress with Custom Styling', fontsize=14, fontweight='bold')
-
-    plt.show()
-
-
-Saving Plots
-============
-
-To save plots instead of displaying them:
-
-.. code-block:: python
-
-    import matplotlib.pyplot as plt
-    from stable_baselines3.common.results_plotter import plot_results
-    from stable_baselines3.common import results_plotter
-
-    # Create the plot but don't show it
-    plot_results([log_dir], None, results_plotter.X_TIMESTEPS, "Training Results")
-
-    # Save as high-quality image
-    plt.savefig("training_results.png", dpi=300, bbox_inches='tight')
-    plt.savefig("training_results.pdf", bbox_inches='tight')  # Vector format
-
-    # Close the figure to free memory
-    plt.close()
-
-
 Monitor File Format
 ===================
 
@@ -234,12 +114,12 @@ The ``Monitor`` wrapper saves training data in CSV format with the following col
 - ``l``: Episode length (number of steps)
 - ``t``: Timestamp (wall-clock time when episode ended)
 
-Additional columns may be present if you log custom metrics in the environment's info dict.
+Additional columns may be present if you log custom metrics in the environment's info dict.
 
 .. note::
 
     The plotting functions automatically handle multiple monitor files from the same directory,
-    which occurs when using vectorized environments. The files are loaded and sorted by timestamp
+    which occurs when using vectorized environments. The episodes are loaded and sorted by timestamp
     to maintain proper chronological order.
 
 
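Review note on the reworded sentence ("The episodes are loaded and sorted by timestamp"): the sketch below illustrates what merging episodes from several monitor files by timestamp can look like. It is not the Stable Baselines3 implementation. It uses only the standard library, the two file contents are made-up data, and the helper names `read_monitor` and `merge_episodes` are hypothetical. It also assumes that each file starts with a `#`-prefixed JSON header containing `t_start` and that the `t` column is stored relative to that start time, which may not match every SB3 version.

```python
import csv
import io
import json

# Made-up contents of two monitor files, e.g. from two vectorized sub-environments.
# Assumed layout: one '#'-prefixed JSON header line, then CSV columns r, l, t.
FILE_A = '#{"t_start": 100.0, "env_id": "CartPole-v1"}\nr,l,t\n12.0,12,0.5\n30.0,30,2.0\n'
FILE_B = '#{"t_start": 100.2, "env_id": "CartPole-v1"}\nr,l,t\n20.0,20,1.0\n'


def read_monitor(text):
    """Parse one monitor file into a list of episode dicts with absolute end times."""
    stream = io.StringIO(text)
    header = json.loads(stream.readline().lstrip("#"))
    episodes = []
    for row in csv.DictReader(stream):
        episodes.append(
            {
                "r": float(row["r"]),
                "l": int(row["l"]),
                # Assumption: t is relative to t_start, so shift it before comparing files
                "t": float(row["t"]) + header["t_start"],
            }
        )
    return episodes


def merge_episodes(files):
    """Pool the episodes of all files and order them by end time."""
    episodes = [ep for text in files for ep in read_monitor(text)]
    return sorted(episodes, key=lambda ep: ep["t"])


episodes = merge_episodes([FILE_A, FILE_B])
lengths = [ep["l"] for ep in episodes]
print(lengths)
```

Sorting individual episodes rather than whole files is what interleaves the single episode from the second file between the two episodes of the first, which is the distinction the new wording makes.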