Skip to content

Commit

Permalink
No public description
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 693112634
  • Loading branch information
shoyer authored and Dinosaur authors committed Nov 4, 2024
1 parent ce7209b commit c07182e
Show file tree
Hide file tree
Showing 4 changed files with 0 additions and 163 deletions.
50 changes: 0 additions & 50 deletions dinosaur/fourier.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,56 +124,6 @@ def real_basis_derivative_with_zero_imag(
return j * jnp.where((i + 1) % 2, u_down, -u_up)


def complex_basis(wavenumbers: int, nodes: int) -> np.ndarray:
"""Returns the complex-valued Fourier Basis.
Args:
wavenumbers: number of wavenumbers.
nodes: number of equally spaced nodes in the range [0, 2π). Must satisfy
wavenumbers >= nodes.
Returns:
The nodes x wavenumbers matrix F, such that
F[j, k] = exp(2πi * jk / nodes) / √2π
i.e., the columns of F are the complex Fourier basis functions evenly
spaced points.
The normalization of the basis functions is chosen such that they have unit
L²([0, 2π]) norm.
"""
if wavenumbers > nodes // 2 + 1:
raise ValueError(
'`wavenumbers` must be no greater than `nodes // 2 + 1`;'
f'got wavenumbers = {wavenumbers}, nodes = {nodes}.'
)
basis = scipy.linalg.dft(nodes).conj()[:, :wavenumbers] / np.sqrt(np.pi)
basis[:, 0] /= np.sqrt(2)
return basis


def complex_basis_derivative(
u: jnp.ndarray | jax.Array, axis: int = -1
) -> jax.Array:
"""Calculate the derivative of a signal using a complex basis.
Args:
u: signal to differentiate, in the real Fourier basis.
axis: the axis along which the transform will be applied.
Returns:
The derivative of `u` along `axis`. In particular, if
`u_x = complex_basis_derivative(u)`:
u_x[..., k] = i * k * u[..., k]
"""
if axis >= 0:
raise ValueError('axis must be negative')
k = jnp.arange(u.shape[axis]).reshape((-1,) + (1,) * (-1 - axis))
return 1j * k * u


def quadrature_nodes(nodes: int) -> tuple[np.ndarray, np.ndarray]:
"""Returns nodes and weights for the trapezoidal rule.
Expand Down
91 changes: 0 additions & 91 deletions dinosaur/spherical_harmonic.py
Original file line number Diff line number Diff line change
Expand Up @@ -572,97 +572,6 @@ def longitudinal_derivative(self, x: Array) -> Array:
return _fourier_derivative_for_real_basis_with_zero_imag(x, self.spmd_mesh)


@dataclasses.dataclass(frozen=True)
class ComplexSphericalHarmonics(SphericalHarmonics):
"""Complex valued spherical harmonics transforms.
This works fine, but in practice is considerably slower (at least on TPUs)
than real-values spherical harmonics transformations, probably because XLA's
code generation for complex numbers is not well optimized.
"""

@functools.cached_property
def nodal_axes(self) -> tuple[np.ndarray, np.ndarray]:
longitude, _ = fourier.quadrature_nodes(self.longitude_nodes)
sin_latitude, _ = get_latitude_nodes(
self.latitude_nodes, self.latitude_spacing
)
return longitude, sin_latitude

@functools.cached_property
def nodal_shape(self) -> tuple[int, int]:
return (self.longitude_nodes, self.latitude_nodes)

@functools.cached_property
def nodal_padding(self) -> tuple[int, int]:
return (0, 0)

@functools.cached_property
def modal_axes(self) -> tuple[np.ndarray, np.ndarray]:
lon_wavenumbers = np.arange(self.longitude_wavenumbers)
tot_wavenumbers = np.arange(self.total_wavenumbers)
return lon_wavenumbers, tot_wavenumbers

@functools.cached_property
def modal_shape(self) -> tuple[int, int]:
return (self.longitude_wavenumbers, self.total_wavenumbers)

@functools.cached_property
def modal_padding(self) -> tuple[int, int]:
return (0, 0)

@functools.cached_property
def modal_dtype(self) -> np.dtype:
return np.dtype(np.complex64)

@functools.cached_property
def mask(self) -> np.ndarray:
m, l = np.meshgrid(*self.modal_axes, indexing='ij')
return m <= l

@functools.cached_property
def basis(self) -> _SphericalHarmonicBasis:
f = fourier.complex_basis(
wavenumbers=self.longitude_wavenumbers,
nodes=self.longitude_nodes,
)
_, wf = fourier.quadrature_nodes(self.longitude_nodes)
x, wp = get_latitude_nodes(self.latitude_nodes, self.latitude_spacing)
w = wf * wp
p = associated_legendre.evaluate(
n_m=self.longitude_wavenumbers, n_l=self.total_wavenumbers, x=x
)
return _SphericalHarmonicBasis(f=f, p=p, w=w)

def inverse_transform(self, x):
p = self.basis.p
f = self.basis.f
px = jax.named_call(einsum, name='inv_legendre')('mjl,...ml->...mj', p, x)
fpx_from_real = jax.named_call(einsum, name='inv_fourier_from_real')(
'im,...mj->...ij', jnp.real(f), jnp.real(px)
)
fpx_from_imag = jax.named_call(einsum, name='inv_fourier_from_imag')(
'im,...mj->...ij', -jnp.imag(f), jnp.imag(px)
)
return fpx_from_real + fpx_from_imag

def transform(self, x):
w = self.basis.w
f = self.basis.f
p = self.basis.p
wx = w * x
fwx = jax.named_call(einsum, name='fwd_fourier')(
'im,...ij->...mj', jnp.conj(f), wx
)
pfwx = jax.named_call(einsum, name='fwd_legendre')(
'mjl,...mj->...ml', p, fwx
)
return pfwx

def longitudinal_derivative(self, x: Array) -> Array:
return fourier.complex_basis_derivative(x, axis=-2)


def _vertical_pad(
field: jax.Array, mesh: jax.sharding.Mesh | None
) -> tuple[jax.Array, int | None]:
Expand Down
21 changes: 0 additions & 21 deletions dinosaur/spherical_harmonic_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,6 @@ class SphericalHarmonicTest(parameterized.TestCase):
impl=[
# RealSphericalHarmonicsWithZeroImag uses a different convention
spherical_harmonic.RealSphericalHarmonics,
spherical_harmonic.ComplexSphericalHarmonics,
],
)
def testBasisShapes(self, params, impl):
Expand All @@ -100,7 +99,6 @@ class GridTest(parameterized.TestCase):
impl=[
spherical_harmonic.RealSphericalHarmonics,
spherical_harmonic.RealSphericalHarmonicsWithZeroImag,
spherical_harmonic.ComplexSphericalHarmonics,
],
)
def testGridShape(self, wavenumbers, latitude_spacing, impl):
Expand Down Expand Up @@ -193,14 +191,6 @@ def testConstructors(self):
reverse_einsum_arg_order=True,
),
),
dict(
longitude_wavenumbers=64,
total_wavenumbers=64,
latitude_spacing='equiangular_with_poles',
jit=True,
seed=0,
spherical_harmonics_impl=spherical_harmonic.ComplexSphericalHarmonics,
),
)
def testRoundTrip(
self,
Expand Down Expand Up @@ -243,7 +233,6 @@ def testRoundTrip(
impl=[
spherical_harmonic.RealSphericalHarmonics,
spherical_harmonic.RealSphericalHarmonicsWithZeroImag,
spherical_harmonic.ComplexSphericalHarmonics,
],
)
def testLaplacianRoundTrip(self, wavenumbers, latitude_spacing, seed, impl):
Expand All @@ -268,7 +257,6 @@ def testLaplacianRoundTrip(self, wavenumbers, latitude_spacing, seed, impl):
impl=[
spherical_harmonic.RealSphericalHarmonics,
spherical_harmonic.RealSphericalHarmonicsWithZeroImag,
spherical_harmonic.ComplexSphericalHarmonics,
],
)
def testDerivatives(
Expand Down Expand Up @@ -421,14 +409,6 @@ def testLaplacian(self, grid, seed):
atol=1e-11,
seed=0,
),
dict(
grid=spherical_harmonic.Grid.with_wavenumbers(
128,
spherical_harmonics_impl=spherical_harmonic.ComplexSphericalHarmonics,
),
atol=1e-10,
seed=0,
),
dict(
grid=spherical_harmonic.Grid(
longitude_wavenumbers=64,
Expand Down Expand Up @@ -542,7 +522,6 @@ def testIntegrationSurfaceArea(self, wavenumbers, latitude_spacing, radius):
impl=[
spherical_harmonic.RealSphericalHarmonics,
spherical_harmonic.RealSphericalHarmonicsWithZeroImag,
spherical_harmonic.ComplexSphericalHarmonics,
],
)
def testIntegrationSphericalHarmonics(self, params, impl):
Expand Down
1 change: 0 additions & 1 deletion dinosaur/xarray_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,6 @@
'RealSphericalHarmonicsWithZeroImag': (
spherical_harmonic.RealSphericalHarmonicsWithZeroImag
),
'ComplexSphericalHarmonics': spherical_harmonic.ComplexSphericalHarmonics,
}


Expand Down

0 comments on commit c07182e

Please sign in to comment.