Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 90 additions & 21 deletions docs_nnx/hijax/hijax.ipynb

Large diffs are not rendered by default.

27 changes: 25 additions & 2 deletions docs_nnx/hijax/hijax.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,40 @@ jupytext:
jupytext_version: 1.13.8
---

# Hijax Variable
# Hijax

```{code-cell} ipython3
from flax import nnx
import jax
import jax.numpy as jnp
import optax
current_mode = nnx.using_hijax()
current_mode = nnx.using_hijax() # ignore: only needed for testing
```

```{code-cell} ipython3
nnx.use_hijax(True)
rngs = nnx.Rngs(0)
model = nnx.Linear(2, 3, rngs=rngs)
optimizer = nnx.Optimizer(model, optax.adamw(1e-2), wrt=nnx.Param)
@jax.jit
def train_step(x, y):
loss_fn = lambda m: jnp.mean((m(x) - y) ** 2)
loss, grads = jax.value_and_grad(loss_fn)(model) # tmp fix for jax.grad
optimizer.update(model, grads)
return loss
x, y = rngs.uniform((4, 2)), rngs.uniform((4, 3))
for _ in range(3):
print(train_step(x, y))
```

## Hijax Variable

+++

State propagation:

```{code-cell} ipython3
Expand Down
46 changes: 1 addition & 45 deletions flax/nnx/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
)
from flax.nnx.statelib import FlatState, State, map_state
from flax.nnx.variablelib import Variable, is_array_ref, V
from flax.typing import Key, PathParts, is_key_like
from flax.typing import HashableMapping, Key, PathParts, is_key_like
import jax
import numpy as np
import treescope # type: ignore[import-not-found,import-untyped]
Expand Down Expand Up @@ -301,50 +301,6 @@ def get_node_impl_for_type(
return None


class HashableMapping(tp.Mapping[HA, HB], tp.Hashable):
_mapping: dict[HA, HB] | tp.Mapping[HA, HB]

def __init__(self, mapping: tp.Mapping[HA, HB], copy: bool = True):
self._mapping = dict(mapping) if copy else mapping

def __contains__(self, key: object) -> bool:
return key in self._mapping

def __getitem__(self, key: HA) -> HB:
return self._mapping[key]

def __iter__(self) -> tp.Iterator[HA]:
return iter(self._mapping)

def __len__(self) -> int:
return len(self._mapping)

def __hash__(self) -> int:
# use type-aware sorting to support int keys
def _pytree__key_sort_fn(item: tuple[tp.Any, tp.Any]) -> tuple[int, tp.Any]:
key, _ = item
if isinstance(key, int):
return (0, key)
elif isinstance(key, str):
return (1, key)
else:
raise ValueError(f'Unsupported key type: {type(key)!r}')
return hash(tuple(sorted(self._mapping.items(), key=_pytree__key_sort_fn)))

def __eq__(self, other: tp.Any) -> bool:
return (
isinstance(other, HashableMapping) and self._mapping == other._mapping
)

def __repr__(self) -> str:
return repr(self._mapping)

def update(self, other: tp.Mapping[HA, HB]) -> HashableMapping[HA, HB]:
"""Updates the mapping with another mapping."""
mapping = dict(self._mapping)
mapping.update(other)
return HashableMapping(mapping, copy=False)


@jax.tree_util.register_static
@dataclasses.dataclass(frozen=True, repr=False)
Expand Down
Loading
Loading