|
1 |
| -import random |
2 | 1 | import logging
|
3 | 2 | import sys
|
4 | 3 | import contextlib
|
|
13 | 12 | import torch.utils._mode_utils as mode_utils
|
14 | 13 | import torch.utils._python_dispatch as torch_dispatch
|
15 | 14 | import torch.utils._pytree as torch_pytree
|
16 |
| -from torchax.view import View, NarrowInfo |
| 15 | +from torchax.view import View |
17 | 16 | from torchax import config
|
18 | 17 | from torchax.ops import mappings, ops_registry
|
19 | 18 | from torchax import amp
|
| 19 | +from jax.experimental import mutable_array |
20 | 20 |
|
21 | 21 | logger = logging.getLogger(__name__)
|
22 | 22 |
|
@@ -323,11 +323,16 @@ def __init__(self, configuration=None):
|
323 | 323 | self._manually_entered = False
|
324 | 324 | self.enabled = False
|
325 | 325 | self._jax_devices = set(["jax", "jax_cpu", "xla"])
|
326 |
| - self.prng_key = jax.random.key(torch.initial_seed() % (1 << 63)) |
| 326 | + self._prng_key = mutable_array( |
| 327 | + jax.random.key(torch.initial_seed() % (1 << 63))) |
327 | 328 | self.autocast_dtype = None
|
328 | 329 |
|
329 | 330 | def manual_seed(self, key):
|
330 |
| - self.prng_key = jax.random.key(key) |
| 331 | + self._prng_key = mutable_array(jax.random.key(key)) |
| 332 | + |
| 333 | + @property |
| 334 | + def prng_key(self): |
| 335 | + return self._prng_key[...] |
331 | 336 |
|
332 | 337 | def get_as_jax_device(self, device: Any):
|
333 | 338 | if device is None:
|
@@ -431,8 +436,10 @@ def get_and_rotate_prng_key(self,
|
431 | 436 | generator: Optional[torch.Generator] = None):
|
432 | 437 | if generator is not None:
|
433 | 438 | with mode_utils.no_dispatch(), torch._C.DisableTorchFunction():
|
434 |
| - self.prng_key = jax.random.key(generator.initial_seed() % (2**63)) |
435 |
| - self.prng_key, next_key = jax.random.split(self.prng_key) |
| 439 | + self._prng_key[...] = jax.random.key(generator.initial_seed() % (2**63)) |
| 440 | + old_key = self._prng_key[...] |
| 441 | + new_prng_key, next_key = jax.random.split(old_key) |
| 442 | + self._prng_key[...] = new_prng_key |
436 | 443 | return next_key
|
437 | 444 |
|
438 | 445 | def _handle_tensor_constructor(self, func, args, kwargs):
|
|
0 commit comments