Skip to content

Commit 720f79f

Browse files
reint-fischerreint-fischer
authored andcommitted
add Vesna's animation and move to examples folder
1 parent 38df000 commit 720f79f

File tree

1 file changed

+150
-54
lines changed

1 file changed

+150
-54
lines changed

docs/examples_v3/tutorial_output.ipynb renamed to docs/examples/tutorial_output.ipynb

Lines changed: 150 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
"- [**Plotting**](#Plotting)\n",
2222
"- [**Animations**](#Animations)\n",
2323
"\n",
24-
"For more advanced reading and tutorials on the analysis of Lagrangian trajectories, we recommend checking out the [Lagrangian Diagnostics](https://lagrangian-diags.readthedocs.io/en/latest/index.html) project."
24+
"For more advanced reading and tutorials on the analysis of Lagrangian trajectories, we recommend checking out the [Lagrangian Diagnostics](https://lagrangian-diags.readthedocs.io/en/latest/index.html) project. The [TrajAn package]() can be used to read and plot datasets of Lagrangian trajectories."
2525
]
2626
},
2727
{
@@ -128,12 +128,7 @@
128128
"metadata": {},
129129
"outputs": [],
130130
"source": [
131-
"data_xarray = xr.open_zarr(\"Output.zarr\", decode_times=False)\n",
132-
"# TODO: remove once decode_times issue is fixed\n",
133-
"data_xarray[\"time\"].values = (\n",
134-
" data_xarray[\"time\"].values.astype(\"timedelta64[s]\").astype(\"timedelta64[ns]\")\n",
135-
")\n",
136-
"data_xarray[\"time\"].values[data_xarray[\"time\"].values < 0] = np.timedelta64(\"NaT\", \"ns\")\n",
131+
"data_xarray = xr.open_zarr(\"Output.zarr\")\n",
137132
"\n",
138133
"print(data_xarray)"
139134
]
@@ -168,10 +163,9 @@
168163
"source": [
169164
"np.set_printoptions(linewidth=160)\n",
170165
"one_hour = np.timedelta64(1, \"h\") # Define timedelta object to help with conversion\n",
166+
"time_from_start = data_xarray[\"time\"].values - fieldset.time_interval.left\n",
171167
"\n",
172-
"print(\n",
173-
" data_xarray[\"time\"].data / one_hour\n",
174-
") # timedelta / timedelta -> float number of hours"
168+
"print(time_from_start / one_hour) # timedelta / timedelta -> float number of hours"
175169
]
176170
},
177171
{
@@ -206,9 +200,9 @@
206200
" np.sqrt(np.square(np.diff(x)) + np.square(np.diff(y))), axis=1\n",
207201
") # d = (dx^2 + dy^2)^(1/2)\n",
208202
"\n",
209-
"real_time = data_xarray[\"time\"] / one_hour # convert time to hours\n",
203+
"real_time = time_from_start / one_hour # convert time to hours\n",
210204
"time_since_release = (\n",
211-
" real_time.values.transpose() - real_time.values[:, 0]\n",
205+
" real_time.transpose() - real_time[:, 0]\n",
212206
") # substract the initial time from each timeseries"
213207
]
214208
},
@@ -380,17 +374,9 @@
380374
"metadata": {},
381375
"outputs": [],
382376
"source": [
383-
"from IPython.display import HTML\n",
384-
"from matplotlib.animation import FuncAnimation\n",
385-
"\n",
386-
"outputdt = timedelta(hours=2)\n",
387-
"\n",
388-
"# timerange in nanoseconds\n",
389-
"timerange = np.arange(\n",
390-
" np.nanmin(data_xarray[\"time\"].values),\n",
391-
" np.nanmax(data_xarray[\"time\"].values) + np.timedelta64(outputdt),\n",
392-
" outputdt,\n",
393-
")"
377+
"import cartopy.crs as ccrs\n",
378+
"import cartopy.feature as cfeature\n",
379+
"import matplotlib"
394380
]
395381
},
396382
{
@@ -399,47 +385,157 @@
399385
"metadata": {},
400386
"outputs": [],
401387
"source": [
402-
"%%capture\n",
403-
"fig = plt.figure(figsize=(5, 5), constrained_layout=True)\n",
404-
"ax = fig.add_subplot()\n",
405-
"\n",
406-
"ax.set_ylabel(\"Meridional distance [m]\")\n",
407-
"ax.set_xlabel(\"Zonal distance [m]\")\n",
408-
"ax.set_xlim(31, 33)\n",
409-
"ax.set_ylim(-33, -30)\n",
410-
"\n",
411-
"# Indices of the data where time = 0\n",
412-
"time_id = np.where(data_xarray[\"time\"] == timerange[0])\n",
413-
"\n",
414-
"scatter = ax.scatter(\n",
415-
" data_xarray[\"lon\"].values[time_id], data_xarray[\"lat\"].values[time_id]\n",
416-
")\n",
417-
"\n",
418-
"t = str(timerange[0].astype(\"timedelta64[h]\"))\n",
419-
"title = ax.set_title(\"Particles at t = \" + t)\n",
420-
"\n",
421-
"\n",
422-
"def animate(i):\n",
423-
" t = str(timerange[i].astype(\"timedelta64[h]\"))\n",
424-
" title.set_text(\"Particles at t = \" + t)\n",
425-
"\n",
426-
" time_id = np.where(data_xarray[\"time\"] == timerange[i])\n",
427-
" scatter.set_offsets(\n",
428-
" np.c_[data_xarray[\"lon\"].values[time_id], data_xarray[\"lat\"].values[time_id]]\n",
388+
"# for interactive display of animation\n",
389+
"plt.rcParams[\"animation.html\"] = \"jshtml\""
390+
]
391+
},
392+
{
393+
"cell_type": "code",
394+
"execution_count": null,
395+
"metadata": {},
396+
"outputs": [],
397+
"source": [
398+
"# Number of timesteps to animate\n",
399+
"nframes = 13 # use less frames for testing purposes\n",
400+
"nreducedtrails = 1 # every 10th particle will have a trail (if 1, all particles have trails. Adjust for faster performance)\n",
401+
"\n",
402+
"\n",
403+
"# Set up the colors and associated trajectories:\n",
404+
"# get release times for each particle (first valide obs for each trajectory)\n",
405+
"release_times = data_xarray[\"time\"].min(dim=\"obs\", skipna=True).values\n",
406+
"\n",
407+
"# get unique release times and assign colors\n",
408+
"unique_release_times = np.unique(release_times[~np.isnat(release_times)])\n",
409+
"n_release_times = len(unique_release_times)\n",
410+
"print(f\"Number of unique release times: {n_release_times}\")\n",
411+
"\n",
412+
"# choose a continuous colormap\n",
413+
"colormap = matplotlib.colormaps[\"tab20b\"]\n",
414+
"\n",
415+
"# set up a unique color for each release time\n",
416+
"release_time_to_color = {}\n",
417+
"for i, release_time in enumerate(unique_release_times):\n",
418+
" release_time_to_color[release_time] = colormap(i / max(n_release_times - 1, 1))\n",
419+
"\n",
420+
"\n",
421+
"# --> Store data for all timeframes (this is needed for faster performance)\n",
422+
"print(\"Pre-computing all particle positions...\")\n",
423+
"all_particles_data = []\n",
424+
"for i, target_time in enumerate(timerange):\n",
425+
" time_id = np.where(data_xarray[\"time\"] == target_time)\n",
426+
" lons = data_xarray[\"lon\"].values[time_id]\n",
427+
" lats = data_xarray[\"lat\"].values[time_id]\n",
428+
" particle_indices = time_id[0]\n",
429+
" valid = ~np.isnan(lons) & ~np.isnan(lats)\n",
430+
"\n",
431+
" all_particles_data.append(\n",
432+
" {\n",
433+
" \"lons\": lons[valid],\n",
434+
" \"lats\": lats[valid],\n",
435+
" \"particle_indices\": particle_indices[valid],\n",
436+
" \"valid_count\": np.sum(valid),\n",
437+
" }\n",
429438
" )\n",
430439
"\n",
431440
"\n",
432-
"anim = FuncAnimation(fig, animate, frames=len(timerange), interval=100)"
441+
"# figure setup\n",
442+
"fig, ax = plt.subplots(figsize=(6, 5), subplot_kw={\"projection\": ccrs.PlateCarree()})\n",
443+
"ax.set_xlim(30, 33)\n",
444+
"ax.set_xticks(np.arange(30, 33.5, 0.5))\n",
445+
"ax.set_xlabel(\"Longitude (deg E)\")\n",
446+
"ax.set_ylim(-33, -30)\n",
447+
"ax.set_yticks(ticks=np.arange(-33, -29.5, 0.5))\n",
448+
"ax.set_yticklabels(np.arange(33, 29.5, -0.5).astype(str))\n",
449+
"ax.set_ylabel(\"Latitude (deg S)\")\n",
450+
"ax.coastlines(color=\"saddlebrown\")\n",
451+
"ax.add_feature(cfeature.LAND, alpha=0.5, facecolor=\"saddlebrown\")\n",
452+
"\n",
453+
"# --> Use pre-computed data for initial setup\n",
454+
"initial_data = all_particles_data[0]\n",
455+
"initial_colors = []\n",
456+
"for particle_idx in initial_data[\"particle_indices\"]:\n",
457+
" rt = release_times[particle_idx]\n",
458+
" if rt in release_time_to_color:\n",
459+
" initial_colors.append(release_time_to_color[rt])\n",
460+
" else:\n",
461+
" initial_colors.append(\"blue\")\n",
462+
"\n",
463+
"# --> plot first timestep\n",
464+
"scatter = ax.scatter(initial_data[\"lons\"], initial_data[\"lats\"], s=10, c=initial_colors)\n",
465+
"\n",
466+
"# --> initialize trails\n",
467+
"trail_plot = []\n",
468+
"\n",
469+
"# Set initial title\n",
470+
"t_str = str(timerange[0])[:19] # Format datetime nicely\n",
471+
"title = ax.set_title(f\"Particles at t = {t_str}\")\n",
472+
"\n",
473+
"\n",
474+
"# loop over for animation\n",
475+
"def animate(i):\n",
476+
" print(f\"Animating frame {i + 1}/{len(timerange)} at time {timerange[i]}\")\n",
477+
" t_str = str(timerange[i])[:19]\n",
478+
" title.set_text(f\"Particles at t = {t_str}\")\n",
479+
"\n",
480+
" # Find particles at current time\n",
481+
" current_data = all_particles_data[i]\n",
482+
"\n",
483+
" if current_data[\"valid_count\"] > 0:\n",
484+
" current_colors = []\n",
485+
" for particle_idx in current_data[\"particle_indices\"]:\n",
486+
" rt = release_times[particle_idx]\n",
487+
" current_colors.append(release_time_to_color[rt])\n",
488+
"\n",
489+
" scatter.set_offsets(np.c_[current_data[\"lons\"], current_data[\"lats\"]])\n",
490+
" scatter.set_color(current_colors)\n",
491+
"\n",
492+
" # --> add trails\n",
493+
"\n",
494+
" for trail in trail_plot:\n",
495+
" trail.remove()\n",
496+
" trail_plot.clear()\n",
497+
"\n",
498+
" trail_length = min(10, i) # trails will have max length of 10 time steps\n",
499+
"\n",
500+
" if trail_length > 0:\n",
501+
" sampled_particles = current_data[\"particle_indices\"][\n",
502+
" ::nreducedtrails\n",
503+
" ] # use all or sample if you want faster computation\n",
504+
"\n",
505+
" for particle_idx in sampled_particles:\n",
506+
" trail_lons = []\n",
507+
" trail_lats = []\n",
508+
" for j in range(i - trail_length, i + 1):\n",
509+
" past_data = all_particles_data[j]\n",
510+
" if particle_idx in past_data[\"particle_indices\"]:\n",
511+
" idx = np.where(past_data[\"particle_indices\"] == particle_idx)[\n",
512+
" 0\n",
513+
" ][0]\n",
514+
" trail_lons.append(past_data[\"lons\"][idx])\n",
515+
" trail_lats.append(past_data[\"lats\"][idx])\n",
516+
" if len(trail_lons) > 1:\n",
517+
" rt = release_times[particle_idx]\n",
518+
" color = release_time_to_color[rt]\n",
519+
" (trail,) = ax.plot(\n",
520+
" trail_lons, trail_lats, color=color, linewidth=0.6, alpha=0.6\n",
521+
" )\n",
522+
" trail_plot.append(trail)\n",
523+
"\n",
524+
" else:\n",
525+
" scatter.set_offsets(np.empty((0, 2)))\n",
526+
"\n",
527+
"\n",
528+
"# Create animation\n",
529+
"anim = matplotlib.animation.FuncAnimation(fig, animate, frames=nframes, interval=100)\n",
530+
"anim"
433531
]
434532
},
435533
{
436534
"cell_type": "code",
437535
"execution_count": null,
438536
"metadata": {},
439537
"outputs": [],
440-
"source": [
441-
"HTML(anim.to_jshtml())"
442-
]
538+
"source": []
443539
}
444540
],
445541
"metadata": {

0 commit comments

Comments
 (0)