Skip to content

Commit c134821

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 7a13614 commit c134821

File tree

1 file changed

+13
-19
lines changed

1 file changed

+13
-19
lines changed

models/tests/layers/test_spectral_helpers.py

Lines changed: 13 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,8 @@
77
# granted to it by virtue of its status as an intergovernmental organisation
88
# nor does it submit to any jurisdiction.
99

10-
import torch
1110
import pytest
11+
import torch
1212

1313
from anemoi.models.layers.spectral_helpers import EcTransOctahedralSHTModule
1414
from anemoi.models.layers.spectral_helpers import InverseEcTransOctahedralSHTModule
@@ -20,14 +20,16 @@
2020
values below the diagonal just set to zero. The m = 0 coefficients are also purely real, to ensure
2121
that inverse transformed fields are also real.
2222
"""
23+
24+
2325
def random_spectral_array(truncation, dtype):
2426
# Shape: [batch index, ensemble member, l, m, variable]
2527
shape = (1, 1, truncation + 1, truncation + 1, 1)
2628
spectral_array = torch.complex(torch.randn(shape, dtype=dtype), torch.randn(shape, dtype=dtype))
27-
spectral_array[0, 0, :, 0, :].imag = 0.0 # m = 0 modes must be real
29+
spectral_array[0, 0, :, 0, :].imag = 0.0 # m = 0 modes must be real
2830
# Zero the lower triangle, which has no meaning
2931
for i in range(truncation + 1):
30-
spectral_array[0, 0, : i, i, :] = 0.0 + 0.0j
32+
spectral_array[0, 0, :i, i, :] = 0.0 + 0.0j
3133

3234
return spectral_array
3335

@@ -41,23 +43,16 @@ def init(self):
4143
device = "cuda" if torch.cuda.is_available() else "cpu" # Spectral truncation
4244
torch.set_default_device(device)
4345

44-
truncation = 39 # T39 corresponding to O40 grid
45-
dtype = torch.float64 # float 64 for numerical correctness checking
46-
torch.manual_seed(0) # set the random seed for reproducibility
47-
tolerance = 1e-08 # define relative tolerance for numerical comparisons
46+
truncation = 39 # T39 corresponding to O40 grid
47+
dtype = torch.float64 # float 64 for numerical correctness checking
48+
torch.manual_seed(0) # set the random seed for reproducibility
49+
tolerance = 1e-08 # define relative tolerance for numerical comparisons
4850

4951
# Create SHT objects
5052
direct = EcTransOctahedralSHTModule(truncation, dtype=dtype).to(device)
5153
inverse = InverseEcTransOctahedralSHTModule(truncation, dtype=dtype).to(device)
5254

53-
return {
54-
"truncation": truncation,
55-
"dtype": dtype,
56-
"tolerance": tolerance,
57-
"direct": direct,
58-
"inverse": inverse
59-
}
60-
55+
return {"truncation": truncation, "dtype": dtype, "tolerance": tolerance, "direct": direct, "inverse": inverse}
6156

6257
def test_idempotency_direct_inverse(self, init):
6358
"""Test that direct followed by inverse transform returns the original data."""
@@ -77,7 +72,6 @@ def test_idempotency_direct_inverse(self, init):
7772
after = inverse(direct(before))
7873
assert torch.allclose(before, after, rtol=tolerance)
7974

80-
8175
def test_idempotency_inverse_direct(self, init):
8276
"""Test that inverse followed by direct transform returns the original data."""
8377

@@ -93,8 +87,8 @@ def test_idempotency_inverse_direct(self, init):
9387
# Compute max relative diff
9488
maxdiff = 0.0
9589
for i in range(truncation + 1):
96-
maxdiff = max(maxdiff, torch.abs(
97-
(before[0, 0, i :, i, :] - after[0, 0, i :, i, :]) / before[0, 0, i :, i, :]
98-
).max())
90+
maxdiff = max(
91+
maxdiff, torch.abs((before[0, 0, i:, i, :] - after[0, 0, i:, i, :]) / before[0, 0, i:, i, :]).max()
92+
)
9993

10094
assert maxdiff < tolerance

0 commit comments

Comments
 (0)