77# granted to it by virtue of its status as an intergovernmental organisation
88# nor does it submit to any jurisdiction.
99
10- import torch
1110import pytest
11+ import torch
1212
1313from anemoi .models .layers .spectral_helpers import EcTransOctahedralSHTModule
1414from anemoi .models .layers .spectral_helpers import InverseEcTransOctahedralSHTModule
2020values below the diagonal just set to zero. The m = 0 coefficients are also purely real, to ensure
2121that inverse transformed fields are also real.
2222"""
23+
24+
2325def random_spectral_array (truncation , dtype ):
2426 # Shape: [batch index, ensemble member, l, m, variable]
2527 shape = (1 , 1 , truncation + 1 , truncation + 1 , 1 )
2628 spectral_array = torch .complex (torch .randn (shape , dtype = dtype ), torch .randn (shape , dtype = dtype ))
27- spectral_array [0 , 0 , :, 0 , :].imag = 0.0 # m = 0 modes must be real
29+ spectral_array [0 , 0 , :, 0 , :].imag = 0.0 # m = 0 modes must be real
2830 # Zero the lower triangle, which has no meaning
2931 for i in range (truncation + 1 ):
30- spectral_array [0 , 0 , : i , i , :] = 0.0 + 0.0j
32+ spectral_array [0 , 0 , :i , i , :] = 0.0 + 0.0j
3133
3234 return spectral_array
3335
@@ -41,23 +43,16 @@ def init(self):
4143 device = "cuda" if torch .cuda .is_available () else "cpu" # Spectral truncation
4244 torch .set_default_device (device )
4345
44- truncation = 39 # T39 corresponding to O40 grid
45- dtype = torch .float64 # float 64 for numerical correctness checking
46- torch .manual_seed (0 ) # set the random seed for reproducibility
47- tolerance = 1e-08 # define relative tolerance for numerical comparisons
46+ truncation = 39 # T39 corresponding to O40 grid
47+ dtype = torch .float64 # float 64 for numerical correctness checking
48+ torch .manual_seed (0 ) # set the random seed for reproducibility
49+ tolerance = 1e-08 # define relative tolerance for numerical comparisons
4850
4951 # Create SHT objects
5052 direct = EcTransOctahedralSHTModule (truncation , dtype = dtype ).to (device )
5153 inverse = InverseEcTransOctahedralSHTModule (truncation , dtype = dtype ).to (device )
5254
53- return {
54- "truncation" : truncation ,
55- "dtype" : dtype ,
56- "tolerance" : tolerance ,
57- "direct" : direct ,
58- "inverse" : inverse
59- }
60-
55+ return {"truncation" : truncation , "dtype" : dtype , "tolerance" : tolerance , "direct" : direct , "inverse" : inverse }
6156
6257 def test_idempotency_direct_inverse (self , init ):
6358 """Test that direct followed by inverse transform returns the original data."""
@@ -77,7 +72,6 @@ def test_idempotency_direct_inverse(self, init):
7772 after = inverse (direct (before ))
7873 assert torch .allclose (before , after , rtol = tolerance )
7974
80-
8175 def test_idempotency_inverse_direct (self , init ):
8276 """Test that inverse followed by direct transform returns the original data."""
8377
@@ -93,8 +87,8 @@ def test_idempotency_inverse_direct(self, init):
9387 # Compute max relative diff
9488 maxdiff = 0.0
9589 for i in range (truncation + 1 ):
96- maxdiff = max (maxdiff , torch . abs (
97- ( before [0 , 0 , i :, i , :] - after [0 , 0 , i :, i , :]) / before [0 , 0 , i :, i , :]
98- ). max ())
90+ maxdiff = max (
91+ maxdiff , torch . abs (( before [0 , 0 , i :, i , :] - after [0 , 0 , i :, i , :]) / before [0 , 0 , i :, i , :]). max ()
92+ )
9993
10094 assert maxdiff < tolerance
0 commit comments