.. _plotting:

========
Plotting
========


Stable Baselines3 provides utilities for plotting training results to monitor and visualize your agent's learning progress.
The plotting functionality is provided by the ``results_plotter`` module, which can load monitor files created during training and generate various plots.

.. note::

   For plotting, we recommend using the
   `RL Baselines3 Zoo plotting scripts <https://rl-baselines3-zoo.readthedocs.io/en/master/guide/plot.html>`_,
   which provide confidence intervals and publication-ready visualizations.


Recommended Approach: RL Baselines3 Zoo Plotting
================================================

For good plotting capabilities, including:

- Comparing results across different environments
- Publication-ready plots with confidence intervals
- Evaluation plots with error bars

we recommend using the plotting scripts from `RL Baselines3 Zoo <https://github.com/DLR-RM/rl-baselines3-zoo>`_:

- `plot_train.py <https://github.com/DLR-RM/rl-baselines3-zoo/blob/master/rl_zoo3/plots/plot_train.py>`_: For training plots
- `all_plots.py <https://github.com/DLR-RM/rl-baselines3-zoo/blob/master/rl_zoo3/plots/all_plots.py>`_: For evaluation plots, to post-process the results
- `plot_from_file.py <https://github.com/DLR-RM/rl-baselines3-zoo/blob/master/rl_zoo3/plots/plot_from_file.py>`_: For more advanced plotting from post-processed results

These scripts provide additional features not available in the basic SB3 plotting utilities.

Installation
------------

First, install RL Baselines3 Zoo:

.. code-block:: bash

    pip install rl_zoo3[plots]

Basic Training Plot Examples
----------------------------

.. code-block:: bash

    # Train an agent
    python -m rl_zoo3.train --algo ppo --env CartPole-v1 -f logs/

    # Plot training results for a single algorithm
    python -m rl_zoo3.plots.plot_train --algo ppo --env CartPole-v1 --exp-folder logs/

Evaluation and Comparison Plots
-------------------------------

.. code-block:: bash

    # Generate evaluation plots and save post-processed results
    # in `logs/demo_plots.pkl` in order to use `plot_from_file`
    python -m rl_zoo3.plots.all_plots --algo ppo sac -e Pendulum-v1 -f logs/ -o logs/demo_plots

    # More advanced plotting from post-processed results (with confidence intervals)
    python -m rl_zoo3.plots.plot_from_file -i logs/demo_plots.pkl --rliable --ci-size 0.95


For more examples, please read the
`RL Baselines3 Zoo plotting guide <https://rl-baselines3-zoo.readthedocs.io/en/master/guide/plot.html>`_.


Real-Time Monitoring
====================

For real-time monitoring during training, consider using the plotting functions within callbacks
(see the `Callbacks guide <callbacks.html>`_) or integrating with tools like `Tensorboard <tensorboard.html>`_ or Weights & Biases
(see the `Integrations guide <integrations.html>`_).

Monitor File Format
===================

The ``Monitor`` wrapper saves training data in CSV format with the following columns:

- ``r``: Episode reward (sum of rewards over the episode)
- ``l``: Episode length (number of steps)
- ``t``: Elapsed wall-clock time since the start of training (in seconds)

Additional columns may be present if you log custom metrics from the environment's info dict.

.. note::

   The plotting functions automatically handle multiple monitor files from the same directory,
   which occurs when using vectorized environments. The episodes are loaded and sorted by timestamp
   to maintain proper chronological order.

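For reference, each monitor file begins with a JSON metadata line prefixed by ``#``, followed by a regular CSV table. The standalone sketch below (with made-up numbers) parses such a file using only the standard library:

```python
import csv
import io
import json

# Hypothetical contents of a monitor.csv file: a "#"-prefixed JSON
# header line, then a normal CSV table with r, l, t columns.
raw = """#{"t_start": 1700000000.0, "env_id": "CartPole-v1"}
r,l,t
18.0,18,0.51
45.0,45,1.32
200.0,200,4.87
"""

lines = raw.splitlines()
metadata = json.loads(lines[0][1:])  # strip the leading "#"
rows = list(csv.DictReader(io.StringIO("\n".join(lines[1:]))))

rewards = [float(row["r"]) for row in rows]
print(metadata["env_id"])           # CartPole-v1
print(sum(rewards) / len(rewards))  # mean episode reward
```

This is essentially what ``load_results`` does for you (across all monitor files in a directory), returning a pandas DataFrame instead.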
Basic SB3 Plotting (Simple Use Cases)
=====================================

Basic Plotting: Single Training Run
-----------------------------------

The simplest way to plot training results is to use the ``plot_results`` function after training an agent.
This function reads monitor files created by the ``Monitor`` wrapper and plots the episode rewards over time.

.. code-block:: python

    import os

    import gymnasium as gym
    import matplotlib.pyplot as plt

    from stable_baselines3 import PPO
    from stable_baselines3.common import results_plotter
    from stable_baselines3.common.monitor import Monitor
    from stable_baselines3.common.results_plotter import plot_results

    # Create log directory
    log_dir = "tmp/"
    os.makedirs(log_dir, exist_ok=True)

    # Create and wrap the environment with Monitor
    env = gym.make("CartPole-v1")
    env = Monitor(env, log_dir)

    # Train the agent
    model = PPO("MlpPolicy", env, verbose=1)
    model.learn(total_timesteps=20_000)

    # Plot the results
    plot_results([log_dir], 20_000, results_plotter.X_TIMESTEPS, "PPO CartPole")
    plt.show()


Different Plotting Modes
------------------------

The plotting functions support three different x-axis modes:

- ``X_TIMESTEPS``: Plot rewards vs. timesteps (default)
- ``X_EPISODES``: Plot rewards vs. episode number
- ``X_WALLTIME``: Plot rewards vs. wall-clock time in hours

.. code-block:: python

    import matplotlib.pyplot as plt

    from stable_baselines3.common import results_plotter
    from stable_baselines3.common.results_plotter import plot_results

    # log_dir is the Monitor log folder from the previous example

    # Plot by timesteps (shows sample efficiency)
    # plot_results([log_dir], None, results_plotter.X_TIMESTEPS, "Rewards vs Timesteps")

    # Plot by episodes
    plot_results([log_dir], None, results_plotter.X_EPISODES, "Rewards vs Episodes")

    # Plot by wall-clock time
    # plot_results([log_dir], None, results_plotter.X_WALLTIME, "Rewards vs Time")

    plt.tight_layout()
    plt.show()


Advanced Plotting with Manual Data Processing
---------------------------------------------

For more control over the plotting, you can use the underlying functions to process the data manually:

.. code-block:: python

    import matplotlib.pyplot as plt
    import numpy as np

    from stable_baselines3.common.monitor import load_results
    from stable_baselines3.common.results_plotter import ts2xy, window_func

    # Load the results (log_dir is the Monitor log folder from the first example)
    df = load_results(log_dir)

    # Convert the dataframe to arrays (x=timesteps, y=episodic return)
    x, y = ts2xy(df, "timesteps")

    # Plot raw data
    plt.figure(figsize=(10, 6))
    plt.subplot(2, 1, 1)
    plt.scatter(x, y, s=2, alpha=0.6)
    plt.xlabel("Timesteps")
    plt.ylabel("Episode Reward")
    plt.title("Raw Episode Rewards")

    # Plot smoothed data with a custom window
    plt.subplot(2, 1, 2)
    if len(x) >= 50:  # Only smooth if we have enough data
        x_smooth, y_smooth = window_func(x, y, 50, np.mean)
        plt.plot(x_smooth, y_smooth, linewidth=2)
    plt.xlabel("Timesteps")
    plt.ylabel("Average Episode Reward (50-episode window)")
    plt.title("Smoothed Episode Rewards")

    plt.tight_layout()
    plt.show()
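``window_func`` applies a reduction (here ``np.mean``) over a sliding window of the most recent episodes. As a standalone sketch of the same idea, assuming a simple uniform window:

```python
import numpy as np


def rolling_mean(y: np.ndarray, window: int) -> np.ndarray:
    """Mean over a sliding window, keeping only positions with a full window."""
    kernel = np.ones(window) / window
    # "valid" mode drops the first (window - 1) positions
    return np.convolve(y, kernel, mode="valid")


rewards = np.array([1.0, 2.0, 3.0, 4.0, 5.0])
print(rolling_mean(rewards, 3))  # [2. 3. 4.]
```

Dropping the first incomplete windows is why the smoothed curve above starts later than the raw one.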