22
33import matplotlib .pyplot as plt
44import numpy as np
5+ from matplotlib .colors import Colormap
56
67from . import convert
78from .zern import FIRST_INDEX_J , ZernikeNorm , ZernikeOrder , ZernikePolynomial
@@ -176,7 +177,13 @@ def get_value(self, pos: np.ndarray) -> np.ndarray:
176177 values [rho > self .r_beam / self .r_wf ] = np .nan
177178 return values
178179
179- def plot (self , ax : plt .Axes | None = None ) -> tuple [plt .Figure , plt .Axes ]:
180+ def plot (
181+ self ,
182+ ax : plt .Axes | None = None ,
183+ cmap : str | Colormap = "RdBu" ,
184+ levels : int = 100 ,
185+ ** kwargs ,
186+ ) -> tuple [plt .Figure , plt .Axes ]:
180187 """
181188 Plot the wavefront data.
182189
@@ -185,7 +192,17 @@ def plot(self, ax: plt.Axes | None = None) -> tuple[plt.Figure, plt.Axes]:
185192 ax : Axis , optional
186193 If axis is provided, they will be used for the plot. if not provided, a new
187194 plot will automatically be created.
195+ cmap : str or Colormap
196+ Colormap for the plot
197+ level : int
198+ Number of levels for the contour plot
199+ **kwargs
200+ Additional keyword arguments for the plot function
188201
202+ Returns
203+ -------
204+ fig, ax : tuple of plt.Figure and plt.Axes
205+ The figure and axis of the plot
189206 """
190207 azimuths = np .radians (np .linspace (0 , 360 , 180 ))
191208 zeniths = np .linspace (0 , self .r_wf , 50 )
@@ -204,14 +221,16 @@ def plot(self, ax: plt.Axes | None = None) -> tuple[plt.Figure, plt.Axes]:
204221 theta = theta .reshape (n_dim , m_dim )
205222 rho = rho .reshape (n_dim , m_dim )
206223 values = values .reshape (n_dim , m_dim )
207- contour = ax .contourf (theta , rho , values )
224+ contour = ax .contourf (theta , rho , values , cmap = cmap , levels = levels , ** kwargs )
208225 cbar = plt .colorbar (contour )
209226 cbar .set_label (r"Aberration / $\lambda$" , rotation = 90 )
210227 plt .tight_layout ()
211228
212229 return fig , ax
213230
214- def plot_coeff (self , ax : plt .Axes | None = None ) -> tuple [plt .Figure , plt .Axes ]:
231+ def plot_coeff (
232+ self , ax : plt .Axes | None = None , ** kwargs
233+ ) -> tuple [plt .Figure , plt .Axes ]:
215234 """
216235 Plot the coefficients as a bar chart.
217236
@@ -220,13 +239,18 @@ def plot_coeff(self, ax: plt.Axes | None = None) -> tuple[plt.Figure, plt.Axes]:
220239 ax : Axis , optional
221240 If axis is provided, they will be used for the plot. if not provided, a new
222241 plot will automatically be created.
242+
243+ Returns
244+ -------
245+ fig, ax : tuple of plt.Figure and plt.Axes
246+ The figure and axis of the plot
223247 """
224248 if ax is None :
225249 fig , ax = plt .subplots ()
226250 else :
227251 fig = ax .figure
228252
229- ax .bar (list (self .coeff .keys ()), list (self .coeff .values ()))
253+ ax .bar (list (self .coeff .keys ()), list (self .coeff .values ()), ** kwargs )
230254 ax .set_xlabel ("Zernike polynomial $j$" )
231255 ax .set_ylabel (r"Zernike coefficient $Z_j$ / $\lambda$" )
232256 ax .set_xlim (min (self .coeff .keys ()) - 1 , max (self .coeff .keys ()) + 1 )
0 commit comments