Skip to content

Commit 02c9752

Browse files
authored
Merge pull request #81 from bleykauf/feature/better-plotting
Add better plotting methods
2 parents c3f52ca + 8c8bc64 commit 02c9752

File tree

3 files changed

+84
-21
lines changed

3 files changed

+84
-21
lines changed

aisim/atoms.py

Lines changed: 47 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
"""Classes and functions related to the atomic cloud."""
22

3+
from typing import Literal
4+
5+
import matplotlib.pyplot as plt
36
import numpy as np
47
import scipy.linalg as splin
58

@@ -32,13 +35,6 @@ class AtomicEnsemble:
3235
phase_space_vectors : ndarray
3336
n x 6 dimensional array representing the phase space vectors
3437
(x0, y0, z0, vx, vy, vz) of the atoms in an atomic ensemble
35-
position
36-
velocity
37-
state_kets
38-
state_bras
39-
density_matrices
40-
density_matrix
41-
4238
"""
4339

4440
def __init__(self, phase_space_vectors, state_kets=[1, 0], time=0):
@@ -53,7 +49,7 @@ def __init__(self, phase_space_vectors, state_kets=[1, 0], time=0):
5349
def __getitem__(self, key):
5450
"""Select certain atoms "from the ensemble.
5551
56-
Parameters
52+
Parameters
5753
----------
5854
key : int or slice or bool map
5955
for example 2, 1:15 or a boolean map
@@ -210,6 +206,49 @@ def fidelity(self, rho_target):
210206
"""
211207
return _fidelity(self.density_matrix, rho_target)
212208

209+
def plot(
210+
self,
211+
ax: plt.Axes | None = None,
212+
view_from: Literal["x", "y", "z"] = "z",
213+
bins: int = 50,
214+
**kwargs,
215+
) -> tuple[plt.Figure, plt.Axes]:
216+
"""Plot the positions of the atoms in the ensemble.
217+
218+
ax : Axis , optional
219+
If axis is provided, they will be used for the plot. if not provided, a new
220+
plot will automatically be created.
221+
view_from : str
222+
View from which direction the plot is created. Options are "x", "y", "z".
223+
bins : int
224+
Number of bins for the histogram
225+
**kwargs
226+
Additional keyword arguments for the plot function
227+
228+
Returns
229+
-------
230+
fig, ax : tuple of plt.Figure and plt.Axes
231+
The figure and axis of the plot
232+
"""
233+
if ax is None:
234+
fig, ax = plt.subplots(subplot_kw=dict(projection="polar"))
235+
else:
236+
fig = ax.figure
237+
238+
view = {
239+
"x": (1, 2),
240+
"y": (0, 2),
241+
"z": (0, 1),
242+
}
243+
ax.hist2d(
244+
self.position[:, view[view_from][0]],
245+
self.position[:, view[view_from][1]],
246+
bins=bins,
247+
**kwargs,
248+
)
249+
250+
return fig, ax
251+
213252

214253
def create_random_ensemble_from_gaussian_distribution(
215254
pos_params, vel_params, n_samples, seed=None, **kwargs

aisim/beam.py

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import matplotlib.pyplot as plt
44
import numpy as np
5+
from matplotlib.colors import Colormap
56

67
from . import convert
78
from .zern import FIRST_INDEX_J, ZernikeNorm, ZernikeOrder, ZernikePolynomial
@@ -176,7 +177,13 @@ def get_value(self, pos: np.ndarray) -> np.ndarray:
176177
values[rho > self.r_beam / self.r_wf] = np.nan
177178
return values
178179

179-
def plot(self, ax: plt.Axes | None = None) -> tuple[plt.Figure, plt.Axes]:
180+
def plot(
181+
self,
182+
ax: plt.Axes | None = None,
183+
cmap: str | Colormap = "RdBu",
184+
levels: int = 100,
185+
**kwargs,
186+
) -> tuple[plt.Figure, plt.Axes]:
180187
"""
181188
Plot the wavefront data.
182189
@@ -185,7 +192,17 @@ def plot(self, ax: plt.Axes | None = None) -> tuple[plt.Figure, plt.Axes]:
185192
ax : Axis , optional
186193
If axis is provided, they will be used for the plot. if not provided, a new
187194
plot will automatically be created.
195+
cmap : str or Colormap
196+
Colormap for the plot
197+
level : int
198+
Number of levels for the contour plot
199+
**kwargs
200+
Additional keyword arguments for the plot function
188201
202+
Returns
203+
-------
204+
fig, ax : tuple of plt.Figure and plt.Axes
205+
The figure and axis of the plot
189206
"""
190207
azimuths = np.radians(np.linspace(0, 360, 180))
191208
zeniths = np.linspace(0, self.r_wf, 50)
@@ -204,14 +221,16 @@ def plot(self, ax: plt.Axes | None = None) -> tuple[plt.Figure, plt.Axes]:
204221
theta = theta.reshape(n_dim, m_dim)
205222
rho = rho.reshape(n_dim, m_dim)
206223
values = values.reshape(n_dim, m_dim)
207-
contour = ax.contourf(theta, rho, values)
224+
contour = ax.contourf(theta, rho, values, cmap=cmap, levels=levels, **kwargs)
208225
cbar = plt.colorbar(contour)
209226
cbar.set_label(r"Aberration / $\lambda$", rotation=90)
210227
plt.tight_layout()
211228

212229
return fig, ax
213230

214-
def plot_coeff(self, ax: plt.Axes | None = None) -> tuple[plt.Figure, plt.Axes]:
231+
def plot_coeff(
232+
self, ax: plt.Axes | None = None, **kwargs
233+
) -> tuple[plt.Figure, plt.Axes]:
215234
"""
216235
Plot the coefficients as a bar chart.
217236
@@ -220,13 +239,18 @@ def plot_coeff(self, ax: plt.Axes | None = None) -> tuple[plt.Figure, plt.Axes]:
220239
ax : Axis , optional
221240
If axis is provided, they will be used for the plot. if not provided, a new
222241
plot will automatically be created.
242+
243+
Returns
244+
-------
245+
fig, ax : tuple of plt.Figure and plt.Axes
246+
The figure and axis of the plot
223247
"""
224248
if ax is None:
225249
fig, ax = plt.subplots()
226250
else:
227251
fig = ax.figure
228252

229-
ax.bar(list(self.coeff.keys()), list(self.coeff.values()))
253+
ax.bar(list(self.coeff.keys()), list(self.coeff.values()), **kwargs)
230254
ax.set_xlabel("Zernike polynomial $j$")
231255
ax.set_ylabel(r"Zernike coefficient $Z_j$ / $\lambda$")
232256
ax.set_xlim(min(self.coeff.keys()) - 1, max(self.coeff.keys()) + 1)

docs/examples/wavefront-aberrations.ipynb

Lines changed: 9 additions & 9 deletions
Large diffs are not rendered by default.

0 commit comments

Comments
 (0)