Skip to content

Commit 0ee876f

Browse files
authored
Implement prng_key as a mutable array (#9305)
1 parent aaff959 commit 0ee876f

File tree

3 files changed

+22
-14
lines changed

3 files changed

+22
-14
lines changed

torchax/pyproject.toml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -40,11 +40,11 @@ classifiers = [
4040
path = "torchax/__init__.py"
4141

4242
[project.optional-dependencies]
43-
cpu = ["jax[cpu]>=0.4.30", "jax[cpu]"]
43+
cpu = ["jax[cpu]>=0.6.2", "jax[cpu]"]
4444
# Add libtpu index `-f https://storage.googleapis.com/libtpu-wheels/index.html -f https://storage.googleapis.com/libtpu-releases/index.html`
45-
tpu = ["jax[cpu]>=0.4.30", "jax[tpu]"]
46-
cuda = ["jax[cpu]>=0.4.30", "jax[cuda12]"]
47-
odml = ["jax[cpu]>=0.4.30", "jax[cpu]"]
45+
tpu = ["jax[cpu]>=0.6.2", "jax[tpu]"]
46+
cuda = ["jax[cpu]>=0.6.2", "jax[cuda12]"]
47+
odml = ["jax[cpu]>=0.6.2", "jax[cpu]"]
4848

4949
[tool.hatch.build.targets.wheel]
5050
packages = ["torchax"]

torchax/test/test_context.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from torchax import tensor
66
import torchax.interop
77

8-
xla_env = tensor.Environment()
8+
xla_env = torchax.default_env()
99

1010

1111
class TestContext(unittest.TestCase):
@@ -73,15 +73,16 @@ def random_op():
7373
random_jit = torchax.interop.jax_jit(random_op)
7474
self.assertIsInstance(random_jit(), tensor.Tensor)
7575

76-
# Result always expected to be the same for a jitted function because seeds
77-
# are baked in
78-
torch.testing.assert_close(random_jit(), random_jit(), atol=0, rtol=0)
76+
# If we run the JIT twice, the random values should be different.
77+
with self.assertRaises(AssertionError):
78+
torch.testing.assert_close(random_jit(), random_jit(), atol=0, rtol=0)
7979

8080
def test_generator_seed(self):
8181
with xla_env:
8282
x = torch.randn(2, 3, generator=torch.Generator().manual_seed(0))
8383
y = torch.randn(2, 3, generator=torch.Generator().manual_seed(0))
8484

85+
# Values will be the same given the same seed.
8586
torch.testing.assert_close(x, y)
8687

8788
def test_buffer(self):

torchax/torchax/tensor.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import random
21
import logging
32
import sys
43
import contextlib
@@ -13,10 +12,11 @@
1312
import torch.utils._mode_utils as mode_utils
1413
import torch.utils._python_dispatch as torch_dispatch
1514
import torch.utils._pytree as torch_pytree
16-
from torchax.view import View, NarrowInfo
15+
from torchax.view import View
1716
from torchax import config
1817
from torchax.ops import mappings, ops_registry
1918
from torchax import amp
19+
from jax.experimental import mutable_array
2020

2121
logger = logging.getLogger(__name__)
2222

@@ -323,11 +323,16 @@ def __init__(self, configuration=None):
323323
self._manually_entered = False
324324
self.enabled = False
325325
self._jax_devices = set(["jax", "jax_cpu", "xla"])
326-
self.prng_key = jax.random.key(torch.initial_seed() % (1 << 63))
326+
self._prng_key = mutable_array(
327+
jax.random.key(torch.initial_seed() % (1 << 63)))
327328
self.autocast_dtype = None
328329

329330
def manual_seed(self, key):
330-
self.prng_key = jax.random.key(key)
331+
self._prng_key = mutable_array(jax.random.key(key))
332+
333+
@property
334+
def prng_key(self):
335+
return self._prng_key[...]
331336

332337
def get_as_jax_device(self, device: Any):
333338
if device is None:
@@ -431,8 +436,10 @@ def get_and_rotate_prng_key(self,
431436
generator: Optional[torch.Generator] = None):
432437
if generator is not None:
433438
with mode_utils.no_dispatch(), torch._C.DisableTorchFunction():
434-
self.prng_key = jax.random.key(generator.initial_seed() % (2**63))
435-
self.prng_key, next_key = jax.random.split(self.prng_key)
439+
self._prng_key[...] = jax.random.key(generator.initial_seed() % (2**63))
440+
old_key = self._prng_key[...]
441+
new_prng_key, next_key = jax.random.split(old_key)
442+
self._prng_key[...] = new_prng_key
436443
return next_key
437444

438445
def _handle_tensor_constructor(self, func, args, kwargs):

0 commit comments

Comments
 (0)