diff --git a/src/progpy/models/aircraft_model/small_rotorcraft.py b/src/progpy/models/aircraft_model/small_rotorcraft.py index 06cee9d..b2f5390 100644 --- a/src/progpy/models/aircraft_model/small_rotorcraft.py +++ b/src/progpy/models/aircraft_model/small_rotorcraft.py @@ -322,7 +322,7 @@ def linear_model(self, phi, theta, psi, p, q, r, T): return A, B - def visualize_traj(self, pred, ref=None, prefix='', fig=None, **kwargs): + def visualize_traj(self, pred, ref=None, prefix='', fig=None, pred_cfg={'linewidth': 2.0, 'alpha': 0.6, 'color': 'tab:blue', 'linestyle':'-', 'label':'predicted'}, ref_cfg={'linewidth': 2.0, 'alpha': 0.6, 'color': 'tab:orange', 'linestyle':'--', 'label':'reference'}): """ This method provides functionality to visualize a predicted trajectory generated, plotted with the reference trajectory. @@ -339,6 +339,10 @@ def visualize_traj(self, pred, ref=None, prefix='', fig=None, **kwargs): Reference trajectory - dict with keys for each state in the vehicle model and corresponding values as numpy arrays prefix : str, optional Prefix added to keys in predicted values. This is used to plot the trajectory using the results from a composite model + pred_cfg : dict, optional + Configuration for the prediction line on the graphs. See matplotlib.pyplot.plot documentation for more details + ref_cfg : dict, optional + Configuration for the reference line (if provided) on the graphs. See matplotlib.pyplot.plot documentation for more details fig : TrajectoryFigure, optional Figure where the additional diagrams are to be added. Creates a new figure if not provided @@ -352,9 +356,6 @@ def visualize_traj(self, pred, ref=None, prefix='', fig=None, **kwargs): elif not isinstance(fig, TrajectoryFigure): raise TypeError(f"fig must be a TrajectorFigure, was {type(fig)}") - params = {'linewidth': 2.0, 'alpha': 0.6} - params.update(kwargs) - # Handle reference information if ref is not None: # Extract reference trajectory information @@ -364,8 +365,8 @@ def visualize_traj(self, pred, ref=None, prefix='', fig=None, **kwargs): ref_z = ref['z'].tolist() # Plot reference trajectories - fig.plot_traj(ref_x, ref_y, linestyle='--', color='tab:orange', label='reference trajectory', **params) - fig.plot_alt(time, ref_z, linestyle='-', color='tab:orange', label='reference trajectory', **params) + fig.plot_traj(ref_x, ref_y, **ref_cfg) + fig.plot_alt(time, ref_z, **ref_cfg) # Extract predicted trajectory information pred_x = [pred.outputs[iter][prefix+'x'] for iter in range(len(pred.times))] @@ -373,8 +374,11 @@ def visualize_traj(self, pred, ref=None, prefix='', fig=None, **kwargs): pred_z = [pred.outputs[iter][prefix+'z'] for iter in range(len(pred.times))] # Plot predictions - fig.plot_traj(pred_x, pred_y, linestyle='--', color='tab:orange', label='prediction', **params) - fig.plot_alt(pred.times, pred_z, linestyle='-', color='tab:orange', label='prediction', **params) + fig.plot_traj(pred_x, pred_y, **pred_cfg) + fig.plot_alt(pred.times, pred_z, **pred_cfg) + + # Final formatting + fig.get_axes()[0].legend(fontsize=14) return fig \ No newline at end of file