Skip to content

Commit

Permalink
Consolidate SphericalHarmonics implementations
Browse files Browse the repository at this point in the history
- Delete ComplexSphericalHarmonics
- Replace the old RealSphericalHarmonics with RealSphericalHarmonicsWithZeroImag, which is what we actually use in practice (it's faster and supports parallelism)

PiperOrigin-RevId: 693112634
  • Loading branch information
shoyer authored and Dinosaur authors committed Nov 6, 2024
1 parent ce7209b commit 711b873
Show file tree
Hide file tree
Showing 8 changed files with 84 additions and 419 deletions.
5 changes: 0 additions & 5 deletions dinosaur/coordinate_systems_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,6 @@ class CoordinateSystemTest(parameterized.TestCase):
dict(
horizontal=spherical_harmonic.Grid.T21(),
vertical=sigma_coordinates.SigmaCoordinates.equidistant(6)),
dict(
horizontal=spherical_harmonic.Grid.T21(
spherical_harmonics_impl=spherical_harmonic.ComplexSphericalHarmonics
),
vertical=sigma_coordinates.SigmaCoordinates.equidistant(8)),
dict(
horizontal=spherical_harmonic.Grid.T21(),
vertical=layer_coordinates.LayerCoordinates(5)),
Expand Down
50 changes: 0 additions & 50 deletions dinosaur/fourier.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,56 +124,6 @@ def real_basis_derivative_with_zero_imag(
return j * jnp.where((i + 1) % 2, u_down, -u_up)


def complex_basis(wavenumbers: int, nodes: int) -> np.ndarray:
"""Returns the complex-valued Fourier Basis.
Args:
wavenumbers: number of wavenumbers.
nodes: number of equally spaced nodes in the range [0, 2π). Must satisfy
wavenumbers >= nodes.
Returns:
The nodes x wavenumbers matrix F, such that
F[j, k] = exp(2πi * jk / nodes) / √2π
i.e., the columns of F are the complex Fourier basis functions evenly
spaced points.
The normalization of the basis functions is chosen such that they have unit
L²([0, 2π]) norm.
"""
if wavenumbers > nodes // 2 + 1:
raise ValueError(
'`wavenumbers` must be no greater than `nodes // 2 + 1`;'
f'got wavenumbers = {wavenumbers}, nodes = {nodes}.'
)
basis = scipy.linalg.dft(nodes).conj()[:, :wavenumbers] / np.sqrt(np.pi)
basis[:, 0] /= np.sqrt(2)
return basis


def complex_basis_derivative(
u: jnp.ndarray | jax.Array, axis: int = -1
) -> jax.Array:
"""Calculate the derivative of a signal using a complex basis.
Args:
u: signal to differentiate, in the real Fourier basis.
axis: the axis along which the transform will be applied.
Returns:
The derivative of `u` along `axis`. In particular, if
`u_x = complex_basis_derivative(u)`:
u_x[..., k] = i * k * u[..., k]
"""
if axis >= 0:
raise ValueError('axis must be negative')
k = jnp.arange(u.shape[axis]).reshape((-1,) + (1,) * (-1 - axis))
return 1j * k * u


def quadrature_nodes(nodes: int) -> tuple[np.ndarray, np.ndarray]:
"""Returns nodes and weights for the trapezoidal rule.
Expand Down
45 changes: 0 additions & 45 deletions dinosaur/fourier_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for fourier."""
import itertools

from absl.testing import absltest
Expand Down Expand Up @@ -68,48 +66,5 @@ def testNormalized(self, wavenumbers):
np.testing.assert_allclose((f.T * w).dot(f), eye, atol=1e-12)


class ComplexFourierTest(parameterized.TestCase):

@parameterized.parameters(
dict(wavenumbers=4, nodes=7),
dict(wavenumbers=11, nodes=21),
dict(wavenumbers=32, nodes=63),
)
def testBasis(self, wavenumbers, nodes):
f = fourier.complex_basis(wavenumbers, nodes)
for j, k in itertools.product(range(nodes), range(wavenumbers)):
normalization = np.sqrt(np.pi)
if k == 0:
normalization *= np.sqrt(2)
expected = np.exp(2 * np.pi * 1j * j * k / nodes) / normalization
np.testing.assert_allclose(f[j, k], expected, atol=1e-12)

@parameterized.parameters(
dict(wavenumbers=4, seed=0),
dict(wavenumbers=11, seed=0),
dict(wavenumbers=32, seed=0),
)
def testDerivatives(self, wavenumbers, seed):
f = np.random.RandomState(seed).normal(size=[wavenumbers])
f_x = fourier.complex_basis_derivative(f)
for k in range(wavenumbers):
np.testing.assert_allclose(f_x[k], 1j * k * f[k])

@parameterized.parameters(
dict(wavenumbers=4),
dict(wavenumbers=16),
dict(wavenumbers=256),
)
def testNormalized(self, wavenumbers):
"""Tests that the basis functions are normalized on [0, 2π]."""
nodes = 2 * wavenumbers - 1
f = fourier.complex_basis(wavenumbers, nodes)
_, w = fourier.quadrature_nodes(nodes)
expected = 2 * np.eye(wavenumbers)
expected[0, 0] = 1
norms = (f.T.conj() * w).dot(f)
np.testing.assert_allclose(norms, expected, atol=1e-12)


if __name__ == '__main__':
absltest.main()
5 changes: 1 addition & 4 deletions dinosaur/primitive_equations_integration_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@

from absl.testing import absltest
import chex

from dinosaur import coordinate_systems
from dinosaur import primitive_equations
from dinosaur import primitive_equations_states
Expand All @@ -27,7 +26,6 @@
from dinosaur import spherical_harmonic
from dinosaur import time_integration
from dinosaur import xarray_utils

import jax
from jax import config
import jax.numpy as jnp
Expand All @@ -38,8 +36,7 @@ def make_coords(
max_wavenumber: int,
num_layers: int,
mesh: jax.sharding.Mesh | None = None,
spherical_harmonics_impl: ... = (
spherical_harmonic.RealSphericalHarmonicsWithZeroImag),
spherical_harmonics_impl: ... = spherical_harmonic.RealSphericalHarmonics,
) -> coordinate_systems.CoordinateSystem:
return coordinate_systems.CoordinateSystem(
spherical_harmonic.Grid.with_wavenumbers(
Expand Down
Loading

0 comments on commit 711b873

Please sign in to comment.