Skip to content

Commit

Permalink
Use jax.random.key
Browse files Browse the repository at this point in the history
  • Loading branch information
NeilGirdhar committed Nov 8, 2023
1 parent c99a071 commit 3381087
Show file tree
Hide file tree
Showing 4 changed files with 8 additions and 9 deletions.
2 changes: 1 addition & 1 deletion docs/tjax/index.html
Original file line number Diff line number Diff line change
Expand Up @@ -1241,4 +1241,4 @@ <h4><code><a title="tjax.custom_vjp" href="#tjax.custom_vjp">custom_vjp</a></cod
<p>Generated by <a href="https://pdoc3.github.io/pdoc"><cite>pdoc</cite> 0.8.4</a>.</p>
</footer>
</body>
</html>
</html>
4 changes: 2 additions & 2 deletions tests/fixed_point/test_fixed_point.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import jax.numpy as jnp
import pytest
from jax import grad
from jax.random import KeyArray, PRNGKey, normal, split
from jax.random import KeyArray, key, normal, split
from numpy.testing import assert_allclose
from typing_extensions import override

Expand Down Expand Up @@ -157,7 +157,7 @@ def test_grad(fixed_point_using_while: C,
def test_noisy_grad(noisy_it_fun: NoisyNewtonsMethod, theta: float) -> None:

def fixed_point_using_while_of_theta(theta: float) -> float:
state = (8.0, PRNGKey(123))
state = (8.0, key(123))
x, _ = noisy_it_fun.find_fixed_point(theta, state).current_state
return x
assert_allclose(theta,
Expand Down
5 changes: 2 additions & 3 deletions tests/fixed_point/test_use_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import jax.numpy as jnp
import pytest
from jax import grad
from jax.random import KeyArray, PRNGKey, normal, split
from jax.random import KeyArray, key, normal, split
from numpy.testing import assert_allclose
from typing_extensions import override

Expand Down Expand Up @@ -76,8 +76,7 @@ class EncodingElement:
diffusion: float = 0.01

def _initial_state(self) -> EncodingState:
return EncodingState(EncodingConfiguration(8.0, 1),
PRNGKey(123))
return EncodingState(EncodingConfiguration(8.0, 1), key(123))

def iterate(self,
ec: EncodingConfiguration,
Expand Down
6 changes: 3 additions & 3 deletions tests/test_display.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import numpy as np
import pytest
from jax import enable_custom_prng, jit, vmap
from jax.random import KeyArray, PRNGKey
from jax.random import KeyArray, key
from pytest import CaptureFixture
from rich.console import Console

Expand Down Expand Up @@ -149,12 +149,12 @@ def f(x: RealArray) -> RealArray:
def test_tapped_key(capsys: CaptureFixture[str],
console: Console) -> None:
with enable_custom_prng():
key = PRNGKey(123)
k = key(123)
@jit
def f(x: KeyArray) -> KeyArray:
return tapped_print_generic(x)

f(key)
f(k)
captured = capsys.readouterr()
verify(captured.out,
"""
Expand Down

0 comments on commit 3381087

Please sign in to comment.