Skip to content

Commit bc4d9eb

Browse files
committed
Working, but perhaps slow, version of the animate_polygon for animating lists of data with Mike Kryjak's polygon plot.
1 parent 367c225 commit bc4d9eb

File tree

2 files changed

+266
-4
lines changed

2 files changed

+266
-4
lines changed

xbout/boutdataset.py

Lines changed: 255 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import animatplot as amp
1212
from matplotlib import pyplot as plt
1313
from matplotlib.animation import PillowWriter
14+
from matplotlib.animation import FuncAnimation
1415

1516
from mpl_toolkits.axes_grid1 import make_axes_locatable
1617

@@ -22,6 +23,7 @@
2223
animate_poloidal,
2324
animate_pcolormesh,
2425
animate_line,
26+
animate_polygon,
2527
_add_controls,
2628
_normalise_time_coord,
2729
_parse_coord_option,
@@ -1345,6 +1347,259 @@ def is_list(variable):
13451347

13461348
return anim
13471349

1350+
def animate_polygon_list(
1351+
self,
1352+
variables,
1353+
animate_over=None,
1354+
save_as=None,
1355+
show=False,
1356+
fps=10,
1357+
nrows=None,
1358+
ncols=None,
1359+
poloidal_plot=False,
1360+
axis_coords=None,
1361+
subplots_adjust=None,
1362+
vmin=None,
1363+
vmax=None,
1364+
logscale=None,
1365+
titles=None,
1366+
aspect=None,
1367+
extend=None,
1368+
#controls="both",
1369+
tight_layout=True,
1370+
**kwargs,
1371+
):
1372+
"""
1373+
Parameters
1374+
----------
1375+
variables : list of str or BoutDataArray
1376+
The variables to plot. For any string passed, the corresponding
1377+
variable in this DataSet is used - then the calling DataSet must
1378+
have only 3 dimensions. It is possible to pass BoutDataArrays to
1379+
allow more flexible plots, e.g. with different variables being
1380+
plotted against different axes.
1381+
animate_over : str, optional
1382+
Dimension over which to animate, defaults to the time dimension
1383+
save_as : str, optional
1384+
If passed, a gif is created with this filename
1385+
show : bool, optional
1386+
Call pyplot.show() to display the animation
1387+
fps : float, optional
1388+
Indicates the number of frames per second to play
1389+
nrows : int, optional
1390+
Specify the number of rows of plots
1391+
ncols : int, optional
1392+
Specify the number of columns of plots
1393+
poloidal_plot : bool or sequence of bool, optional
1394+
If set to True, make all 2D animations in the poloidal plane instead of using
1395+
grid coordinates, per variable if sequence is given
1396+
axis_coords : None, str, dict or list of None, str or dict
1397+
Coordinates to use for axis labelling.
1398+
1399+
- None: Use the dimension coordinate for each axis, if it exists.
1400+
- "index": Use the integer index values.
1401+
- dict: keys are dimension names, values set axis_coords for each axis
1402+
separately. Values can be: None, "index", the name of a 1d variable or
1403+
coordinate (which must have the dimension given by 'key'), or a 1d
1404+
numpy array, dask array or DataArray whose length matches the length of
1405+
the dimension given by 'key'.
1406+
1407+
Only affects time coordinate for plots with poloidal_plot=True.
1408+
If a list is passed, it must have the same length as 'variables' and gives
1409+
the axis_coords setting for each plot individually.
1410+
The setting to use for the 'animate_over' coordinate can be passed in one or
1411+
more dict values, but must be the same in all dicts if given more than once.
1412+
subplots_adjust : dict, optional
1413+
Arguments passed to fig.subplots_adjust()()
1414+
vmin : float or sequence of floats
1415+
Minimum value for color scale, per variable if a sequence is given
1416+
vmax : float or sequence of floats
1417+
Maximum value for color scale, per variable if a sequence is given
1418+
logscale : bool or float, sequence of bool or float, optional
1419+
If True, default to a logarithmic color scale instead of a linear one.
1420+
If a non-bool type is passed it is treated as a float used to set the linear
1421+
threshold of a symmetric logarithmic scale as
1422+
linthresh=min(abs(vmin),abs(vmax))*logscale, defaults to 1e-5 if True is
1423+
passed.
1424+
Per variable if sequence is given.
1425+
titles : sequence of str or None, optional
1426+
Custom titles for each plot. Pass None in the sequence to use the default for
1427+
a certain variable
1428+
aspect : str or None, or sequence of str or None, optional
1429+
Argument to set_aspect() for each plot. Defaults to "equal" for poloidal
1430+
plots and "auto" for others.
1431+
extend : str or None, optional
1432+
Passed to fig.colorbar()
1433+
controls : string or None, default "both"
1434+
By default, add both the timeline and play/pause toggle to the animation. If
1435+
"timeline" is passed add only the timeline, if "toggle" is passed add only
1436+
the play/pause toggle. If None or an empty string is passed, add neither.
1437+
tight_layout : bool or dict, optional
1438+
If set to False, don't call tight_layout() on the figure.
1439+
If a dict is passed, the dict entries are passed as arguments to
1440+
tight_layout()
1441+
**kwargs : dict, optional
1442+
Additional keyword arguments are passed on to each animation function, per
1443+
variable if a sequence is given.
1444+
1445+
Returns
1446+
-------
1447+
animation
1448+
An animatplot.Animation object.
1449+
"""
1450+
1451+
if animate_over is None:
1452+
animate_over = self.metadata.get("bout_tdim", "t")
1453+
1454+
nvars = len(variables)
1455+
1456+
if nrows is None and ncols is None:
1457+
ncols = int(np.ceil(np.sqrt(nvars)))
1458+
nrows = int(np.ceil(nvars / ncols))
1459+
elif nrows is None:
1460+
nrows = int(np.ceil(nvars / ncols))
1461+
elif ncols is None:
1462+
ncols = int(np.ceil(nvars / nrows))
1463+
else:
1464+
if nrows * ncols < nvars:
1465+
raise ValueError("Not enough rows*columns to fit all variables")
1466+
1467+
fig, axes = plt.subplots(nrows, ncols, squeeze=False)
1468+
axes = axes.flatten()
1469+
1470+
ncells = nrows * ncols
1471+
1472+
if nvars < ncells:
1473+
for index in range(ncells - nvars):
1474+
fig.delaxes(axes[ncells - index - 1])
1475+
1476+
if subplots_adjust is not None:
1477+
fig.subplots_adjust(**subplots_adjust)
1478+
1479+
def _expand_list_arg(arg, arg_name):
1480+
if isinstance(arg, collections.abc.Sequence) and not isinstance(arg, str):
1481+
if len(arg) != len(variables):
1482+
raise ValueError(
1483+
"if %s is a sequence, it must have the same "
1484+
'number of elements as "variables"' % arg_name
1485+
)
1486+
else:
1487+
arg = [arg] * len(variables)
1488+
return arg
1489+
1490+
poloidal_plot = _expand_list_arg(poloidal_plot, "poloidal_plot")
1491+
vmin = _expand_list_arg(vmin, "vmin")
1492+
vmax = _expand_list_arg(vmax, "vmax")
1493+
logscale = _expand_list_arg(logscale, "logscale")
1494+
titles = _expand_list_arg(titles, "titles")
1495+
aspect = _expand_list_arg(aspect, "aspect")
1496+
extend = _expand_list_arg(extend, "extend")
1497+
axis_coords = _expand_list_arg(axis_coords, "axis_coords")
1498+
for k in kwargs:
1499+
kwargs[k] = _expand_list_arg(kwargs[k], k)
1500+
1501+
animate_data = []
1502+
1503+
def is_list(variable):
1504+
return (
1505+
isinstance(variable, list)
1506+
or isinstance(variable, tuple)
1507+
or isinstance(variable, set)
1508+
)
1509+
1510+
for i, subplot_args in enumerate(
1511+
zip(
1512+
variables,
1513+
axes,
1514+
poloidal_plot,
1515+
vmin,
1516+
vmax,
1517+
logscale,
1518+
titles,
1519+
)
1520+
):
1521+
(
1522+
v,
1523+
ax,
1524+
this_poloidal_plot,
1525+
this_vmin,
1526+
this_vmax,
1527+
this_logscale,
1528+
this_title,
1529+
) = subplot_args
1530+
1531+
this_kwargs = {k: v[i] for k, v in kwargs.items()}
1532+
1533+
divider = make_axes_locatable(ax)
1534+
cax = divider.append_axes("right", size="5%", pad=0.1)
1535+
1536+
if isinstance(v, str):
1537+
v = self.data[v]
1538+
1539+
data = v.bout.data
1540+
ndims = len(data.dims)
1541+
ax.set_title(data.name)
1542+
1543+
if ndims == 3:
1544+
if this_poloidal_plot:
1545+
polys, da, update_func = animate_polygon(
1546+
data,
1547+
ax=ax,
1548+
cax=cax,
1549+
vmin=this_vmin,
1550+
vmax=this_vmax,
1551+
logscale=this_logscale,
1552+
animate=False,
1553+
**this_kwargs,
1554+
)
1555+
animate_data.append([polys,da,update_func])
1556+
else:
1557+
raise ValueError(
1558+
"Unsupported option "
1559+
+ ". this_poloidal_plot "
1560+
+ str(this_poloidal_plot)
1561+
)
1562+
else:
1563+
raise ValueError(
1564+
"Unsupported number of dimensions "
1565+
+ str(ndims)
1566+
+ ". Dims are "
1567+
+ str(v.dims)
1568+
)
1569+
1570+
if this_title is not None:
1571+
# Replace default title with user-specified one
1572+
ax.set_title(this_title)
1573+
1574+
def update(frame):
1575+
for list in animate_data:
1576+
(polys, da, update_func) = list
1577+
# call update function for each axes
1578+
update_func(frame,polys,da)
1579+
1580+
# make the animation for all the subplots simultaneously
1581+
# use the last data array da to choose the number of frames
1582+
# assumes time dimension same length for all variables
1583+
anim = FuncAnimation(fig=fig, func=update, frames=np.shape(da.data)[0], interval=30)
1584+
if tight_layout:
1585+
if subplots_adjust is not None:
1586+
warnings.warn(
1587+
"tight_layout argument to animate_list() is True, but "
1588+
"subplots_adjust argument is not None. subplots_adjust "
1589+
"is being ignored."
1590+
)
1591+
if not isinstance(tight_layout, dict):
1592+
tight_layout = {}
1593+
fig.tight_layout(**tight_layout)
1594+
1595+
if save_as is not None:
1596+
anim.save(save_as + ".gif", writer=PillowWriter(fps=fps))
1597+
1598+
if show:
1599+
plt.show()
1600+
1601+
return anim
1602+
13481603
def with_cherab_grid(self):
13491604
"""
13501605
Returns a new DataSet with a 'cherab_grid' attribute.

xbout/plotting/animate.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -700,6 +700,7 @@ def animate_polygon(
700700
grid_only=False,
701701
linewidth=0,
702702
linecolor="black",
703+
animate=True,
703704
):
704705
"""
705706
Nice looking 2D plots which have no visual artifacts around the X-point.
@@ -862,8 +863,14 @@ def update(frame):
862863

863864
if targets:
864865
plot_targets(da, ax, x="R", y="Z", hatching=add_limiter_hatching)
865-
866-
# make the animation by using FuncAnimation and update() to generate frames
867-
ani = matplotlib.animation.FuncAnimation(fig=fig, func=update, frames=np.shape(da.data)[0], interval=30)
868-
return ani
866+
if animate:
867+
# make the animation by using FuncAnimation and update() to generate frames
868+
ani = matplotlib.animation.FuncAnimation(fig=fig, func=update, frames=np.shape(da.data)[0], interval=30)
869+
return ani
870+
else:
871+
# return function and data for making the animation
872+
def update_out(frame,polys,da):
873+
colors = da.data[frame,:,:].flatten()
874+
polys.set_array(colors)
875+
return polys, da, update_out
869876

0 commit comments

Comments
 (0)