|
21 | 21 | "- [**Plotting**](#Plotting)\n", |
22 | 22 | "- [**Animations**](#Animations)\n", |
23 | 23 | "\n", |
24 | | - "For more advanced reading and tutorials on the analysis of Lagrangian trajectories, we recommend checking out the [Lagrangian Diagnostics](https://lagrangian-diags.readthedocs.io/en/latest/index.html) project." |
| 24 | + "For more advanced reading and tutorials on the analysis of Lagrangian trajectories, we recommend checking out the [Lagrangian Diagnostics](https://lagrangian-diags.readthedocs.io/en/latest/index.html) project. The [TrajAn package]() can be used to read and plot datasets of Lagrangian trajectories." |
25 | 25 | ] |
26 | 26 | }, |
27 | 27 | { |
|
128 | 128 | "metadata": {}, |
129 | 129 | "outputs": [], |
130 | 130 | "source": [ |
131 | | - "data_xarray = xr.open_zarr(\"Output.zarr\", decode_times=False)\n", |
132 | | - "# TODO: remove once decode_times issue is fixed\n", |
133 | | - "data_xarray[\"time\"].values = (\n", |
134 | | - " data_xarray[\"time\"].values.astype(\"timedelta64[s]\").astype(\"timedelta64[ns]\")\n", |
135 | | - ")\n", |
136 | | - "data_xarray[\"time\"].values[data_xarray[\"time\"].values < 0] = np.timedelta64(\"NaT\", \"ns\")\n", |
| 131 | + "data_xarray = xr.open_zarr(\"Output.zarr\")\n", |
137 | 132 | "\n", |
138 | 133 | "print(data_xarray)" |
139 | 134 | ] |
|
168 | 163 | "source": [ |
169 | 164 | "np.set_printoptions(linewidth=160)\n", |
170 | 165 | "one_hour = np.timedelta64(1, \"h\") # Define timedelta object to help with conversion\n", |
| 166 | + "time_from_start = data_xarray[\"time\"].values - fieldset.time_interval.left\n", |
171 | 167 | "\n", |
172 | | - "print(\n", |
173 | | - " data_xarray[\"time\"].data / one_hour\n", |
174 | | - ") # timedelta / timedelta -> float number of hours" |
| 168 | + "print(time_from_start / one_hour) # timedelta / timedelta -> float number of hours" |
175 | 169 | ] |
176 | 170 | }, |
177 | 171 | { |
|
206 | 200 | " np.sqrt(np.square(np.diff(x)) + np.square(np.diff(y))), axis=1\n", |
207 | 201 | ") # d = (dx^2 + dy^2)^(1/2)\n", |
208 | 202 | "\n", |
209 | | - "real_time = data_xarray[\"time\"] / one_hour # convert time to hours\n", |
| 203 | + "real_time = time_from_start / one_hour # convert time to hours\n", |
210 | 204 | "time_since_release = (\n", |
211 | | - " real_time.values.transpose() - real_time.values[:, 0]\n", |
| 205 | + " real_time.transpose() - real_time[:, 0]\n", |
212 | 206 | ") # substract the initial time from each timeseries" |
213 | 207 | ] |
214 | 208 | }, |
|
380 | 374 | "metadata": {}, |
381 | 375 | "outputs": [], |
382 | 376 | "source": [ |
383 | | - "from IPython.display import HTML\n", |
384 | | - "from matplotlib.animation import FuncAnimation\n", |
385 | | - "\n", |
386 | | - "outputdt = timedelta(hours=2)\n", |
387 | | - "\n", |
388 | | - "# timerange in nanoseconds\n", |
389 | | - "timerange = np.arange(\n", |
390 | | - " np.nanmin(data_xarray[\"time\"].values),\n", |
391 | | - " np.nanmax(data_xarray[\"time\"].values) + np.timedelta64(outputdt),\n", |
392 | | - " outputdt,\n", |
393 | | - ")" |
| 377 | + "import cartopy.crs as ccrs\n", |
| 378 | + "import cartopy.feature as cfeature\n", |
| 379 | + "import matplotlib" |
394 | 380 | ] |
395 | 381 | }, |
396 | 382 | { |
|
399 | 385 | "metadata": {}, |
400 | 386 | "outputs": [], |
401 | 387 | "source": [ |
402 | | - "%%capture\n", |
403 | | - "fig = plt.figure(figsize=(5, 5), constrained_layout=True)\n", |
404 | | - "ax = fig.add_subplot()\n", |
405 | | - "\n", |
406 | | - "ax.set_ylabel(\"Meridional distance [m]\")\n", |
407 | | - "ax.set_xlabel(\"Zonal distance [m]\")\n", |
408 | | - "ax.set_xlim(31, 33)\n", |
409 | | - "ax.set_ylim(-33, -30)\n", |
410 | | - "\n", |
411 | | - "# Indices of the data where time = 0\n", |
412 | | - "time_id = np.where(data_xarray[\"time\"] == timerange[0])\n", |
413 | | - "\n", |
414 | | - "scatter = ax.scatter(\n", |
415 | | - " data_xarray[\"lon\"].values[time_id], data_xarray[\"lat\"].values[time_id]\n", |
416 | | - ")\n", |
417 | | - "\n", |
418 | | - "t = str(timerange[0].astype(\"timedelta64[h]\"))\n", |
419 | | - "title = ax.set_title(\"Particles at t = \" + t)\n", |
420 | | - "\n", |
421 | | - "\n", |
422 | | - "def animate(i):\n", |
423 | | - " t = str(timerange[i].astype(\"timedelta64[h]\"))\n", |
424 | | - " title.set_text(\"Particles at t = \" + t)\n", |
425 | | - "\n", |
426 | | - " time_id = np.where(data_xarray[\"time\"] == timerange[i])\n", |
427 | | - " scatter.set_offsets(\n", |
428 | | - " np.c_[data_xarray[\"lon\"].values[time_id], data_xarray[\"lat\"].values[time_id]]\n", |
| 388 | + "# for interactive display of animation\n", |
| 389 | + "plt.rcParams[\"animation.html\"] = \"jshtml\"" |
| 390 | + ] |
| 391 | + }, |
| 392 | + { |
| 393 | + "cell_type": "code", |
| 394 | + "execution_count": null, |
| 395 | + "metadata": {}, |
| 396 | + "outputs": [], |
| 397 | + "source": [ |
| 398 | + "# Number of timesteps to animate\n", |
| 399 | + "nframes = 13 # use less frames for testing purposes\n", |
| 400 | + "nreducedtrails = 1 # every 10th particle will have a trail (if 1, all particles have trails. Adjust for faster performance)\n", |
| 401 | + "\n", |
| 402 | + "\n", |
| 403 | + "# Set up the colors and associated trajectories:\n", |
| 404 | + "# get release times for each particle (first valide obs for each trajectory)\n", |
| 405 | + "release_times = data_xarray[\"time\"].min(dim=\"obs\", skipna=True).values\n", |
| 406 | + "\n", |
| 407 | + "# get unique release times and assign colors\n", |
| 408 | + "unique_release_times = np.unique(release_times[~np.isnat(release_times)])\n", |
| 409 | + "n_release_times = len(unique_release_times)\n", |
| 410 | + "print(f\"Number of unique release times: {n_release_times}\")\n", |
| 411 | + "\n", |
| 412 | + "# choose a continuous colormap\n", |
| 413 | + "colormap = matplotlib.colormaps[\"tab20b\"]\n", |
| 414 | + "\n", |
| 415 | + "# set up a unique color for each release time\n", |
| 416 | + "release_time_to_color = {}\n", |
| 417 | + "for i, release_time in enumerate(unique_release_times):\n", |
| 418 | + " release_time_to_color[release_time] = colormap(i / max(n_release_times - 1, 1))\n", |
| 419 | + "\n", |
| 420 | + "\n", |
| 421 | + "# --> Store data for all timeframes (this is needed for faster performance)\n", |
| 422 | + "print(\"Pre-computing all particle positions...\")\n", |
| 423 | + "all_particles_data = []\n", |
| 424 | + "for i, target_time in enumerate(timerange):\n", |
| 425 | + " time_id = np.where(data_xarray[\"time\"] == target_time)\n", |
| 426 | + " lons = data_xarray[\"lon\"].values[time_id]\n", |
| 427 | + " lats = data_xarray[\"lat\"].values[time_id]\n", |
| 428 | + " particle_indices = time_id[0]\n", |
| 429 | + " valid = ~np.isnan(lons) & ~np.isnan(lats)\n", |
| 430 | + "\n", |
| 431 | + " all_particles_data.append(\n", |
| 432 | + " {\n", |
| 433 | + " \"lons\": lons[valid],\n", |
| 434 | + " \"lats\": lats[valid],\n", |
| 435 | + " \"particle_indices\": particle_indices[valid],\n", |
| 436 | + " \"valid_count\": np.sum(valid),\n", |
| 437 | + " }\n", |
429 | 438 | " )\n", |
430 | 439 | "\n", |
431 | 440 | "\n", |
432 | | - "anim = FuncAnimation(fig, animate, frames=len(timerange), interval=100)" |
| 441 | + "# figure setup\n", |
| 442 | + "fig, ax = plt.subplots(figsize=(6, 5), subplot_kw={\"projection\": ccrs.PlateCarree()})\n", |
| 443 | + "ax.set_xlim(30, 33)\n", |
| 444 | + "ax.set_xticks(np.arange(30, 33.5, 0.5))\n", |
| 445 | + "ax.set_xlabel(\"Longitude (deg E)\")\n", |
| 446 | + "ax.set_ylim(-33, -30)\n", |
| 447 | + "ax.set_yticks(ticks=np.arange(-33, -29.5, 0.5))\n", |
| 448 | + "ax.set_yticklabels(np.arange(33, 29.5, -0.5).astype(str))\n", |
| 449 | + "ax.set_ylabel(\"Latitude (deg S)\")\n", |
| 450 | + "ax.coastlines(color=\"saddlebrown\")\n", |
| 451 | + "ax.add_feature(cfeature.LAND, alpha=0.5, facecolor=\"saddlebrown\")\n", |
| 452 | + "\n", |
| 453 | + "# --> Use pre-computed data for initial setup\n", |
| 454 | + "initial_data = all_particles_data[0]\n", |
| 455 | + "initial_colors = []\n", |
| 456 | + "for particle_idx in initial_data[\"particle_indices\"]:\n", |
| 457 | + " rt = release_times[particle_idx]\n", |
| 458 | + " if rt in release_time_to_color:\n", |
| 459 | + " initial_colors.append(release_time_to_color[rt])\n", |
| 460 | + " else:\n", |
| 461 | + " initial_colors.append(\"blue\")\n", |
| 462 | + "\n", |
| 463 | + "# --> plot first timestep\n", |
| 464 | + "scatter = ax.scatter(initial_data[\"lons\"], initial_data[\"lats\"], s=10, c=initial_colors)\n", |
| 465 | + "\n", |
| 466 | + "# --> initialize trails\n", |
| 467 | + "trail_plot = []\n", |
| 468 | + "\n", |
| 469 | + "# Set initial title\n", |
| 470 | + "t_str = str(timerange[0])[:19] # Format datetime nicely\n", |
| 471 | + "title = ax.set_title(f\"Particles at t = {t_str}\")\n", |
| 472 | + "\n", |
| 473 | + "\n", |
| 474 | + "# loop over for animation\n", |
| 475 | + "def animate(i):\n", |
| 476 | + " print(f\"Animating frame {i + 1}/{len(timerange)} at time {timerange[i]}\")\n", |
| 477 | + " t_str = str(timerange[i])[:19]\n", |
| 478 | + " title.set_text(f\"Particles at t = {t_str}\")\n", |
| 479 | + "\n", |
| 480 | + " # Find particles at current time\n", |
| 481 | + " current_data = all_particles_data[i]\n", |
| 482 | + "\n", |
| 483 | + " if current_data[\"valid_count\"] > 0:\n", |
| 484 | + " current_colors = []\n", |
| 485 | + " for particle_idx in current_data[\"particle_indices\"]:\n", |
| 486 | + " rt = release_times[particle_idx]\n", |
| 487 | + " current_colors.append(release_time_to_color[rt])\n", |
| 488 | + "\n", |
| 489 | + " scatter.set_offsets(np.c_[current_data[\"lons\"], current_data[\"lats\"]])\n", |
| 490 | + " scatter.set_color(current_colors)\n", |
| 491 | + "\n", |
| 492 | + " # --> add trails\n", |
| 493 | + "\n", |
| 494 | + " for trail in trail_plot:\n", |
| 495 | + " trail.remove()\n", |
| 496 | + " trail_plot.clear()\n", |
| 497 | + "\n", |
| 498 | + " trail_length = min(10, i) # trails will have max length of 10 time steps\n", |
| 499 | + "\n", |
| 500 | + " if trail_length > 0:\n", |
| 501 | + " sampled_particles = current_data[\"particle_indices\"][\n", |
| 502 | + " ::nreducedtrails\n", |
| 503 | + " ] # use all or sample if you want faster computation\n", |
| 504 | + "\n", |
| 505 | + " for particle_idx in sampled_particles:\n", |
| 506 | + " trail_lons = []\n", |
| 507 | + " trail_lats = []\n", |
| 508 | + " for j in range(i - trail_length, i + 1):\n", |
| 509 | + " past_data = all_particles_data[j]\n", |
| 510 | + " if particle_idx in past_data[\"particle_indices\"]:\n", |
| 511 | + " idx = np.where(past_data[\"particle_indices\"] == particle_idx)[\n", |
| 512 | + " 0\n", |
| 513 | + " ][0]\n", |
| 514 | + " trail_lons.append(past_data[\"lons\"][idx])\n", |
| 515 | + " trail_lats.append(past_data[\"lats\"][idx])\n", |
| 516 | + " if len(trail_lons) > 1:\n", |
| 517 | + " rt = release_times[particle_idx]\n", |
| 518 | + " color = release_time_to_color[rt]\n", |
| 519 | + " (trail,) = ax.plot(\n", |
| 520 | + " trail_lons, trail_lats, color=color, linewidth=0.6, alpha=0.6\n", |
| 521 | + " )\n", |
| 522 | + " trail_plot.append(trail)\n", |
| 523 | + "\n", |
| 524 | + " else:\n", |
| 525 | + " scatter.set_offsets(np.empty((0, 2)))\n", |
| 526 | + "\n", |
| 527 | + "\n", |
| 528 | + "# Create animation\n", |
| 529 | + "anim = matplotlib.animation.FuncAnimation(fig, animate, frames=nframes, interval=100)\n", |
| 530 | + "anim" |
433 | 531 | ] |
434 | 532 | }, |
435 | 533 | { |
436 | 534 | "cell_type": "code", |
437 | 535 | "execution_count": null, |
438 | 536 | "metadata": {}, |
439 | 537 | "outputs": [], |
440 | | - "source": [ |
441 | | - "HTML(anim.to_jshtml())" |
442 | | - ] |
| 538 | + "source": [] |
443 | 539 | } |
444 | 540 | ], |
445 | 541 | "metadata": { |
|
0 commit comments