Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
92 commits
Select commit Hold shift + click to select a range
4154104
new module with basic structure
frazane Nov 14, 2025
09373fb
move spectral transform to anemoi models
frazane Dec 2, 2025
fc88f77
Merge branch 'main' into feat/spectral-losses-groundwork
frazane Dec 2, 2025
12a753c
add log spectral distance loss
frazane Dec 3, 2025
75e2f1b
temporary node indexing for 2d fft
frazane Dec 3, 2025
73d958a
fix combined loss schema
frazane Dec 4, 2025
5d33958
add pydantic schemas
frazane Dec 4, 2025
2c8cd46
Spherical harmonics transform (cartesian transform by B. Bonev et al.…
PortillaS-Predictia Dec 4, 2025
5d9f721
bugfixes
frazane Dec 5, 2025
9b98381
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 8, 2025
b49aafc
Minor change after pre-commit
PortillaS-Predictia Dec 10, 2025
0833dfb
Refactor spatial
OpheliaMiralles Dec 22, 2025
7e8968f
Merge remote-tracking branch 'origin/main' into feat/spectral-losses-…
OpheliaMiralles Dec 22, 2025
5e7f1e1
Merge remote-tracking branch 'origin/feature/spectral-transform' into…
OpheliaMiralles Dec 23, 2025
6d8db50
Merge #729 and add tests
OpheliaMiralles Dec 23, 2025
3554b84
Remove generic SHT from transform choices
OpheliaMiralles Dec 23, 2025
75836b6
Add SpectralCRPS loss and tests
OpheliaMiralles Dec 29, 2025
8a9440b
Precommit
OpheliaMiralles Dec 29, 2025
15037f3
Add torch-dct to pyproject
OpheliaMiralles Dec 29, 2025
1a0a221
Make ruff happy
OpheliaMiralles Dec 29, 2025
eb0460b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 29, 2025
219ff37
Merge branch 'main' into feat/spectral-losses-groundwork
OpheliaMiralles Jan 2, 2026
c787082
Replace hardcoded tensordim
OpheliaMiralles Jan 5, 2026
9404215
Merge remote-tracking branch 'origin/main' into feat/spectral-losses-…
OpheliaMiralles Jan 5, 2026
1a73269
Merge branch 'main' into feat/spectral-losses-groundwork
OpheliaMiralles Jan 6, 2026
4dd6c9b
Merge branch 'main' into feat/spectral-losses-groundwork
OpheliaMiralles Jan 6, 2026
9537bc2
Merge branch 'main' into feat/spectral-losses-groundwork
OpheliaMiralles Jan 7, 2026
5329bc7
Merge branch 'main' into feat/harmonize-spectral-losses
OpheliaMiralles Jan 12, 2026
445a6b8
Merge branch 'main' into feat/harmonize-spectral-losses
OpheliaMiralles Jan 12, 2026
d013f46
Merge branch 'main' into feat/harmonize-spectral-losses
OpheliaMiralles Jan 12, 2026
bcd6b8c
Merge branch 'main' into feat/harmonize-spectral-losses
OpheliaMiralles Jan 14, 2026
70eb8bc
Merge branch 'main' into feat/harmonize-spectral-losses
OpheliaMiralles Jan 16, 2026
cea9d04
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 16, 2026
c81cf48
Change SpectralTransform to abstract base class
OpheliaMiralles Jan 16, 2026
2d76c55
Add copyright and license comments to test_filtered_loss.py
OpheliaMiralles Jan 16, 2026
0efd194
Rename forward method to __call__ in SpectralTransform
OpheliaMiralles Jan 16, 2026
9f91942
Change exception type in loss function test
OpheliaMiralles Jan 16, 2026
55e6fde
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 16, 2026
d727c08
Refactor data selection and rearrangement in transforms
OpheliaMiralles Jan 16, 2026
8fdb520
Change exception type in loss function test
OpheliaMiralles Jan 16, 2026
764acc3
Import TensorDim from training.utils.enums
OpheliaMiralles Jan 16, 2026
228f207
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 16, 2026
fa8eef8
Merge conflicts
OpheliaMiralles Jan 16, 2026
1a37faa
Remove xdim ydim from signature
OpheliaMiralles Jan 16, 2026
077d620
Update models/src/anemoi/models/layers/spectral_transforms.py
OpheliaMiralles Jan 19, 2026
14a6b32
Merge branch 'main' into feat/harmonize-spectral-losses
OpheliaMiralles Jan 19, 2026
16ef3e9
Merge branch 'main' into feat/harmonize-spectral-losses
OpheliaMiralles Jan 20, 2026
27278f8
Apply suggestions from code review
OpheliaMiralles Jan 20, 2026
97db876
remove x_dim and y_dim from spectral loss as not used.
sahahner Jan 20, 2026
d61a7dc
remove variables that are not used from ectrans sht layer
sahahner Jan 20, 2026
d341447
Merge branch 'feat/harmonize-spectral-losses' of github.com:ecmwf/ane…
sahahner Jan 20, 2026
e9672ae
remove xdim and ydim from octahedral sht and replace ylim with nlat
sahahner Jan 20, 2026
0b2bf43
update tests to changes in when to take x_dim and y_dim as init argument
sahahner Jan 20, 2026
6ecb2dd
update test_filtered_loss
sahahner Jan 20, 2026
8d58072
Remove cutoff factor
OpheliaMiralles Jan 20, 2026
739741c
Add torch-dct as optional dependency
OpheliaMiralles Jan 20, 2026
10e7c0d
Add dependency to model too
OpheliaMiralles Jan 20, 2026
a77e265
Add dependencies to test extra
OpheliaMiralles Jan 20, 2026
2c0575a
Merge branch 'main' into feat/harmonize-spectral-losses
OpheliaMiralles Jan 21, 2026
d3786b6
Added alpha and no_autocast arguments
evenmn Jan 21, 2026
9ef8319
Minor
evenmn Jan 21, 2026
cb619e9
Added low-pass filter to FFT2D transform
evenmn Jan 21, 2026
cf66868
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 21, 2026
da826b0
Merge branch 'main' into feat/harmonize-spectral-losses
OpheliaMiralles Jan 21, 2026
bd5a53b
Fix indentation issue
OpheliaMiralles Jan 21, 2026
ba8f531
remove node_slice, inherit from torch.nn.module
sahahner Jan 21, 2026
b18e1cc
Merge branch 'feat/harmonize-spectral-losses' of github.com:ecmwf/ane…
sahahner Jan 21, 2026
d148f5a
Replace abc by torch nn
OpheliaMiralles Jan 21, 2026
aa825db
Fix test
OpheliaMiralles Jan 21, 2026
a78dced
baseloss takes kwargs in forward function
sahahner Jan 21, 2026
f85de18
Merge branch 'main' into feat/harmonize-spectral-losses
OpheliaMiralles Jan 22, 2026
5723682
Changed TensorDim to make it compatible with targets without ensemble…
evenmn Jan 22, 2026
0dfc7d1
Removed kwargs passing to BaseLoss, as BaseLoss only takes one argume…
evenmn Jan 22, 2026
494ea80
A bunch of fixed needed to make SpectralCRPS run with stretched grid …
evenmn Jan 22, 2026
f75275f
Merge branch 'feat/harmonize-spectral-losses' of github.com:ecmwf/ane…
evenmn Jan 22, 2026
5d2d6a9
Fix tests
OpheliaMiralles Jan 23, 2026
452277d
no sharding in loss function for spectral losses
sahahner Jan 26, 2026
5c95ff4
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 26, 2026
e2dbdfe
remove healpix inverse from spectral helpers
sahahner Jan 28, 2026
ed4c85d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 28, 2026
5549910
Merge branch 'main' into feat/harmonize-spectral-losses
OpheliaMiralles Jan 29, 2026
d796b02
Merge branch 'main' into feat/harmonize-spectral-losses
OpheliaMiralles Feb 2, 2026
c6b9062
Remove ecTrans-based transform
samhatfield Jan 29, 2026
7bcabac
Remove Cartesian SHT
samhatfield Jan 29, 2026
2b811c1
Implement generic SphericalHarmonicTransform
samhatfield Jan 29, 2026
15fbfbf
Rename nlon to lons_per_lat
samhatfield Jan 29, 2026
f9f6ce3
Update spectral_transform wrapper classes
samhatfield Jan 29, 2026
bae58cd
Add test suites for regular and octahedral SHT
samhatfield Jan 29, 2026
2366ea5
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 3, 2026
f5ae9a2
Update spectral losses
samhatfield Feb 3, 2026
e300c2f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 3, 2026
ca81ae3
Merge branch 'main' into feat/harmonize-spectral-losses
samhatfield Feb 3, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions models/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,8 @@ optional-dependencies.migrations = [
"rich",
]

optional-dependencies.tests = [ "hypothesis>=6.11", "pytest>=8" ]

optional-dependencies.spectral = [ "torch-dct>=0.1.6" ]
optional-dependencies.tests = [ "anemoi-models[spectral]", "hypothesis>=6.11", "pytest>=8" ]
urls.Documentation = "https://anemoi-models.readthedocs.io/"
urls.Homepage = "https://github.com/ecmwf/anemoi-models/"
urls.Issues = "https://github.com/ecmwf/anemoi-models/issues"
Expand Down
213 changes: 213 additions & 0 deletions models/src/anemoi/models/layers/spectral_helpers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,213 @@
# (C) Copyright 2025 Anemoi contributors.
#
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
#
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.


import numpy as np
import torch
from torch import Tensor
from torch.nn import Module


def legendre_gauss_weights(n: int, a: float = -1.0, b: float = 1.0) -> np.ndarray:
r"""Helper routine which returns the Legendre-Gauss nodes and weights
on the interval [a, b].
"""

xlg, wlg = np.polynomial.legendre.leggauss(n)
xlg = (b - a) * 0.5 * xlg + (b + a) * 0.5
wlg = wlg * (b - a) * 0.5

return xlg, wlg


def legpoly(
mmax: int,
lmax: int,
x: np.ndarray,
norm: str = "ortho",
inverse: bool = False,
csphase: bool = True,
) -> np.ndarray:
r"""Computes the values of (-1)^m c^l_m P^l_m(x) at the positions specified by x.
The resulting tensor has shape (mmax, lmax, len(x)). The Condon-Shortley Phase (-1)^m
can be turned off optionally.

Method of computation follows
[1] Schaeffer, N.; Efficient spherical harmonic transforms aimed at pseudospectral numerical simulations, G3: Geochemistry, Geophysics, Geosystems.
[2] Rapp, R.H.; A Fortran Program for the Computation of Gravimetric Quantities from High Degree Spherical Harmonic Expansions, Ohio State University Columbus; report; 1982; https://apps.dtic.mil/sti/citations/ADA123406.
[3] Schrama, E.; Orbit integration based upon interpolated gravitational gradients.
"""

# Compute the tensor P^m_n:
nmax = max(mmax, lmax)
vdm = np.zeros((nmax, nmax, len(x)), dtype=np.float64)

norm_factor = 1.0 if norm == "ortho" else np.sqrt(4 * np.pi)
norm_factor = 1.0 / norm_factor if inverse else norm_factor
vdm[0, 0, :] = norm_factor / np.sqrt(4 * np.pi)

# Fill the diagonal and the lower diagonal
for n in range(1, nmax):
vdm[n - 1, n, :] = np.sqrt(2 * n + 1) * x * vdm[n - 1, n - 1, :]
vdm[n, n, :] = np.sqrt((2 * n + 1) * (1 + x) * (1 - x) / 2 / n) * vdm[n - 1, n - 1, :]

# Fill the remaining values on the upper triangle and multiply b
for n in range(2, nmax):
for m in range(0, n - 1):
vdm[m, n, :] = (
x * np.sqrt((2 * n - 1) / (n - m) * (2 * n + 1) / (n + m)) * vdm[m, n - 1, :]
- np.sqrt((n + m - 1) / (n - m) * (2 * n + 1) / (2 * n - 3) * (n - m - 1) / (n + m)) * vdm[m, n - 2, :]
)

if norm == "schmidt":
for num in range(0, nmax):
if inverse:
vdm[:, num, :] = vdm[:, num, :] * np.sqrt(2 * num + 1)
else:
vdm[:, num, :] = vdm[:, num, :] / np.sqrt(2 * num + 1)

vdm = vdm[:mmax, :lmax]

if csphase:
for m in range(1, mmax, 2):
vdm[m] *= -1

return vdm


def precompute_legpoly(
mmax: int,
lmax: int,
t: np.ndarray,
norm: str = "ortho",
inverse: bool = False,
csphase: bool = True,
) -> np.ndarray:
r"""Computes the values of (-1)^m c^l_m P^l_m(\cos \theta) at the positions specified by t (theta).
The resulting tensor has shape (mmax, lmax, len(x)). The Condon-Shortley Phase (-1)^m
can be turned off optionally.

Method of computation follows
[1] Schaeffer, N.; Efficient spherical harmonic transforms aimed at pseudospectral numerical simulations, G3: Geochemistry, Geophysics, Geosystems.
[2] Rapp, R.H.; A Fortran Program for the Computation of Gravimetric Quantities from High Degree Spherical Harmonic Expansions, Ohio State University Columbus; report; 1982; https://apps.dtic.mil/sti/citations/ADA123406.
[3] Schrama, E.; Orbit integration based upon interpolated gravitational gradients.
"""

return legpoly(mmax, lmax, np.cos(t), norm=norm, inverse=inverse, csphase=csphase)


class SphericalHarmonicTransform(Module):

def __init__(self, nlat: int, lons_per_lat: list[int], lmax: int | None = None, mmax: int | None = None) -> None:

super().__init__()

self.lmax = lmax or nlat
self.mmax = mmax or nlat

self.nlat = nlat
self.lons_per_lat = lons_per_lat

self.n_grid_points = sum(self.lons_per_lat)

self.slon = [0] + list(np.cumsum(self.lons_per_lat))[:-1]
self.rlon = [nlat + 8 - nlon // 2 for nlon in self.lons_per_lat]

theta, weight = legendre_gauss_weights(nlat)
theta = np.flip(np.arccos(theta))

pct = precompute_legpoly(self.mmax, self.lmax, theta)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
pct = precompute_legpoly(self.mmax, self.lmax, theta)
pct = precompute_legpoly(self.mmax, self.lmax, theta, norm="ectrans", csphase=False)

I suggest we do this so the polynomials match the ones from ecTrans, in case we want to use ecTrans to generate polynomials. Note that norm="ectrans" doesn't actually activate a particular code path, but we have to give it some non-None value so the default isn't used.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Any thoughts @PortillaS-Predictia ?

pct = torch.from_numpy(pct)

weight = torch.from_numpy(weight)
weight = torch.einsum("mlk, k -> mlk", pct, weight)

self.register_buffer("weight", weight, persistent=False)

def rfft(self, x: Tensor) -> Tensor:

return torch.fft.rfft(input=x, norm="forward")

def rfft_rings(self, x: Tensor) -> Tensor:

rfft = [self.rfft(x[..., slon : slon + nlon]) for slon, nlon in zip(self.slon, self.lons_per_lat)]

rfft = [
torch.cat([x, torch.zeros((*x.shape[:-1], rlon), device=x.device)], dim=-1)
for x, rlon in zip(rfft, self.rlon)
]

return torch.stack(
tensors=rfft,
dim=-2,
)

def forward(self, x: Tensor) -> Tensor:

x = 2.0 * torch.pi * self.rfft_rings(x)
x = torch.view_as_real(x)

rl = torch.einsum("...km, mlk -> ...lm", x[..., : self.mmax, 0], self.weight.to(x.dtype))
im = torch.einsum("...km, mlk -> ...lm", x[..., : self.mmax, 1], self.weight.to(x.dtype))

x = torch.stack((rl, im), -1)
x = torch.view_as_complex(x)

return x


class InverseSphericalHarmonicTransform(Module):

def __init__(self, nlat: int, lons_per_lat: list[int], lmax: int | None = None, mmax: int | None = None) -> None:

super().__init__()

self.lmax = lmax or nlat
self.mmax = mmax or nlat

self.nlat = nlat
self.lons_per_lat = lons_per_lat

theta, _ = legendre_gauss_weights(nlat)
theta = np.flip(np.arccos(theta))

pct = precompute_legpoly(self.mmax, self.lmax, theta, inverse=True)
pct = torch.from_numpy(pct)

self.register_buffer("pct", pct, persistent=False)

def irfft(self, x: Tensor, nlon: int) -> Tensor:

return torch.fft.irfft(
input=x,
n=nlon,
norm="forward",
)

def irfft_rings(self, x: Tensor) -> Tensor:

irfft = [self.irfft(x[..., t, :], nlon) for t, nlon in enumerate(self.lons_per_lat)]

return torch.cat(
tensors=irfft,
dim=-1,
)

def forward(self, x: Tensor) -> Tensor:

x = torch.view_as_real(x)

rl = torch.einsum("...lm, mlk -> ...km", x[..., 0], self.pct.to(x.dtype))
im = torch.einsum("...lm, mlk -> ...km", x[..., 1], self.pct.to(x.dtype))

x = torch.stack((rl, im), -1).to(x.dtype)
x = torch.view_as_complex(x)
x = self.irfft_rings(x)

return x
Loading
Loading