diff --git a/docs_nnx/guides/array_ref.ipynb b/docs_nnx/guides/array_ref.ipynb deleted file mode 100644 index 1df71c629..000000000 --- a/docs_nnx/guides/array_ref.ipynb +++ /dev/null @@ -1,602 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "id": "15c2d208", - "metadata": {}, - "source": [ - "# Array Refs (experimental)" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "id": "99809892", - "metadata": {}, - "outputs": [], - "source": [ - "from flax import nnx\n", - "import jax\n", - "import jax.numpy as jnp\n", - "import optax" - ] - }, - { - "cell_type": "markdown", - "id": "787cf22a", - "metadata": {}, - "source": [ - "## Basics" - ] - }, - { - "cell_type": "markdown", - "id": "d896c926", - "metadata": {}, - "source": [ - "### Array Refs 101" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "id": "cae099ce", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[1] = ArrayRef([1, 2, 3], dtype=int32)\n", - "[2] = ArrayRef([2, 3, 4], dtype=int32)\n" - ] - } - ], - "source": [ - "a_ref = jax.new_ref(jnp.array([1, 2, 3]))\n", - "\n", - "@jax.jit\n", - "def increment(a_ref: jax.Ref): # no return!\n", - " array: jax.Array = a_ref[...] # access\n", - " a_ref[...] = array + 1 # update\n", - "\n", - "print(\"[1] =\", a_ref); increment(a_ref); print(\"[2] =\", a_ref)" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "id": "fb081f49", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "module @jit_increment attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {\n", - " func.func public @main(%arg0: tensor<3xi32> {tf.aliasing_output = 0 : i32}) -> (tensor<3xi32> {jax.result_info = \"\"}) {\n", - " %c = stablehlo.constant dense<1> : tensor\n", - " %0 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor) -> tensor<3xi32>\n", - " %1 = stablehlo.add %arg0, %0 : tensor<3xi32>\n", - " return %1 : tensor<3xi32>\n", - " }\n", - "}\n", - "\n" - ] - } - ], - "source": [ - "@jax.jit\n", - "def inc(x):\n", - " x[...] 
+= 1\n", - "\n", - "print(increment.lower(a_ref).as_text())" - ] - }, - { - "cell_type": "markdown", - "id": "26969861", - "metadata": {}, - "source": [ - "### Variables Refs" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "id": "8c3da93c", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "variable.has_ref = True\n", - "\n", - "[1] = \u001b[38;2;79;201;177mVariable\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # 3 (12 B)\u001b[0m\n", - " \u001b[38;2;156;220;254mvalue\u001b[0m\u001b[38;2;212;212;212m=\u001b[0mArrayRef([1, 2, 3], dtype=int32)\n", - "\u001b[38;2;255;213;3m)\u001b[0m\n", - "[2] = \u001b[38;2;79;201;177mVariable\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # 3 (12 B)\u001b[0m\n", - " \u001b[38;2;156;220;254mvalue\u001b[0m\u001b[38;2;212;212;212m=\u001b[0mArrayRef([2, 3, 4], dtype=int32)\n", - "\u001b[38;2;255;213;3m)\u001b[0m\n" - ] - } - ], - "source": [ - "variable = nnx.Variable(jnp.array([1, 2, 3]), use_ref=True)\n", - "print(f\"{variable.has_ref = }\\n\")\n", - "\n", - "print(\"[1] =\", variable); increment(variable); print(\"[2] =\", variable)" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "id": "0a55df94", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "variable.has_ref = True\n" - ] - } - ], - "source": [ - "with nnx.use_refs(True):\n", - " variable = nnx.Variable(jnp.array([1, 2, 3]))\n", - "\n", - "print(f\"{variable.has_ref = }\")" - ] - }, - { - "cell_type": "markdown", - "id": "839332be", - "metadata": {}, - "source": [ - "Mention `nnx.use_refs` can be used as global flag" - ] - }, - { - "cell_type": "markdown", - "id": "1b2632f1", - "metadata": {}, - "source": [ - "### Changing Status" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "id": "b7b1f421", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "nnx.to_refs(model) = \u001b[38;2;79;201;177mLinear\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # Param: 6 (24 B)\u001b[0m\n", - " \u001b[38;2;156;220;254mbias\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mParam\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # 3 (12 B)\u001b[0m\n", - " \u001b[38;2;156;220;254mvalue\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mArrayRef\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;156;220;254mshape\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;182;207;169m3\u001b[0m,\u001b[38;2;255;213;3m)\u001b[0m, \u001b[38;2;156;220;254mdtype\u001b[0m\u001b[38;2;212;212;212m=\u001b[0mdtype('float32')\u001b[38;2;255;213;3m)\u001b[0m\n", - " \u001b[38;2;255;213;3m)\u001b[0m,\n", - " \u001b[38;2;156;220;254mkernel\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mParam\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # 3 (12 B)\u001b[0m\n", - " \u001b[38;2;156;220;254mvalue\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mArrayRef\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;156;220;254mshape\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;182;207;169m1\u001b[0m, \u001b[38;2;182;207;169m3\u001b[0m\u001b[38;2;255;213;3m)\u001b[0m, \u001b[38;2;156;220;254mdtype\u001b[0m\u001b[38;2;212;212;212m=\u001b[0mdtype('float32')\u001b[38;2;255;213;3m)\u001b[0m\n", - " \u001b[38;2;255;213;3m)\u001b[0m\n", - 
"\u001b[38;2;255;213;3m)\u001b[0m\n", - "nnx.to_arrays(refs_model) = \u001b[38;2;79;201;177mLinear\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # Param: 6 (24 B)\u001b[0m\n", - " \u001b[38;2;156;220;254mbias\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mParam\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # 3 (12 B)\u001b[0m\n", - " \u001b[38;2;156;220;254mvalue\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mArray\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;156;220;254mshape\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;182;207;169m3\u001b[0m,\u001b[38;2;255;213;3m)\u001b[0m, \u001b[38;2;156;220;254mdtype\u001b[0m\u001b[38;2;212;212;212m=\u001b[0mdtype('float32')\u001b[38;2;255;213;3m)\u001b[0m\n", - " \u001b[38;2;255;213;3m)\u001b[0m,\n", - " \u001b[38;2;156;220;254mkernel\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mParam\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # 3 (12 B)\u001b[0m\n", - " \u001b[38;2;156;220;254mvalue\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mArray\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;156;220;254mshape\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;182;207;169m1\u001b[0m, \u001b[38;2;182;207;169m3\u001b[0m\u001b[38;2;255;213;3m)\u001b[0m, \u001b[38;2;156;220;254mdtype\u001b[0m\u001b[38;2;212;212;212m=\u001b[0mdtype('float32')\u001b[38;2;255;213;3m)\u001b[0m\n", - " \u001b[38;2;255;213;3m)\u001b[0m\n", - "\u001b[38;2;255;213;3m)\u001b[0m\n" - ] - } - ], - "source": [ - "class Linear(nnx.Module):\n", - " def __init__(self, in_features, out_features, rngs: nnx.Rngs):\n", - " self.kernel = nnx.Param(jax.random.normal(rngs(), (in_features, out_features)))\n", - " self.bias = nnx.Param(jnp.zeros(out_features))\n", - "\n", - " def __call__(self, x):\n", - " return x @ self.kernel + self.bias[None]\n", - "\n", - "model = Linear(1, 3, rngs=nnx.Rngs(0)) # without array refs\n", - "refs_model = nnx.to_refs(model) # convert to array refs\n", - "arrays_model = nnx.to_arrays(refs_model) # convert to regular arrays\n", - "\n", - "print(\"nnx.to_refs(model) =\", refs_model)\n", - "print(\"nnx.to_arrays(refs_model) =\", arrays_model)" - ] - }, - { - "cell_type": "markdown", - "id": "f4e35e75", - "metadata": {}, - "source": [ - "## Examples" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "id": "5400fe58", - "metadata": {}, - "outputs": [], - "source": [ - "class Block(nnx.Module):\n", - " def __init__(self, din, dmid, dout, rngs: nnx.Rngs):\n", - " self.linear = Linear(din, dmid, rngs=rngs)\n", - " self.bn = nnx.BatchNorm(dmid, rngs=rngs)\n", - " self.dropout = nnx.Dropout(0.1, rngs=rngs)\n", - " self.linear_out = Linear(dmid, dout, rngs=rngs)\n", - "\n", - " def __call__(self, x):\n", - " x = nnx.gelu(self.dropout(self.bn(self.linear(x))))\n", - " return self.linear_out(x)" - ] - }, - { - "cell_type": "markdown", - "id": "ba980b6b", - "metadata": {}, - "source": [ - "### Training Loop" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "id": "566c4249", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "Array(1.000178, dtype=float32)" - ] - }, - "execution_count": 8, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "with nnx.use_refs(True):\n", - " model = Block(2, 64, 3, rngs=nnx.Rngs(0))\n", - " optimizer = nnx.Optimizer(model, optax.adam(1e-3), wrt=nnx.Param)\n", - "\n", - 
"@jax.jit\n", - "def train_step(model, optimizer, x, y):\n", - " graphdef, params, nondiff = nnx.split(model, nnx.Param, ...)\n", - " def loss_fn(params):\n", - " model = nnx.merge(graphdef, params, nondiff)\n", - " return ((model(x) - y) ** 2).mean()\n", - "\n", - " loss, grads = jax.value_and_grad(loss_fn)(nnx.to_arrays(params)) # freeze ArrayRefs for jax.grad\n", - " optimizer.update(model, grads)\n", - "\n", - " return loss\n", - "\n", - "train_step(model, optimizer, x=jnp.ones((10, 2)), y=jnp.ones((10, 3)))" - ] - }, - { - "cell_type": "markdown", - "id": "1dea99c1", - "metadata": {}, - "source": [ - "### Scan Over Layers" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "id": "d8136be4", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "y = [[ 0.82840395 -0.25364894]\n", - " [ 4.9552917 4.93638 ]\n", - " [-7.6721525 -3.4668717 ]]\n" - ] - } - ], - "source": [ - "@nnx.vmap\n", - "def create_stack(rngs):\n", - " return Block(2, 64, 2, rngs=rngs)\n", - "\n", - "with nnx.use_refs(True):\n", - " block_stack = create_stack(nnx.Rngs(0).fork(split=8))\n", - "\n", - "def scan_fn(x, block):\n", - " x = block(x)\n", - " return x, None\n", - "\n", - "x = jax.random.uniform(jax.random.key(0), (3, 2))\n", - "y, _ = jax.lax.scan(scan_fn, x, block_stack)\n", - "\n", - "print(\"y = \", y)" - ] - }, - { - "cell_type": "markdown", - "id": "7ca18a0d", - "metadata": {}, - "source": [ - "## Limitations" - ] - }, - { - "cell_type": "markdown", - "id": "1dd39c79", - "metadata": {}, - "source": [ - "### MutableArray Outputs" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "id": "c6062d19", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Error: function create_model at /var/folders/qj/tkq3kvtd66z1t36rfyj9vg0w016bdd/T/ipykernel_43144/1421484665.py:1 traced for jit returned a mutable array reference of type Ref{float32[64]} at output tree path result.bn.bias.value, but mutable array references cannot be returned.\n", - "\n", - "The returned mutable array was created on line /Users/cgarciae/repos/flax/flax/nnx/variablelib.py:250:17 (Variable.__init__).\n" - ] - } - ], - "source": [ - "@jax.jit\n", - "def create_model(rngs):\n", - " return Block(2, 64, 3, rngs=rngs)\n", - "\n", - "try:\n", - " with nnx.use_refs(True):\n", - " model = create_model(nnx.Rngs(0))\n", - "except Exception as e:\n", - " print(f\"Error:\", e)" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "id": "8bb1e9e7", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "model.linear = \u001b[38;2;79;201;177mLinear\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # Param: 192 (768 B)\u001b[0m\n", - " \u001b[38;2;156;220;254mbias\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mParam\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # 64 (256 B)\u001b[0m\n", - " \u001b[38;2;156;220;254mvalue\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mArrayRef\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;156;220;254mshape\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;182;207;169m64\u001b[0m,\u001b[38;2;255;213;3m)\u001b[0m, \u001b[38;2;156;220;254mdtype\u001b[0m\u001b[38;2;212;212;212m=\u001b[0mdtype('float32')\u001b[38;2;255;213;3m)\u001b[0m\n", - " \u001b[38;2;255;213;3m)\u001b[0m,\n", - " 
\u001b[38;2;156;220;254mkernel\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mParam\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # 128 (512 B)\u001b[0m\n", - " \u001b[38;2;156;220;254mvalue\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mArrayRef\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;156;220;254mshape\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;182;207;169m2\u001b[0m, \u001b[38;2;182;207;169m64\u001b[0m\u001b[38;2;255;213;3m)\u001b[0m, \u001b[38;2;156;220;254mdtype\u001b[0m\u001b[38;2;212;212;212m=\u001b[0mdtype('float32')\u001b[38;2;255;213;3m)\u001b[0m\n", - " \u001b[38;2;255;213;3m)\u001b[0m\n", - "\u001b[38;2;255;213;3m)\u001b[0m\n" - ] - } - ], - "source": [ - "with nnx.use_refs(False): # <-- disable array refs\n", - " model = create_model(nnx.Rngs(0))\n", - "\n", - "model = nnx.to_refs(model) # convert to mutable after creation\n", - "\n", - "print(\"model.linear =\", model.linear)" - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "id": "3a078025", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "model.linear = \u001b[38;2;79;201;177mLinear\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # Param: 192 (768 B)\u001b[0m\n", - " \u001b[38;2;156;220;254mbias\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mParam\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # 64 (256 B)\u001b[0m\n", - " \u001b[38;2;156;220;254mvalue\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mArrayRef\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;156;220;254mshape\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;182;207;169m64\u001b[0m,\u001b[38;2;255;213;3m)\u001b[0m, \u001b[38;2;156;220;254mdtype\u001b[0m\u001b[38;2;212;212;212m=\u001b[0mdtype('float32')\u001b[38;2;255;213;3m)\u001b[0m\n", - " \u001b[38;2;255;213;3m)\u001b[0m,\n", - " \u001b[38;2;156;220;254mkernel\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mParam\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # 128 (512 B)\u001b[0m\n", - " \u001b[38;2;156;220;254mvalue\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mArrayRef\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;156;220;254mshape\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;182;207;169m2\u001b[0m, \u001b[38;2;182;207;169m64\u001b[0m\u001b[38;2;255;213;3m)\u001b[0m, \u001b[38;2;156;220;254mdtype\u001b[0m\u001b[38;2;212;212;212m=\u001b[0mdtype('float32')\u001b[38;2;255;213;3m)\u001b[0m\n", - " \u001b[38;2;255;213;3m)\u001b[0m\n", - "\u001b[38;2;255;213;3m)\u001b[0m\n" - ] - } - ], - "source": [ - "@nnx.jit\n", - "def create_model(rngs):\n", - " return Block(2, 64, 3, rngs=rngs)\n", - "\n", - "with nnx.use_refs(True):\n", - " model = create_model(nnx.Rngs(0))\n", - "\n", - "print(\"model.linear =\", model.linear)" - ] - }, - { - "cell_type": "markdown", - "id": "609bed7c", - "metadata": {}, - "source": [ - "### Reference Sharing (aliasing)" - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "id": "045d03c1", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "ValueError: only one reference to a mutable array may be passed as an argument to a function, but when tracing f at /var/folders/qj/tkq3kvtd66z1t36rfyj9vg0w016bdd/T/ipykernel_43144/1563421490.py:9 for jit the 
mutable array reference of type Ref{int32[]} appeared at both a and b.\n" - ] - } - ], - "source": [ - "def get_error(f, *args):\n", - " try:\n", - " return f(*args)\n", - " except Exception as e:\n", - " return f\"{type(e).__name__}: {e}\"\n", - " \n", - "x = jax.new_ref(jnp.array(0))\n", - "\n", - "@jax.jit\n", - "def f(a, b):\n", - " ...\n", - "\n", - "print(get_error(f, x, x))" - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "id": "bc2e87e5", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "SharedVariables ValueError: only one reference to a mutable array may be passed as an argument to a function, but when tracing g at /var/folders/qj/tkq3kvtd66z1t36rfyj9vg0w016bdd/T/ipykernel_43144/1828746469.py:13 for jit the mutable array reference of type Ref{int32[]} appeared at both pytree.a.value and pytree.c.value.\n", - "SharedModules ValueError: only one reference to a mutable array may be passed as an argument to a function, but when tracing g at /var/folders/qj/tkq3kvtd66z1t36rfyj9vg0w016bdd/T/ipykernel_43144/1828746469.py:13 for jit the mutable array reference of type Ref{float32[1]} appeared at both pytree.d.bias.value and pytree.f.bias.value.\n" - ] - } - ], - "source": [ - "class SharedVariables(nnx.Pytree):\n", - " def __init__(self):\n", - " self.a = nnx.Variable(jnp.array(0))\n", - " self.b = nnx.Variable(jnp.array(1))\n", - " self.c = self.a\n", - "\n", - "class SharedModules(nnx.Pytree):\n", - " def __init__(self):\n", - " self.d = Linear(1, 1, rngs=nnx.Rngs(0))\n", - " self.e = Linear(1, 1, rngs=nnx.Rngs(0))\n", - " self.f = self.d\n", - "\n", - "@jax.jit\n", - "def g(pytree):\n", - " ...\n", - "\n", - "with nnx.use_refs(True):\n", - " shared_variables = SharedVariables()\n", - " shared_modules = SharedModules()\n", - "\n", - "print(\"SharedVariables\", get_error(g, shared_variables))\n", - "print(\"SharedModules\", get_error(g, shared_modules))" - ] - }, - { - "cell_type": "code", - "execution_count": 15, - "id": "6298f3d9", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "shared variables duplicates: [[('a',), ('c',)]]\n", - "shared modules duplicates: [[('d',), ('f',)]]\n" - ] - } - ], - "source": [ - "if (duplicates := nnx.find_duplicates(shared_variables)):\n", - " print(\"shared variables duplicates:\", duplicates)\n", - "\n", - "if (duplicates := nnx.find_duplicates(shared_modules)):\n", - " print(\"shared modules duplicates: \", duplicates)" - ] - }, - { - "cell_type": "code", - "execution_count": 16, - "id": "00854d38", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\u001b[38;2;79;201;177mState\u001b[0m\u001b[38;2;255;213;3m({\u001b[0m\u001b[38;2;105;105;105m\u001b[0m\n", - " \u001b[38;2;156;220;254m'a'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0m\u001b[38;2;79;201;177mVariable\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # 1 (4 B)\u001b[0m\n", - " \u001b[38;2;156;220;254mvalue\u001b[0m\u001b[38;2;212;212;212m=\u001b[0mArrayRef(0, dtype=int32, weak_type=True)\n", - " \u001b[38;2;255;213;3m)\u001b[0m,\n", - " \u001b[38;2;156;220;254m'b'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0m\u001b[38;2;79;201;177mVariable\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # 1 (4 B)\u001b[0m\n", - " \u001b[38;2;156;220;254mvalue\u001b[0m\u001b[38;2;212;212;212m=\u001b[0mArrayRef(1, dtype=int32, weak_type=True)\n", - " \u001b[38;2;255;213;3m)\u001b[0m\n", - 
"\u001b[38;2;255;213;3m})\u001b[0m\n", - "updated \u001b[38;2;79;201;177mSharedVariables\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # Variable: 2 (8 B)\u001b[0m\n", - " \u001b[38;2;156;220;254ma\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mVariable\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # 1 (4 B)\u001b[0m\n", - " \u001b[38;2;156;220;254mvalue\u001b[0m\u001b[38;2;212;212;212m=\u001b[0mArrayRef(10, dtype=int32)\n", - " \u001b[38;2;255;213;3m)\u001b[0m,\n", - " \u001b[38;2;156;220;254mb\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mVariable\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # 1 (4 B)\u001b[0m\n", - " \u001b[38;2;156;220;254mvalue\u001b[0m\u001b[38;2;212;212;212m=\u001b[0mArrayRef(1, dtype=int32, weak_type=True)\n", - " \u001b[38;2;255;213;3m)\u001b[0m,\n", - " \u001b[38;2;156;220;254mc\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mVariable\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # 1 (4 B)\u001b[0m\n", - " \u001b[38;2;156;220;254mvalue\u001b[0m\u001b[38;2;212;212;212m=\u001b[0mArrayRef(10, dtype=int32)\n", - " \u001b[38;2;255;213;3m)\u001b[0m\n", - "\u001b[38;2;255;213;3m)\u001b[0m\n" - ] - } - ], - "source": [ - "@jax.jit\n", - "def h(graphdef, state):\n", - " obj = nnx.merge(graphdef, state)\n", - " obj.a[...] += 10\n", - "\n", - "graphdef, state = nnx.split(shared_variables)\n", - "print(state) # split deduplicates the state\n", - "\n", - "h(graphdef, state)\n", - "\n", - "print(\"updated\", shared_variables)" - ] - } - ], - "metadata": { - "jupytext": { - "formats": "ipynb,md:myst" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.11.9" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/docs_nnx/guides/hijax.ipynb b/docs_nnx/guides/hijax.ipynb new file mode 100644 index 000000000..23ae167d4 --- /dev/null +++ b/docs_nnx/guides/hijax.ipynb @@ -0,0 +1,572 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "15c2d208", + "metadata": {}, + "source": [ + "# Array Refs (experimental)" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "99809892", + "metadata": {}, + "outputs": [], + "source": [ + "from flax import nnx\n", + "import jax\n", + "import jax.numpy as jnp\n", + "import optax" + ] + }, + { + "cell_type": "markdown", + "id": "787cf22a", + "metadata": {}, + "source": [ + "## Basics" + ] + }, + { + "cell_type": "markdown", + "id": "26969861", + "metadata": {}, + "source": [ + "### Variables Refs" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "8c3da93c", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "variable.is_hijax = True\n", + "\n", + "Before = \u001b[38;2;79;201;177mVariable\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # 3 (12 B)\u001b[0m\n", + " \u001b[38;2;156;220;254mvalue\u001b[0m\u001b[38;2;212;212;212m=\u001b[0mArray([1, 2, 3], dtype=int32),\n", + " \u001b[38;2;156;220;254mis_hijax\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;86;156;214mTrue\u001b[0m\n", + "\u001b[38;2;255;213;3m)\u001b[0m\n", + "After = \u001b[38;2;79;201;177mVariable\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # 3 (12 B)\u001b[0m\n", + " 
\u001b[38;2;156;220;254mvalue\u001b[0m\u001b[38;2;212;212;212m=\u001b[0mArray([2, 3, 4], dtype=int32),\n", + " \u001b[38;2;156;220;254mis_hijax\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;86;156;214mTrue\u001b[0m\n", + "\u001b[38;2;255;213;3m)\u001b[0m\n" + ] + } + ], + "source": [ + "variable = nnx.Variable(jnp.array([1, 2, 3]), is_hijax=True)\n", + "print(f\"{variable.is_hijax = }\\n\")\n", + "\n", + "@jax.jit\n", + "def increment(variable: nnx.Variable[jax.Array]): # no return!\n", + " new_value = variable + 1 # Array-like operations\n", + " variable[...] = new_value # in-place updates\n", + "\n", + "print(\"Before =\", variable); increment(variable); print(\"After =\", variable)" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "703265df", + "metadata": {}, + "outputs": [], + "source": [ + "# TODO: enable once as_text is fixed\n", + "# print(increment.lower(variable).as_text())" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "0a55df94", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "variable.is_hijax = True\n" + ] + } + ], + "source": [ + "nnx.use_hijax(True)\n", + "\n", + "variable = nnx.Variable(jnp.array([1, 2, 3]))\n", + "\n", + "print(f\"{variable.is_hijax = }\")" + ] + }, + { + "cell_type": "markdown", + "id": "839332be", + "metadata": {}, + "source": [ + "Mention `nnx.use_refs` can be used as global flag" + ] + }, + { + "cell_type": "markdown", + "id": "1b2632f1", + "metadata": {}, + "source": [ + "### Changing Status" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "b7b1f421", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "nnx.to_hijax(model) = \u001b[38;2;79;201;177mLinear\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # HijaxVariable: 6 (24 B)\u001b[0m\n", + " \u001b[38;2;156;220;254mbias\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mParam\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m\u001b[0m\n", + " \u001b[38;2;156;220;254mvalue\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mArray\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;156;220;254mshape\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;182;207;169m3\u001b[0m,\u001b[38;2;255;213;3m)\u001b[0m, \u001b[38;2;156;220;254mdtype\u001b[0m\u001b[38;2;212;212;212m=\u001b[0mdtype('float32')\u001b[38;2;255;213;3m)\u001b[0m,\n", + " \u001b[38;2;156;220;254mis_hijax\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;86;156;214mTrue\u001b[0m\n", + " \u001b[38;2;255;213;3m)\u001b[0m,\n", + " \u001b[38;2;156;220;254mkernel\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mParam\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m\u001b[0m\n", + " \u001b[38;2;156;220;254mvalue\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mArray\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;156;220;254mshape\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;182;207;169m1\u001b[0m, \u001b[38;2;182;207;169m3\u001b[0m\u001b[38;2;255;213;3m)\u001b[0m, \u001b[38;2;156;220;254mdtype\u001b[0m\u001b[38;2;212;212;212m=\u001b[0mdtype('float32')\u001b[38;2;255;213;3m)\u001b[0m,\n", + " \u001b[38;2;156;220;254mis_hijax\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;86;156;214mTrue\u001b[0m\n", + " \u001b[38;2;255;213;3m)\u001b[0m\n", + "\u001b[38;2;255;213;3m)\u001b[0m\n", + 
"nnx.to_lojax(refs_model) = \u001b[38;2;79;201;177mLinear\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # Param: 6 (24 B)\u001b[0m\n", + " \u001b[38;2;156;220;254mbias\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mParam\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m\u001b[0m\n", + " \u001b[38;2;156;220;254mvalue\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mArray\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;156;220;254mshape\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;182;207;169m3\u001b[0m,\u001b[38;2;255;213;3m)\u001b[0m, \u001b[38;2;156;220;254mdtype\u001b[0m\u001b[38;2;212;212;212m=\u001b[0mdtype('float32')\u001b[38;2;255;213;3m)\u001b[0m\n", + " \u001b[38;2;255;213;3m)\u001b[0m,\n", + " \u001b[38;2;156;220;254mkernel\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mParam\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m\u001b[0m\n", + " \u001b[38;2;156;220;254mvalue\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mArray\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;156;220;254mshape\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;182;207;169m1\u001b[0m, \u001b[38;2;182;207;169m3\u001b[0m\u001b[38;2;255;213;3m)\u001b[0m, \u001b[38;2;156;220;254mdtype\u001b[0m\u001b[38;2;212;212;212m=\u001b[0mdtype('float32')\u001b[38;2;255;213;3m)\u001b[0m\n", + " \u001b[38;2;255;213;3m)\u001b[0m\n", + "\u001b[38;2;255;213;3m)\u001b[0m\n" + ] + } + ], + "source": [ + "class Linear(nnx.Module):\n", + " def __init__(self, in_features, out_features, rngs: nnx.Rngs):\n", + " self.kernel = nnx.Param(jax.random.normal(rngs(), (in_features, out_features)))\n", + " self.bias = nnx.Param(jnp.zeros(out_features))\n", + "\n", + " def __call__(self, x):\n", + " return x @ self.kernel + self.bias[None]\n", + "\n", + "with nnx.use_hijax(False): # use lojax Variables\n", + " model = Linear(1, 3, rngs=nnx.Rngs(0))\n", + "\n", + "hijax_model = nnx.to_hijax(model) # convert hijax Variables\n", + "arrays_model = nnx.to_lojax(hijax_model) # convert to lojax Variables\n", + "\n", + "print(\"nnx.to_hijax(model) =\", hijax_model)\n", + "print(\"nnx.to_lojax(refs_model) =\", arrays_model)" + ] + }, + { + "cell_type": "markdown", + "id": "f4e35e75", + "metadata": {}, + "source": [ + "## Examples" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "5400fe58", + "metadata": {}, + "outputs": [], + "source": [ + "class Block(nnx.Module):\n", + " def __init__(self, din, dmid, dout, rngs: nnx.Rngs):\n", + " self.linear = Linear(din, dmid, rngs=rngs)\n", + " self.bn = nnx.BatchNorm(dmid, rngs=rngs)\n", + " self.dropout = nnx.Dropout(0.1, rngs=rngs)\n", + " self.linear_out = Linear(dmid, dout, rngs=rngs)\n", + "\n", + " def __call__(self, x):\n", + " x = nnx.gelu(self.dropout(self.bn(self.linear(x))))\n", + " return self.linear_out(x)" + ] + }, + { + "cell_type": "markdown", + "id": "ba980b6b", + "metadata": {}, + "source": [ + "### Training Loop" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "566c4249", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Array(1.000178, dtype=float32)" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# hijax Variables by default\n", + "model = Block(2, 64, 3, rngs=nnx.Rngs(0))\n", + "optimizer = nnx.Optimizer(model, optax.adam(1e-3), wrt=nnx.Param)\n", + "\n", + 
"@jax.jit\n", + "def train_step(model, optimizer, x, y):\n", + " graphdef, params, nondiff = nnx.split(model, nnx.Param, ...)\n", + " def loss_fn(params):\n", + " model = nnx.merge(graphdef, params, nondiff)\n", + " return ((model(x) - y) ** 2).mean()\n", + "\n", + " loss, grads = jax.value_and_grad(loss_fn)(nnx.to_lojax(params)) # lojax Variables for jax.grad\n", + " optimizer.update(model, grads)\n", + "\n", + " return loss\n", + "\n", + "train_step(model, optimizer, x=jnp.ones((10, 2)), y=jnp.ones((10, 3)))" + ] + }, + { + "cell_type": "markdown", + "id": "1dea99c1", + "metadata": {}, + "source": [ + "### Scan Over Layers" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "d8136be4", + "metadata": {}, + "outputs": [ + { + "ename": "AttributeError", + "evalue": "'aval_property' object has no attribute 'spec'", + "output_type": "error", + "traceback": [ + "\u001b[31m---------------------------------------------------------------------------\u001b[39m", + "\u001b[31mAttributeError\u001b[39m Traceback (most recent call last)", + "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[14]\u001b[39m\u001b[32m, line 5\u001b[39m\n\u001b[32m 1\u001b[39m \u001b[38;5;129m@jax\u001b[39m.vmap\n\u001b[32m 2\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34mcreate_stack\u001b[39m(rngs):\n\u001b[32m 3\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m nnx.to_lojax(Block(\u001b[32m2\u001b[39m, \u001b[32m64\u001b[39m, \u001b[32m2\u001b[39m, rngs=rngs))\n\u001b[32m----> \u001b[39m\u001b[32m5\u001b[39m block_stack = nnx.to_hijax(\u001b[43mcreate_stack\u001b[49m\u001b[43m(\u001b[49m\u001b[43mnnx\u001b[49m\u001b[43m.\u001b[49m\u001b[43mRngs\u001b[49m\u001b[43m(\u001b[49m\u001b[32;43m0\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m.\u001b[49m\u001b[43mfork\u001b[49m\u001b[43m(\u001b[49m\u001b[43msplit\u001b[49m\u001b[43m=\u001b[49m\u001b[32;43m8\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m)\n\u001b[32m 7\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34mscan_fn\u001b[39m(x, block):\n\u001b[32m 8\u001b[39m x = block(x)\n", + " \u001b[31m[... 
skipping hidden 1 frame]\u001b[39m\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/repos/flax/.venv/lib/python3.11/site-packages/jax/_src/api.py:1105\u001b[39m, in \u001b[36mvmap..vmap_f\u001b[39m\u001b[34m(*args, **kwargs)\u001b[39m\n\u001b[32m 1101\u001b[39m api_util._check_no_aliased_ref_args(dbg, avals, args_flat)\n\u001b[32m 1103\u001b[39m axis_size_ = (axis_size \u001b[38;5;28;01mif\u001b[39;00m axis_size \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;28;01melse\u001b[39;00m\n\u001b[32m 1104\u001b[39m _mapped_axis_size(fun, in_tree, args_flat, in_axes_flat, \u001b[33m\"\u001b[39m\u001b[33mvmap\u001b[39m\u001b[33m\"\u001b[39m))\n\u001b[32m-> \u001b[39m\u001b[32m1105\u001b[39m explicit_mesh_axis = \u001b[43m_mapped_axis_spec\u001b[49m\u001b[43m(\u001b[49m\u001b[43margs_flat\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43min_axes_flat\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 1106\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m spmd_axis_name \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mand\u001b[39;00m explicit_mesh_axis \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[32m 1107\u001b[39m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\n\u001b[32m 1108\u001b[39m \u001b[33m\"\u001b[39m\u001b[33mOnly one of spmd_axis_name or arrays sharded on `Explicit` mesh\u001b[39m\u001b[33m\"\u001b[39m\n\u001b[32m 1109\u001b[39m \u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[33m axis type is allowed. Got \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mspmd_axis_name\u001b[38;5;132;01m=}\u001b[39;00m\u001b[33m and\u001b[39m\u001b[33m\"\u001b[39m\n\u001b[32m 1110\u001b[39m \u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[33m arrays sharded on \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mexplicit_mesh_axis\u001b[38;5;132;01m=}\u001b[39;00m\u001b[33m\"\u001b[39m)\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/repos/flax/.venv/lib/python3.11/site-packages/jax/_src/api.py:1141\u001b[39m, in \u001b[36m_mapped_axis_spec\u001b[39m\u001b[34m(args_flat, in_axes_flat)\u001b[39m\n\u001b[32m 1139\u001b[39m \u001b[38;5;28;01mfor\u001b[39;00m arg, i \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mzip\u001b[39m(args_flat, in_axes_flat):\n\u001b[32m 1140\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m i \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[32m-> \u001b[39m\u001b[32m1141\u001b[39m spec = \u001b[43m_get_spec\u001b[49m\u001b[43m(\u001b[49m\u001b[43marg\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mi\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 1142\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m out_spec \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mand\u001b[39;00m out_spec != spec:\n\u001b[32m 1143\u001b[39m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\n\u001b[32m 1144\u001b[39m \u001b[33m\"\u001b[39m\u001b[33mMapped away dimension of inputs passed to vmap should be sharded\u001b[39m\u001b[33m\"\u001b[39m\n\u001b[32m 1145\u001b[39m \u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[33m the same. 
Got inconsistent axis specs: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mout_spec\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m vs \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mspec\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m\"\u001b[39m)\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/repos/flax/.venv/lib/python3.11/site-packages/jax/_src/api.py:1134\u001b[39m, in \u001b[36m_mapped_axis_spec.._get_spec\u001b[39m\u001b[34m(arg, i)\u001b[39m\n\u001b[32m 1131\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34m_get_spec\u001b[39m(arg, i):\n\u001b[32m 1132\u001b[39m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[32m 1133\u001b[39m \u001b[38;5;66;03m# Duck type arrays like BCOO arrays can be passed to vmap.\u001b[39;00m\n\u001b[32m-> \u001b[39m\u001b[32m1134\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mshaped_abstractify\u001b[49m\u001b[43m(\u001b[49m\u001b[43marg\u001b[49m\u001b[43m)\u001b[49m\u001b[43m.\u001b[49m\u001b[43msharding\u001b[49m\u001b[43m.\u001b[49m\u001b[43mspec\u001b[49m[i]\n\u001b[32m 1135\u001b[39m \u001b[38;5;28;01mexcept\u001b[39;00m (\u001b[38;5;167;01mIndexError\u001b[39;00m, \u001b[38;5;167;01mTypeError\u001b[39;00m):\n\u001b[32m 1136\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n", + "\u001b[31mAttributeError\u001b[39m: 'aval_property' object has no attribute 'spec'" + ] + } + ], + "source": [ + "@jax.vmap\n", + "def create_stack(rngs):\n", + " return nnx.to_lojax(Block(2, 64, 2, rngs=rngs))\n", + "\n", + "block_stack = nnx.to_hijax(create_stack(nnx.Rngs(0).fork(split=8)))\n", + "\n", + "def scan_fn(x, block):\n", + " x = block(x)\n", + " return x, None\n", + "\n", + "x = jax.random.uniform(jax.random.key(0), (3, 2))\n", + "y, _ = jax.lax.scan(scan_fn, x, block_stack)\n", + "\n", + "print(\"y = \", y)" + ] + }, + { + "cell_type": "markdown", + "id": "7ca18a0d", + "metadata": {}, + "source": [ + "## Limitations" + ] + }, + { + "cell_type": "markdown", + "id": "1dd39c79", + "metadata": {}, + "source": [ + "### MutableArray Outputs" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "c6062d19", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Error: mutable hitypes should use lo_ty_qdd instead\n" + ] + } + ], + "source": [ + "@jax.jit\n", + "def create_model(rngs):\n", + " return Block(2, 64, 3, rngs=rngs)\n", + "\n", + "try:\n", + " model = create_model(nnx.Rngs(0))\n", + "except Exception as e:\n", + " print(f\"Error:\", e)" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "id": "8bb1e9e7", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "model.linear = \u001b[38;2;79;201;177mLinear\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # HijaxVariable: 192 (768 B)\u001b[0m\n", + " \u001b[38;2;156;220;254mbias\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mParam\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m\u001b[0m\n", + " \u001b[38;2;156;220;254mvalue\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mArray\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;156;220;254mshape\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;182;207;169m64\u001b[0m,\u001b[38;2;255;213;3m)\u001b[0m, \u001b[38;2;156;220;254mdtype\u001b[0m\u001b[38;2;212;212;212m=\u001b[0mdtype('float32')\u001b[38;2;255;213;3m)\u001b[0m,\n", + " 
\u001b[38;2;156;220;254mis_hijax\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;86;156;214mTrue\u001b[0m\n", + " \u001b[38;2;255;213;3m)\u001b[0m,\n", + " \u001b[38;2;156;220;254mkernel\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mParam\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m\u001b[0m\n", + " \u001b[38;2;156;220;254mvalue\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mArray\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;156;220;254mshape\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;182;207;169m2\u001b[0m, \u001b[38;2;182;207;169m64\u001b[0m\u001b[38;2;255;213;3m)\u001b[0m, \u001b[38;2;156;220;254mdtype\u001b[0m\u001b[38;2;212;212;212m=\u001b[0mdtype('float32')\u001b[38;2;255;213;3m)\u001b[0m,\n", + " \u001b[38;2;156;220;254mis_hijax\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;86;156;214mTrue\u001b[0m\n", + " \u001b[38;2;255;213;3m)\u001b[0m\n", + "\u001b[38;2;255;213;3m)\u001b[0m\n" + ] + } + ], + "source": [ + "with nnx.use_hijax(False): # <-- disable hijax Variables\n", + " model = create_model(nnx.Rngs(0))\n", + "\n", + "model = nnx.to_hijax(model) # convert to mutable after creation\n", + "\n", + "print(\"model.linear =\", model.linear)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3a078025", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "model.linear = \u001b[38;2;79;201;177mLinear\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # HijaxVariable: 192 (768 B)\u001b[0m\n", + " \u001b[38;2;156;220;254mbias\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mParam\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m\u001b[0m\n", + " \u001b[38;2;156;220;254mvalue\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mArray\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;156;220;254mshape\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;182;207;169m64\u001b[0m,\u001b[38;2;255;213;3m)\u001b[0m, \u001b[38;2;156;220;254mdtype\u001b[0m\u001b[38;2;212;212;212m=\u001b[0mdtype('float32')\u001b[38;2;255;213;3m)\u001b[0m,\n", + " \u001b[38;2;156;220;254mis_hijax\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;86;156;214mTrue\u001b[0m\n", + " \u001b[38;2;255;213;3m)\u001b[0m,\n", + " \u001b[38;2;156;220;254mkernel\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mParam\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m\u001b[0m\n", + " \u001b[38;2;156;220;254mvalue\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mArray\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;156;220;254mshape\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;182;207;169m2\u001b[0m, \u001b[38;2;182;207;169m64\u001b[0m\u001b[38;2;255;213;3m)\u001b[0m, \u001b[38;2;156;220;254mdtype\u001b[0m\u001b[38;2;212;212;212m=\u001b[0mdtype('float32')\u001b[38;2;255;213;3m)\u001b[0m,\n", + " \u001b[38;2;156;220;254mis_hijax\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;86;156;214mTrue\u001b[0m\n", + " \u001b[38;2;255;213;3m)\u001b[0m\n", + "\u001b[38;2;255;213;3m)\u001b[0m\n" + ] + } + ], + "source": [ + "# TODO: why does this work?\n", + "@nnx.jit\n", + "def create_model(rngs):\n", + " return Block(2, 64, 3, rngs=rngs)\n", + "\n", + "model = create_model(nnx.Rngs(0))\n", + "\n", + "print(\"model.linear =\", model.linear)" + ] + }, + { + "cell_type": 
"markdown", + "id": "609bed7c", + "metadata": {}, + "source": [ + "### Reference Sharing (aliasing)" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "id": "045d03c1", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "None\n" + ] + } + ], + "source": [ + "# TODO: why does this not fail?\n", + "def get_error(f, *args):\n", + " try:\n", + " return f(*args)\n", + " except Exception as e:\n", + " return f\"{type(e).__name__}: {e}\"\n", + "\n", + "x = nnx.Variable(jnp.array(0))\n", + "\n", + "@jax.jit\n", + "def f(a, b):\n", + " ...\n", + "\n", + "print(get_error(f, x, x))" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "id": "bc2e87e5", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "SharedVariables None\n", + "SharedModules None\n" + ] + } + ], + "source": [ + "class SharedVariables(nnx.Pytree):\n", + " def __init__(self):\n", + " self.a = nnx.Variable(jnp.array(0))\n", + " self.b = nnx.Variable(jnp.array(1))\n", + " self.c = self.a\n", + "\n", + "class SharedModules(nnx.Pytree):\n", + " def __init__(self):\n", + " self.d = Linear(1, 1, rngs=nnx.Rngs(0))\n", + " self.e = Linear(1, 1, rngs=nnx.Rngs(0))\n", + " self.f = self.d\n", + "\n", + "@jax.jit\n", + "def g(pytree):\n", + " ...\n", + "\n", + "shared_variables = SharedVariables()\n", + "shared_modules = SharedModules()\n", + "\n", + "print(\"SharedVariables\", get_error(g, shared_variables))\n", + "print(\"SharedModules\", get_error(g, shared_modules))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6298f3d9", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "shared variables duplicates: [[('a',), ('c',)]]\n", + "shared modules duplicates: [[('d',), ('f',)]]\n" + ] + } + ], + "source": [ + "if (duplicates := nnx.find_duplicates(shared_variables)):\n", + " print(\"shared variables duplicates:\", duplicates)\n", + "\n", + "if (duplicates := nnx.find_duplicates(shared_modules)):\n", + " print(\"shared modules duplicates: \", duplicates)" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "id": "00854d38", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\u001b[38;2;79;201;177mState\u001b[0m\u001b[38;2;255;213;3m({\u001b[0m\u001b[38;2;105;105;105m\u001b[0m\n", + " \u001b[38;2;156;220;254m'a'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0m\u001b[38;2;79;201;177mVariable\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # 1 (4 B)\u001b[0m\n", + " \u001b[38;2;156;220;254mvalue\u001b[0m\u001b[38;2;212;212;212m=\u001b[0mArray(0, dtype=int32, weak_type=True),\n", + " \u001b[38;2;156;220;254mis_hijax\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;86;156;214mTrue\u001b[0m\n", + " \u001b[38;2;255;213;3m)\u001b[0m,\n", + " \u001b[38;2;156;220;254m'b'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0m\u001b[38;2;79;201;177mVariable\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # 1 (4 B)\u001b[0m\n", + " \u001b[38;2;156;220;254mvalue\u001b[0m\u001b[38;2;212;212;212m=\u001b[0mArray(1, dtype=int32, weak_type=True),\n", + " \u001b[38;2;156;220;254mis_hijax\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;86;156;214mTrue\u001b[0m\n", + " \u001b[38;2;255;213;3m)\u001b[0m\n", + "\u001b[38;2;255;213;3m})\u001b[0m\n", + "updated \u001b[38;2;79;201;177mSharedVariables\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # HijaxVariable: 2 (8 B)\u001b[0m\n", 
+ " \u001b[38;2;156;220;254ma\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mVariable\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # 1 (4 B)\u001b[0m\n", + " \u001b[38;2;156;220;254mvalue\u001b[0m\u001b[38;2;212;212;212m=\u001b[0mArray(10, dtype=int32, weak_type=True),\n", + " \u001b[38;2;156;220;254mis_hijax\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;86;156;214mTrue\u001b[0m\n", + " \u001b[38;2;255;213;3m)\u001b[0m,\n", + " \u001b[38;2;156;220;254mb\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mVariable\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # 1 (4 B)\u001b[0m\n", + " \u001b[38;2;156;220;254mvalue\u001b[0m\u001b[38;2;212;212;212m=\u001b[0mArray(1, dtype=int32, weak_type=True),\n", + " \u001b[38;2;156;220;254mis_hijax\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;86;156;214mTrue\u001b[0m\n", + " \u001b[38;2;255;213;3m)\u001b[0m,\n", + " \u001b[38;2;156;220;254mc\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;79;201;177mVariable\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # 1 (4 B)\u001b[0m\n", + " \u001b[38;2;156;220;254mvalue\u001b[0m\u001b[38;2;212;212;212m=\u001b[0mArray(10, dtype=int32, weak_type=True),\n", + " \u001b[38;2;156;220;254mis_hijax\u001b[0m\u001b[38;2;212;212;212m=\u001b[0m\u001b[38;2;86;156;214mTrue\u001b[0m\n", + " \u001b[38;2;255;213;3m)\u001b[0m\n", + "\u001b[38;2;255;213;3m)\u001b[0m\n" + ] + } + ], + "source": [ + "@jax.jit\n", + "def h(graphdef, state):\n", + " obj = nnx.merge(graphdef, state)\n", + " obj.a[...] += 10\n", + "\n", + "graphdef, state = nnx.split(shared_variables)\n", + "print(state) # split deduplicates the state\n", + "\n", + "h(graphdef, state)\n", + "\n", + "print(\"updated\", shared_variables)" + ] + } + ], + "metadata": { + "jupytext": { + "formats": "ipynb,md:myst" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.9" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/docs_nnx/guides/array_ref.md b/docs_nnx/guides/hijax.md similarity index 69% rename from docs_nnx/guides/array_ref.md rename to docs_nnx/guides/hijax.md index 1d00c77f5..5a68162b0 100644 --- a/docs_nnx/guides/array_ref.md +++ b/docs_nnx/guides/hijax.md @@ -21,41 +21,31 @@ import optax +++ -### Array Refs 101 +### Variables Refs ```{code-cell} ipython3 -a_ref = jax.new_ref(jnp.array([1, 2, 3])) +variable = nnx.Variable(jnp.array([1, 2, 3]), is_hijax=True) +print(f"{variable.is_hijax = }\n") @jax.jit -def increment(a_ref: jax.Ref): # no return! - array: jax.Array = a_ref[...] # access - a_ref[...] = array + 1 # update +def increment(variable: nnx.Variable[jax.Array]): # no return! + new_value = variable + 1 # Array-like operations + variable[...] = new_value # in-place updates -print("[1] =", a_ref); increment(a_ref); print("[2] =", a_ref) +print("Before =", variable); increment(variable); print("After =", variable) ``` ```{code-cell} ipython3 -@jax.jit -def inc(x): - x[...] 
+= 1 - -print(increment.lower(a_ref).as_text()) +# TODO: enable once as_text is fixed +# print(increment.lower(variable).as_text()) ``` -### Variables Refs - ```{code-cell} ipython3 -variable = nnx.Variable(jnp.array([1, 2, 3]), use_ref=True) -print(f"{variable.has_ref = }\n") - -print("[1] =", variable); increment(variable); print("[2] =", variable) -``` +nnx.use_hijax(True) -```{code-cell} ipython3 -with nnx.use_refs(True): - variable = nnx.Variable(jnp.array([1, 2, 3])) +variable = nnx.Variable(jnp.array([1, 2, 3])) -print(f"{variable.has_ref = }") +print(f"{variable.is_hijax = }") ``` Mention `nnx.use_refs` can be used as global flag @@ -73,12 +63,14 @@ class Linear(nnx.Module): def __call__(self, x): return x @ self.kernel + self.bias[None] -model = Linear(1, 3, rngs=nnx.Rngs(0)) # without array refs -refs_model = nnx.to_refs(model) # convert to array refs -arrays_model = nnx.to_arrays(refs_model) # convert to regular arrays +with nnx.use_hijax(False): # use lojax Variables + model = Linear(1, 3, rngs=nnx.Rngs(0)) + +hijax_model = nnx.to_hijax(model) # convert hijax Variables +arrays_model = nnx.to_lojax(hijax_model) # convert to lojax Variables -print("nnx.to_refs(model) =", refs_model) -print("nnx.to_arrays(refs_model) =", arrays_model) +print("nnx.to_hijax(model) =", hijax_model) +print("nnx.to_lojax(refs_model) =", arrays_model) ``` ## Examples @@ -99,9 +91,9 @@ class Block(nnx.Module): ### Training Loop ```{code-cell} ipython3 -with nnx.use_refs(True): - model = Block(2, 64, 3, rngs=nnx.Rngs(0)) - optimizer = nnx.Optimizer(model, optax.adam(1e-3), wrt=nnx.Param) +# hijax Variables by default +model = Block(2, 64, 3, rngs=nnx.Rngs(0)) +optimizer = nnx.Optimizer(model, optax.adam(1e-3), wrt=nnx.Param) @jax.jit def train_step(model, optimizer, x, y): @@ -110,7 +102,7 @@ def train_step(model, optimizer, x, y): model = nnx.merge(graphdef, params, nondiff) return ((model(x) - y) ** 2).mean() - loss, grads = jax.value_and_grad(loss_fn)(nnx.to_arrays(params)) # freeze ArrayRefs for jax.grad + loss, grads = jax.value_and_grad(loss_fn)(nnx.to_lojax(params)) # lojax Variables for jax.grad optimizer.update(model, grads) return loss @@ -121,12 +113,11 @@ train_step(model, optimizer, x=jnp.ones((10, 2)), y=jnp.ones((10, 3))) ### Scan Over Layers ```{code-cell} ipython3 -@nnx.vmap +@jax.vmap def create_stack(rngs): - return Block(2, 64, 2, rngs=rngs) + return nnx.to_lojax(Block(2, 64, 2, rngs=rngs)) -with nnx.use_refs(True): - block_stack = create_stack(nnx.Rngs(0).fork(split=8)) +block_stack = nnx.to_hijax(create_stack(nnx.Rngs(0).fork(split=8))) def scan_fn(x, block): x = block(x) @@ -150,28 +141,27 @@ def create_model(rngs): return Block(2, 64, 3, rngs=rngs) try: - with nnx.use_refs(True): - model = create_model(nnx.Rngs(0)) + model = create_model(nnx.Rngs(0)) except Exception as e: print(f"Error:", e) ``` ```{code-cell} ipython3 -with nnx.use_refs(False): # <-- disable array refs +with nnx.use_hijax(False): # <-- disable hijax Variables model = create_model(nnx.Rngs(0)) -model = nnx.to_refs(model) # convert to mutable after creation +model = nnx.to_hijax(model) # convert to mutable after creation print("model.linear =", model.linear) ``` ```{code-cell} ipython3 +# TODO: why does this work? 
@nnx.jit def create_model(rngs): return Block(2, 64, 3, rngs=rngs) -with nnx.use_refs(True): - model = create_model(nnx.Rngs(0)) +model = create_model(nnx.Rngs(0)) print("model.linear =", model.linear) ``` @@ -179,13 +169,14 @@ print("model.linear =", model.linear) ### Reference Sharing (aliasing) ```{code-cell} ipython3 +# TODO: why does this not fail? def get_error(f, *args): try: return f(*args) except Exception as e: return f"{type(e).__name__}: {e}" - -x = jax.new_ref(jnp.array(0)) + +x = nnx.Variable(jnp.array(0)) @jax.jit def f(a, b): @@ -211,9 +202,8 @@ class SharedModules(nnx.Pytree): def g(pytree): ... -with nnx.use_refs(True): - shared_variables = SharedVariables() - shared_modules = SharedModules() +shared_variables = SharedVariables() +shared_modules = SharedModules() print("SharedVariables", get_error(g, shared_variables)) print("SharedModules", get_error(g, shared_modules)) diff --git a/docs_nnx/index.rst b/docs_nnx/index.rst index 58490edc1..07971470e 100644 --- a/docs_nnx/index.rst +++ b/docs_nnx/index.rst @@ -107,15 +107,11 @@ Basic usage model = Model(2, 64, 3, rngs=nnx.Rngs(0)) # eager initialization optimizer = nnx.Optimizer(model, optax.adam(1e-3), wrt=nnx.Param) - @nnx.jit # automatic state management for JAX transforms + @nnx.jit # automatic state propagation def train_step(model, optimizer, x, y): - def loss_fn(model): - y_pred = model(x) # call methods directly - return ((y_pred - y) ** 2).mean() - + loss_fn = lambda model: ((model(x) - y) ** 2).mean() loss, grads = nnx.value_and_grad(loss_fn)(model) optimizer.update(model, grads) # in-place updates - return loss diff --git a/docs_nnx/nnx_basics.ipynb b/docs_nnx/nnx_basics.ipynb index ab43de2b5..727478d45 100644 --- a/docs_nnx/nnx_basics.ipynb +++ b/docs_nnx/nnx_basics.ipynb @@ -48,7 +48,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 3, "metadata": {}, "outputs": [], "source": [ @@ -59,7 +59,7 @@ " self.din, self.dout = din, dout\n", "\n", " def __call__(self, x: jax.Array):\n", - " return x @ self.w + self.b" + " return x @ self.w + self.b[None]" ] }, { @@ -84,31 +84,10 @@ "[[1.5643291 0.94782424 0.37971854 1.0724319 0.22112393]]\n" ] }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/Users/cgarciae/repos/flax/.venv/lib/python3.11/site-packages/treescope/renderers.py:251: UserWarning: Ignoring error inside wrapper hook :\n", - "Traceback (most recent call last):\n", - " File \"/Users/cgarciae/repos/flax/.venv/lib/python3.11/site-packages/treescope/renderers.py\", line 225, in _render_subtree\n", - " postprocessed_result = hook(\n", - " ^^^^^\n", - " File \"/Users/cgarciae/repos/flax/.venv/lib/python3.11/site-packages/treescope/_internal/handlers/autovisualizer_hook.py\", line 47, in use_autovisualizer_if_present\n", - " result = autoviz(node, path)\n", - " ^^^^^^^^^^^^^^^^^^^\n", - " File \"/Users/cgarciae/repos/flax/.venv/lib/python3.11/site-packages/treescope/_internal/api/array_autovisualizer.py\", line 306, in __call__\n", - " jax.sharding.PositionalSharding\n", - " File \"/Users/cgarciae/repos/flax/.venv/lib/python3.11/site-packages/jax/_src/deprecations.py\", line 54, in getattr\n", - " raise AttributeError(message)\n", - "AttributeError: jax.sharding.PositionalSharding was deprecated in JAX v0.6.0 and removed in JAX v0.7.0\n", - "\n", - " warnings.warn(\n" - ] - }, { "data": { "text/html": [ - "
" + "
" ], "text/plain": [ "" @@ -120,7 +99,7 @@ { "data": { "text/html": [ - "
" + "
" ], "text/plain": [ "" @@ -164,8 +143,8 @@ "name": "stdout", "output_type": "stream", "text": [ - "counter.count.value = Array(0, dtype=int32, weak_type=True)\n", - "counter.count.value = Array(1, dtype=int32, weak_type=True)\n" + "counter.count[...] = Array(0, dtype=int32, weak_type=True)\n", + "counter.count[...] = Array(1, dtype=int32, weak_type=True)\n" ] } ], @@ -177,12 +156,12 @@ " self.count = Count(jnp.array(0))\n", "\n", " def __call__(self):\n", - " self.count.value += 1\n", + " self.count[...] += 1\n", "\n", "counter = Counter()\n", - "print(f'{counter.count.value = }')\n", + "print(f'{counter.count[...] = }')\n", "counter()\n", - "print(f'{counter.count.value = }')" + "print(f'{counter.count[...] = }')" ] }, { @@ -212,7 +191,7 @@ { "data": { "text/html": [ - "
" + "
" ], "text/plain": [ "" @@ -224,7 +203,7 @@ { "data": { "text/html": [ - "
" + "
" ], "text/plain": [ "" @@ -273,13 +252,13 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "text/html": [ - "
" + "
" ], "text/plain": [ "" @@ -291,7 +270,7 @@ { "data": { "text/html": [ - "
" + "
" ], "text/plain": [ "" @@ -340,7 +319,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 8, "metadata": {}, "outputs": [ { @@ -415,7 +394,7 @@ { "data": { "text/html": [ - "
" + "
" ], "text/plain": [ "" @@ -427,7 +406,7 @@ { "data": { "text/html": [ - "
" + "
" ], "text/plain": [ "" @@ -480,13 +459,13 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 10, "metadata": {}, "outputs": [ { "data": { "text/html": [ - "
" + "
" ], "text/plain": [ "" @@ -498,7 +477,7 @@ { "data": { "text/html": [ - "
" + "
" ], "text/plain": [ "" @@ -547,7 +526,7 @@ { "data": { "text/html": [ - "
" + "
" ], "text/plain": [ "" @@ -559,7 +538,7 @@ { "data": { "text/html": [ - "
" + "
" ], "text/plain": [ "" @@ -571,7 +550,7 @@ { "data": { "text/html": [ - "
" + "
" ], "text/plain": [ "" @@ -672,7 +651,7 @@ { "data": { "text/html": [ - "
" + "
" ], "text/plain": [ "" @@ -684,7 +663,7 @@ { "data": { "text/html": [ - "
" + "
" ], "text/plain": [ "" @@ -696,7 +675,7 @@ { "data": { "text/html": [ - "
" + "
" ], "text/plain": [ "" @@ -708,7 +687,7 @@ { "data": { "text/html": [ - "
" + "
" ], "text/plain": [ "" diff --git a/docs_nnx/nnx_basics.md b/docs_nnx/nnx_basics.md index 149e859d2..c2aca4cae 100644 --- a/docs_nnx/nnx_basics.md +++ b/docs_nnx/nnx_basics.md @@ -40,7 +40,7 @@ class Linear(nnx.Module): self.din, self.dout = din, dout def __call__(self, x: jax.Array): - return x @ self.w + self.b + return x @ self.w + self.b[None] ``` Also note that the inner values of `Variable`s can be accessed using the `value` property, but for convenience they implement all numeric operators and can be used directly in arithmetic expressions (as shown in the code above). @@ -73,12 +73,12 @@ class Counter(nnx.Module): self.count = Count(jnp.array(0)) def __call__(self): - self.count.value += 1 + self.count[...] += 1 counter = Counter() -print(f'{counter.count.value = }') +print(f'{counter.count[...] = }') counter() -print(f'{counter.count.value = }') +print(f'{counter.count[...] = }') ``` Mutable references are usually avoided in JAX. But Flax NNX provides sound mechanisms diff --git a/examples/gemma/helpers.py b/examples/gemma/helpers.py index f74845bed..eeb02848d 100644 --- a/examples/gemma/helpers.py +++ b/examples/gemma/helpers.py @@ -74,7 +74,7 @@ def assign_val_fn( mapped_path: tuple[str | int, ...], val: Any, ) -> dict[tuple[str, ...], Any]: - state[mapped_path].value = val + state[mapped_path].set_value(val) return state mdl: M = nnx.eval_shape(module_factory) diff --git a/examples/gemma/helpers_test.py b/examples/gemma/helpers_test.py index 8d5e899f9..dd7e5fe4e 100644 --- a/examples/gemma/helpers_test.py +++ b/examples/gemma/helpers_test.py @@ -137,11 +137,11 @@ def _map_key_fn(key: tuple[str, ...]) -> tuple[str | int, ...]: np.testing.assert_array_equal(output, linen_output) for i in range(len(num_features)): np.testing.assert_array_equal( - mdl.layers[i].layers[0].mean.value, + mdl.layers[i].layers[0].mean[...], linen_vars['batch_stats'][f'layers_{i}']['layers_0']['mean'], ) np.testing.assert_array_equal( - mdl.layers[i].layers[0].var.value, + mdl.layers[i].layers[0].var[...], linen_vars['batch_stats'][f'layers_{i}']['layers_0']['var'], ) diff --git a/examples/gemma/layers.py b/examples/gemma/layers.py index f764c61a0..5fb959ada 100644 --- a/examples/gemma/layers.py +++ b/examples/gemma/layers.py @@ -44,11 +44,11 @@ def __init__( self.w = nnx.Param(kernel_init(rngs.params(), shape, dtype)) def __call__(self, x: ArrayLike) -> Array: - return jnp.einsum(self.einsum_str, x, self.w.value) + return jnp.einsum(self.einsum_str, x, self.w[...]) @property def shape(self) -> Shape: - return self.w.value.shape + return self.w.shape class RMSNorm(nnx.Module): @@ -65,12 +65,12 @@ def __init__( self.scale = nnx.Param(scale_init(rngs.params(), dim, dtype)) def __call__(self, x: Array) -> Array: - dtype = self.scale.value.dtype + dtype = self.scale.dtype var = jnp.mean(jnp.square(x), axis=-1, keepdims=True) normed_inputs = jnp.asarray(x * jax.lax.rsqrt(var + 1e-06), dtype=dtype) # normed_inputs is a rank-K tensor, K > 1 (K is typically 2 or 3). scale is # a rank-1 tensor. To avoid implicit rank-promotion, reshape scale to # a (1, ..., 1, D) tensor, so the rank of scale matches normed_inputs. 
- scale = jnp.expand_dims(self.scale.value, axis=range(len(x.shape) - 1)) + scale = jnp.expand_dims(self.scale, axis=range(len(x.shape) - 1)) normed_inputs = normed_inputs * (1 + scale) return normed_inputs diff --git a/examples/gemma/modules.py b/examples/gemma/modules.py index 56c426fbd..48c9a018a 100644 --- a/examples/gemma/modules.py +++ b/examples/gemma/modules.py @@ -63,15 +63,15 @@ def encode(self, x: ArrayLike) -> Array: return x def decode(self, x: ArrayLike) -> Array: - return jnp.dot(x, self.input_embedding.value.T) + return jnp.dot(x, self.input_embedding.T) @property def embed_dim(self): - return self.input_embedding.value.shape[1] + return self.input_embedding.shape[1] @property def num_embed(self): - return self.input_embedding.value.shape[0] + return self.input_embedding.shape[0] class Attention(nnx.Module): diff --git a/examples/gemma/sampler_test.py b/examples/gemma/sampler_test.py index 307b0e43e..8d2ed5a83 100644 --- a/examples/gemma/sampler_test.py +++ b/examples/gemma/sampler_test.py @@ -232,9 +232,9 @@ def test_forbidden_tokens(self): transformer_config, rngs=nnx.Rngs(params=0) ) # Pre-cook the embedding matrix so that the output is deterministic. - transformer.embedder.input_embedding.value = jnp.eye( + transformer.embedder.input_embedding.set_value(jnp.eye( vocab.GetPieceSize(), 32 - ) + )) sampler = sampler_lib.Sampler( transformer=transformer, vocab=vocab, diff --git a/examples/gemma/sow_lib.py b/examples/gemma/sow_lib.py index 6bc808501..7580cdfe2 100644 --- a/examples/gemma/sow_lib.py +++ b/examples/gemma/sow_lib.py @@ -49,13 +49,11 @@ def merge(self, decoding_step, layer: nnx.Module): if field.name.startswith('attn_'): step_value = getattr( layer.attn, field.name.replace('attn_', '') - ).value[0] + )[0] elif field.name.startswith('mlp_'): - step_value = getattr(layer.mlp, field.name.replace('mlp_', '')).value[ - 0 - ] + step_value = getattr(layer.mlp, field.name.replace('mlp_', ''))[0] else: - step_value = getattr(layer, field.name).value[0] + step_value = getattr(layer, field.name)[0] except AttributeError as exc: raise ValueError( f'Intermediate {field.name} is not in the step intermediates.' @@ -93,7 +91,7 @@ def merge(self, decoding_step, transformer: nnx.Module): if self.embeddings is not None: try: self.embeddings = self.embeddings.at[:, decoding_step + 1, ...].set( - transformer.embeddings.value[0][:, 0, ...] + transformer.embeddings[0][:, 0, ...] 
) except AttributeError as exc: raise ValueError( diff --git a/examples/gemma/transformer.py b/examples/gemma/transformer.py index 842121e96..54fb6748e 100644 --- a/examples/gemma/transformer.py +++ b/examples/gemma/transformer.py @@ -487,10 +487,10 @@ def _assign_linen_params_to_nnx_state( if 'gate_proj' in mapped_path: if transpose_gating_einsum: val = jnp.swapaxes(val, 1, 2) - state[mapped_path].value = val[0] - state[mapped_path[:-2] + ('up_proj', 'kernel')].value = val[1] + state[mapped_path].set_value(val[0]) + state[mapped_path[:-2] + ('up_proj', 'kernel')].set_value(val[1]) else: - state[mapped_path].value = val + state[mapped_path].set_value(val) return state diff --git a/examples/gemma/transformer_test.py b/examples/gemma/transformer_test.py index 3d30f9277..97916604b 100644 --- a/examples/gemma/transformer_test.py +++ b/examples/gemma/transformer_test.py @@ -461,7 +461,7 @@ def test_sow_intermediates(self, sow_config): if sow_config.embeddings: self.assertTrue(hasattr(transformer, 'embeddings')) - embeddings = transformer.embeddings.value[0] + embeddings = transformer.embeddings[0] self.assertEqual( embeddings.shape, (batch_size, sequence_length, config.embed_dim), @@ -472,7 +472,7 @@ def test_sow_intermediates(self, sow_config): for layer in transformer.layers: if sow_config.rs_after_attention: self.assertTrue(hasattr(layer, 'rs_after_attention')) - rs_after_attention = layer.rs_after_attention.value[0] + rs_after_attention = layer.rs_after_attention[0] self.assertIsNotNone(rs_after_attention) self.assertEqual( rs_after_attention.shape, @@ -482,7 +482,7 @@ def test_sow_intermediates(self, sow_config): self.assertFalse(hasattr(layer, 'rs_after_attention')) if sow_config.rs_after_ffw: self.assertTrue(hasattr(layer, 'rs_after_ffw')) - rs_after_ffw = layer.rs_after_ffw.value[0] + rs_after_ffw = layer.rs_after_ffw[0] self.assertIsNotNone(rs_after_ffw) self.assertEqual( rs_after_ffw.shape, @@ -492,7 +492,7 @@ def test_sow_intermediates(self, sow_config): self.assertFalse(hasattr(layer, 'rs_after_ffw')) if sow_config.attn_logits_topk: self.assertTrue(hasattr(layer.attn, 'logits_topk_values')) - attn_logits_topk_values = layer.attn.logits_topk_values.value[0] + attn_logits_topk_values = layer.attn.logits_topk_values[0] self.assertIsNotNone(attn_logits_topk_values) self.assertEqual( attn_logits_topk_values.shape, @@ -504,7 +504,7 @@ def test_sow_intermediates(self, sow_config): ), ) self.assertTrue(hasattr(layer.attn, 'logits_topk_indices')) - attn_logits_topk_indices = layer.attn.logits_topk_indices.value[0] + attn_logits_topk_indices = layer.attn.logits_topk_indices[0] self.assertIsNotNone(attn_logits_topk_indices) self.assertEqual( attn_logits_topk_indices.shape, @@ -520,7 +520,7 @@ def test_sow_intermediates(self, sow_config): self.assertFalse(hasattr(layer.attn, 'logits_topk_indices')) if sow_config.mlp_hidden_topk: self.assertTrue(hasattr(layer.mlp, 'hidden_topk_values')) - ffw_hidden_topk_values = layer.mlp.hidden_topk_values.value[0] + ffw_hidden_topk_values = layer.mlp.hidden_topk_values[0] self.assertIsNotNone(ffw_hidden_topk_values) self.assertEqual( ffw_hidden_topk_values.shape, @@ -531,7 +531,7 @@ def test_sow_intermediates(self, sow_config): ), ) self.assertTrue(hasattr(layer.mlp, 'hidden_topk_indices')) - ffw_hidden_topk_indices = layer.mlp.hidden_topk_indices.value[0] + ffw_hidden_topk_indices = layer.mlp.hidden_topk_indices[0] self.assertIsNotNone(ffw_hidden_topk_indices) self.assertEqual( ffw_hidden_topk_indices.shape, diff --git 
a/examples/nnx_toy_examples/mutable_array_basic.py b/examples/nnx_toy_examples/hijax_basic.py similarity index 91% rename from examples/nnx_toy_examples/mutable_array_basic.py rename to examples/nnx_toy_examples/hijax_basic.py index 7386163c1..294508871 100644 --- a/examples/nnx_toy_examples/mutable_array_basic.py +++ b/examples/nnx_toy_examples/hijax_basic.py @@ -54,9 +54,9 @@ def __call__(self, x): self.count[...] += 1 return self.linear2(jax.nn.relu(self.linear1(x)) * 0.5) -with nnx.use_refs(True): - model = MLP(din=1, dhidden=32, dout=1, rngs=nnx.Rngs(0)) - optimizer = nnx.Optimizer(model, optax.sgd(learning_rate=0.1), wrt=nnx.Param) +nnx.use_hijax(True) +model = MLP(din=1, dhidden=32, dout=1, rngs=nnx.Rngs(0)) +optimizer = nnx.Optimizer(model, optax.sgd(learning_rate=0.1), wrt=nnx.Param) @jax.jit @@ -67,7 +67,7 @@ def loss_fn(params): model = nnx.merge(graphdef, params, counts) return jnp.mean((y - model(x)) ** 2) - grads = jax.grad(loss_fn)(nnx.to_arrays(params)) + grads = jax.grad(loss_fn)(nnx.to_lojax(params)) optimizer.update(model, grads) diff --git a/examples/nnx_toy_examples/mutable_array_demo.py b/examples/nnx_toy_examples/hijax_demo.py similarity index 81% rename from examples/nnx_toy_examples/mutable_array_demo.py rename to examples/nnx_toy_examples/hijax_demo.py index 6d9619444..9b14e3404 100644 --- a/examples/nnx_toy_examples/mutable_array_demo.py +++ b/examples/nnx_toy_examples/hijax_demo.py @@ -49,9 +49,8 @@ def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs): self.w = nnx.Param(initializer(rngs.params(), (din, dout))) self.b = nnx.Param(jnp.zeros((dout,))) - # [...] is used to access the array def __call__(self, x: jax.Array): - return x @ self.w[...] + self.b[None] + return x @ self.w + self.b[None] # Block implements linear, batch norm, and dropout. Its behavior @@ -88,21 +87,21 @@ def __call__( self, x: jax.Array, *, rngs: nnx.Rngs | None = None ) -> jax.Array: # ----------- linear -------------------- - x = x @ self.w[...] + self.b[None] + x = x @ self.w + self.b[None] # ----------- batch norm ---------------- if self.use_stats: - mean = self.mean[...] - var = self.var[...] + mean = self.mean + var = self.var else: mean = jnp.mean(x, axis=0) var = jnp.var(x, axis=0) # ema updates - # stop gradient is used until a ArrayRef supports updates from grad tracers + # stop gradient is used until a Hijax supports updates from grad tracers sg = jax.lax.stop_gradient - self.mean[...] = sg(self.mu * self.mean[...] + (1 - self.mu) * mean) - self.var[...] = sg(self.mu * self.var[...] + (1 - self.mu) * var) + self.mean[...] = sg(self.mu * self.mean + (1 - self.mu) * mean) + self.var[...] = sg(self.mu * self.var + (1 - self.mu) * var) x = (x - mean[None]) / jnp.sqrt(var[None] + 1e-5) - x = x * self.scale[...] + self.bias[...] 
+ x = x * self.scale + self.bias # ----------- dropout ------------------- if not self.deterministic and self.dropout_rate > 0.0: assert rngs is not None @@ -125,7 +124,7 @@ def __init__( use_scan: bool = True, rngs: nnx.Rngs, ): - self.count: jax.Ref = jax.new_ref(jnp.array(0)) + self.count = nnx.Variable(jnp.array(0)) self.block_in = Block(din, dhidden, rngs=rngs) self.linear_out = Linear(dhidden, dout, rngs=rngs) @@ -136,11 +135,15 @@ def __init__( @jax.vmap def create_block(rngs, /): - return nnx.to_arrays(Block(dhidden, dhidden, rngs=rngs)) + # return nnx.stateless(Block(dhidden, dhidden, rngs=rngs)) + return Block(dhidden, dhidden, rngs=rngs) - self.blocks = nnx.to_refs(create_block(rngs.fork(split=num_blocks))) + # self.blocks = nnx.stateful(create_block(rngs.fork(split=num_blocks))) + self.blocks = create_block(rngs.fork(split=num_blocks)) else: - self.blocks = nnx.List([Block(dhidden, dhidden, rngs=rngs) for i in range(num_blocks)]) + self.blocks = nnx.List( + [Block(dhidden, dhidden, rngs=rngs) for i in range(num_blocks)] + ) def __call__(self, x: jax.Array, *, rngs: nnx.Rngs | None = None): self.count[...] += 1 @@ -169,7 +172,7 @@ class OptState(nnx.Variable): ... # Optimizer are an interesting case as they are inherently stateful and -# pose a good use case for ArrayRef. Here we implement SGD with +# pose a good use case for MutableHijax. Here we implement SGD with # momentum. The optimizer receives the params as constructor arguments but doesn't # hold a reference to them, it only uses the params to initialize its state # by creating new OptState Variables that reuse the param's metadata. @@ -180,40 +183,36 @@ def __init__(self, params, lr: float, decay: float = 0.9): def make_opt_state(x): if isinstance(x, nnx.Variable): - return OptState(jnp.zeros_like(x.value), **x.get_metadata()) + return OptState(jnp.zeros_like(x[...]), **x.get_metadata()) else: return OptState(jnp.zeros_like(x)) - self.momentum = nnx.data(jax.tree.map( - make_opt_state, - params, - is_leaf=lambda x: isinstance(x, nnx.Variable), - )) + self.momentum = nnx.data(jax.tree.map(make_opt_state, params)) # during the update we simply map over (params, momentum, grads), # for each triplet we implement the SGD update rule which updates # both the optimizer's state (momentum) and the params in place. def update(self, params, grads): - params = nnx.pure(params) - grads = nnx.pure(grads) - momentum = nnx.pure(self.momentum) - def update_fn( - param: jax.Ref, momentum: jax.Ref, grad: jax.Array + param: nnx.Variable[jax.Array], + momentum: nnx.Variable[jax.Array], + grad: nnx.Variable[jax.Array], ): - momentum[...] = self.decay * momentum[...] + (1 - self.decay) * grad[...] - param[...] -= self.lr * momentum[...] + momentum[...] = self.decay * momentum + (1 - self.decay) * grad + param[...] 
-= self.lr * momentum + + # is_leaf might not be necesarry as MutableHijaxVariable are not pytreees + jax.tree.map(update_fn, params, self.momentum, grads) - jax.tree.map(update_fn, params, momentum, grads) # ## Training -with nnx.use_refs(True): - rngs = nnx.Rngs(params=0, dropout=1) - model = Model( - num_blocks=3, din=1, dhidden=256, dout=1, use_scan=False, rngs=rngs - ) - optimizer = SGD(params=nnx.state(model, nnx.Param), lr=3e-3, decay=0.99) +nnx.use_hijax(True) +rngs = nnx.Rngs(params=0, dropout=1) +model = Model( + num_blocks=3, din=1, dhidden=256, dout=1, use_scan=False, rngs=rngs +) +optimizer = SGD(params=nnx.state(model, nnx.Param), lr=3e-3, decay=0.99) # Create a copy of the model structure and set its attributes to eval model. # This works because they share the underlying ArrayRefs so both models @@ -237,13 +236,14 @@ def loss_fn(params): loss = jnp.mean((model(x, rngs=rngs) - y) ** 2) return loss - # For the time being we have to use 'freeze' make the Variables immutable - # as 'jax.grad' doesn't support ArrayRefs yet. - grads = jax.grad(loss_fn)(nnx.to_arrays(params)) + # For the time being we have to use 'to_lojax' + # as 'jax.grad' doesn't support Hijax types yet. + grads = jax.grad(loss_fn)(nnx.to_lojax(params)) # 'update' mutates the optimizer's state and the params in place # so we don't need to return anything 🚀 optimizer.update(params, grads) + # simple test step that computes the loss @jax.jit def test_step(model: Model, x, y): diff --git a/flax/configurations.py b/flax/configurations.py index 9240b697f..ca34a387f 100644 --- a/flax/configurations.py +++ b/flax/configurations.py @@ -27,7 +27,7 @@ class Config: flax_pytree_module: bool flax_max_repr_depth: int | None flax_always_shard_variable: bool - flax_hijax_variable: bool + flax_variable_mode: str # See https://google.github.io/pytype/faq.html. _HAS_DYNAMIC_ATTRIBUTES = True @@ -201,6 +201,38 @@ def static_bool_env(varname: str, default: bool) -> bool: ) +def str_flag(name: str, *, default: str, help: str) -> FlagHolder[str]: + """Set up a string flag. + + Example:: + + some_string = str_flag( + name='flax_some_string', + default='default_value', + help='Some string configuration.', + ) + + Now the ``FLAX_SOME_STRING`` shell environment variable can be used to + control the process-level value of the flag, in addition to using e.g. + ``config.update("flax_some_string", "new_value")`` directly. + + Args: + name: converted to lowercase to define the name of the flag. It is + converted to uppercase to define the corresponding shell environment + variable. + default: a default value for the flag. + help: used to populate the docstring of the returned flag holder object. + + Returns: + A flag holder object for accessing the value of the flag. + """ + name = name.lower() + config._add_option(name, static_str_env(name.upper(), default)) + fh = FlagHolder[str](name, help) + setattr(Config, name, property(lambda _: fh.value, doc=help)) + return fh + + def static_int_env(varname: str, default: int | None) -> int | None: """Read an environment variable and interpret it as an integer. @@ -222,6 +254,18 @@ def static_int_env(varname: str, default: int | None) -> int | None: ) from None +def static_str_env(varname: str, default: str) -> str: + """Read an environment variable and interpret it as a string. + + Args: + varname: the name of the variable + default: the default string value + Returns: + string return value derived from defaults and environment. 
+ """ + return os.getenv(varname, default) + + # Flax Global Configuration Variables: flax_filter_frames = bool_flag( @@ -291,8 +335,8 @@ def static_int_env(varname: str, default: int | None) -> int | None: default=True, help='Whether a `nnx.Variable` should always automatically be sharded if it contains sharding annotations.', ) -flax_hijax_variable = bool_flag( - name='flax_hijax_variable', - default=False, - help='Whether to enable HiJAX support for `nnx.Variable`.', +flax_variable_mode = str_flag( + name='flax_variable_mode', + default='lojax', + help='The variable mode for `nnx.Variable`. Options are "lojax", "hijax", and "ref".', ) \ No newline at end of file diff --git a/flax/nnx/__init__.py b/flax/nnx/__init__.py index 3ef14ce4c..066f3e99d 100644 --- a/flax/nnx/__init__.py +++ b/flax/nnx/__init__.py @@ -68,6 +68,8 @@ from .graph import variables as variables from .graph import to_arrays as to_arrays from .graph import to_refs as to_refs +from .graph import to_hijax as to_hijax +from .graph import to_lojax as to_lojax from .graph import pure as pure from .graph import cached_partial as cached_partial from .graph import flatten as flatten @@ -192,8 +194,8 @@ from .variablelib import variable_type_from_name as variable_type_from_name from .variablelib import variable_name_from_type as variable_name_from_type from .variablelib import register_variable_name as register_variable_name -from .variablelib import use_refs as use_refs -from .variablelib import using_refs as using_refs +from .variablelib import variable_mode as variable_mode +from .variablelib import current_variable_mode as current_variable_mode from .visualization import display as display from .extract import to_tree as to_tree from .extract import from_tree as from_tree diff --git a/flax/nnx/bridge/module.py b/flax/nnx/bridge/module.py index ed30d2895..bf913e386 100644 --- a/flax/nnx/bridge/module.py +++ b/flax/nnx/bridge/module.py @@ -391,7 +391,7 @@ def _get_variables(self) -> tp.Mapping: if isinstance( variable, variablelib.Variable ) and bridge_variables.is_vanilla_variable(variable): - leaf = variable.value + leaf = variable.get_value() else: leaf = bridge_variables.to_linen_var(variable) diff --git a/flax/nnx/bridge/variables.py b/flax/nnx/bridge/variables.py index e354f2e70..b0584d388 100644 --- a/flax/nnx/bridge/variables.py +++ b/flax/nnx/bridge/variables.py @@ -64,7 +64,7 @@ def remove_axis(self, index: int, params: dict[Any, Any]) -> 'NNXMeta[A]': def get_partition_spec(self) -> jax.sharding.PartitionSpec: """Returns the ``Partitionspec`` for this partitioned value.""" nnx_var = self.to_nnx_variable() - spec = spmd.get_partition_spec(nnx_var).raw_value + spec = spmd.get_partition_spec(nnx_var).get_raw_value() assert isinstance(spec, jax.sharding.PartitionSpec) return spec @@ -78,11 +78,11 @@ def is_vanilla_variable(vs: variablelib.Variable) -> bool: Returns False only if it has non-empty hooks or any non-built-in attribute. 
""" for key, value in vs.get_metadata().items(): - if key.endswith('_hooks'): - if value != (): - return False - else: - return False + if key in ('is_hijax', 'eager_sharding'): + continue + if key.endswith('_hooks') and value == (): + continue + return False return True @@ -91,11 +91,11 @@ def to_linen_var(vs: variablelib.Variable) -> meta.AxisMetadata: if 'linen_meta_type' in metadata: linen_type = metadata['linen_meta_type'] if hasattr(linen_type, 'from_nnx_metadata'): - return linen_type.from_nnx_metadata({'value': vs.value, **metadata}) - return linen_type(vs.value, **metadata) + return linen_type.from_nnx_metadata({'value': vs.get_value(), **metadata}) + return linen_type(vs.get_value(), **metadata) if is_vanilla_variable(vs): - return vs.value - return NNXMeta(type(vs), vs.value, metadata) + return vs.get_value() + return NNXMeta(type(vs), vs.get_value(), metadata) def get_col_name(keypath: tp.Sequence[Any]) -> str: diff --git a/flax/nnx/bridge/wrappers.py b/flax/nnx/bridge/wrappers.py index 6ab594e81..50a1041e5 100644 --- a/flax/nnx/bridge/wrappers.py +++ b/flax/nnx/bridge/wrappers.py @@ -413,7 +413,7 @@ def _to_linen_var(x): if self.metadata_fn is not None: return self.metadata_fn(x) # pylint: disable=too-many-function-args else: - return x.value + return x.get_value() return x collection_state = nnx.traversals.unflatten_mapping(flat_state) diff --git a/flax/nnx/extract.py b/flax/nnx/extract.py index fd8db8124..fe48ecc3d 100644 --- a/flax/nnx/extract.py +++ b/flax/nnx/extract.py @@ -63,7 +63,7 @@ def check_consistent_aliasing( lambda: f'Trying to extract graph node from different trace level, got {value!r}' ) if isinstance(value, graph.Variable): - if not value._trace_state.is_valid(): + if not value.trace_state.is_valid(): raise ValueError( f'Cannot extract graph node from different trace level, got {value!r}' ) diff --git a/flax/nnx/graph.py b/flax/nnx/graph.py index 45c0dede5..db98ceedd 100644 --- a/flax/nnx/graph.py +++ b/flax/nnx/graph.py @@ -20,7 +20,7 @@ import threading import typing as tp -import jax.experimental +import jax.core from flax import config from flax.nnx import filterlib, reprlib, traversals, variablelib @@ -34,6 +34,7 @@ from flax.nnx.variablelib import Variable, is_array_ref from flax.typing import Key, PathParts, is_key_like import jax +from jax._src import hijax import numpy as np import treescope # type: ignore[import-not-found,import-untyped] import typing_extensions as tpe @@ -373,7 +374,6 @@ class VariableDef(reprlib.Representable, tp.Generic[Node]): index: int outer_index: int | None metadata: HashableMapping[str, tp.Any] - array_refdef: ArrayRefDef | NodeRef | None def with_no_outer_index(self) -> VariableDef: return VariableDef( @@ -381,9 +381,6 @@ def with_no_outer_index(self) -> VariableDef: index=self.index, outer_index=None, metadata=self.metadata, - array_refdef=self.array_refdef.with_no_outer_index() - if isinstance(self.array_refdef, ArrayRefDef) - else self.array_refdef, ) def with_same_outer_index(self) -> VariableDef: @@ -392,9 +389,6 @@ def with_same_outer_index(self) -> VariableDef: index=self.index, outer_index=self.index, metadata=self.metadata, - array_refdef=self.array_refdef.with_same_outer_index() - if isinstance(self.array_refdef, ArrayRefDef) - else self.array_refdef, ) def __nnx_repr__(self): @@ -761,32 +755,21 @@ def make_mutable_arraydef(value: variablelib.Ref): if is_variable: assert isinstance(node, Variable) assert index is not None - prev_inner_value = node.raw_value - if variablelib.is_array_ref(prev_inner_value): - 
array_refdef, inner_value = make_mutable_arraydef(prev_inner_value) - else: - array_refdef = None - inner_value = prev_inner_value if path is None: - leaf = inner_value + leaf = node.get_raw_value() else: leaf = node # type: ignore[assignment] - if inner_value is not prev_inner_value: - leaf.raw_value = inner_value variabledef = VariableDef( - type=type(node), + type=type(node), # type: ignore index=index, outer_index=ref_outer_index.get(node, None) if ref_outer_index else None, - metadata=HashableMapping(node._var_metadata), - array_refdef=array_refdef, + metadata=HashableMapping(node.get_metadata()), ) - if type(inner_value) is not Repeated: - assert not isinstance(leaf, Repeated) - leaves.append(leaf) - if path is not None: - assert paths is not None - paths.append(tuple(path)) + leaves.append(leaf) + if path is not None: + assert paths is not None + paths.append(tuple(path)) nodes.append(variabledef) return elif is_array_ref: @@ -944,7 +927,7 @@ def _graph_fingerprint( variable_index = new_ref_index[value] = ctx.next_index ctx.next_index += 1 append_fn(variable_index) - for key_value in value._var_metadata.items(): + for key_value in value.get_metadata().items(): append_fn(key_value) elif not isinstance(value, (jax.Array, np.ndarray)): append_fn(value) @@ -1048,7 +1031,7 @@ def _check_graph_fingerprint( # append_fn(variable_index) if variable_index != next(fp_iterator): return False - for key_value in value._var_metadata.items(): + for key_value in value.get_metadata().items(): # append_fn(key_value) if key_value != next(fp_iterator): return False @@ -1197,10 +1180,10 @@ def get_mutable_array(array_refdef: ArrayRefDef, leaf): raise RuntimeError(f'Expected a no update for ArrayRef but got {leaf}.') elif type(leaf) in (NoUpdate, Repeated): raise ValueError( - f"Expected a ArrayRefOutput type but got '{leaf.value}.'" + f"Expected a ArrayRefOutput type but got '{leaf}.'" ) elif type(leaf) is ArrayRefOutput: - array_ref = variablelib.new_ref(leaf.value) + array_ref = jax.new_ref(leaf.value) elif variablelib.is_array_ref(leaf): array_ref = leaf else: @@ -1217,26 +1200,9 @@ def get_mutable_array(array_refdef: ArrayRefDef, leaf): variabledef = tp.cast(VariableDef[Variable], nodedef) # its a unseen variable, create a new one - if variabledef.array_refdef is not None: - if type(variabledef.array_refdef) is NodeRef: - value = index_ref[variabledef.array_refdef.index] - else: - value = next(leaves_iter) - assert type(variabledef.array_refdef) is ArrayRefDef - if isinstance(value, Variable): - value = value.copy() if copy_variables else value - inner_value = value.raw_value - array_ref = get_mutable_array(variabledef.array_refdef, inner_value) - if array_ref is not inner_value: - value.raw_value = array_ref - else: - # if value is an array or array ref, we need call get_mutable_array - # to register it in the index_ref - value = get_mutable_array(variabledef.array_refdef, value) - else: - value = next(leaves_iter) - if isinstance(value, Variable) and copy_variables: - value = value.copy() + value = next(leaves_iter) + if isinstance(value, Variable) and copy_variables: + value = value.copy() # when idxmap is present, check if the Varable exists there # and update existing variables if it does @@ -1252,7 +1218,7 @@ def get_mutable_array(array_refdef: ArrayRefDef, leaf): elif isinstance(value, Variable): variable.update_from_state(value) else: - variable.raw_value = value + variable.set_value(value) else: # variabledef.index not in index_ref_cache # variable reference does not exist outside, create a 
new one if isinstance(value, Variable): @@ -1437,12 +1403,12 @@ def _update_variable(node: Variable, value): # can happen when using standalone Variables with `grad` pass else: - if is_array_ref(node.raw_value) and ( + if is_array_ref(node.get_value()) and ( isinstance(value, jax.Array) or is_array_ref(value) ): node[...] = value[...] else: - node.raw_value = value + node.set_raw_value(value) if isinstance(node, Variable): _update_variable(node, state) @@ -1780,7 +1746,7 @@ def flatten( # type: ignore[invalid-annotation] else: paths = None leaves = [ - variable.raw_value for variable in node_static_cache.variables + variable.get_value() for variable in node_static_cache.variables ] else: graphdef, flat_state = flatten( @@ -1916,7 +1882,7 @@ def unflatten( # type: ignore[invalid-annotation] if isinstance(leaf, Variable): variable.update_from_state(leaf) else: - variable.raw_value = leaf + variable.set_value(leaf) self.index_ref.update(static_cache_node.new_index_ref) else: # uncached node, create it @@ -2556,7 +2522,7 @@ def pop( >>> assert hasattr(model, 'i') >>> intermediates = nnx.pop(model, nnx.Intermediate) - >>> assert intermediates['i'].value[0].shape == (1, 3) + >>> assert intermediates['i'][0].shape == (1, 3) >>> assert not hasattr(model, 'i') Args: @@ -2616,7 +2582,7 @@ def clone(node: Node, variables: bool = True) -> Node: def _mutable_like(path, x): - return (isinstance(x, Variable) and x.has_ref) or variablelib.is_array_ref(x) + return variablelib.is_array_ref(x) def to_arrays( @@ -2687,7 +2653,7 @@ def to_arrays( def _array_like(path, x): - return (isinstance(x, Variable) and not x.has_ref) or isinstance(x, jax.Array) + return isinstance(x, jax.Array) def to_refs(node: A, /, only: filterlib.Filter = _array_like) -> A: @@ -2743,11 +2709,83 @@ def to_refs(node: A, /, only: filterlib.Filter = _array_like) -> A: raise ValueError(f'Found duplicate at paths:{duplicates_strs}') graphdef, frozen_state, rest = split(node, only, ...) # type: ignore[misc] - mutable_state = jax.tree.map(variablelib.new_ref, frozen_state) + mutable_state = jax.tree.map(jax.new_ref, frozen_state) node = merge(graphdef, mutable_state, rest) return node +def _is_lojax_variable(path, x): + return isinstance(x, variablelib.Variable) and not x.mode == 'hijax' + + +def to_hijax( + node: A, /, *, only: filterlib.Filter = ..., mutable: bool = True +) -> A: + """ """ + if not mutable: + raise ValueError('to_hijax only supports mutable=True at the moment.') + + only = filterlib.All(_is_lojax_variable, only) + predicate = filterlib.to_predicate(only) + + if all_duplicates := find_duplicates(node, only=only): + duplicates_strs = '\n ---' + for node_duplicates in all_duplicates: + for path in node_duplicates: + path_str = '/'.join(map(str, path)) + duplicates_strs += f'\n {path_str}' + duplicates_strs += '\n ---' + raise ValueError(f'Found duplicate at paths:{duplicates_strs}') + + def _to_hijax(jax_path, x): + if predicate(to_nnx_path(jax_path), x): + assert isinstance(x, variablelib.Variable) + x = x.copy() + x.set_raw_value(hijax.Box(x.get_raw_value())) + return x + + node = jax.tree.map_with_path( + _to_hijax, node, is_leaf=lambda x: isinstance(x, variablelib.Variable) + ) + return node + + +def _is_hijax_variable(path, x): + return isinstance(x, variablelib.Variable) and x.mode == 'hijax' + +def to_lojax( + node: A, /, *, allow_duplicates: bool = False, only: filterlib.Filter = ... 
+) -> A: + """ """ + only = filterlib.All(_is_hijax_variable, only) + predicate = filterlib.to_predicate(only) + + if not allow_duplicates and ( + all_duplicates := find_duplicates(node, only=only) + ): + duplicates_strs = '\n ---' + for node_duplicates in all_duplicates: + for path in node_duplicates: + path_str = '/'.join(map(str, path)) + duplicates_strs += f'\n {path_str}' + duplicates_strs += '\n ---' + raise ValueError(f'Found duplicate at paths:{duplicates_strs}') + + def _to_lojax(jax_path, x): + if predicate(to_nnx_path(jax_path), x): + assert isinstance(x, variablelib.Variable) + x = x.copy() + box = x.get_raw_value() + assert isinstance(box, hijax.Box) + x.set_raw_value(box.get()) # unwrap hijax.Box + return x + + node = jax.tree.map_with_path( + _to_lojax, node, is_leaf=lambda x: isinstance(x, variablelib.Variable) + ) + return node + + def pure(tree: A) -> A: """Returns a new tree with all ``Variable`` objects replaced with inner values. @@ -2787,7 +2825,7 @@ def pure(tree: A) -> A: def _pure_fn(x): if isinstance(x, Variable): - return x.raw_value + return x.get_raw_value() return x return jax.tree.map( @@ -3072,7 +3110,7 @@ def _key_path_to_key(key: tp.Any) -> Key: return str(key) -def jax_to_nnx_path(jax_path: tuple, /): +def to_nnx_path(jax_path: tuple, /): return tuple(_key_path_to_key(part) for part in jax_path) diff --git a/flax/nnx/module.py b/flax/nnx/module.py index aa32a7edf..83fafd587 100644 --- a/flax/nnx/module.py +++ b/flax/nnx/module.py @@ -182,7 +182,7 @@ def sow( f"Expected '{name}' to be of type '{variable_type.__name__}', " f"got '{type(variable).__name__}'" ) - variable.raw_value = reduce_fn(variable.raw_value, value) + variable.set_value(reduce_fn(variable.get_value(), value)) else: reduced_value = reduce_fn(init_fn(), value) setattr(self, name, variable_type(reduced_value)) diff --git a/flax/nnx/nn/normalization.py b/flax/nnx/nn/normalization.py index 17a164ebf..b6a164adf 100644 --- a/flax/nnx/nn/normalization.py +++ b/flax/nnx/nn/normalization.py @@ -368,7 +368,7 @@ def __call__( mask=mask, ) # stop_gradient only for flax_array_ref - if self.mean.has_ref or self.var.has_ref: + if self.mean.is_hijax or self.var.is_hijax: stop_gradient = jax.lax.stop_gradient else: stop_gradient = lambda x: x diff --git a/flax/nnx/pytreelib.py b/flax/nnx/pytreelib.py index 2d4d2e646..46663cb4b 100644 --- a/flax/nnx/pytreelib.py +++ b/flax/nnx/pytreelib.py @@ -238,7 +238,7 @@ def _collect_stats( var_type = type(node) if issubclass(var_type, nnx.RngState): var_type = nnx.RngState - size_bytes = SizeBytes.from_any(node.raw_value) + size_bytes = SizeBytes.from_any(node.get_value()) if size_bytes: stats[var_type] = size_bytes @@ -355,6 +355,7 @@ def _graph_node_meta_call(cls: tp.Type[P], *args, **kwargs) -> P: return node +@jax.tree_util.register_static @dataclasses.dataclass(frozen=True, repr=False) class ArrayRepr(reprlib.Representable): shape: tp.Tuple[int, ...] 
@@ -507,12 +508,8 @@ def _setattr(self, name, value: tp.Any) -> None: vars(self)[name] = value def _check_value(self, key, value, new_status: AttributeStatus | None): - def _has_arrays(leaves): - return any( - isinstance(leaf, (np.ndarray, jax.Array)) - or variablelib.is_array_ref(leaf) - for leaf in leaves - ) + def _has_data(leaves): + return any(is_data(leaf) for leaf in leaves) def _get_annotations(leaves): return { @@ -547,7 +544,7 @@ def _has_visited(x): f' _.{key} = nnx.data(...)\n\n' ) - if _has_arrays(leaves): + if _has_data(leaves): # check no data in nnx.static assignments if new_status is not None: if not new_status.is_data and new_status.explicit: @@ -663,7 +660,7 @@ def __nnx_repr__(self): def to_shape_dtype(value): if isinstance(value, Variable): return value.replace( - raw_value=jax.tree.map(to_shape_dtype, value.raw_value) + value=jax.tree.map(to_shape_dtype, value.get_value()) ) elif variablelib.is_array_ref(value) and np.prod(value.shape) > 1: return MutableArrayRepr(value.shape, value.dtype) diff --git a/flax/nnx/rnglib.py b/flax/nnx/rnglib.py index 4489b6431..b48a61d46 100644 --- a/flax/nnx/rnglib.py +++ b/flax/nnx/rnglib.py @@ -23,7 +23,6 @@ from flax import errors, struct from flax import typing from flax.nnx import graph -from flax.nnx import variablelib from flax.nnx.nn import initializers from flax.nnx.variablelib import Variable from flax.nnx import filterlib @@ -116,7 +115,7 @@ def __init__( self.count = RngCount(count, tag=tag) def __call__(self) -> jax.Array: - if not self.count.has_ref and not self.count._trace_state.is_valid(): + if not self.count.trace_state.is_valid(): raise errors.TraceContextError( f'Cannot mutate {type(self).__name__} from a different trace level' ) @@ -826,14 +825,11 @@ def split_rngs_wrapper(*args, **kwargs): and predicate((*path, 'count'), stream.count) ): key = stream() - backups.append((stream, stream.key.raw_value, stream.count.raw_value)) + backups.append((stream, stream.key[...], stream.count[...])) key = random.split(key, splits) if squeeze: key = key[0] - if variablelib.is_array_ref(stream.key.raw_value): - stream.key.raw_value = variablelib.new_ref(key) # type: ignore[assignment] - else: - stream.key.value = key + stream.key.set_value(key) if squeeze: counts_shape = stream.count.shape elif isinstance(splits, int): @@ -841,11 +837,7 @@ def split_rngs_wrapper(*args, **kwargs): else: counts_shape = (*splits, *stream.count.shape) - count = jnp.zeros(counts_shape, dtype=jnp.uint32) - if variablelib.is_array_ref(stream.count.raw_value): - stream.count.raw_value = variablelib.new_ref(count) # type: ignore[assignment] - else: - stream.count.value = count + stream.count.set_value(jnp.zeros(counts_shape, dtype=jnp.uint32)) return SplitBackups(backups) @@ -992,10 +984,10 @@ def fork_rngs_wrapper(*args, **kwargs): ): forked_stream = stream.fork(split=splits) # backup the original stream state - backups.append((stream, stream.key.raw_value, stream.count.raw_value)) + backups.append((stream, stream.key[...], stream.count[...])) # apply the forked key and count to the original stream - stream.key.raw_value = forked_stream.key.raw_value - stream.count.raw_value = forked_stream.count.raw_value + stream.key.set_value(forked_stream.key.get_value()) + stream.count.set_value(forked_stream.count.get_value()) return SplitBackups(backups) @@ -1004,7 +996,7 @@ def backup_keys(node: tp.Any, /): backups: list[StreamBackup] = [] for _, stream in graph.iter_graph(node): if isinstance(stream, RngStream): - backups.append((stream, stream.key.raw_value)) + 
backups.append((stream, stream.key[...])) return backups def _scalars_only( @@ -1090,13 +1082,13 @@ def reseed( if stream.key.tag in stream_keys: key = rngs[stream.key.tag]() key = policy(path, key, stream.key.shape) - stream.key[...] = key - stream.count[...] = jnp.zeros(key.shape, dtype=jnp.uint32) + stream.key.set_value(key) + stream.count.set_value(jnp.zeros(key.shape, dtype=jnp.uint32)) def restore_rngs(backups: tp.Iterable[StreamBackup], /): for backup in backups: stream = backup[0] - stream.key.raw_value = backup[1] + stream.key.set_value(backup[1]) if len(backup) == 3: - stream.count.raw_value = backup[2] # count + stream.count.set_value(backup[2]) # count diff --git a/flax/nnx/statelib.py b/flax/nnx/statelib.py index 5cbab6394..2a11b325e 100644 --- a/flax/nnx/statelib.py +++ b/flax/nnx/statelib.py @@ -491,7 +491,7 @@ def to_pure_dict( ) -> dict[str, tp.Any]: # Works for nnx.Variable if extract_fn is None: - extract_fn = lambda x: x.value if isinstance(x, variablelib.Variable) else x + extract_fn = lambda x: x.get_value() if isinstance(x, variablelib.Variable) else x flat_values = {k: extract_fn(x) for k, x in to_flat_state(state)} return traversals.unflatten_mapping(flat_values) @@ -767,6 +767,6 @@ def create_path_filters(state: State): value_paths: dict[tp.Any, set[PathParts]] = {} for path, value in flat_state: if isinstance(value, variablelib.Variable): - value = value.raw_value + value = value.get_value() value_paths.setdefault(value, set()).add(path) return {filterlib.PathIn(*value_paths[value]): value for value in value_paths} \ No newline at end of file diff --git a/flax/nnx/summary.py b/flax/nnx/summary.py index 000561985..7d9d8f9ef 100644 --- a/flax/nnx/summary.py +++ b/flax/nnx/summary.py @@ -96,7 +96,7 @@ def _collect_stats( var_type = type(value) if issubclass(var_type, nnx.RngState): var_type = nnx.RngState - size_bytes = SizeBytes.from_any(value.value) + size_bytes = SizeBytes.from_any(value.get_value()) if var_type in stats: stats[var_type] += size_bytes else: @@ -455,11 +455,13 @@ def do_vjp(*args, **kwargs): for var_type in variable_types: attributes = {} + variable: variablelib.Variable for name, variable in node_info.variable_groups[var_type].items(): - value = variable.value + value = variable.get_value() value_repr = _render_array(value) if _has_shape_dtype(value) else '' metadata = variable.get_metadata() - + metadata.pop('is_hijax') + metadata.pop('eager_sharding', None) if metadata: attributes[name] = { 'value': value_repr, diff --git a/flax/nnx/training/optimizer.py b/flax/nnx/training/optimizer.py index 586f64d7f..2bd1068cb 100644 --- a/flax/nnx/training/optimizer.py +++ b/flax/nnx/training/optimizer.py @@ -30,6 +30,11 @@ # TODO: add tests and docstrings +def _to_pure_lojax(x): + x = nnx.pure(x) + x = nnx.to_arrays(x, allow_duplicates=True) + x = nnx.to_lojax(x, allow_duplicates=True) + return x class OptState(Variable): @@ -53,7 +58,7 @@ class OptVariable(OptState): def to_opt_state(tree): def _to_opt_state(x): if isinstance(x, Variable): - opt_state = OptVariable(x.value, **x.get_metadata()) # type: ignore + opt_state = OptVariable(x.get_value(), **x.get_metadata()) # type: ignore else: opt_state = OptArray(x) return opt_state @@ -210,10 +215,10 @@ def update(self, model: M, grads, /, **kwargs): **kwargs: additional keyword arguments passed to the tx.update, to support ``GradientTransformationExtraArgs``, such as ``optax.scale_by_backtracking_linesearch``. 
""" - param_arrays = nnx.to_arrays(nnx.pure(nnx.state(model, self.wrt))) - grad_arrays = nnx.to_arrays(nnx.pure(nnx.state(grads))) - opt_state_arrays = nnx.to_arrays(nnx.pure(self.opt_state)) - kwargs_arrays = nnx.to_arrays(nnx.pure(kwargs)) + param_arrays = _to_pure_lojax(nnx.state(model, self.wrt)) + grad_arrays = _to_pure_lojax(nnx.state(grads, self.wrt)) + opt_state_arrays = _to_pure_lojax(self.opt_state) + kwargs_arrays = _to_pure_lojax(kwargs) updates, new_opt_state = self.tx.update( grad_arrays, opt_state_arrays, param_arrays, **kwargs_arrays diff --git a/flax/nnx/transforms/compilation.py b/flax/nnx/transforms/compilation.py index d08d83335..e493bc9a4 100644 --- a/flax/nnx/transforms/compilation.py +++ b/flax/nnx/transforms/compilation.py @@ -30,7 +30,7 @@ statelib, variablelib, ) -from flax.typing import MISSING, Missing +from flax.typing import MISSING, Missing, PathParts F = tp.TypeVar('F', bound=tp.Callable[..., tp.Any]) P = tp.ParamSpec('P') @@ -71,7 +71,7 @@ def shardings(self) -> tuple[tp.Any, ...]: return self._shardings def map_prefix( - self, path: variablelib.PathParts, variable: variablelib.Variable + self, path: PathParts, variable: variablelib.Variable ) -> tp.Any: for filter, sharding in zip(self.filters, self.shardings): predicate = filterlib.to_predicate(filter) diff --git a/flax/nnx/transforms/transforms.py b/flax/nnx/transforms/transforms.py index 54027c6ff..132ca2b0e 100644 --- a/flax/nnx/transforms/transforms.py +++ b/flax/nnx/transforms/transforms.py @@ -146,7 +146,9 @@ def eval_shape( def _eval_shape_fn(*args, **kwargs): args, kwargs = extract.from_tree((args, kwargs)) out = f(*args, **kwargs) - return graph.to_arrays(extract.to_tree(out), allow_duplicates=True) + return graph.to_lojax( + graph.to_arrays(extract.to_tree(out), allow_duplicates=True) + ) out = jax.eval_shape(_eval_shape_fn, *args, **kwargs) return extract.from_tree(out) diff --git a/flax/nnx/variablelib.py b/flax/nnx/variablelib.py index 6d4c358a1..9601bf681 100644 --- a/flax/nnx/variablelib.py +++ b/flax/nnx/variablelib.py @@ -14,29 +14,32 @@ # pytype: skip-file from __future__ import annotations -import contextlib import dataclasses import functools from functools import partial import threading import typing as tp from typing import Any +import warnings + from flax import config +from jax._src import hijax import jax import treescope # type: ignore[import-untyped] from flax import errors from flax.core import spmd as core_spmd -from flax.nnx import filterlib, reprlib, tracers, visualization -from flax.typing import MISSING, Missing, PathParts, SizeBytes +from flax.nnx import reprlib, tracers, visualization +from flax.typing import MISSING, Missing, SizeBytes import jax.tree_util as jtu -import jax.numpy as jnp from jax._src.state.types import AbstractRef A = tp.TypeVar('A') B = tp.TypeVar('B') +C = tp.TypeVar('C') F = tp.TypeVar('F', bound=tp.Callable[..., tp.Any]) +P = tp.TypeVar('P', bound=property) V = tp.TypeVar('V', bound='Variable[Any]') GetValueHook = tp.Callable[['Variable[A]', A], A] SetValueHook = tp.Callable[['Variable[A]', A], A] @@ -50,109 +53,77 @@ # The following ensures we avoid an ImportError or DeprecationWarning. 
if hasattr(jax, 'new_ref') and hasattr(jax, 'Ref'): # JAX v0.7.2 or newer - from jax import new_ref from jax import Ref elif hasattr(jax, 'array_ref') and hasattr(jax, 'ArrayRef'): # JAX v0.7.1 - from jax import array_ref as new_ref # type: ignore[import-untyped] from jax import ArrayRef as Ref # type: ignore[import-untyped] else: # JAX v0.7.0 or older - from jax.experimental import mutable_array as new_ref from jax.experimental import MutableArray as Ref @dataclasses.dataclass class VariableContext(threading.local): - mutable_variable_stack: list[bool] = dataclasses.field(default_factory=list) + variable_mode_stack: list[tp.Literal['lojax', 'hijax', 'ref']] = ( + dataclasses.field(default_factory=list) + ) VARIABLE_CONTEXT = VariableContext() -def using_refs() -> bool: - """Returns whether Variables are using ArrayRefs by default. +def current_variable_mode() -> tp.Literal['lojax', 'hijax', 'ref']: + """ """ + if VARIABLE_CONTEXT.variable_mode_stack: + return VARIABLE_CONTEXT.variable_mode_stack[-1] + match config.flax_variable_mode: + case 'lojax' | 'hijax' | 'ref': + return config.flax_variable_mode + case other: + raise ValueError(f'Unrecognized variable mode: {other}') - Example:: - >>> from flax import nnx - ... - >>> nnx.using_refs() - False - >>> nnx.use_refs(True) - <...> - >>> nnx.using_refs() - True - >>> nnx.use_refs(False) - <...> - >>> nnx.using_refs() - False - - - Returns: - A boolean indicating if Variables are using ArrayRefs by default. - """ - if VARIABLE_CONTEXT.mutable_variable_stack: - return VARIABLE_CONTEXT.mutable_variable_stack[-1] +def variable_mode(value: tp.Literal['lojax', 'hijax', 'ref'], /): + """ """ + if VARIABLE_CONTEXT.variable_mode_stack: + prev_value = VARIABLE_CONTEXT.variable_mode_stack[-1] + VARIABLE_CONTEXT.variable_mode_stack[-1] = value else: - return config.flax_array_ref + prev_value = None + VARIABLE_CONTEXT.variable_mode_stack.append(value) + return ModeContext(prev_value, value) -def use_refs(value: bool, /): - """Sets whether Variables should use ArrayRefs by default or not. +class ModeContext: + def __init__( + self, + prev_value: tp.Literal['lojax', 'hijax', 'ref'] | None, + new_value: tp.Literal['lojax', 'hijax', 'ref'], + ): + self.prev_value: tp.Literal['lojax', 'hijax', 'ref'] | None = prev_value + self.new_value: tp.Literal['lojax', 'hijax', 'ref'] = new_value - Example usage:: + def __enter__(self): + if self.prev_value is not None: + VARIABLE_CONTEXT.variable_mode_stack.insert(-1, self.prev_value) - >>> from flax import nnx - >>> # Use ArrayRefs by default - >>> nnx.use_refs(True) - <...> - >>> # Variable will now use ArrayRefs - >>> v = nnx.Variable(jax.numpy.ones((2, 3))) - >>> v.has_ref - True - >>> v.raw_value - Ref(...) - >>> nnx.use_refs(False) - <...> - - It can also be used as a context manager to temporarily - change the default behavior for a block of code:: - - >>> nnx.use_refs(False) - <...> - >>> with nnx.use_refs(True): - ... v = nnx.Variable(jax.numpy.ones((2, 3))) - ... v.has_ref - True - >>> # it will reset outside - >>> v = nnx.Variable(jax.numpy.ones((2, 3))) - >>> v.has_ref - False - - Args: - value: A boolean indicating if Variables should use ArrayRefs by default. - - Returns: - A context manager that resets the context to the previous value. 
- """ - # prev_value = VARIABLE_CONTEXT.mutable_variable_stack[-1] if VARIABLE_CONTEXT.mutable_variable_stack else None - # VARIABLE_CONTEXT.mutable_variable_stack.append(value) - if VARIABLE_CONTEXT.mutable_variable_stack: - prev_value = VARIABLE_CONTEXT.mutable_variable_stack[-1] - VARIABLE_CONTEXT.mutable_variable_stack[-1] = value - else: - prev_value = None - VARIABLE_CONTEXT.mutable_variable_stack.append(value) - return _clean_mutable_arrays_context(prev_value) + def __exit__(self, exc_type, exc_value, traceback): + VARIABLE_CONTEXT.variable_mode_stack.pop() -@contextlib.contextmanager -def _clean_mutable_arrays_context(prev_value: bool | None): - if prev_value is not None: - VARIABLE_CONTEXT.mutable_variable_stack.insert(-1, prev_value) - try: - yield - finally: - VARIABLE_CONTEXT.mutable_variable_stack.pop() + def __call__(self, f: F) -> F: + # undo eager stack change + VARIABLE_CONTEXT.variable_mode_stack.pop() + if self.prev_value is not None: + VARIABLE_CONTEXT.variable_mode_stack.append(self.prev_value) + + @functools.wraps(f) + def set_variable_mode_wrapper(*args, **kwargs): + VARIABLE_CONTEXT.variable_mode_stack.append(self.new_value) + try: + return f(*args, **kwargs) + finally: + VARIABLE_CONTEXT.variable_mode_stack.pop() + + return set_variable_mode_wrapper # type: ignore[return-value] def is_array_ref(x) -> tp.TypeGuard[Ref]: @@ -172,6 +143,38 @@ class VariableMetadata(tp.Generic[A]): metadata: tp.Mapping[str, tp.Any] = dataclasses.field(default_factory=dict) +# -------------------------------------------- +# Variable +# -------------------------------------------- + + +def _variable_operator(name: str) -> tp.Callable[[Variable[A], tp.Any], A]: + def variable_operator_method(self, other): + value = self.get_value() + if isinstance(other, Variable): + other = other.get_value() + return getattr(value, name)(other) + + variable_operator_method.__name__ = name + return variable_operator_method + + +def _variable_unary_operator(name: str) -> tp.Callable[[Variable[A]], A]: + def variable_unary_operator_method(self): + value = self.get_value() + return getattr(value, name)() + + variable_unary_operator_method.__name__ = name + return variable_unary_operator_method + +@dataclasses.dataclass(frozen=True) +class BoxRepr(reprlib.Representable): + box: hijax.Box + + def __nnx_repr__(self): + yield reprlib.Object(type=type(self.box).__name__) + yield reprlib.Attr('value', self.box.get()) + class VariableMeta(type): def __new__(cls, cls_name, bases, attrs): if '__slots__' not in attrs: @@ -239,32 +242,61 @@ class Variable(tp.Generic[A], reprlib.Representable, metaclass=VariableMeta): }) """ - __slots__ = ('raw_value', '_trace_state', '_var_metadata') - - raw_value: A + __slots__ = ('_raw_value', '_trace_state', '_var_metadata') + _raw_value: A _trace_state: tracers.TraceState _var_metadata: dict[str, tp.Any] + @property + def mode(self) -> tp.Literal['lojax', 'hijax', 'ref']: + if isinstance(self._raw_value, hijax.Box): + return 'hijax' + elif is_array_ref(self._raw_value): + return 'ref' + else: + return 'lojax' + + @property + def shape(self: Variable[jax.Array]) -> tuple[int, ...]: + return self.get_value().shape + def __init__( self, - value: tp.Union[A, VariableMetadata[A]], + value: A | VariableMetadata[A], *, - use_ref: bool | None = None, + mode: tp.Literal['lojax', 'hijax', 'ref'] | None = None, + eager_sharding: bool | None = None, **metadata: tp.Any, ): - if use_ref is None: - use_ref = using_refs() - var_t = type(self) object.__setattr__(self, '_trace_state', 
tracers.TraceState()) if isinstance(value, VariableMetadata): - metadata.update(value.metadata) + aux_metadata = dict(value.metadata) + if 'mode' in aux_metadata: + if mode is not None and mode != aux_metadata['mode']: + raise ValueError( + 'Cannot specify mode both in VariableMetadata and as an ' + 'argument to Variable constructor.' + ) + mode = aux_metadata.pop('mode') + if 'eager_sharding' in aux_metadata: + if ( + eager_sharding is not None + and eager_sharding != aux_metadata['eager_sharding'] + ): + raise ValueError( + 'Cannot specify eager_sharding both in VariableMetadata and as ' + 'an argument to Variable constructor.' + ) + eager_sharding = aux_metadata.pop('eager_sharding') + metadata.update(aux_metadata) value = tp.cast(A, value.raw_value) - elif is_array_ref(value): - raise ValueError('Cannot pass a ArrayRef directly into Variable init.') - object.__setattr__(self, 'raw_value', value) + if any(is_array_ref(v) for v in jax.tree.leaves(value)): + raise ValueError('Cannot pass a Ref directly into Variable constructor.') + + object.__setattr__(self, '_raw_value', value) if hasattr(var_t, 'on_get_value') and 'on_get_value' not in metadata: metadata['on_get_value'] = var_t.on_get_value @@ -284,34 +316,49 @@ def __init__( if 'sharding' in metadata: metadata['sharding_names'] = metadata.pop('sharding') - object.__setattr__(self, '_var_metadata', metadata) # run create_value hooks - value = self.create_value(self.raw_value) + if 'on_create_value' in metadata: + value = metadata['on_create_value'](self, value) + + if eager_sharding is None: + eager_sharding = config.flax_always_shard_variable + + if mode is None: + mode = current_variable_mode() - # shard the value if applicable - do_eager_sharding = config.flax_always_shard_variable - if 'eager_sharding' in metadata: - do_eager_sharding = metadata['eager_sharding'] - if do_eager_sharding and 'sharding_names' in metadata: + metadata['mode'] = mode + object.__setattr__(self, '_var_metadata', metadata) + object.__setattr__(self, '_raw_value', value) + # run create_value hook + value = self.create_value(value) # type: ignore + # shard the _value if applicable + if eager_sharding and 'sharding_names' in metadata: + metadata['eager_sharding'] = eager_sharding value = core_spmd.shard_value( - value, metadata['sharding_names'], metadata.get('sharding_rules', None), - metadata.get('mesh', None)) + value, + metadata['sharding_names'], + metadata.get('sharding_rules', None), + metadata.get('mesh', None), + ) + + if mode == 'hijax': + value = hijax.Box(value) # type: ignore + elif mode == 'ref': + value = jax.new_ref(value) # type: ignore - # Create the ref out of the array value - if use_ref: - value = new_ref(jnp.asarray(value)) # type: ignore[assignment] # type: ignore[assignment] + object.__setattr__(self, '_raw_value', value) - object.__setattr__(self, 'raw_value', value) + @property + def trace_state(self) -> tracers.TraceState: + return self._trace_state def __getattr__(self, name: str) -> tp.Any: if name in object.__getattribute__(self, '_var_metadata'): return self._var_metadata[name] - return getattr(self.raw_value, name) + return getattr(self._raw_value, name) def __setattr__(self, name: str, value: tp.Any): - if not self._trace_state.is_valid() and ( - name != 'value' or not self.has_ref - ): + if not self._trace_state.is_valid(): raise errors.TraceContextError( f'Cannot mutate {type(self).__name__} from a different trace level' ) @@ -330,32 +377,21 @@ def __delattr__(self, name: str): raise errors.TraceContextError( f'Cannot mutate 
{type(self).__name__} from a different trace level' ) - - if ( - name == 'value' - or name == 'raw_value' - or name == '_var_metadata' - or name == '_trace_state' - ): + try: object.__delattr__(self, name) - else: - del self._var_metadata[name] + except AttributeError as e: + raise AttributeError( + f'Cannot delete attribute {name}. ' + f'To delete Variable metadata use:\n\n' + f" variable.del_metadata('{name}')" + ) from e # NOTE(cgarciae): adding this for backward compatibility with VariableState @property def type(self): """The type of the variable.""" - import warnings - warnings.warn( - "'.type' is deprecated, use 'type(variable)' instead.", - DeprecationWarning, - stacklevel=2, - ) - return type(self) - @property - def has_ref(self) -> bool: - return is_array_ref(self.raw_value) + return type(self) @tp.overload def get_metadata(self) -> dict[str, tp.Any]: ... @@ -372,11 +408,12 @@ def get_metadata( default: The default value to return if the metadata key is not found. If not provided and the key is not found, raises a KeyError. """ + metadata = self._var_metadata.copy() if name is None: - return self._var_metadata - if name not in self._var_metadata and not isinstance(default, Missing): + return metadata + if name not in metadata and not isinstance(default, Missing): return default - return self._var_metadata[name] + return metadata[name] @tp.overload def set_metadata(self, metadata: dict[str, tp.Any], /) -> None: ... @@ -405,11 +442,28 @@ def set_metadata(self, *args, **kwargs) -> None: 'Cannot mix positional and keyword arguments in set_metadata' ) if len(args) == 1: - self._var_metadata = dict(args[0]) + metadata = dict(args[0]) + if 'mode' not in metadata: + raise ValueError('metadata is missing required key `mode` key') + if metadata['mode'] != self.mode: + raise ValueError( + f'Cannot change `mode` metadata, expected {self.mode}, ' + f'got {metadata["mode"]}' + ) + self._var_metadata = metadata elif len(args) == 2: name, value = args + if name == 'mode' and value != self.mode: + raise ValueError( + f'Cannot change `mode` metadata, expected {self.mode}, got {value}' + ) self._var_metadata[name] = value elif kwargs: + if 'mode' in kwargs and kwargs['mode'] != self.mode: + raise ValueError( + f'Cannot change `mode` metadata, expected {self.mode}, ' + f'got {kwargs["mode"]}' + ) self._var_metadata.update(kwargs) else: raise TypeError( @@ -417,6 +471,20 @@ def set_metadata(self, *args, **kwargs) -> None: f'got args={args}, kwargs={kwargs}' ) + def del_metadata(self, name: str) -> None: + """Delete a metadata entry for the Variable. + + Args: + name: The key of the metadata element to delete. + """ + if not self._trace_state.is_valid(): + raise errors.TraceContextError( + f'Cannot mutate {type(self).__name__} from a different trace level' + ) + if name == 'mode': + raise ValueError('Cannot delete `mode` metadata') + del self._var_metadata[name] + def copy_from(self, other: Variable[A]) -> None: if type(self) is not type(other): raise ValueError( @@ -425,50 +493,103 @@ def copy_from(self, other: Variable[A]) -> None: ) if self is other: return - self.raw_value = other.raw_value + self._raw_value = other._raw_value self._var_metadata.clear() self._var_metadata.update(other.get_metadata()) def update_from_state(self, variable_state: Variable[A]): - if self.has_ref and ( - variable_state.has_ref or isinstance(variable_state.raw_value, jax.Array) - ): - self.raw_value[...] = variable_state.raw_value[...] 
# type: ignore - else: - object.__setattr__(self, 'raw_value', variable_state.raw_value) + object.__setattr__(self, '_raw_value', variable_state._raw_value) if self._var_metadata != variable_state._var_metadata: - object.__setattr__( - self, '_var_metadata', variable_state._var_metadata.copy() + metadata = variable_state.get_metadata().copy() + metadata['mode'] = self.mode + object.__setattr__(self, '_var_metadata', metadata) + + @tp.final + def get_raw_value(self) -> A: + return self._raw_value + + @tp.final + def set_raw_value(self, value: A): + if not self._trace_state.is_valid(): + raise errors.TraceContextError( + f'Cannot mutate {type(self).__name__} from a different trace level' ) + object.__setattr__(self, '_raw_value', value) + + @property + def raw_value(self) -> A: + warnings.warn( + "'.raw_value' access is now deprecated. Use:\n\n" + ' variable.get_raw_value()\n', + DeprecationWarning, + stacklevel=2, + ) + return self.get_raw_value() + + @raw_value.setter + def raw_value(self, value: A): + warnings.warn( + "'.raw_value' access is now deprecated. Use:\n\n" + ' variable.set_raw_value(value)\n', + DeprecationWarning, + stacklevel=2, + ) + self.set_raw_value(value) @property def value(self) -> A: - value = self.raw_value + warnings.warn( + "'.value' access is now deprecated. For Variable[Array] instances use:\n\n" + ' variable[...]\n\n' + 'For other Variable types use:\n\n' + ' variable.get_value()\n', + DeprecationWarning, + stacklevel=2, + ) - if is_array_ref(value): - value = value[...] + return self.get_value() + + @value.setter + def value(self, value: A): + warnings.warn( + "'.value' access is now deprecated. For Variable[Array] instances use:\n\n" + ' variable[...] = value\n\n' + 'For other Variable types use:\n\n' + ' variable.set_value(value)\n', + DeprecationWarning, + stacklevel=2, + ) + self.set_value(value) + + def create_value(self, value: A): + return value + + def get_value(self) -> A: + if isinstance(self._raw_value, hijax.Box): + value = self._raw_value.get() + else: + value = jax.tree.map(lambda x: x, self._raw_value) # make a copy if 'on_get_value' in self._var_metadata: value = self._var_metadata['on_get_value'](self, value) return value - @value.setter - def value(self, value: A): + def set_value(self, value: A): + value = jax.tree.map(lambda x: x, value) # make a copy if isinstance(value, Variable): raise ValueError( 'Cannot set value to a Variable, use `copy_from` method instead' ) if 'on_set_value' in self._var_metadata: value = self._var_metadata['on_set_value'](self, value) - if self.has_ref: - self.raw_value[...] = value # type: ignore - else: - object.__setattr__(self, 'raw_value', value) - def create_value(self, value: A): - if 'on_create_value' in self._var_metadata: - value = self._var_metadata['on_create_value'](self, value) - return value + if isinstance(self._raw_value, hijax.Box): + self._raw_value.set(value) + else: + object.__setattr__(self, '_raw_value', value) def add_axis(self, axis_index: AxisIndex, axis_name: AxisName | None): if 'on_add_axis' in self._var_metadata: @@ -479,27 +600,28 @@ def remove_axis(self, axis_index: AxisIndex, axis_name: AxisName | None): self._var_metadata['on_remove_axis'](self, axis_index, axis_name) @tp.overload - def replace(self, value: B, **kwargs) -> Variable[B]: - ... + def replace(self, value: B, **kwargs) -> Variable[B]: ... @tp.overload - def replace(self, **kwargs) -> Variable[A]: - ... + def replace(self, **kwargs) -> Variable[A]: ...
def replace(self, value: tp.Any = Missing, **kwargs) -> Variable[tp.Any]: if value is not Missing: - kwargs['raw_value'] = value + kwargs['value'] = value + + if 'is_hijax' in kwargs and kwargs['is_hijax'] != self.is_hijax: + raise ValueError( + f'Cannot change `is_hijax` metadata, expected {self.is_hijax}, ' + f'got {kwargs["is_hijax"]}' + ) - # rename `value` to `raw_value` - if 'value' in kwargs: - kwargs['raw_value'] = kwargs.pop('value') + if 'raw_value' in kwargs: + raise RuntimeError # return `value` if it is a Variable - if 'raw_value' in kwargs and isinstance( - value := kwargs['raw_value'], Variable - ): + if 'value' in kwargs and isinstance(value := kwargs['value'], Variable): # remove value from kwargs - kwargs.pop('raw_value') + kwargs.pop('value') if type(self) is not type(value): raise ValueError( 'Cannot replace value from incompatible container, ' @@ -516,41 +638,59 @@ def replace(self, value: tp.Any = Missing, **kwargs) -> Variable[tp.Any]: # return new instance with updated attributes obj = object.__new__(type(self)) object.__setattr__(obj, '_trace_state', self._trace_state) - object.__setattr__(obj, 'raw_value', kwargs.pop('raw_value')) + object.__setattr__(obj, '_raw_value', kwargs.pop('value')) object.__setattr__(obj, '_var_metadata', self.get_metadata() | kwargs) return obj @classmethod - def from_metadata(cls, value: A, attributes: dict[str, tp.Any]): + def _new( + cls, + value: A, + metadata: dict[str, tp.Any], + ) -> Variable[A]: obj = object.__new__(cls) object.__setattr__(obj, '_trace_state', tracers.TraceState()) - object.__setattr__(obj, 'raw_value', value) - object.__setattr__(obj, '_var_metadata', attributes) + object.__setattr__(obj, '_raw_value', value) + object.__setattr__(obj, '_var_metadata', metadata) return obj + @classmethod + def from_metadata( + cls, + value: A, + attributes: dict[str, tp.Any], + ) -> Variable[A]: + obj = cls._new(value, dict(attributes)) + return obj # type: ignore[return-value] + def copy(self: Variable[A]) -> Variable[A]: obj = object.__new__(type(self)) object.__setattr__(obj, '_trace_state', tracers.TraceState()) - object.__setattr__(obj, 'raw_value', self.raw_value) + object.__setattr__(obj, '_raw_value', self._raw_value) object.__setattr__(obj, '_var_metadata', self.get_metadata().copy()) return obj to_state = copy def __nnx_repr__(self): - stats = SizeBytes.from_any(self.raw_value) + stats = SizeBytes.from_any(self._raw_value) if stats: comment = f' # {stats}' else: comment = '' yield reprlib.Object(type=type(self).__name__, comment=comment) - yield reprlib.Attr('value', self.raw_value) + if isinstance(self._raw_value, hijax.Box): + yield reprlib.Attr('value', BoxRepr(self._raw_value)) + else: + yield reprlib.Attr('value', self._raw_value) for name, value in self._var_metadata.items(): + if name == 'is_hijax' and value is False: + continue yield reprlib.Attr(name, value) def __treescope_repr__(self, path, subtree_renderer): - size_bytes = SizeBytes.from_any(self.value) + size_bytes = SizeBytes.from_any(self.get_value()) if size_bytes: stats_repr = f' # {size_bytes}' first_line_annotation = treescope.rendering_parts.comment_color( @@ -559,7 +699,7 @@ def __treescope_repr__(self, path, subtree_renderer): else: first_line_annotation = None - children = {'value': self.raw_value, **self._var_metadata} + children = {'value': self.get_value(), **self._var_metadata} return visualization.render_object_constructor( object_type=type(self), attributes=children, @@ -586,18 +726,18 @@ def on_remove_axis( ) -> V: ... 
def __jax_array__(self): - return self.value + return self.get_value() # pickle support def __getstate__(self): return { - 'raw_value': self.raw_value, + '_raw_value': self._raw_value, '_trace_state': self._trace_state, '_var_metadata': self._var_metadata, } def __setstate__(self, state): - object.__setattr__(self, 'raw_value', state['raw_value']) + object.__setattr__(self, '_raw_value', state['_raw_value']) object.__setattr__(self, '_trace_state', state['_trace_state']) object.__setattr__(self, '_var_metadata', state['_var_metadata']) @@ -615,287 +755,161 @@ def __getitem__(self: Variable[tuple[B, ...]], key: int) -> B: ... @tp.overload def __getitem__(self, key) -> tp.Any: ... def __getitem__(self, key): - return self.value[key] # type: ignore + return self.get_value()[key] # type: ignore def __setitem__(self, key, item_value) -> None: - value = self.value + value = self.get_value() if isinstance(value, jax.Array): value = value.at[key].set(item_value) # type: ignore[assignment] else: value[key] = item_value # type: ignore - self.value = value # type: ignore + self.set_value(value) # type: ignore + + def __delitem__(self, key) -> None: + value = self.get_value() + del value[key] # type: ignore + self.set_value(value) # type: ignore def __call__(self, *args, **kwargs) -> tp.Any: - return self.value(*args, **kwargs) # type: ignore + return self.get_value()(*args, **kwargs) # type: ignore def __len__(self) -> int: - return len(self.value) # type: ignore + return len(self.get_value()) # type: ignore def __iter__(self) -> tp.Iterator: - return iter(self.value) # type: ignore + return iter(self.get_value()) # type: ignore def __contains__(self, item) -> bool: - return item in self.value # type: ignore - - def __add__(self, other) -> A: - if isinstance(other, Variable): - other = other.value - return self.value.__add__(other) # type: ignore - - def __sub__(self, other) -> A: - if isinstance(other, Variable): - other = other.value - return self.value.__sub__(other) # type: ignore - - def __mul__(self, other) -> A: - if isinstance(other, Variable): - other = other.value - return self.value.__mul__(other) # type: ignore - - def __matmul__(self, other) -> A: - if isinstance(other, Variable): - other = other.value - return self.value.__matmul__(other) # type: ignore - - def __truediv__(self, other) -> A: - if isinstance(other, Variable): - other = other.value - return self.value.__truediv__(other) # type: ignore - - def __floordiv__(self, other) -> A: - if isinstance(other, Variable): - other = other.value - return self.value.__floordiv__(other) # type: ignore - - def __mod__(self, other) -> A: - if isinstance(other, Variable): - other = other.value - return self.value.__mod__(other) # type: ignore - - def __divmod__(self, other) -> A: - if isinstance(other, Variable): - other = other.value - return self.value.__divmod__(other) # type: ignore - - def __pow__(self, other) -> A: - if isinstance(other, Variable): - other = other.value - return self.value.__pow__(other) # type: ignore - - def __lshift__(self, other) -> A: - if isinstance(other, Variable): - other = other.value - return self.value.__lshift__(other) # type: ignore - - def __rshift__(self, other) -> A: - if isinstance(other, Variable): - other = other.value - return self.value.__rshift__(other) # type: ignore - - def __and__(self, other) -> A: - if isinstance(other, Variable): - other = other.value - return self.value.__and__(other) # type: ignore - - def __xor__(self, other) -> A: - if isinstance(other, Variable): - other = other.value - 
return self.value.__xor__(other) # type: ignore - - def __or__(self, other) -> A: - if isinstance(other, Variable): - other = other.value - return self.value.__or__(other) # type: ignore - - def __radd__(self, other) -> A: - if isinstance(other, Variable): - other = other.value - return self.value.__radd__(other) # type: ignore - - def __rsub__(self, other) -> A: - if isinstance(other, Variable): - other = other.value - return self.value.__rsub__(other) # type: ignore - - def __rmul__(self, other) -> A: - if isinstance(other, Variable): - other = other.value - return self.value.__rmul__(other) # type: ignore - - def __rmatmul__(self, other) -> A: - if isinstance(other, Variable): - other = other.value - return self.value.__rmatmul__(other) # type: ignore - - def __rtruediv__(self, other) -> A: - if isinstance(other, Variable): - other = other.value - return self.value.__rtruediv__(other) # type: ignore - - def __rfloordiv__(self, other) -> A: - if isinstance(other, Variable): - other = other.value - return self.value.__rfloordiv__(other) # type: ignore - - def __rmod__(self, other) -> A: - if isinstance(other, Variable): - other = other.value - return self.value.__rmod__(other) # type: ignore - - def __rdivmod__(self, other) -> A: - if isinstance(other, Variable): - other = other.value - return self.value.__rdivmod__(other) # type: ignore - - def __rpow__(self, other) -> A: - if isinstance(other, Variable): - other = other.value - return self.value.__rpow__(other) # type: ignore - - def __rlshift__(self, other) -> A: - if isinstance(other, Variable): - other = other.value - return self.value.__rlshift__(other) # type: ignore - - def __rrshift__(self, other) -> A: - if isinstance(other, Variable): - other = other.value - return self.value.__rrshift__(other) # type: ignore - - def __rand__(self, other) -> A: - if isinstance(other, Variable): - other = other.value - return self.value.__rand__(other) # type: ignore - - def __rxor__(self, other) -> A: - if isinstance(other, Variable): - other = other.value - return self.value.__rxor__(other) # type: ignore - - def __ror__(self, other) -> A: - if isinstance(other, Variable): - other = other.value - return self.value.__ror__(other) # type: ignore - + return item in self.get_value() # type: ignore + + # binary operators + __add__ = _variable_operator('__add__') + __sub__ = _variable_operator('__sub__') + __mul__ = _variable_operator('__mul__') + __matmul__ = _variable_operator('__matmul__') + __truediv__ = _variable_operator('__truediv__') + __floordiv__ = _variable_operator('__floordiv__') + __mod__ = _variable_operator('__mod__') + __pow__ = _variable_operator('__pow__') + __lshift__ = _variable_operator('__lshift__') + __rshift__ = _variable_operator('__rshift__') + __and__ = _variable_operator('__and__') + __xor__ = _variable_operator('__xor__') + __or__ = _variable_operator('__or__') + __radd__ = _variable_operator('__radd__') + __rsub__ = _variable_operator('__rsub__') + __rmul__ = _variable_operator('__rmul__') + __rmatmul__ = _variable_operator('__rmatmul__') + __rtruediv__ = _variable_operator('__rtruediv__') + __rfloordiv__ = _variable_operator('__rfloordiv__') + __rmod__ = _variable_operator('__rmod__') + __rpow__ = _variable_operator('__rpow__') + __rlshift__ = _variable_operator('__rlshift__') + __rrshift__ = _variable_operator('__rrshift__') + __rand__ = _variable_operator('__rand__') + __rxor__ = _variable_operator('__rxor__') + __ror__ = _variable_operator('__ror__') + + # in-place operators def __iadd__(self: V, other) -> V: raise 
NotImplementedError( 'In-place operations are no longer supported for Variable.\n' - 'Use `variable.value += x` instead.' + 'Use `variable[...] += x` instead.' ) def __isub__(self: V, other) -> V: raise NotImplementedError( 'In-place operations are no longer supported for Variable.\n' - 'Use `variable.value -= x` instead.' + 'Use `variable[...] -= x` instead.' ) def __imul__(self: V, other) -> V: raise NotImplementedError( 'In-place operations are no longer supported for Variable.\n' - 'Use `variable.value *= x` instead.' + 'Use `variable[...] *= x` instead.' ) def __imatmul__(self: V, other) -> V: raise NotImplementedError( 'In-place operations are no longer supported for Variable.\n' - 'Use `variable.value @= x` instead.' + 'Use `variable[...] @= x` instead.' ) def __itruediv__(self: V, other) -> V: raise NotImplementedError( 'In-place operations are no longer supported for Variable.\n' - 'Use `variable.value /= x` instead.' + 'Use `variable[...] /= x` instead.' ) def __ifloordiv__(self: V, other) -> V: raise NotImplementedError( 'In-place operations are no longer supported for Variable.\n' - 'Use `variable.value //= x`` instead.' + 'Use `variable[...] //= x` instead.' ) def __imod__(self: V, other) -> V: raise NotImplementedError( 'In-place operations are no longer supported for Variable.\n' - 'Use `variable.value %= x` instead.' + 'Use `variable[...] %= x` instead.' ) def __ipow__(self: V, other) -> V: raise NotImplementedError( 'In-place operations are no longer supported for Variable.\n' - 'Use `variable.value **= x`` instead.' + 'Use `variable[...] **= x` instead.' ) def __ilshift__(self: V, other) -> V: raise NotImplementedError( 'In-place operations are no longer supported for Variable.\n' - 'Use `variable.value <<= x`` instead.' + 'Use `variable[...] <<= x` instead.' ) def __irshift__(self: V, other) -> V: raise NotImplementedError( 'In-place operations are no longer supported for Variable.\n' - 'Use `variable.value >>= x`` instead.' + 'Use `variable[...] >>= x` instead.' ) def __iand__(self: V, other) -> V: raise NotImplementedError( 'In-place operations are no longer supported for Variable.\n' - 'Use `variable.value &= x` instead.' + 'Use `variable[...] &= x` instead.' ) def __ixor__(self: V, other) -> V: raise NotImplementedError( 'In-place operations are no longer supported for Variable.\n' - 'Use `variable.value ^= x` instead.' + 'Use `variable[...] ^= x` instead.' ) def __ior__(self: V, other) -> V: raise NotImplementedError( 'In-place operations are no longer supported for Variable.\n' - 'Use `variable.value |= x` instead.' + 'Use `variable[...] |= x` instead.' 
) - def __neg__(self) -> A: - return self.value.__neg__() # type: ignore - - def __pos__(self) -> A: - return self.value.__pos__() # type: ignore - - def __abs__(self) -> A: - return self.value.__abs__() # type: ignore - - def __invert__(self) -> A: - return self.value.__invert__() # type: ignore - - def __complex__(self) -> A: - return self.value.__complex__() # type: ignore - - def __int__(self) -> A: - return self.value.__int__() # type: ignore - - def __float__(self) -> A: - return self.value.__float__() # type: ignore - - def __index__(self) -> A: - return self.value.__index__() # type: ignore - - def __round__(self, ndigits: int) -> A: - return self.value.__round__(ndigits) # type: ignore - - def __trunc__(self) -> A: - return self.value.__trunc__() # type: ignore - - def __floor__(self) -> A: - return self.value.__floor__() # type: ignore - - def __ceil__(self) -> A: - return self.value.__ceil__() # type: ignore + __neg__ = _variable_unary_operator('__neg__') + __pos__ = _variable_unary_operator('__pos__') + __abs__ = _variable_unary_operator('__abs__') + __invert__ = _variable_unary_operator('__invert__') + __complex__ = _variable_unary_operator('__complex__') + __int__ = _variable_unary_operator('__int__') + __float__ = _variable_unary_operator('__float__') + __index__ = _variable_unary_operator('__index__') + __trunc__ = _variable_unary_operator('__trunc__') + __floor__ = _variable_unary_operator('__floor__') + __ceil__ = _variable_unary_operator('__ceil__') + + def __round__(self, ndigits: int = 0) -> A: + return self.get_value().__round__(ndigits) # type: ignore # -------------------------------------------- def __init_subclass__(cls) -> None: + if '__slots__' not in vars(cls): + cls.__slots__ = () # type: ignore[assignment] super().__init_subclass__() - jax.tree_util.register_pytree_with_keys( cls, flatten_with_keys=_variable_flatten_with_keys, @@ -906,13 +920,13 @@ def __init_subclass__(cls) -> None: def _variable_flatten_with_keys(x: Variable[tp.Any]): metadata = tuple(sorted(x._var_metadata.items())) - node = (jtu.GetAttrKey('value'), x.raw_value) + node = (jtu.GetAttrKey('value'), x._raw_value) return (node,), metadata def _variable_flatten(x: Variable[tp.Any]): metadata = tuple(sorted(x._var_metadata.items())) - return (x.raw_value,), metadata + return (x._raw_value,), metadata def _variable_unflatten( @@ -920,7 +934,7 @@ def _variable_unflatten( static: tuple[tuple[str, tp.Any], ...], children: tuple[tp.Any], ): - return cls.from_metadata(value=children[0], attributes=dict(static)) + return cls._new(children[0], dict(static)) jax.tree_util.register_pytree_with_keys( @@ -930,9 +944,9 @@ def _variable_unflatten( flatten_func=_variable_flatten, ) - VariableState = Variable + class Param(Variable[A]): """The canonical learnable parameter. All learnable parameters in NNX layer modules will have the ``Param`` :class:`Variable` @@ -1153,32 +1167,6 @@ def wrapper(*args): return wrapper # type: ignore -def split_flat_state( - flat_state: tp.Iterable[tuple[PathParts, Variable]], - filters: tuple[filterlib.Filter, ...], -) -> tuple[list[tuple[PathParts, Variable]], ...]: - predicates = filterlib.filters_to_predicates(filters) - # we have n + 1 states, where n is the number of predicates - # the last state is for values that don't match any predicate - flat_states: tuple[list[tuple[PathParts, Variable]], ...] 
= ( - tuple([] for _ in predicates) - ) - - for path, value in flat_state: - for i, predicate in enumerate(predicates): - if predicate(path, value): - flat_states[i].append((path, value)) - break - else: - raise ValueError( - 'Non-exhaustive filters, got a non-empty remainder: ' - f'{path} -> {value}.' - '\nUse `...` to match all remaining elements.' - ) - - return flat_states - - ################################################### ### Variable type/class <-> string name mapping ### ################################################### @@ -1232,13 +1220,6 @@ def variable_name_from_type( return name -class _Missing: - pass - - -_MISSING = _Missing() - - @tp.overload def register_variable_name( name: str, @@ -1258,12 +1239,12 @@ def register_variable_name( def register_variable_name( name: str, - typ: type[Variable[A]] | _Missing = _MISSING, + typ: type[Variable[A]] | Missing = MISSING, *, overwrite=False, ) -> type[Variable[A]] | tp.Callable[[type[Variable[A]]], type[Variable[A]]]: """Register a pair of Linen collection name and its NNX type.""" - if typ is _MISSING: + if isinstance(typ, Missing): return partial(register_variable_name, name, overwrite=overwrite) typ = tp.cast(type[Variable[A]], typ) if not overwrite and name in VariableTypeCache: diff --git a/tests/nnx/containers_test.py b/tests/nnx/containers_test.py index 70b51283f..7aa80cb73 100644 --- a/tests/nnx/containers_test.py +++ b/tests/nnx/containers_test.py @@ -34,7 +34,7 @@ def test_on_set_value(self): ) x[...] = 5 - assert x.raw_value == 12 + assert x.get_raw_value() == 12 def test_module_unbox(self): class Foo(nnx.Module): @@ -43,8 +43,8 @@ def __init__(self) -> None: module = Foo() - assert module.x.value == 4 - assert vars(module)['x'].raw_value == 1 + assert module.x.get_value() == 4 + assert vars(module)['x'].get_raw_value() == 1 def test_module_box(self): class Foo(nnx.Module): @@ -58,7 +58,7 @@ def __init__(self) -> None: module.x[...] = 5 assert module.x[...] == 12 - assert vars(module)['x'].raw_value == 12 + assert vars(module)['x'][...] == 12 if __name__ == '__main__': diff --git a/tests/nnx/graph_utils_test.py b/tests/nnx/graph_utils_test.py index 523ff31fa..30638b3b8 100644 --- a/tests/nnx/graph_utils_test.py +++ b/tests/nnx/graph_utils_test.py @@ -46,8 +46,8 @@ def test_flatten(self): refmap = nnx.graph.RefMap() graphdef, flat_state = nnx.graph.flatten(g, ref_index=refmap) - assert flat_state[0][1].value == 2 - assert flat_state[1][1].value == 4 + assert flat_state[0][1].get_value() == 2 + assert flat_state[1][1].get_value() == 4 assert len(refmap) == 2 # 2 Variables assert a['b'] in refmap @@ -156,8 +156,8 @@ def test_update_dynamic(self): state[0]['b'][...] = 3 nnx.update(g, state) - assert g[0]['b'].value == 3 - assert g[2]['b'].value == 3 + assert g[0]['b'][...] == 3 + assert g[2]['b'][...] 
== 3 def test_update_from_pure_dict(self): a = {'a': 1, 'b': nnx.Param(jnp.array(2))} @@ -342,7 +342,7 @@ def __init__(self): m2 = nnx.merge(graphdef, state) assert isinstance(m2.tree, Tree) - assert m2.tree.a.raw_value == 1 + assert m2.tree.a.get_value() == 1 assert m2.tree.b == 'a' assert m2.tree.a is m.tree.a assert m2.tree is not m.tree diff --git a/tests/nnx/integration_test.py b/tests/nnx/integration_test.py index 67eaa71c5..e70919683 100644 --- a/tests/nnx/integration_test.py +++ b/tests/nnx/integration_test.py @@ -146,17 +146,17 @@ def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs): key = rngs.params() self.w = nnx.Param(jax.random.uniform(key, (din, dout))) self.b = nnx.Param(jnp.zeros((dout,))) - self.count = State(0) + self.count = State(jnp.array(0)) def __call__(self, x): - self.count.value += 1 - return x @ self.w.value + self.b.value[None] + self.count[...] += 1 + return x @ self.w + self.b[None] model = Linear(din=12, dout=2, rngs=nnx.Rngs(0)) # forward pass x = jnp.ones((8, 12)) y = model(x) - assert model.count.value == 1 + assert model.count[...] == 1 @nnx.jit def train_step(model, x, y): @@ -176,7 +176,7 @@ def loss_fn(model): # execute the training step train_step(model, x, y) - assert model.count.value == 2 + assert model.count[...] == 2 def test_functional_example(self): class Count(nnx.Variable[A]): @@ -187,17 +187,17 @@ def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs): key = rngs.params() self.w = nnx.Param(jax.random.uniform(key, (din, dout))) self.b = nnx.Param(jnp.zeros((dout,))) - self.count = Count(0) + self.count = Count(jnp.array(0)) def __call__(self, x): - self.count.value += 1 - return x @ self.w.value + self.b.value[None] + self.count[...] += 1 + return x @ self.w + self.b[None] model = Linear(din=12, dout=2, rngs=nnx.Rngs(0)) # forward pass x = jnp.ones((8, 12)) y = model(x) - assert model.count.value == 1 + assert model.count[...] == 1 graphdef, params, counts = nnx.split(model, nnx.Param, Count) @@ -218,7 +218,7 @@ def loss_fn(params): # execute the training step params, counts = train_step(params, counts, x, y) model = nnx.merge(graphdef, params, counts) - assert model.count.value == 2 + assert model.count[...] == 2 def test_intermediates_example(self): class Linear(nnx.Module): @@ -228,7 +228,7 @@ def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs): self.b = nnx.Param(jnp.zeros((dout,))) def __call__(self, x): - y = x @ self.w.value + self.b.value[None] + y = x @ self.w + self.b[None] self.y = nnx.Intermediate(y) return y @@ -248,7 +248,7 @@ def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs): self.b = nnx.Param(jnp.zeros((dout,))) def __call__(self, x): - y = x @ self.w.value + self.b.value[None] + y = x @ self.w + self.b[None] self.y = nnx.Intermediate(y) return y @@ -310,7 +310,7 @@ def __call__(self, x): x = nnx.relu(self.dropout(self.bn(self.linear(x)))) return self.linear_out(x) - with nnx.use_refs(True): + with nnx.use_hijax(True): model = Model(2, 64, 3, rngs=nnx.Rngs(0)) # eager initialization optimizer = nnx.Optimizer(model, optax.adam(1e-3), wrt=nnx.Param) @@ -319,9 +319,9 @@ def train_step(x, y): graphdef, params, nondiff = nnx.split(model, nnx.Param, ...) 
def loss_fn(params): model = nnx.merge(graphdef, params, nondiff) - return ((model(x) - y) ** 2).mean() # call methods directly + return ((model(x) - y) ** 2).mean() # call methods directly - loss, grads = jax.value_and_grad(loss_fn)(nnx.to_arrays(params)) + loss, grads = jax.value_and_grad(loss_fn)(nnx.to_lojax(params)) optimizer.update(model, grads) # in-place updates return loss diff --git a/tests/nnx/module_test.py b/tests/nnx/module_test.py index 64bf1dda2..a5e17bd7a 100644 --- a/tests/nnx/module_test.py +++ b/tests/nnx/module_test.py @@ -1,4 +1,5 @@ # Copyright 2024 The Flax Authors. + # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -333,13 +334,13 @@ def test_clone(self): m2 = nnx.clone(m) assert m is not m2 - assert m2.a[0].value == m2.b.c.value - assert m2.a[1].value == m2.b.d.value + assert m2.a[0].get_value() == m2.b.c.get_value() + assert m2.a[1].get_value() == m2.b.d.get_value() - assert m.a[0].value == m2.a[0].value - assert m.a[1].value == m2.a[1].value - assert m.b.c.value == m2.b.c.value - assert m.b.d.value == m2.b.d.value + assert m.a[0].get_value() == m2.a[0].get_value() + assert m.a[1].get_value() == m2.a[1].get_value() + assert m.b.c.get_value() == m2.b.c.get_value() + assert m.b.d.get_value() == m2.b.d.get_value() def test_sow_basic(self): class Foo(nnx.Module): @@ -354,12 +355,12 @@ def __call__(self, x): assert y1 == 3 assert y2 == 11 - assert m.y.value == (3, 11) + assert m.y.get_value() == (3, 11) intermediates = nnx.pop(m, nnx.Intermediate) assert isinstance(intermediates['y'], nnx.Intermediate) - assert intermediates['y'].value == (3, 11) + assert intermediates['y'].get_value() == (3, 11) assert not hasattr(m, 'y') @@ -550,13 +551,13 @@ def add_submodule(self): def test_create_abstract(self): linear = nnx.eval_shape(lambda: nnx.Linear(2, 3, rngs=nnx.Rngs(0))) - assert linear.kernel.value == jax.ShapeDtypeStruct((2, 3), jnp.float32) - assert linear.bias.value == jax.ShapeDtypeStruct((3,), jnp.float32) + assert linear.kernel.get_value() == jax.ShapeDtypeStruct((2, 3), jnp.float32) + assert linear.bias.get_value() == jax.ShapeDtypeStruct((3,), jnp.float32) def test_create_abstract_stateful(self): linear = nnx.eval_shape(lambda: nnx.Dropout(0.5, rngs=nnx.Rngs(0))) - assert linear.rngs.key.value == jax.ShapeDtypeStruct( + assert linear.rngs.key.get_value() == jax.ShapeDtypeStruct( (), jax.random.key(0).dtype ) @@ -742,13 +743,13 @@ class Foo(nnx.Module): graphdef, state = nnx.split(m) assert len(state) == 4 - assert state['b'].value == 2 + assert state['b'].get_value() == 2 assert isinstance(state['b'], nnx.Variable) - assert state['c'].value == 3 + assert state['c'].get_value() == 3 assert isinstance(state['c'], nnx.Param) - assert state['d'].value == 4 + assert state['d'].get_value() == 4 assert isinstance(state['d'], nnx.Variable) - assert state['e'].value == 5 + assert state['e'].get_value() == 5 assert isinstance(state['e'], nnx.BatchStat) def test_post_init(self): diff --git a/tests/nnx/mutable_array_test.py b/tests/nnx/mutable_array_test.py index d30a2489d..4339c2b8a 100644 --- a/tests/nnx/mutable_array_test.py +++ b/tests/nnx/mutable_array_test.py @@ -15,22 +15,22 @@ import dataclasses from absl.testing import absltest import optax +import pytest from flax import nnx import flax.errors import jax import jax.numpy as jnp - class TestObject(absltest.TestCase): @classmethod def setUpClass(cls): - cls.using_refs = nnx.using_refs() - nnx.use_refs(True) + cls.using_hijax = 
nnx.current_variable_mode() + nnx.variable_mode('hijax') @classmethod def tearDownClass(cls): - nnx.use_refs(cls.using_refs) + nnx.variable_mode(cls.using_hijax) def test_pytree(self): class Foo(nnx.Module): @@ -127,12 +127,12 @@ def __init__(self, a): class TestMutableArrayGraph(absltest.TestCase): @classmethod def setUpClass(cls): - cls.using_refs = nnx.using_refs() - nnx.use_refs(True) + cls.using_hijax = nnx.current_variable_mode() + nnx.variable_mode('hijax') @classmethod def tearDownClass(cls): - nnx.use_refs(cls.using_refs) + nnx.variable_mode(cls.using_hijax) def test_split_mutable_array(self): m = jax.new_ref(1) @@ -150,14 +150,14 @@ def __init__(self): self.a = nnx.Param(1) m = Foo() - self.assertTrue(m.a.has_ref) + self.assertEqual(m.a.mode, 'hijax') - m2 = nnx.to_arrays(m) - self.assertFalse(m2.a.has_ref) + m2 = nnx.to_lojax(m) + self.assertEqual(m2.a.mode, 'lojax') self.assertIsNot(m, m2) - m3 = nnx.to_refs(m2) - self.assertTrue(m3.a.has_ref) + m3 = nnx.to_hijax(m2) + self.assertEqual(m3.a.mode, 'hijax') self.assertIsNot(m2, m3) self.assertIsNot(m2.a, m3.a) @@ -188,17 +188,17 @@ def __init__(self): self.b = nnx.BatchStat(2) m = Foo() - self.assertTrue(m.a.has_ref) - self.assertTrue(m.b.has_ref) + self.assertEqual(m.a.mode, 'hijax') + self.assertEqual(m.b.mode, 'hijax') - m2 = nnx.to_arrays(m, only=nnx.BatchStat) - self.assertTrue(m2.a.has_ref) - self.assertFalse(m2.b.has_ref) + m2 = nnx.to_lojax(m, only=nnx.BatchStat) + self.assertEqual(m2.a.mode, 'hijax') + self.assertEqual(m2.b.mode, 'lojax') self.assertIsNot(m, m2) - m3 = nnx.to_refs(m2, nnx.BatchStat) - self.assertTrue(m3.a.has_ref) - self.assertTrue(m3.b.has_ref) + m3 = nnx.to_hijax(m2, only=nnx.BatchStat) + self.assertEqual(m3.a.mode, 'hijax') + self.assertEqual(m3.b.mode, 'hijax') self.assertIsNot(m2, m3) self.assertIs(m.a, m3.a) @@ -233,7 +233,7 @@ def __init__(self): def test_mutable_array_split_merge_in_variable(self): class Foo(nnx.Module): def __init__(self): - self.a = nnx.Param(1, use_ref=True) + self.a = nnx.Param(1, mode='hijax') self.b = self.a m = Foo() @@ -241,7 +241,7 @@ def __init__(self): ref_map = nnx.graph.RefMap() graphdef, state = nnx.graph.flatten(m, ref_index=ref_map) self.assertLen(state, 1) - self.assertLen(ref_map, 3) # 1 Foo + 1 Param + 1 ArrayRef + self.assertLen(ref_map, 2) # 1 Foo + 1 Param m1 = nnx.merge(graphdef, state) self.assertIs(m1.a, m1.b) @@ -251,20 +251,18 @@ def test_mutable_array_split_merge_in_variable_shared_array(self): class Foo(nnx.Module): def __init__(self): m_array = 1 - self.a = nnx.Param(m_array, use_ref=True) - self.b = nnx.Param(m_array, use_ref=True) + self.a = nnx.Param(m_array, mode='hijax') + self.b = nnx.Param(m_array, mode='hijax') m = Foo() - self.assertIsNot(m.a.raw_value, m.b.raw_value) ref_map = nnx.graph.RefMap() graphdef, state = nnx.graph.flatten(m, ref_index=ref_map) self.assertLen(state, 2) - self.assertLen(ref_map, 5) # 1 Foo + 2 Param + 2 ArrayRefs + self.assertLen(ref_map, 3) # 1 Foo + 2 Param m1 = nnx.merge(graphdef, state) # Each variable will own its own array and ref. 
- self.assertIsNot(m1.a.raw_value, m1.b.raw_value) self.assertIsInstance(m1.a, nnx.Param) def test_mutable_example(self): @@ -283,15 +281,20 @@ def __init__(self): ref_map = nnx.graph.RefMap() graphdef, state = nnx.graph.flatten(m, ref_index=ref_map) - state = nnx.to_arrays(state) + state = nnx.to_lojax(state) self.assertLen(state, 1) - m1 = nnx.merge(graphdef, nnx.to_refs(state)) + m1 = nnx.merge(graphdef, nnx.to_hijax(state)) self.assertIs(m1.a, m1.b) self.assertIsInstance(m1.a, jax.Ref) def test_update_context(self): - m1 = nnx.Linear(1, 1, rngs=nnx.Rngs(0)) + class Foo(nnx.Module): + def __init__(self): + self.kernel = jax.new_ref(1) + self.bias = jax.new_ref(2) + + m1 = Foo() with nnx.update_context('example'): with nnx.split_context('example') as ctx: graphdef, state = ctx.split(m1) @@ -299,19 +302,15 @@ def test_update_context(self): with nnx.merge_context('example', True) as ctx: m2 = ctx.merge(graphdef, state) - m_out1 = nnx.Linear(1, 1, rngs=nnx.Rngs(0)) + m_out1 = Foo() with nnx.split_context('example') as ctx: graphdef_out, state_out = ctx.split((m2, m_out1, m2)) - self.assertIsInstance(state_out[0]['kernel'].value, nnx.graph.NoUpdate) - self.assertIsInstance(state_out[0]['bias'].value, nnx.graph.NoUpdate) - self.assertIsInstance( - state_out[1]['kernel'].value, nnx.graph.ArrayRefOutput - ) - self.assertIsInstance( - state_out[1]['bias'].value, nnx.graph.ArrayRefOutput - ) + self.assertIsInstance(state_out[0]['kernel'], nnx.graph.NoUpdate) + self.assertIsInstance(state_out[0]['bias'], nnx.graph.NoUpdate) + self.assertIsInstance(state_out[1]['kernel'], nnx.graph.ArrayRefOutput) + self.assertIsInstance(state_out[1]['bias'], nnx.graph.ArrayRefOutput) # 2 ArrayRefOutput + 2 NoUpdate, however, NoUpdate are empty nodes self.assertLen(jax.tree.leaves(state_out), 2) @@ -322,7 +321,12 @@ def test_update_context(self): self.assertIsNot(m_out2, m_out1) def test_update_context_flatten(self): - m1 = nnx.Linear(1, 1, rngs=nnx.Rngs(0)) + class Foo(nnx.Module): + def __init__(self): + self.kernel = jax.new_ref(1) + self.bias = jax.new_ref(2) + + m1 = Foo() with nnx.update_context('example'): with nnx.split_context('example') as ctx: graphdef, state = ctx.flatten(m1) @@ -330,24 +334,20 @@ def test_update_context_flatten(self): with nnx.merge_context('example', True) as ctx: m2 = ctx.merge(graphdef, state) - m_out1 = nnx.Linear(1, 1, rngs=nnx.Rngs(0)) + m_out1 = Foo() with nnx.split_context('example') as ctx: graphdef_out, state_out = ctx.flatten((m2, m_out1, m2)) state_out_dict = dict(state_out) + self.assertIsInstance(state_out_dict[(0, 'kernel')], nnx.graph.NoUpdate) + self.assertIsInstance(state_out_dict[(0, 'bias')], nnx.graph.NoUpdate) self.assertIsInstance( - state_out_dict[(0, 'kernel')].value, nnx.graph.NoUpdate - ) - self.assertIsInstance( - state_out_dict[(0, 'bias')].value, nnx.graph.NoUpdate - ) - self.assertIsInstance( - state_out_dict[(1, 'kernel')].value, nnx.graph.ArrayRefOutput + state_out_dict[(1, 'kernel')], nnx.graph.ArrayRefOutput ) self.assertIsInstance( - state_out_dict[(1, 'bias')].value, nnx.graph.ArrayRefOutput + state_out_dict[(1, 'bias')], nnx.graph.ArrayRefOutput ) # 2 ArrayRefOutput + 2 NoUpdate, however, NoUpdate are empty nodes self.assertLen(jax.tree.leaves(state_out), 2) @@ -359,29 +359,34 @@ def test_update_context_flatten(self): self.assertIsNot(m_out2, m_out1) def test_update_context_to_tree1(self): - m1 = nnx.Linear(1, 1, rngs=nnx.Rngs(0)) + class Foo(nnx.Module): + def __init__(self): + self.kernel = jax.new_ref(1) + self.bias = jax.new_ref(2) + + m1 = 
Foo() with nnx.update_context('example'): m1_tree = nnx.to_tree((m1,), ctxtag='example') (m2,) = nnx.from_tree(m1_tree, ctxtag='example', is_inner=True) - m_out1 = nnx.Linear(1, 1, rngs=nnx.Rngs(0)) + m_out1 = Foo() # with nnx.split_context('example') as ctx: # graphdef_out, state_out = ctx.split((m2, m_out1)) out_tree = nnx.to_tree(((m2,), m_out1, m2), ctxtag='example') self.assertIsInstance( - out_tree[0][0].states[0]['kernel'].value, nnx.graph.NoUpdate + out_tree[0][0].states[0]['kernel'], nnx.graph.NoUpdate ) self.assertIsInstance( - out_tree[0][0].states[0]['bias'].value, nnx.graph.NoUpdate + out_tree[0][0].states[0]['bias'], nnx.graph.NoUpdate ) self.assertIsInstance( - out_tree[1].states[0]['kernel'].value, nnx.graph.ArrayRefOutput + out_tree[1].states[0]['kernel'], nnx.graph.ArrayRefOutput ) self.assertIsInstance( - out_tree[1].states[0]['bias'].value, nnx.graph.ArrayRefOutput + out_tree[1].states[0]['bias'], nnx.graph.ArrayRefOutput ) self.assertEmpty(out_tree[2].states[0]) # Repeated m2 State @@ -398,29 +403,34 @@ def test_update_context_to_tree1(self): self.assertIsNot(m_out2, m_out1) def test_update_context_to_tree2(self): - m1 = nnx.Linear(1, 1, rngs=nnx.Rngs(0)) + class Foo(nnx.Module): + def __init__(self): + self.kernel = jax.new_ref(1) + self.bias = jax.new_ref(2) + + m1 = Foo() with nnx.update_context('example') as ctx: m1_tree = nnx.to_tree((m1,), ctxtag='example') (m2,) = nnx.from_tree(m1_tree, ctxtag='example', is_inner=True) - m_out1 = nnx.Linear(1, 1, rngs=nnx.Rngs(0)) + m_out1 = Foo() # with nnx.split_context('example') as ctx: # graphdef_out, state_out = ctx.split((m2, m_out1)) out_tree = nnx.to_tree(((m2,), m_out1, m2), ctxtag='example') self.assertIsInstance( - out_tree[0][0].states[0]['kernel'].value, nnx.graph.NoUpdate + out_tree[0][0].states[0]['kernel'], nnx.graph.NoUpdate ) self.assertIsInstance( - out_tree[0][0].states[0]['bias'].value, nnx.graph.NoUpdate + out_tree[0][0].states[0]['bias'], nnx.graph.NoUpdate ) self.assertIsInstance( - out_tree[1].states[0]['kernel'].value, nnx.graph.ArrayRefOutput + out_tree[1].states[0]['kernel'], nnx.graph.ArrayRefOutput ) self.assertIsInstance( - out_tree[1].states[0]['bias'].value, nnx.graph.ArrayRefOutput + out_tree[1].states[0]['bias'], nnx.graph.ArrayRefOutput ) self.assertEmpty(out_tree[2].states[0]) # Repeated m2 State @@ -437,29 +447,34 @@ def test_update_context_to_tree2(self): self.assertIsNot(m_out2, m_out1) def test_update_context_to_tree_trivial_prefix(self): - m1 = nnx.Linear(1, 1, rngs=nnx.Rngs(0)) + class Foo(nnx.Module): + def __init__(self): + self.kernel = jax.new_ref(1) + self.bias = jax.new_ref(2) + + m1 = Foo() with nnx.update_context('example'): m1_tree = nnx.to_tree((m1,), ctxtag='example', prefix=0) (m2,) = nnx.from_tree(m1_tree, ctxtag='example', is_inner=True, prefix=0) - m_out1 = nnx.Linear(1, 1, rngs=nnx.Rngs(0)) + m_out1 = Foo() # with nnx.split_context('example') as ctx: # graphdef_out, state_out = ctx.split((m2, m_out1)) out_tree = nnx.to_tree(((m2,), m_out1, m2), ctxtag='example', prefix=0) self.assertIsInstance( - out_tree[0][0].states[0]['kernel'].value, nnx.graph.NoUpdate + out_tree[0][0].states[0]['kernel'], nnx.graph.NoUpdate ) self.assertIsInstance( - out_tree[0][0].states[0]['bias'].value, nnx.graph.NoUpdate + out_tree[0][0].states[0]['bias'], nnx.graph.NoUpdate ) self.assertIsInstance( - out_tree[1].states[0]['kernel'].value, nnx.graph.ArrayRefOutput + out_tree[1].states[0]['kernel'], nnx.graph.ArrayRefOutput ) self.assertIsInstance( - out_tree[1].states[0]['bias'].value, 
nnx.graph.ArrayRefOutput + out_tree[1].states[0]['bias'], nnx.graph.ArrayRefOutput ) self.assertEmpty(out_tree[2].states[0]) # Repeated m2 State @@ -480,12 +495,12 @@ def test_update_context_to_tree_trivial_prefix(self): class TestMutableArrayNNXTransforms(absltest.TestCase): @classmethod def setUpClass(cls): - cls.using_refs = nnx.using_refs() - nnx.use_refs(True) + cls.using_hijax = nnx.current_variable_mode() + nnx.variable_mode('hijax') @classmethod def tearDownClass(cls): - nnx.use_refs(cls.using_refs) + nnx.variable_mode(cls.using_hijax) def test_simple_jit(self): m1 = nnx.Linear(1, 1, rngs=nnx.Rngs(0)) @@ -501,7 +516,7 @@ def f(m2): self.assertIsNot(m_out1, m_out2) self.assertIsInstance(m_out2.kernel, nnx.Param) - self.assertIsInstance(m_out2.kernel.raw_value, jax.Ref) + self.assertIsInstance(m_out2.kernel[...], jax.Array) def test_jit_mutable(self): @dataclasses.dataclass @@ -525,12 +540,12 @@ def f(m2: Foo): class TestMutableArray(absltest.TestCase): @classmethod def setUpClass(cls): - cls.using_refs = nnx.using_refs() - nnx.use_refs(True) + cls.using_hijax = nnx.current_variable_mode() + nnx.variable_mode('hijax') @classmethod def tearDownClass(cls): - nnx.use_refs(cls.using_refs) + nnx.variable_mode(cls.using_hijax) def test_static(self): class C(nnx.Module): @@ -554,12 +569,12 @@ def f(x): assert n == 2 def test_variable_creation(self): - v = nnx.Variable(1) + v = nnx.Variable(jnp.array(1)) self.assertEqual(v[...], 1) - self.assertTrue(v.has_ref) + self.assertEqual(v.mode, 'hijax') def test_variable_metadata(self): - v = nnx.Variable(1, a=2, b=3) + v = nnx.Variable(jnp.array(1), a=2, b=3) self.assertEqual(v.a, 2) self.assertEqual(v.b, 3) @@ -568,7 +583,7 @@ class Params(nnx.Pytree): def __init__(self, din: int, dout: int): self.w = nnx.Param(jnp.zeros((din, dout), jnp.float32)) self.b = nnx.Param(jnp.zeros((dout,), jnp.float32)) - self.count = nnx.Variable(0) + self.count = nnx.Variable(jnp.array(0)) params: Params params = Params(3, 4) @@ -580,18 +595,9 @@ def __init__(self, din: int, dout: int): self.assertEqual(leaves[0].shape, (4,)) # b self.assertEqual(leaves[1].shape, ()) # count self.assertEqual(leaves[2].shape, (3, 4)) # w - self.assertEqual( - paths[0], - (jax.tree_util.GetAttrKey('b'), jax.tree_util.GetAttrKey('value')), - ) - self.assertEqual( - paths[1], - (jax.tree_util.GetAttrKey('count'), jax.tree_util.GetAttrKey('value')), - ) - self.assertEqual( - paths[2], - (jax.tree_util.GetAttrKey('w'), jax.tree_util.GetAttrKey('value')), - ) + self.assertEqual(paths[0], (jax.tree_util.GetAttrKey('b'),)) + self.assertEqual(paths[1], (jax.tree_util.GetAttrKey('count'),)) + self.assertEqual(paths[2], (jax.tree_util.GetAttrKey('w'),)) params = jax.tree.unflatten(treedef, leaves) @@ -655,7 +661,6 @@ def test_rngs_create(self): ( jax.tree_util.GetAttrKey('default'), jax.tree_util.GetAttrKey('count'), - jax.tree_util.GetAttrKey('value'), ), ) self.assertEqual( @@ -663,7 +668,6 @@ def test_rngs_create(self): ( jax.tree_util.GetAttrKey('default'), jax.tree_util.GetAttrKey('key'), - jax.tree_util.GetAttrKey('value'), ), ) @@ -687,7 +691,7 @@ def __call__(self, x): x = jax.random.normal(jax.random.key(0), (5, 2)) y = jnp.ones((5, 4)) - with nnx.use_refs(False): + with nnx.variable_mode('lojax'): wrt = lambda path, x: path[-1] == 'w' model = Model(nnx.Rngs(1)) optimizer = nnx.Optimizer( @@ -717,8 +721,8 @@ def loss_fn(params): def test_optimize_mutable_arrays(self): class Model(nnx.Module): def __init__(self, rngs): - self.w = jax.new_ref(jax.random.uniform(rngs(), (2, 4))) - 
self.count = jax.new_ref(jnp.array(0)) + self.w = nnx.Variable(jax.random.uniform(rngs(), (2, 4))) + self.count = nnx.Variable(jnp.array(0)) def __call__(self, x): self.count[...] += 1 @@ -727,7 +731,7 @@ def __call__(self, x): x = jax.random.normal(jax.random.key(0), (5, 2)) y = jnp.ones((5, 4)) - with nnx.use_refs(True): + with nnx.variable_mode('hijax'): wrt = lambda path, x: path[-1] == 'w' model = Model(nnx.Rngs(1)) optimizer = nnx.Optimizer(model, tx=optax.adam(1e-3), wrt=wrt) @@ -740,7 +744,7 @@ def loss_fn(params): model = nnx.merge(graphdef, params, nondiff) return jnp.mean((model(x) - y) ** 2) - loss, grads = jax.value_and_grad(loss_fn)(nnx.to_arrays(params)) + loss, grads = jax.value_and_grad(loss_fn)(nnx.to_lojax(params)) optimizer.update(params, grads) return loss @@ -748,6 +752,158 @@ def loss_fn(params): self.assertNotEqual(loss, 0.0) +class TestHijaxVariables(absltest.TestCase): + def test_variable_to_hijax(self): + v_low = nnx.Param(jnp.array(1), a='hi') + v_hi = nnx.to_hijax(v_low) + + self.assertEqual(v_hi.mode, 'hijax') + self.assertEqual(v_hi[...], 1) + self.assertIsInstance(v_hi, nnx.Param) + + v_hi[...] = 2 + self.assertEqual(v_hi[...], 2) + + @jax.jit + def set(v_hi: nnx.Param, a): + self.assertIsInstance(v_hi, nnx.Param) + v_hi[...] = a + self.assertEqual(v_hi.a, 'hi') + self.assertEqual(v_hi.mode, 'hijax') + v_hi[...] += 5 + return v_hi + 2 + + y = set(v_hi, 10) + self.assertEqual(v_hi[...], 15) + self.assertEqual(y, 17) + + v_low = nnx.to_lojax(v_hi) + self.assertEqual(v_low.mode, 'lojax') + self.assertIsInstance(v_low, nnx.Param) + + def test_from_metadata(self): + value = 1 + metadata = {'a': 'hi', 'mode': 'lojax'} + v_low = nnx.Param.from_metadata(value, metadata) + self.assertIsInstance(v_low, nnx.Param) + self.assertEqual(v_low.mode, 'lojax') + + metadata['mode'] = 'hijax' + v_hi = nnx.Param.from_metadata(value, metadata) + self.assertIsInstance(v_hi, nnx.Param) + self.assertEqual(v_hi.mode, 'hijax') + + def test_variable_to_hijax_clean(self): + v_low = nnx.Param(jnp.array([1]), tag='hello') + print() + print(v_low) + assert v_low.mode == 'lojax' + v_hi = nnx.to_hijax(v_low) + v_hi[...] = jnp.array([2]) + assert v_hi.mode == 'hijax' + print(v_hi) + assert v_hi[...] == 2 + + @jax.jit + def set(v_hi, a): + v_hi[...] = a + print(v_hi) + assert v_hi.tag == 'hello' + + set(v_hi, 10) + + assert v_hi[...] == 10 + + v_low = nnx.to_lojax(v_hi) + + assert v_low.mode == 'lojax' + assert v_low[...] == 10 + + def test_hijax_and_pytree(self): + class Foo(nnx.Pytree): + def __init__(self, din, dout, rngs: nnx.Rngs): + self.w = nnx.Param(rngs.uniform((din, dout))) + self.b = nnx.Param(jnp.zeros((dout,))) + self.count = nnx.Variable(0) + + foo = Foo(2, 4, nnx.Rngs(1)) + assert foo.w.mode == 'lojax' + assert foo.b.mode == 'lojax' + + foo = nnx.to_hijax(foo) + + assert foo.w.mode == 'hijax' + assert foo.b.mode == 'hijax' + + @jax.jit + def forward(foo, x): + foo.count[...] += 1 + return x @ foo.w + foo.b[None] + + x = jnp.ones((1, 2)) + y = forward(foo, x) + assert y.shape == (1, 4) + assert foo.count[...] 
== 1 + + def test_use_hijax(self): + v_low = nnx.Param(1, a='hi') + self.assertEqual(v_low.mode, 'lojax') + + v_hi = nnx.Param(1, a='hi', mode='hijax') + self.assertEqual(v_hi.mode, 'hijax') + + with nnx.variable_mode('hijax'): + v2 = nnx.Param(1, a='hi') + self.assertEqual(v2.mode, 'hijax') + + @nnx.variable_mode('hijax') + def test_hijax_rngs(self): + rngs = nnx.Rngs(0) + + @jax.jit + def f(rngs: nnx.Rngs): + return rngs() + + k1 = f(rngs) + k2 = f(rngs) + + assert k1 != k2 + + @pytest.mark.skip(reason='not yet supported') + def test_return_hijax_from_transform(self): + @jax.jit + def create_var(): + return nnx.Param(1, mode='hijax') + + v = create_var() + self.assertEqual(v.mode, 'hijax') + + @pytest.mark.skip(reason='not yet supported') + @nnx.variable_mode('hijax') + def test_lower(self): + v = nnx.Param(jnp.ones((2, 3))) + + @jax.jit + def f(v): + v[...] += 1 + return v[...] + + e = f.lower(v) + y = e.out_info[2] + self.assertEqual(y.shape, ()) + + @nnx.variable_mode('hijax') + def test_eval_shape(self): + v = nnx.Param(jnp.array(0)) + + def f(v): + v[...] += 1 + return v[...] + + y = jax.eval_shape(f, v) + + self.assertEqual(y.shape, ()) + if __name__ == '__main__': absltest.main() diff --git a/tests/nnx/nn/embed_test.py b/tests/nnx/nn/embed_test.py index 8991c8a35..3120ad84d 100644 --- a/tests/nnx/nn/embed_test.py +++ b/tests/nnx/nn/embed_test.py @@ -59,7 +59,7 @@ def test_nnx_linen_equivalence( NUM_EMBEDDINGS, IN_FEATURES, dtype=dtype, param_dtype=param_dtype ) variables = model.init(key, x) - model_nnx.embedding.value = variables['params']['embedding'] + model_nnx.embedding.set_value(variables['params']['embedding']) out_nnx = model_nnx(x) out = model.apply(variables, x) diff --git a/tests/nnx/nn/linear_test.py b/tests/nnx/nn/linear_test.py index a69c37765..5b7e89749 100644 --- a/tests/nnx/nn/linear_test.py +++ b/tests/nnx/nn/linear_test.py @@ -121,9 +121,9 @@ def test_nnx_linear_equivalence( dot_general=dot_general, ) variables = model.init(key, x) - model_nnx.kernel.value = variables['params']['kernel'] + model_nnx.kernel.set_value(variables['params']['kernel']) if use_bias: - model_nnx.bias.value = variables['params']['bias'] + model_nnx.bias.set_value(variables['params']['bias']) out_nnx = model_nnx(x) out = model.apply(variables, x) @@ -184,10 +184,10 @@ def test_nnx_einsum_equivalence( np.testing.assert_array_equal(out, out_nnx) variables = model.init(key, x) - model_nnx.kernel.value = variables['params']['kernel'] + model_nnx.kernel.set_value(variables['params']['kernel']) if bias_shape is not None: assert model_nnx.bias is not None - model_nnx.bias.value = variables['params']['bias'] + model_nnx.bias.set_value(variables['params']['bias']) out_nnx = model_nnx(x) out = model.apply(variables, x) assert isinstance(out, jax.Array) diff --git a/tests/nnx/nn/normalization_test.py b/tests/nnx/nn/normalization_test.py index 2651056c2..bd5d1dd70 100644 --- a/tests/nnx/nn/normalization_test.py +++ b/tests/nnx/nn/normalization_test.py @@ -241,8 +241,8 @@ def __call__(self, x, *, mask=None): use_fast_variance=use_fast_variance, rngs=rngs, ) - nnx_model.linear.kernel.value = variables['params']['linear']['kernel'] - nnx_model.linear.bias.value = variables['params']['linear']['bias'] + nnx_model.linear.kernel.set_value(variables['params']['linear']['kernel']) + nnx_model.linear.bias.set_value(variables['params']['linear']['bias']) nnx_out = nnx_model(x, mask=mask) assert isinstance(linen_out, jax.Array) @@ -468,20 +468,20 @@ def __call__(self, x): ) # Setup the same weights and batch 
stats var_params_seq_0 = variables['params']['seq']['layers_0'] - nnx_model.seq.layers[0].kernel.value = var_params_seq_0['kernel'] - nnx_model.seq.layers[0].bias.value = var_params_seq_0['bias'] + nnx_model.seq.layers[0].kernel.set_value(var_params_seq_0['kernel']) + nnx_model.seq.layers[0].bias.set_value(var_params_seq_0['bias']) var_params_seq_2 = variables['params']['seq']['layers_2'] - nnx_model.seq.layers[2].scale.value = var_params_seq_2['scale'] - nnx_model.seq.layers[2].bias.value = var_params_seq_0['bias'] + nnx_model.seq.layers[2].scale.set_value(var_params_seq_2['scale']) + nnx_model.seq.layers[2].bias.set_value(var_params_seq_0['bias']) var_norm_layer = variables['batch_stats']['norm_layer'] nnx_model.norm_layer.batch_stats[ ('layers', 0, 'kernel', 'u') - ].value = var_norm_layer['seq/layers_0/kernel/u'] + ].set_value(var_norm_layer['seq/layers_0/kernel/u']) nnx_model.norm_layer.batch_stats[ ('layers', 0, 'kernel', 'sigma') - ].value = var_norm_layer['seq/layers_0/kernel/sigma'] + ].set_value(var_norm_layer['seq/layers_0/kernel/sigma']) linen_out = linen_model.apply(variables, x, mutable=['batch_stats']) nnx_out = nnx_model(x) diff --git a/tests/nnx/optimizer_test.py b/tests/nnx/optimizer_test.py index c05f9dd4d..d9e3b8faa 100644 --- a/tests/nnx/optimizer_test.py +++ b/tests/nnx/optimizer_test.py @@ -87,7 +87,7 @@ def test_sharding_propagation(self): self.assertEqual(state['opt_state'][0]['mu']['kernel'].sharding_names, ('a', 'b')) self.assertEqual( - partition_spec['opt_state'][0]['mu']['kernel'].value, + partition_spec['opt_state'][0]['mu']['kernel'].get_value(), jax.sharding.PartitionSpec('a', 'b'), ) diff --git a/tests/nnx/spmd_test.py b/tests/nnx/spmd_test.py index 4138a9922..646af30d3 100644 --- a/tests/nnx/spmd_test.py +++ b/tests/nnx/spmd_test.py @@ -175,9 +175,9 @@ def __call__(self, x: jax.Array): self.assertEqual(badds, [(0, 'layers'), (0, 'layers')]) self.assertEqual(bremoves, [(0, 'layers')]) - @parameterized.product(use_ref=[True, False]) - def test_logical_rules(self, use_ref): - self.enter_context(nnx.use_refs(use_ref)) + @parameterized.product(use_hijax=[True, False]) + def test_logical_rules(self, use_hijax): + self.enter_context(nnx.use_hijax(use_hijax)) class Foo(nnx.Module): def __init__(self): diff --git a/tests/nnx/transforms_test.py b/tests/nnx/transforms_test.py index d79aa35c8..eae485dfe 100644 --- a/tests/nnx/transforms_test.py +++ b/tests/nnx/transforms_test.py @@ -422,13 +422,13 @@ class TestEvalShape(absltest.TestCase): def test_eval_shape(self): abs_model = nnx.eval_shape(lambda: nnx.Linear(1, 2, rngs=nnx.Rngs(0))) self.assertIsInstance(abs_model, nnx.Linear) - self.assertIsInstance(abs_model.kernel.value, jax.ShapeDtypeStruct) + self.assertIsInstance(abs_model.kernel.get_value(), jax.ShapeDtypeStruct) def test_eval_shape_mutable_array(self): - with nnx.use_refs(True): + with nnx.use_hijax(True): abs_model = nnx.eval_shape(lambda: nnx.Linear(1, 2, rngs=nnx.Rngs(0))) self.assertIsInstance(abs_model, nnx.Linear) - self.assertIsInstance(abs_model.kernel.value, jax.ShapeDtypeStruct) + self.assertIsInstance(abs_model.kernel.get_value(), jax.ShapeDtypeStruct) self.assertEqual(abs_model.kernel.shape, (1, 2)) class TestShardMap(absltest.TestCase): @@ -782,7 +782,7 @@ def loss_fn(l1: list[nnx.Linear], l2: list[nnx.Linear]): loss = jnp.mean(l1[0].kernel * l2[0].kernel) + jnp.mean( l1[0].bias * l2[0].bias ) - l1[0].kernel.value = jnp.array(-1.0) + l1[0].kernel.set_value(jnp.array(-1.0)) m3 = nnx.Linear(2, 3, rngs=nnx.Rngs(2)) return loss, m3 diff --git 
a/tests/nnx/variable_test.py b/tests/nnx/variable_test.py index 362519f81..ec2922343 100644 --- a/tests/nnx/variable_test.py +++ b/tests/nnx/variable_test.py @@ -25,12 +25,12 @@ class TestVariable(absltest.TestCase): def test_pytree(self): r1 = nnx.Param(1) - self.assertEqual(r1.value, 1) + self.assertEqual(r1.get_value(), 1) r2 = jax.tree.map(lambda x: x + 1, r1) - self.assertEqual(r1.value, 1) - self.assertEqual(r2.value, 2) + self.assertEqual(r1.get_value(), 1) + self.assertEqual(r2.get_value(), 2) self.assertIsNot(r1, r2) def test_overloads_module(self): @@ -94,38 +94,41 @@ def test_binary_ops(self): self.assertEqual(v1[...], 5) def test_mutable_array_context(self): - with nnx.use_refs(False): + with nnx.use_hijax(False): v = nnx.Variable(jnp.array(1.0)) - self.assertFalse(nnx.using_refs()) - self.assertNotIsInstance(v.raw_value, jax.Ref) + self.assertFalse(nnx.using_hijax()) + self.assertNotIsInstance(v[...], jax.Ref) - with nnx.use_refs(True): + with nnx.use_hijax(True): v = nnx.Variable(jnp.array(1.0)) - self.assertTrue(nnx.using_refs()) - self.assertIsInstance(v.raw_value, jax.Ref) + self.assertTrue(nnx.using_hijax()) + self.assertIsInstance(v[...], jax.Array) v = nnx.Variable(jnp.array(2.0)) - self.assertNotIsInstance(v.raw_value, jax.Ref) - self.assertFalse(nnx.using_refs()) + self.assertIsInstance(v[...], jax.Array) + self.assertFalse(nnx.using_hijax()) - nnx.use_refs(True) + nnx.use_hijax(True) v = nnx.Variable(jnp.array(0.0)) - self.assertTrue(nnx.using_refs()) - self.assertIsInstance(v.raw_value, jax.Ref) + self.assertTrue(nnx.using_hijax()) + self.assertIsInstance(v[...], jax.Array) v = nnx.Variable(jnp.array(1.0)) - self.assertFalse(nnx.using_refs()) - self.assertNotIsInstance(v.raw_value, jax.Ref) + self.assertFalse(nnx.using_hijax()) + self.assertIsInstance(v[...], jax.Array) def test_get_set_metadata(self): v = nnx.Variable(jnp.array(1.0)) - self.assertEqual(v.get_metadata(), {}) + self.assertEqual(v.get_metadata(), {'is_hijax': False}) v.set_metadata(a=1, b=2) self.assertEqual(v.get_metadata('a'), 1) self.assertEqual(v.get_metadata('b'), 2) - v.set_metadata({'b': 3, 'c': 4}) - self.assertEqual(v.get_metadata(), {'b': 3, 'c': 4}) + v.set_metadata({'b': 3, 'c': 4, 'is_hijax': False}) + self.assertEqual( + v.get_metadata(), + {'b': 3, 'c': 4, 'is_hijax': False}, + ) self.assertEqual(v.get_metadata('b'), 3) self.assertEqual(v.get_metadata('c'), 4) c = v.get_metadata('c')