-
Notifications
You must be signed in to change notification settings - Fork 78
feat: Consolidate and expand spectral losses #788
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
4154104
09373fb
fc88f77
12a753c
75e2f1b
73d958a
5d33958
2c8cd46
5d9f721
9b98381
b49aafc
0833dfb
7e8968f
5e7f1e1
6d8db50
3554b84
75836b6
8a9440b
15037f3
1a0a221
eb0460b
219ff37
c787082
9404215
1a73269
4dd6c9b
9537bc2
5329bc7
445a6b8
d013f46
bcd6b8c
70eb8bc
cea9d04
c81cf48
2d76c55
0efd194
9f91942
55e6fde
d727c08
8fdb520
764acc3
228f207
fa8eef8
1a37faa
077d620
14a6b32
16ef3e9
27278f8
97db876
d61a7dc
d341447
e9672ae
0b2bf43
6ecb2dd
8d58072
739741c
10e7c0d
a77e265
2c0575a
d3786b6
9ef8319
cb619e9
cf66868
da826b0
bd5a53b
ba8f531
b18e1cc
d148f5a
aa825db
a78dced
f85de18
5723682
0dfc7d1
494ea80
f75275f
5d2d6a9
452277d
5c95ff4
e2dbdfe
ed4c85d
5549910
d796b02
c6b9062
7bcabac
2b811c1
15fbfbf
f9f6ce3
bae58cd
2366ea5
f5ae9a2
e300c2f
ca81ae3
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
| @@ -0,0 +1,213 @@ | ||||||
| # (C) Copyright 2025 Anemoi contributors. | ||||||
| # | ||||||
| # This software is licensed under the terms of the Apache Licence Version 2.0 | ||||||
| # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. | ||||||
| # | ||||||
| # In applying this licence, ECMWF does not waive the privileges and immunities | ||||||
| # granted to it by virtue of its status as an intergovernmental organisation | ||||||
| # nor does it submit to any jurisdiction. | ||||||
|
|
||||||
|
|
||||||
| import numpy as np | ||||||
| import torch | ||||||
| from torch import Tensor | ||||||
| from torch.nn import Module | ||||||
|
|
||||||
|
|
||||||
| def legendre_gauss_weights(n: int, a: float = -1.0, b: float = 1.0) -> np.ndarray: | ||||||
| r"""Helper routine which returns the Legendre-Gauss nodes and weights | ||||||
| on the interval [a, b]. | ||||||
| """ | ||||||
|
|
||||||
| xlg, wlg = np.polynomial.legendre.leggauss(n) | ||||||
| xlg = (b - a) * 0.5 * xlg + (b + a) * 0.5 | ||||||
| wlg = wlg * (b - a) * 0.5 | ||||||
|
|
||||||
| return xlg, wlg | ||||||
|
|
||||||
|
|
||||||
| def legpoly( | ||||||
| mmax: int, | ||||||
| lmax: int, | ||||||
| x: np.ndarray, | ||||||
| norm: str = "ortho", | ||||||
| inverse: bool = False, | ||||||
| csphase: bool = True, | ||||||
| ) -> np.ndarray: | ||||||
| r"""Computes the values of (-1)^m c^l_m P^l_m(x) at the positions specified by x. | ||||||
| The resulting tensor has shape (mmax, lmax, len(x)). The Condon-Shortley Phase (-1)^m | ||||||
| can be turned off optionally. | ||||||
|
|
||||||
| Method of computation follows | ||||||
| [1] Schaeffer, N.; Efficient spherical harmonic transforms aimed at pseudospectral numerical simulations, G3: Geochemistry, Geophysics, Geosystems. | ||||||
| [2] Rapp, R.H.; A Fortran Program for the Computation of Gravimetric Quantities from High Degree Spherical Harmonic Expansions, Ohio State University Columbus; report; 1982; https://apps.dtic.mil/sti/citations/ADA123406. | ||||||
| [3] Schrama, E.; Orbit integration based upon interpolated gravitational gradients. | ||||||
| """ | ||||||
|
|
||||||
| # Compute the tensor P^m_n: | ||||||
| nmax = max(mmax, lmax) | ||||||
| vdm = np.zeros((nmax, nmax, len(x)), dtype=np.float64) | ||||||
|
|
||||||
| norm_factor = 1.0 if norm == "ortho" else np.sqrt(4 * np.pi) | ||||||
| norm_factor = 1.0 / norm_factor if inverse else norm_factor | ||||||
| vdm[0, 0, :] = norm_factor / np.sqrt(4 * np.pi) | ||||||
|
|
||||||
| # Fill the diagonal and the lower diagonal | ||||||
| for n in range(1, nmax): | ||||||
| vdm[n - 1, n, :] = np.sqrt(2 * n + 1) * x * vdm[n - 1, n - 1, :] | ||||||
| vdm[n, n, :] = np.sqrt((2 * n + 1) * (1 + x) * (1 - x) / 2 / n) * vdm[n - 1, n - 1, :] | ||||||
|
|
||||||
| # Fill the remaining values on the upper triangle and multiply b | ||||||
| for n in range(2, nmax): | ||||||
| for m in range(0, n - 1): | ||||||
| vdm[m, n, :] = ( | ||||||
| x * np.sqrt((2 * n - 1) / (n - m) * (2 * n + 1) / (n + m)) * vdm[m, n - 1, :] | ||||||
| - np.sqrt((n + m - 1) / (n - m) * (2 * n + 1) / (2 * n - 3) * (n - m - 1) / (n + m)) * vdm[m, n - 2, :] | ||||||
| ) | ||||||
|
|
||||||
| if norm == "schmidt": | ||||||
| for num in range(0, nmax): | ||||||
| if inverse: | ||||||
| vdm[:, num, :] = vdm[:, num, :] * np.sqrt(2 * num + 1) | ||||||
| else: | ||||||
| vdm[:, num, :] = vdm[:, num, :] / np.sqrt(2 * num + 1) | ||||||
|
|
||||||
| vdm = vdm[:mmax, :lmax] | ||||||
|
|
||||||
| if csphase: | ||||||
| for m in range(1, mmax, 2): | ||||||
| vdm[m] *= -1 | ||||||
|
|
||||||
| return vdm | ||||||
|
|
||||||
|
|
||||||
| def precompute_legpoly( | ||||||
| mmax: int, | ||||||
| lmax: int, | ||||||
| t: np.ndarray, | ||||||
| norm: str = "ortho", | ||||||
| inverse: bool = False, | ||||||
| csphase: bool = True, | ||||||
| ) -> np.ndarray: | ||||||
| r"""Computes the values of (-1)^m c^l_m P^l_m(\cos \theta) at the positions specified by t (theta). | ||||||
| The resulting tensor has shape (mmax, lmax, len(x)). The Condon-Shortley Phase (-1)^m | ||||||
| can be turned off optionally. | ||||||
|
|
||||||
| Method of computation follows | ||||||
| [1] Schaeffer, N.; Efficient spherical harmonic transforms aimed at pseudospectral numerical simulations, G3: Geochemistry, Geophysics, Geosystems. | ||||||
| [2] Rapp, R.H.; A Fortran Program for the Computation of Gravimetric Quantities from High Degree Spherical Harmonic Expansions, Ohio State University Columbus; report; 1982; https://apps.dtic.mil/sti/citations/ADA123406. | ||||||
| [3] Schrama, E.; Orbit integration based upon interpolated gravitational gradients. | ||||||
| """ | ||||||
|
|
||||||
| return legpoly(mmax, lmax, np.cos(t), norm=norm, inverse=inverse, csphase=csphase) | ||||||
|
|
||||||
|
|
||||||
| class SphericalHarmonicTransform(Module): | ||||||
|
|
||||||
| def __init__(self, nlat: int, lons_per_lat: list[int], lmax: int | None = None, mmax: int | None = None) -> None: | ||||||
|
|
||||||
| super().__init__() | ||||||
|
|
||||||
| self.lmax = lmax or nlat | ||||||
| self.mmax = mmax or nlat | ||||||
|
|
||||||
| self.nlat = nlat | ||||||
| self.lons_per_lat = lons_per_lat | ||||||
|
|
||||||
| self.n_grid_points = sum(self.lons_per_lat) | ||||||
|
|
||||||
| self.slon = [0] + list(np.cumsum(self.lons_per_lat))[:-1] | ||||||
| self.rlon = [nlat + 8 - nlon // 2 for nlon in self.lons_per_lat] | ||||||
|
|
||||||
| theta, weight = legendre_gauss_weights(nlat) | ||||||
| theta = np.flip(np.arccos(theta)) | ||||||
|
|
||||||
| pct = precompute_legpoly(self.mmax, self.lmax, theta) | ||||||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
I suggest we do this so the polynomials match the ones from ecTrans, in case we want to use ecTrans to generate polynomials. Note that
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Any thoughts @PortillaS-Predictia ? |
||||||
| pct = torch.from_numpy(pct) | ||||||
|
|
||||||
| weight = torch.from_numpy(weight) | ||||||
| weight = torch.einsum("mlk, k -> mlk", pct, weight) | ||||||
|
|
||||||
| self.register_buffer("weight", weight, persistent=False) | ||||||
|
|
||||||
| def rfft(self, x: Tensor) -> Tensor: | ||||||
|
|
||||||
| return torch.fft.rfft(input=x, norm="forward") | ||||||
|
|
||||||
| def rfft_rings(self, x: Tensor) -> Tensor: | ||||||
|
|
||||||
| rfft = [self.rfft(x[..., slon : slon + nlon]) for slon, nlon in zip(self.slon, self.lons_per_lat)] | ||||||
|
|
||||||
| rfft = [ | ||||||
| torch.cat([x, torch.zeros((*x.shape[:-1], rlon), device=x.device)], dim=-1) | ||||||
| for x, rlon in zip(rfft, self.rlon) | ||||||
| ] | ||||||
|
|
||||||
| return torch.stack( | ||||||
| tensors=rfft, | ||||||
| dim=-2, | ||||||
| ) | ||||||
|
|
||||||
| def forward(self, x: Tensor) -> Tensor: | ||||||
|
|
||||||
| x = 2.0 * torch.pi * self.rfft_rings(x) | ||||||
| x = torch.view_as_real(x) | ||||||
|
|
||||||
| rl = torch.einsum("...km, mlk -> ...lm", x[..., : self.mmax, 0], self.weight.to(x.dtype)) | ||||||
| im = torch.einsum("...km, mlk -> ...lm", x[..., : self.mmax, 1], self.weight.to(x.dtype)) | ||||||
|
|
||||||
| x = torch.stack((rl, im), -1) | ||||||
| x = torch.view_as_complex(x) | ||||||
|
|
||||||
| return x | ||||||
|
|
||||||
|
|
||||||
| class InverseSphericalHarmonicTransform(Module): | ||||||
|
|
||||||
| def __init__(self, nlat: int, lons_per_lat: list[int], lmax: int | None = None, mmax: int | None = None) -> None: | ||||||
|
|
||||||
| super().__init__() | ||||||
|
|
||||||
| self.lmax = lmax or nlat | ||||||
| self.mmax = mmax or nlat | ||||||
|
|
||||||
| self.nlat = nlat | ||||||
| self.lons_per_lat = lons_per_lat | ||||||
|
|
||||||
| theta, _ = legendre_gauss_weights(nlat) | ||||||
| theta = np.flip(np.arccos(theta)) | ||||||
|
|
||||||
| pct = precompute_legpoly(self.mmax, self.lmax, theta, inverse=True) | ||||||
| pct = torch.from_numpy(pct) | ||||||
|
|
||||||
| self.register_buffer("pct", pct, persistent=False) | ||||||
|
|
||||||
| def irfft(self, x: Tensor, nlon: int) -> Tensor: | ||||||
|
|
||||||
| return torch.fft.irfft( | ||||||
| input=x, | ||||||
| n=nlon, | ||||||
| norm="forward", | ||||||
| ) | ||||||
|
|
||||||
| def irfft_rings(self, x: Tensor) -> Tensor: | ||||||
|
|
||||||
| irfft = [self.irfft(x[..., t, :], nlon) for t, nlon in enumerate(self.lons_per_lat)] | ||||||
|
|
||||||
| return torch.cat( | ||||||
| tensors=irfft, | ||||||
| dim=-1, | ||||||
| ) | ||||||
|
|
||||||
| def forward(self, x: Tensor) -> Tensor: | ||||||
|
|
||||||
| x = torch.view_as_real(x) | ||||||
|
|
||||||
| rl = torch.einsum("...lm, mlk -> ...km", x[..., 0], self.pct.to(x.dtype)) | ||||||
| im = torch.einsum("...lm, mlk -> ...km", x[..., 1], self.pct.to(x.dtype)) | ||||||
|
|
||||||
| x = torch.stack((rl, im), -1).to(x.dtype) | ||||||
| x = torch.view_as_complex(x) | ||||||
| x = self.irfft_rings(x) | ||||||
|
|
||||||
| return x | ||||||
Uh oh!
There was an error while loading. Please reload this page.