diff --git a/docs_nnx/hijax/hijax.ipynb b/docs_nnx/hijax/hijax.ipynb index 0e7573b71..62aa47c32 100644 --- a/docs_nnx/hijax/hijax.ipynb +++ b/docs_nnx/hijax/hijax.ipynb @@ -5,7 +5,7 @@ "id": "15c2d208", "metadata": {}, "source": [ - "# Hijax Variable" + "# Hijax" ] }, { @@ -20,7 +20,76 @@ "import jax.numpy as jnp\n", "import optax\n", "\n", - "current_mode = nnx.using_hijax()" + "current_mode = nnx.using_hijax() # ignore: only needed for testing" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "d1aaa0ec", + "metadata": {}, + "outputs": [ + { + "ename": "ValueError", + "evalue": "safe_zip() argument 2 is shorter than argument 1", + "output_type": "error", + "traceback": [ + "\u001b[31m---------------------------------------------------------------------------\u001b[39m", + "\u001b[31mInvalidInputException\u001b[39m Traceback (most recent call last)", + "\u001b[36mFile \u001b[39m\u001b[32m~/repos/flax/.venv/lib/python3.11/site-packages/jax/_src/pjit.py:149\u001b[39m, in \u001b[36m_python_pjit_helper\u001b[39m\u001b[34m(fun, jit_info, *args, **kwargs)\u001b[39m\n\u001b[32m 148\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[32m--> \u001b[39m\u001b[32m149\u001b[39m out_flat = \u001b[43mjit_p\u001b[49m\u001b[43m.\u001b[49m\u001b[43mbind\u001b[49m\u001b[43m(\u001b[49m\u001b[43m*\u001b[49m\u001b[43margs_flat\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mp\u001b[49m\u001b[43m.\u001b[49m\u001b[43mparams\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 150\u001b[39m compiled = \u001b[38;5;28;01mNone\u001b[39;00m\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/repos/flax/.venv/lib/python3.11/site-packages/jax/_src/core.py:632\u001b[39m, in \u001b[36mPrimitive.bind\u001b[39m\u001b[34m(self, *args, **params)\u001b[39m\n\u001b[32m 631\u001b[39m args = args \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m.skip_canonicalization \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28mmap\u001b[39m(canonicalize_value, args)\n\u001b[32m--> \u001b[39m\u001b[32m632\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43m_true_bind\u001b[49m\u001b[43m(\u001b[49m\u001b[43m*\u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mparams\u001b[49m\u001b[43m)\u001b[49m\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/repos/flax/.venv/lib/python3.11/site-packages/jax/_src/core.py:648\u001b[39m, in \u001b[36mPrimitive._true_bind\u001b[39m\u001b[34m(self, *args, **params)\u001b[39m\n\u001b[32m 647\u001b[39m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[32m--> \u001b[39m\u001b[32m648\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mbind_with_trace\u001b[49m\u001b[43m(\u001b[49m\u001b[43mprev_trace\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mparams\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 649\u001b[39m \u001b[38;5;28;01mfinally\u001b[39;00m:\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/repos/flax/.venv/lib/python3.11/site-packages/jax/_src/core.py:659\u001b[39m, in \u001b[36mPrimitive.bind_with_trace\u001b[39m\u001b[34m(self, trace, args, params)\u001b[39m\n\u001b[32m 658\u001b[39m \u001b[38;5;28;01mwith\u001b[39;00m set_current_trace(trace):\n\u001b[32m--> \u001b[39m\u001b[32m659\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mto_lojax\u001b[49m\u001b[43m(\u001b[49m\u001b[43m*\u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mparams\u001b[49m\u001b[43m)\u001b[49m \u001b[38;5;66;03m# type: ignore\u001b[39;00m\n\u001b[32m 660\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m trace.process_primitive(\u001b[38;5;28mself\u001b[39m, args, params)\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/repos/flax/.venv/lib/python3.11/site-packages/jax/_src/pjit.py:1346\u001b[39m, in \u001b[36m_to_lojax\u001b[39m\u001b[34m(jaxpr, *hi_args, **params)\u001b[39m\n\u001b[32m 1345\u001b[39m lo_jaxpr = pe.lower_jaxpr(jaxpr)\n\u001b[32m-> \u001b[39m\u001b[32m1346\u001b[39m all_outs = \u001b[43mjit_p\u001b[49m\u001b[43m.\u001b[49m\u001b[43mbind\u001b[49m\u001b[43m(\u001b[49m\u001b[43m*\u001b[49m\u001b[43mlo_args\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mjaxpr\u001b[49m\u001b[43m=\u001b[49m\u001b[43mlo_jaxpr\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mparams\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 1347\u001b[39m out_mut, lo_outs = split_list(all_outs, [lo_muts_out])\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/repos/flax/.venv/lib/python3.11/site-packages/jax/_src/core.py:632\u001b[39m, in \u001b[36mPrimitive.bind\u001b[39m\u001b[34m(self, *args, **params)\u001b[39m\n\u001b[32m 631\u001b[39m args = args \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m.skip_canonicalization \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28mmap\u001b[39m(canonicalize_value, args)\n\u001b[32m--> \u001b[39m\u001b[32m632\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43m_true_bind\u001b[49m\u001b[43m(\u001b[49m\u001b[43m*\u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mparams\u001b[49m\u001b[43m)\u001b[49m\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/repos/flax/.venv/lib/python3.11/site-packages/jax/_src/core.py:648\u001b[39m, in \u001b[36mPrimitive._true_bind\u001b[39m\u001b[34m(self, *args, **params)\u001b[39m\n\u001b[32m 647\u001b[39m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[32m--> \u001b[39m\u001b[32m648\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mbind_with_trace\u001b[49m\u001b[43m(\u001b[49m\u001b[43mprev_trace\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mparams\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 649\u001b[39m \u001b[38;5;28;01mfinally\u001b[39;00m:\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/repos/flax/.venv/lib/python3.11/site-packages/jax/_src/core.py:660\u001b[39m, in \u001b[36mPrimitive.bind_with_trace\u001b[39m\u001b[34m(self, trace, args, params)\u001b[39m\n\u001b[32m 659\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m.to_lojax(*args, **params) \u001b[38;5;66;03m# type: ignore\u001b[39;00m\n\u001b[32m--> \u001b[39m\u001b[32m660\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mtrace\u001b[49m\u001b[43m.\u001b[49m\u001b[43mprocess_primitive\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mparams\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 661\u001b[39m trace.process_primitive(\u001b[38;5;28mself\u001b[39m, args, params) \u001b[38;5;66;03m# may raise lojax error\u001b[39;00m\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/repos/flax/.venv/lib/python3.11/site-packages/jax/_src/core.py:1189\u001b[39m, in \u001b[36mEvalTrace.process_primitive\u001b[39m\u001b[34m(self, primitive, args, params)\u001b[39m\n\u001b[32m 1188\u001b[39m \u001b[38;5;66;03m# check_eval_args(args)\u001b[39;00m\n\u001b[32m-> \u001b[39m\u001b[32m1189\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mprimitive\u001b[49m\u001b[43m.\u001b[49m\u001b[43mimpl\u001b[49m\u001b[43m(\u001b[49m\u001b[43m*\u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mparams\u001b[49m\u001b[43m)\u001b[49m\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/repos/flax/.venv/lib/python3.11/site-packages/jax/_src/pjit.py:1691\u001b[39m, in \u001b[36m_pjit_call_impl\u001b[39m\u001b[34m(jaxpr, in_shardings, out_shardings, in_layouts, out_layouts, donated_invars, ctx_mesh, name, keep_unused, inline, compiler_options_kvs, *args)\u001b[39m\n\u001b[32m 1684\u001b[39m cache_key = pxla.JitGlobalCppCacheKeys(\n\u001b[32m 1685\u001b[39m donate_argnums=donated_argnums, donate_argnames=\u001b[38;5;28;01mNone\u001b[39;00m,\n\u001b[32m 1686\u001b[39m device=\u001b[38;5;28;01mNone\u001b[39;00m, backend=\u001b[38;5;28;01mNone\u001b[39;00m,\n\u001b[32m (...)\u001b[39m\u001b[32m 1689\u001b[39m in_layouts_treedef=\u001b[38;5;28;01mNone\u001b[39;00m, in_layouts_leaves=in_layouts,\n\u001b[32m 1690\u001b[39m out_layouts_treedef=\u001b[38;5;28;01mNone\u001b[39;00m, out_layouts_leaves=out_layouts)\n\u001b[32m-> \u001b[39m\u001b[32m1691\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mxc\u001b[49m\u001b[43m.\u001b[49m\u001b[43m_xla\u001b[49m\u001b[43m.\u001b[49m\u001b[43mpjit\u001b[49m\u001b[43m(\u001b[49m\n\u001b[32m 1692\u001b[39m \u001b[43m \u001b[49m\u001b[43mname\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mf\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcall_impl_cache_miss\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m[\u001b[49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m[\u001b[49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcache_key\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 1693\u001b[39m \u001b[43m \u001b[49m\u001b[43mtree_util\u001b[49m\u001b[43m.\u001b[49m\u001b[43mdispatch_registry\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mpxla\u001b[49m\u001b[43m.\u001b[49m\u001b[43mcc_shard_arg\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 1694\u001b[39m \u001b[43m \u001b[49m\u001b[43m_get_cpp_global_cache\u001b[49m\u001b[43m(\u001b[49m\u001b[43mcache_key\u001b[49m\u001b[43m.\u001b[49m\u001b[43mcontains_explicit_attributes\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\u001b[43m(\u001b[49m\u001b[43m*\u001b[49m\u001b[43margs\u001b[49m\u001b[43m)\u001b[49m\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/repos/flax/.venv/lib/python3.11/site-packages/jax/_src/pjit.py:1667\u001b[39m, in \u001b[36m_pjit_call_impl..call_impl_cache_miss\u001b[39m\u001b[34m(*args_, **kwargs_)\u001b[39m\n\u001b[32m 1663\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34mcall_impl_cache_miss\u001b[39m(*args_, **kwargs_):\n\u001b[32m 1664\u001b[39m \u001b[38;5;66;03m# args_ do not include the const args\u001b[39;00m\n\u001b[32m 1665\u001b[39m \u001b[38;5;66;03m# See https://docs.jax.dev/en/latest/internals/constants.html.\u001b[39;00m\n\u001b[32m 1666\u001b[39m \u001b[38;5;66;03m# TODO(necula): remove num_const_args when fixing the C++ path\u001b[39;00m\n\u001b[32m-> \u001b[39m\u001b[32m1667\u001b[39m out_flat, compiled, pgle_profiler, const_args = \u001b[43m_pjit_call_impl_python\u001b[49m\u001b[43m(\u001b[49m\n\u001b[32m 1668\u001b[39m \u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mjaxpr\u001b[49m\u001b[43m=\u001b[49m\u001b[43mjaxpr\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43min_shardings\u001b[49m\u001b[43m=\u001b[49m\u001b[43min_shardings\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 1669\u001b[39m \u001b[43m \u001b[49m\u001b[43mout_shardings\u001b[49m\u001b[43m=\u001b[49m\u001b[43mout_shardings\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43min_layouts\u001b[49m\u001b[43m=\u001b[49m\u001b[43min_layouts\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 1670\u001b[39m \u001b[43m \u001b[49m\u001b[43mout_layouts\u001b[49m\u001b[43m=\u001b[49m\u001b[43mout_layouts\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdonated_invars\u001b[49m\u001b[43m=\u001b[49m\u001b[43mdonated_invars\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 1671\u001b[39m \u001b[43m \u001b[49m\u001b[43mctx_mesh\u001b[49m\u001b[43m=\u001b[49m\u001b[43mctx_mesh\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mname\u001b[49m\u001b[43m=\u001b[49m\u001b[43mname\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mkeep_unused\u001b[49m\u001b[43m=\u001b[49m\u001b[43mkeep_unused\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 1672\u001b[39m \u001b[43m \u001b[49m\u001b[43minline\u001b[49m\u001b[43m=\u001b[49m\u001b[43minline\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcompiler_options_kvs\u001b[49m\u001b[43m=\u001b[49m\u001b[43mcompiler_options_kvs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 1673\u001b[39m fastpath_data = _get_fastpath_data(\n\u001b[32m 1674\u001b[39m compiled, tree_structure(out_flat), args, out_flat,\n\u001b[32m 1675\u001b[39m jaxpr.effects, jaxpr.consts, \u001b[38;5;28;01mNone\u001b[39;00m, pgle_profiler,\n\u001b[32m 1676\u001b[39m const_args)\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/repos/flax/.venv/lib/python3.11/site-packages/jax/_src/pjit.py:1642\u001b[39m, in \u001b[36m_pjit_call_impl_python\u001b[39m\u001b[34m(jaxpr, in_shardings, out_shardings, in_layouts, out_layouts, donated_invars, ctx_mesh, name, keep_unused, inline, compiler_options_kvs, *args)\u001b[39m\n\u001b[32m 1635\u001b[39m distributed_debug_log((\u001b[33m\"\u001b[39m\u001b[33mRunning pjit\u001b[39m\u001b[33m'\u001b[39m\u001b[33md function\u001b[39m\u001b[33m\"\u001b[39m, name),\n\u001b[32m 1636\u001b[39m (\u001b[33m\"\u001b[39m\u001b[33min_shardings\u001b[39m\u001b[33m\"\u001b[39m, in_shardings),\n\u001b[32m 1637\u001b[39m (\u001b[33m\"\u001b[39m\u001b[33mout_shardings\u001b[39m\u001b[33m\"\u001b[39m, out_shardings),\n\u001b[32m (...)\u001b[39m\u001b[32m 1640\u001b[39m (\u001b[33m\"\u001b[39m\u001b[33mabstract args\u001b[39m\u001b[33m\"\u001b[39m, \u001b[38;5;28mmap\u001b[39m(core.abstractify, args)),\n\u001b[32m 1641\u001b[39m (\u001b[33m\"\u001b[39m\u001b[33mfingerprint\u001b[39m\u001b[33m\"\u001b[39m, fingerprint))\n\u001b[32m-> \u001b[39m\u001b[32m1642\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m (\u001b[43mcompiled\u001b[49m\u001b[43m.\u001b[49m\u001b[43munsafe_call\u001b[49m\u001b[43m(\u001b[49m\u001b[43m*\u001b[49m\u001b[43mcomputation\u001b[49m\u001b[43m.\u001b[49m\u001b[43mconst_args\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43margs\u001b[49m\u001b[43m)\u001b[49m,\n\u001b[32m 1643\u001b[39m compiled, pgle_profiler, computation.const_args)\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/repos/flax/.venv/lib/python3.11/site-packages/jax/_src/profiler.py:359\u001b[39m, in \u001b[36mannotate_function..wrapper\u001b[39m\u001b[34m(*args, **kwargs)\u001b[39m\n\u001b[32m 358\u001b[39m \u001b[38;5;28;01mwith\u001b[39;00m TraceAnnotation(name, **decorator_kwargs):\n\u001b[32m--> \u001b[39m\u001b[32m359\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[43m*\u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/repos/flax/.venv/lib/python3.11/site-packages/jax/_src/interpreters/pxla.py:1366\u001b[39m, in \u001b[36mExecuteReplicated.__call__\u001b[39m\u001b[34m(self, *args)\u001b[39m\n\u001b[32m 1365\u001b[39m args = [*args, *\u001b[38;5;28mself\u001b[39m.mut.in_mut]\n\u001b[32m-> \u001b[39m\u001b[32m1366\u001b[39m input_bufs = \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43min_handler\u001b[49m\u001b[43m(\u001b[49m\u001b[43margs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 1367\u001b[39m \u001b[38;5;28;01mwith\u001b[39;00m profiler.PGLEProfiler.trace(\u001b[38;5;28mself\u001b[39m.pgle_profiler):\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/repos/flax/.venv/lib/python3.11/site-packages/jax/_src/interpreters/pxla.py:1249\u001b[39m, in \u001b[36mInputsHandler.__call__\u001b[39m\u001b[34m(self, input_buffers)\u001b[39m\n\u001b[32m 1248\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34m__call__\u001b[39m(\u001b[38;5;28mself\u001b[39m, input_buffers):\n\u001b[32m-> \u001b[39m\u001b[32m1249\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mhandler\u001b[49m\u001b[43m(\u001b[49m\u001b[43minput_buffers\u001b[49m\u001b[43m)\u001b[49m\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/repos/flax/.venv/lib/python3.11/site-packages/jax/_src/profiler.py:359\u001b[39m, in \u001b[36mannotate_function..wrapper\u001b[39m\u001b[34m(*args, **kwargs)\u001b[39m\n\u001b[32m 358\u001b[39m \u001b[38;5;28;01mwith\u001b[39;00m TraceAnnotation(name, **decorator_kwargs):\n\u001b[32m--> \u001b[39m\u001b[32m359\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[43m*\u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/repos/flax/.venv/lib/python3.11/site-packages/jax/_src/interpreters/pxla.py:134\u001b[39m, in \u001b[36mshard_args\u001b[39m\u001b[34m(shardings, layouts, copy_semantics, args, canonicalize)\u001b[39m\n\u001b[32m 133\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m canonicalize:\n\u001b[32m--> \u001b[39m\u001b[32m134\u001b[39m arg = \u001b[43mdtypes\u001b[49m\u001b[43m.\u001b[49m\u001b[43mcanonicalize_value\u001b[49m\u001b[43m(\u001b[49m\u001b[43marg\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 135\u001b[39m batch = batches[\u001b[38;5;28mtype\u001b[39m(arg)]\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/repos/flax/.venv/lib/python3.11/site-packages/jax/_src/dtypes.py:388\u001b[39m, in \u001b[36mcanonicalize_value\u001b[39m\u001b[34m(x)\u001b[39m\n\u001b[32m 383\u001b[39m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\n\u001b[32m 384\u001b[39m \u001b[33m'\u001b[39m\u001b[33mTriggering __jax_array__() during abstractification is no longer\u001b[39m\u001b[33m'\u001b[39m\n\u001b[32m 385\u001b[39m \u001b[33m'\u001b[39m\u001b[33m supported. To avoid this error, either explicitly convert your object\u001b[39m\u001b[33m'\u001b[39m\n\u001b[32m 386\u001b[39m \u001b[33m'\u001b[39m\u001b[33m using jax.numpy.array(), or register your object as a pytree.\u001b[39m\u001b[33m'\u001b[39m\n\u001b[32m 387\u001b[39m )\n\u001b[32m--> \u001b[39m\u001b[32m388\u001b[39m \u001b[38;5;28;01mraise\u001b[39;00m InvalidInputException(\n\u001b[32m 389\u001b[39m \u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[33mArgument \u001b[39m\u001b[33m'\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mx\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m'\u001b[39m\u001b[33m of type \u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mtype\u001b[39m(x)\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m is not a valid JAX type.\u001b[39m\u001b[33m\"\u001b[39m\n\u001b[32m 390\u001b[39m )\n", + "\u001b[31mInvalidInputException\u001b[39m: Argument 'JitTracer' of type is not a valid JAX type.", + "\nDuring handling of the above exception, another exception occurred:\n", + "\u001b[31mValueError\u001b[39m Traceback (most recent call last)", + "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[2]\u001b[39m\u001b[32m, line 16\u001b[39m\n\u001b[32m 14\u001b[39m x, y = rngs.uniform((\u001b[32m4\u001b[39m, \u001b[32m2\u001b[39m)), rngs.uniform((\u001b[32m4\u001b[39m, \u001b[32m3\u001b[39m))\n\u001b[32m 15\u001b[39m \u001b[38;5;28;01mfor\u001b[39;00m _ \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(\u001b[32m3\u001b[39m):\n\u001b[32m---> \u001b[39m\u001b[32m16\u001b[39m \u001b[38;5;28mprint\u001b[39m(\u001b[43mtrain_step\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43my\u001b[49m\u001b[43m)\u001b[49m)\n", + " \u001b[31m[... skipping hidden 1 frame]\u001b[39m\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/repos/flax/.venv/lib/python3.11/site-packages/jax/_src/pjit.py:264\u001b[39m, in \u001b[36m_cpp_pjit..cache_miss\u001b[39m\u001b[34m(*args, **kwargs)\u001b[39m\n\u001b[32m 259\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m config.no_tracing.value:\n\u001b[32m 260\u001b[39m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mRuntimeError\u001b[39;00m(\u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[33mre-tracing function \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mjit_info.fun_sourceinfo\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m for \u001b[39m\u001b[33m\"\u001b[39m\n\u001b[32m 261\u001b[39m \u001b[33m\"\u001b[39m\u001b[33m`jit`, but \u001b[39m\u001b[33m'\u001b[39m\u001b[33mno_tracing\u001b[39m\u001b[33m'\u001b[39m\u001b[33m is set\u001b[39m\u001b[33m\"\u001b[39m)\n\u001b[32m 263\u001b[39m (outs, out_flat, out_tree, args_flat, jaxpr,\n\u001b[32m--> \u001b[39m\u001b[32m264\u001b[39m executable, pgle_profiler, const_args) = \u001b[43m_python_pjit_helper\u001b[49m\u001b[43m(\u001b[49m\n\u001b[32m 265\u001b[39m \u001b[43m \u001b[49m\u001b[43mfun\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mjit_info\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 267\u001b[39m maybe_fastpath_data = _get_fastpath_data(\n\u001b[32m 268\u001b[39m executable, out_tree, args_flat, out_flat, jaxpr.effects, jaxpr.consts,\n\u001b[32m 269\u001b[39m jit_info.abstracted_axes, pgle_profiler,\n\u001b[32m 270\u001b[39m const_args)\n\u001b[32m 272\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m outs, maybe_fastpath_data, _need_to_rebuild_with_fdo(pgle_profiler)\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/repos/flax/.venv/lib/python3.11/site-packages/jax/_src/pjit.py:166\u001b[39m, in \u001b[36m_python_pjit_helper\u001b[39m\u001b[34m(fun, jit_info, *args, **kwargs)\u001b[39m\n\u001b[32m 164\u001b[39m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mTypeError\u001b[39;00m(e.args[\u001b[32m0\u001b[39m]) \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01me\u001b[39;00m\n\u001b[32m 165\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[32m--> \u001b[39m\u001b[32m166\u001b[39m \u001b[38;5;28;01mfor\u001b[39;00m arg, name, aval \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mzip\u001b[39m(args_flat, arg_names, p.in_avals):\n\u001b[32m 167\u001b[39m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[32m 168\u001b[39m dtypes.canonicalize_value(arg)\n", + "\u001b[31mValueError\u001b[39m: safe_zip() argument 2 is shorter than argument 1" + ] + } + ], + "source": [ + "nnx.use_hijax(True)\n", + "\n", + "rngs = nnx.Rngs(0)\n", + "model = nnx.Linear(2, 3, rngs=rngs)\n", + "optimizer = nnx.Optimizer(model, optax.adamw(1e-2), wrt=nnx.Param)\n", + "\n", + "@jax.jit\n", + "def train_step(x, y):\n", + " loss_fn = lambda m: jnp.mean((m(x) - y) ** 2)\n", + " loss, grads = jax.value_and_grad(loss_fn)(model) # tmp fix for jax.grad\n", + " optimizer.update(model, grads)\n", + " return loss\n", + "\n", + "x, y = rngs.uniform((4, 2)), rngs.uniform((4, 3))\n", + "for _ in range(3):\n", + " print(train_step(x, y))" + ] + }, + { + "cell_type": "markdown", + "id": "04458d66", + "metadata": {}, + "source": [ + "## Hijax Variable" ] }, { @@ -33,7 +102,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": null, "id": "396a07a3", "metadata": {}, "outputs": [ @@ -58,7 +127,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": null, "id": "2ab7d801", "metadata": {}, "outputs": [ @@ -98,7 +167,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": null, "id": "fcd0de3f", "metadata": {}, "outputs": [ @@ -138,7 +207,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": null, "id": "0d83a130", "metadata": {}, "outputs": [ @@ -176,7 +245,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": null, "id": "0a55df94", "metadata": {}, "outputs": [ @@ -210,7 +279,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": null, "id": "b7b1f421", "metadata": {}, "outputs": [ @@ -250,7 +319,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": null, "id": "594cb65e", "metadata": {}, "outputs": [ @@ -283,7 +352,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": null, "id": "fcd4fb4f", "metadata": {}, "outputs": [ @@ -309,7 +378,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": null, "id": "18256668", "metadata": {}, "outputs": [ @@ -349,7 +418,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": null, "id": "5400fe58", "metadata": {}, "outputs": [], @@ -376,7 +445,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": null, "id": "566c4249", "metadata": {}, "outputs": [ @@ -422,7 +491,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": null, "id": "d8136be4", "metadata": {}, "outputs": [], @@ -462,7 +531,7 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": null, "id": "c6062d19", "metadata": {}, "outputs": [ @@ -487,7 +556,7 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": null, "id": "8bb1e9e7", "metadata": {}, "outputs": [ @@ -524,7 +593,7 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": null, "id": "045d03c1", "metadata": {}, "outputs": [ @@ -555,7 +624,7 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": null, "id": "bc2e87e5", "metadata": {}, "outputs": [ @@ -587,7 +656,7 @@ }, { "cell_type": "code", - "execution_count": 18, + "execution_count": null, "id": "6298f3d9", "metadata": {}, "outputs": [ @@ -610,7 +679,7 @@ }, { "cell_type": "code", - "execution_count": 19, + "execution_count": null, "id": "00854d38", "metadata": {}, "outputs": [ @@ -645,7 +714,7 @@ }, { "cell_type": "code", - "execution_count": 20, + "execution_count": null, "id": "195296c8", "metadata": {}, "outputs": [], diff --git a/docs_nnx/hijax/hijax.md b/docs_nnx/hijax/hijax.md index 7a24d416a..f553e9cdd 100644 --- a/docs_nnx/hijax/hijax.md +++ b/docs_nnx/hijax/hijax.md @@ -8,7 +8,7 @@ jupytext: jupytext_version: 1.13.8 --- -# Hijax Variable +# Hijax ```{code-cell} ipython3 from flax import nnx @@ -16,9 +16,32 @@ import jax import jax.numpy as jnp import optax -current_mode = nnx.using_hijax() +current_mode = nnx.using_hijax() # ignore: only needed for testing ``` +```{code-cell} ipython3 +nnx.use_hijax(True) + +rngs = nnx.Rngs(0) +model = nnx.Linear(2, 3, rngs=rngs) +optimizer = nnx.Optimizer(model, optax.adamw(1e-2), wrt=nnx.Param) + +@jax.jit +def train_step(x, y): + loss_fn = lambda m: jnp.mean((m(x) - y) ** 2) + loss, grads = jax.value_and_grad(loss_fn)(model) # tmp fix for jax.grad + optimizer.update(model, grads) + return loss + +x, y = rngs.uniform((4, 2)), rngs.uniform((4, 3)) +for _ in range(3): + print(train_step(x, y)) +``` + +## Hijax Variable + ++++ + State propagation: ```{code-cell} ipython3 diff --git a/flax/nnx/graph.py b/flax/nnx/graph.py index 6fb14fa69..80c1291e5 100644 --- a/flax/nnx/graph.py +++ b/flax/nnx/graph.py @@ -32,7 +32,7 @@ ) from flax.nnx.statelib import FlatState, State, map_state from flax.nnx.variablelib import Variable, is_array_ref, V -from flax.typing import Key, PathParts, is_key_like +from flax.typing import HashableMapping, Key, PathParts, is_key_like import jax import numpy as np import treescope # type: ignore[import-not-found,import-untyped] @@ -301,50 +301,6 @@ def get_node_impl_for_type( return None -class HashableMapping(tp.Mapping[HA, HB], tp.Hashable): - _mapping: dict[HA, HB] | tp.Mapping[HA, HB] - - def __init__(self, mapping: tp.Mapping[HA, HB], copy: bool = True): - self._mapping = dict(mapping) if copy else mapping - - def __contains__(self, key: object) -> bool: - return key in self._mapping - - def __getitem__(self, key: HA) -> HB: - return self._mapping[key] - - def __iter__(self) -> tp.Iterator[HA]: - return iter(self._mapping) - - def __len__(self) -> int: - return len(self._mapping) - - def __hash__(self) -> int: - # use type-aware sorting to support int keys - def _pytree__key_sort_fn(item: tuple[tp.Any, tp.Any]) -> tuple[int, tp.Any]: - key, _ = item - if isinstance(key, int): - return (0, key) - elif isinstance(key, str): - return (1, key) - else: - raise ValueError(f'Unsupported key type: {type(key)!r}') - return hash(tuple(sorted(self._mapping.items(), key=_pytree__key_sort_fn))) - - def __eq__(self, other: tp.Any) -> bool: - return ( - isinstance(other, HashableMapping) and self._mapping == other._mapping - ) - - def __repr__(self) -> str: - return repr(self._mapping) - - def update(self, other: tp.Mapping[HA, HB]) -> HashableMapping[HA, HB]: - """Updates the mapping with another mapping.""" - mapping = dict(self._mapping) - mapping.update(other) - return HashableMapping(mapping, copy=False) - @jax.tree_util.register_static @dataclasses.dataclass(frozen=True, repr=False) diff --git a/flax/nnx/variablelib.py b/flax/nnx/variablelib.py index 6cf01d869..7feb467b7 100644 --- a/flax/nnx/variablelib.py +++ b/flax/nnx/variablelib.py @@ -229,103 +229,146 @@ class VariableMetadata(tp.Generic[A]): PyTreeDef = tp.Any +Leaf = tp.Any # --------------------------------- # hijax # --------------------------------- -def _new_hijax_variable(var_type: type[Variable]) -> HijaxVariable: - variable = var_type._new(None, {}) - (), treedef = jax.tree.flatten(variable) - return new_variable_p.bind(treedef=treedef, var_type=var_type) - - -def _get_hijax_state(hijax_var) -> Variable: - tys: VariableQDD = jax.experimental.cur_qdd(hijax_var) - leaf_vals = get_variable_p.bind(hijax_var, avals=tuple(tys.leaf_avals)) - variable = jax.tree.unflatten(tys.treedef, leaf_vals) - return variable - - -def _set_hijax_state(hijax_var, variable: Variable): - leaves, treedef = jax.tree.flatten(variable) - set_variable_p.bind( - hijax_var, *leaves, treedef=treedef, var_type=type(variable) - ) - - -def _new_hijax_from_variable(variable: Variable) -> HijaxVariable: - hijax_var = _new_hijax_variable(type(variable)) - _set_hijax_state(hijax_var, variable) - return hijax_var - @dataclasses.dataclass(frozen=True) class VariableQDD: leaf_avals: tuple[hijax.AbstractValue, ...] treedef: PyTreeDef + var_type: type[Variable[Any]] def to_tangent_qdd(self): leaf_avals = tuple(a.to_tangent_aval() for a in self.leaf_avals) - return VariableQDD(leaf_avals, self.treedef) + return VariableQDD(leaf_avals, self.treedef, self.var_type) def normalize(self): leaf_types = tuple(a.normalize() for a in self.leaf_avals) - return VariableQDD(leaf_types, self.treedef) - + return VariableQDD(leaf_types, self.treedef, self.var_type) -class VariableEffect(jax.core.Effect): - ... +class VariableEffect(jax.core.Effect): ... variable_effect = VariableEffect() hijax.control_flow_allowed_effects.add_type(VariableEffect) +def _new_hijax_from_variable(variable: Variable) -> HijaxVariable: + has_qdd = variable.is_mutable and not variable.has_ref + leaves, treedef = jax.tree.flatten(variable) + var_type = type(variable) + hijax_var = new_variable_p.bind( + *leaves, treedef=treedef, var_type=var_type, has_qdd=has_qdd + ) + return hijax_var + + class NewVariable(hijax.HiPrimitive): - def is_high(self, *, treedef, var_type) -> bool: + def is_high(self, *leaves, treedef, var_type, has_qdd) -> bool: return True # type: ignore - def abstract_eval(self, *, treedef, var_type: type[Variable]): - variable = var_type._new(None, {}) - leaves, treedef = jax.tree.flatten(variable) - qdd = VariableQDD(tuple(leaves), treedef) - return hijax.AvalQDD(AbstractVariable(var_type), qdd), {variable_effect} # type: ignore + def impl(self, *leaves, treedef, var_type, has_qdd): + return HijaxVariable._new(leaves, treedef, var_type, has_qdd) - def to_lojax(self, *, treedef, var_type: type[Variable]): - return HijaxVariable._new(None, {}, var_type) + def abstract_eval(self, *leaves, treedef, var_type, has_qdd): + aval = AbstractVariable(var_type, treedef, leaves, has_qdd) + if has_qdd: + qdd = VariableQDD(tuple(leaves), treedef, var_type) + aval_qdd = hijax.AvalQDD(aval, qdd) # type: ignore + return aval_qdd, {variable_effect} + else: + return aval, set() - def jvp(_, primals, tangents, *, treedef, var_type): - raise NotImplementedError('jvp not implemented for NewHijaxVariable') + def to_lojax(self, *leaves, treedef, var_type, has_qdd): + return HijaxVariable._new(leaves, treedef, var_type, has_qdd) - def transpose(_, *args, treedef, var_type): - raise NotImplementedError('transpose not implemented for NewHijaxVariable') + def jvp(_, primals, tangents, *, treedef, var_type, has_qdd): + if has_qdd: + raise NotImplementedError( + "jvp not implemented for 'new_variable' with QDD" + ) + primal_hijax_var = new_variable_p.bind( + *primals, treedef=treedef, var_type=var_type, has_qdd=has_qdd + ) + tangent_hijax_var = new_variable_p.bind( + *tangents, treedef=treedef, var_type=var_type, has_qdd=has_qdd + ) + return primal_hijax_var, tangent_hijax_var + + def transpose( + _, out_var: HijaxVariable, *input_leaves, treedef, var_type, has_qdd + ): + if has_qdd: + raise NotImplementedError( + "transpose not implemented for 'new_variable' with QDD" + ) + avals = tuple( + map( + lambda x: x.aval if hijax.is_undefined_primal(x) else jax.typeof(x), + input_leaves, + ) + ) + leaves_dot = get_variable_p.bind( + out_var, + treedef=treedef, + avals=avals, + var_type=var_type, + has_qdd=has_qdd, + ) + return leaves_dot new_variable_p = NewVariable(f'new_variable') +def _set_hijax_state(hijax_var, variable: Variable): + leaves, treedef = jax.tree.flatten(variable) + set_variable_p.bind( + hijax_var, *leaves, treedef=treedef, var_type=type(variable) + ) + + class SetVariable(hijax.HiPrimitive): multiple_results = True - def is_high(self, *leaf_avals, treedef, var_type) -> bool: + def is_high(_, *leaf_avals, treedef, var_type) -> bool: return True # type: ignore # TODO: upstream this to Box - def impl(self, hijax_var: HijaxVariable, *leaves, treedef, var_type): - variable: Variable = jax.tree.unflatten(treedef, leaves) - object.__setattr__(hijax_var, '_raw_value', variable._raw_value) - object.__setattr__(hijax_var, '_metadata', variable._var_metadata) + def impl(_, hijax_var: HijaxVariable, *leaves, treedef, var_type): + if not hijax_var.has_qdd: + raise errors.ImmutableVariableError( + "Trying to update Variable with 'has_qdd=False'." + ) + assert var_type is hijax_var._var_type + object.__setattr__(hijax_var, '_leaves', leaves) + object.__setattr__(hijax_var, '_treedef', treedef) return [] - def abstract_eval(self, hijax_var_type, *leaf_avals, treedef, var_type): - hijax_var_type.mutable_qdd.update(VariableQDD(leaf_avals, treedef)) - return [], {variable_effect} # TODO better typechecking... + def abstract_eval(_, hijax_var_type, *leaf_avals, treedef, var_type): + if not hijax_var_type.has_qdd: + raise errors.ImmutableVariableError( + "Trying to update Variable with 'has_qdd=False'." + ) + assert var_type is hijax_var_type._var_type + hijax_var_type.mutable_qdd.update( + VariableQDD(leaf_avals, treedef, var_type) + ) + effects = {variable_effect} if hijax_var_type.has_qdd else set() + return [], effects # TODO better typechecking... def to_lojax(_, hijax_var: HijaxVariable, *leaves, treedef, var_type): - variable: Variable = jax.tree.unflatten(treedef, leaves) - object.__setattr__(hijax_var, '_raw_value', variable._raw_value) - object.__setattr__(hijax_var, '_metadata', variable._var_metadata) + if not hijax_var.has_qdd: + raise errors.ImmutableVariableError( + "Trying to update Variable with 'has_qdd=False'." + ) + assert var_type is hijax_var._var_type + object.__setattr__(hijax_var, '_leaves', leaves) + object.__setattr__(hijax_var, '_treedef', treedef) return [] def jvp(_, primals, tangents, *, treedef, var_type): @@ -353,26 +396,92 @@ def transpose(_, *args, treedef, var_type): set_variable_p = SetVariable(f'set_variable') +def _get_hijax_state(hijax_var: HijaxVariable | AbstractVariable) -> Variable: + if hijax_var.has_qdd: + tys: VariableQDD = jax.experimental.cur_qdd(hijax_var) + leaf_vals = get_variable_p.bind( + hijax_var, + treedef=tys.treedef, + avals=tuple(tys.leaf_avals), + var_type=hijax_var._var_type, + has_qdd=hijax_var.has_qdd, + ) + variable = jax.tree.unflatten(tys.treedef, leaf_vals) + else: + assert hijax_var._treedef is not None + assert hijax_var._leaves is not None + if isinstance(hijax_var, (jax.core.Tracer, AbstractVariable)): + leaf_avals = hijax_var._leaves + else: + leaf_avals = tuple(map(jax.typeof, hijax_var._leaves)) + leaf_vals = get_variable_p.bind( + hijax_var, + treedef=hijax_var._treedef, + avals=leaf_avals, + var_type=hijax_var._var_type, + has_qdd=hijax_var.has_qdd, + ) + variable = jax.tree.unflatten(hijax_var._treedef, leaf_vals) + + return variable + + class GetVariable(hijax.HiPrimitive): multiple_results = True - def abstract_eval(self, abstract_var, *, avals): - return avals, {variable_effect} + def impl( + self, hijax_var: HijaxVariable, *, treedef, avals, var_type, has_qdd + ): + return hijax_var._leaves - def to_lojax(_, hijax_var: HijaxVariable, *, avals): - return jax.tree.leaves(hijax_var._raw_value) + def abstract_eval(self, abstract_var, *, treedef, avals, var_type, has_qdd): + if has_qdd: + return avals, {variable_effect} + else: + return avals, set() - def jvp(_, primals, tangents, *, avals): - (box,), (variable_dot,) = primals, tangents + def to_lojax( + _, hijax_var: HijaxVariable, *, treedef, avals, var_type, has_qdd + ): + return hijax_var._leaves + + def jvp(_, primals, tangents, *, treedef, avals, var_type, has_qdd): + if has_qdd: + raise NotImplementedError( + "jvp not implemented for 'get_variable' with QDD" + ) + (hijax_var,), (hijax_var_dot,) = primals, tangents return ( - get_variable_p.bind(box, avals=avals), get_variable_p.bind( - variable_dot, avals=tuple(a.to_tangent_aval() for a in avals) + hijax_var, + treedef=treedef, + avals=avals, + var_type=var_type, + has_qdd=has_qdd, + ), + get_variable_p.bind( + hijax_var_dot, + treedef=treedef, + avals=tuple(a.to_tangent_aval() for a in avals), + var_type=var_type, + has_qdd=has_qdd, ), ) - def transpose(_, *args): - raise NotImplementedError('transpose not implemented for GetHijaxVariable') + def transpose(_, out, hijax_var, *, treedef, avals, var_type, has_qdd): + if has_qdd: + raise NotImplementedError( + "transpose not implemented for 'get_variable' with QDD" + ) + abstract_var: AbstractVariable = ( + hijax_var.aval + if hijax.is_undefined_primal(hijax_var) + else jax.typeof(hijax_var) + ) + hijax_var_dot = new_variable_p.bind( + *out, treedef=abstract_var._treedef, var_type=var_type, has_qdd=has_qdd + ) + return (hijax_var_dot,) get_variable_p = GetVariable(f'get_variable') @@ -477,26 +586,29 @@ def __instancecheck__(self, instance): return isinstance(ty, AbstractVariable) return False -jax.Ref + class HijaxVariable( tp.Generic[A], reprlib.Representable, metaclass=HijaxVariableMeta ): # type: ignore - __slots__ = ('_raw_value', '_metadata', '_var_type') - _raw_value: A - _metadata: dict[str, tp.Any] + __slots__ = ('_treedef', '_leaves', '_var_type', 'has_qdd') + _treedef: PyTreeDef + _leaves: tuple[Leaf, ...] _var_type: type[Variable[tp.Any]] + has_qdd: bool @classmethod def _new( cls, - value, - metadata: dict[str, tp.Any], + leaves: tuple[Leaf, ...], + treedef: PyTreeDef, var_type: type[Variable[A]], + has_qdd: bool, ): hijax_var = object.__new__(cls) - object.__setattr__(hijax_var, '_raw_value', value) - object.__setattr__(hijax_var, '_metadata', metadata) + object.__setattr__(hijax_var, '_treedef', treedef) + object.__setattr__(hijax_var, '_leaves', leaves) object.__setattr__(hijax_var, '_var_type', var_type) + object.__setattr__(hijax_var, 'has_qdd', has_qdd) return hijax_var __init__ = _as_hijax_method('__init__') @@ -641,26 +753,35 @@ def from_metadata(cls, value: A, metadata: dict[str, tp.Any]): def cur_qdd(self): return self.type_state() - @property - def ty(self): - return AbstractVariable(self._var_type) - def type_state(self): - variable = self._var_type._new(self._raw_value, self._metadata) - leaves, treedef = jax.tree.flatten(variable) - leaf_avals = tuple(map(jax.typeof, leaves)) - return VariableQDD(leaf_avals, treedef) + leaf_avals = tuple(map(jax.typeof, self._leaves)) + return VariableQDD(leaf_avals, self._treedef, self._var_type) -hijax.register_hitype(HijaxVariable, lambda b: b.ty) +def _to_abstract_variable(hijax_var: HijaxVariable): + if hijax_var.has_qdd: + treedef = None + leaves = None + else: + leaves = tuple(map(jax.typeof, hijax_var._leaves)) + treedef = hijax_var._treedef + return AbstractVariable( + hijax_var._var_type, treedef, leaves, hijax_var.has_qdd + ) + + +hijax.register_hitype(HijaxVariable, _to_abstract_variable) # --------------------------------- # AbstractVariable # --------------------------------- class AbstractVariable(tp.Generic[A], hijax.MutableHiType): - __slots__ = ['_var_type'] + __slots__ = ['_var_type', '_treedef', '_leaves', 'has_qdd'] _var_type: type[Variable[A]] + _treedef: PyTreeDef | None + _leaves: tuple[hijax.AbstractValue, ...] | None + has_qdd: bool # forwarded to value var_type = hijax.aval_property(lambda self: self.aval._var_type) is_hijax = _as_aval_property(HijaxVariable.is_hijax) @@ -670,8 +791,19 @@ class AbstractVariable(tp.Generic[A], hijax.MutableHiType): _can_update = _as_aval_property(HijaxVariable._can_update) _check_can_update = hijax.aval_method(HijaxVariable._check_can_update) - def __init__(self, var_type: type[Variable[A]]): + def __init__( + self, + var_type: type[Variable[A]], + treedef: PyTreeDef | None, + leaves: tuple[hijax.AbstractValue, ...] | None, + has_qdd: bool, + ): + if (treedef is None) ^ (leaves is None): + raise ValueError('treedef and leaves must be both provided or both None') + object.__setattr__(self, '_treedef', treedef) + object.__setattr__(self, '_leaves', leaves) object.__setattr__(self, '_var_type', var_type) + object.__setattr__(self, 'has_qdd', has_qdd) @property def dtype(self): @@ -802,13 +934,19 @@ def __treescope_repr__(self, path, subtree_renderer): # -------------------------------- # hijax interface # -------------------------------- - has_qdd = True - def __hash__(self): - return hash((AbstractVariable, self._var_type)) + if self._leaves is not None and self._treedef is not None: + return hash( + (AbstractVariable, self._var_type, self._treedef, self._leaves) + ) + else: + assert self._leaves is None and self._treedef is None + return hash((AbstractVariable, self._var_type)) def __eq__(self, other): - return isinstance(other, AbstractVariable) and self._var_type == other._var_type + return ( + isinstance(other, AbstractVariable) and self._var_type == other._var_type + ) def str_short(self, short_dtypes=False, **_) -> str: # type: ignore return f'{self._var_type.__name__}()' @@ -828,7 +966,10 @@ def new_from_loval( # type: ignore[override] assert next(lo_vals_, None) is None variable: Variable = jax.tree.unflatten(variable_state.treedef, hi_vals) return HijaxVariable._new( - variable._raw_value, variable._var_metadata, self._var_type + hi_vals, + variable_state.treedef, + self._var_type, + has_qdd=self.has_qdd, ) # will be mutated def read_loval(self, variable_state: VariableQDD, variable) -> list: # type: ignore @@ -852,7 +993,18 @@ def update_from_loval( # type: ignore[override] _set_hijax_state(variable, jax.tree.unflatten(box_state.treedef, hi_vals)) def to_tangent_aval(self): - return AbstractVariable(self._var_type) + if self.has_qdd: + variable = _get_hijax_state(self) + variable = variable.copy(is_mutable=False) + hijax_var = _new_hijax_from_variable(variable) + return _to_abstract_variable(hijax_var) + else: + return AbstractVariable( + self._var_type, + self._treedef, + self._leaves, + self.has_qdd, + ) # -------------------------------------------- @@ -974,9 +1126,9 @@ class Variable(tp.Generic[A], reprlib.Representable, metaclass=VariableMeta): _raw_value: A _trace_state: tracers.TraceState _var_metadata: dict[str, tp.Any] - required_metadata = frozenset([ - 'is_hijax', 'has_ref', 'is_mutable', 'eager_sharding' - ]) + required_metadata = frozenset( + ['is_hijax', 'has_ref', 'is_mutable', 'eager_sharding'] + ) @property def var_type(self): @@ -999,14 +1151,14 @@ def shape(self: Variable[jax.Array]) -> tuple[int, ...]: return self.get_value().shape def __init__( - self, - value: A | VariableMetadata[A], - *, - is_hijax: bool | None = None, - has_ref: bool = False, - is_mutable: bool = True, - eager_sharding: bool | None = None, - **metadata: tp.Any, + self, + value: A | VariableMetadata[A], + *, + is_hijax: bool | None = None, + has_ref: bool = False, + is_mutable: bool = True, + eager_sharding: bool | None = None, + **metadata: tp.Any, ): var_t = type(self) @@ -1052,16 +1204,6 @@ def __init__( if eager_sharding is None: eager_sharding = using_eager_sharding() - if is_hijax and not is_mutable: - raise ValueError( - 'Cannot set is_hijax=True and is_mutable=False simultaneously.' - ) - - if has_ref and is_hijax: - raise ValueError( - 'Cannot set has_ref=True and is_hijax=True simultaneously.' - ) - if has_ref and not is_mutable: raise ValueError( 'Cannot set has_ref=True and is_mutable=False simultaneously.' @@ -1169,12 +1311,12 @@ def type(self): return type(self) @tp.overload - def get_metadata(self, *, exclude_required: bool = False) -> dict[str, tp.Any]: - ... + def get_metadata( + self, *, exclude_required: bool = False + ) -> dict[str, tp.Any]: ... @tp.overload - def get_metadata(self, name: str, default: tp.Any = MISSING) -> tp.Any: - ... + def get_metadata(self, name: str, default: tp.Any = MISSING) -> tp.Any: ... def get_metadata( self, @@ -1211,16 +1353,13 @@ def get_metadata( return metadata[name] @tp.overload - def set_metadata(self, metadata: dict[str, tp.Any], /) -> None: - ... + def set_metadata(self, metadata: dict[str, tp.Any], /) -> None: ... @tp.overload - def set_metadata(self, name: str, value: tp.Any, /) -> None: - ... + def set_metadata(self, name: str, value: tp.Any, /) -> None: ... @tp.overload - def set_metadata(self, **metadata: tp.Any) -> None: - ... + def set_metadata(self, **metadata: tp.Any) -> None: ... def set_metadata(self, *args, **kwargs) -> None: """Set metadata for the Variable. @@ -1469,12 +1608,10 @@ def remove_axis(self, axis_index: AxisIndex, axis_name: AxisName | None): self._var_metadata['on_remove_axis'](self, axis_index, axis_name) @tp.overload - def copy(self, value: B, **kwargs) -> Variable[B]: - ... + def copy(self, value: B, **kwargs) -> Variable[B]: ... @tp.overload - def copy(self, **kwargs) -> Variable[A]: - ... + def copy(self, **kwargs) -> Variable[A]: ... def copy( self, @@ -1485,36 +1622,17 @@ def copy( ) -> Variable[tp.Any]: assert 'raw_value' not in updates - if updates.get('has_ref', False) and updates.get('is_hijax', False): - raise ValueError( - 'Cannot set has_ref=True and is_hijax=True simultaneously.' - ) if not updates.get('is_mutable', True) and updates.get('has_ref', False): raise ValueError( 'Cannot set has_ref=True and is_mutable=False simultaneously.' ) - if updates.get('is_mutable', False) and updates.get('is_hijax', False): - raise ValueError( - 'Cannot set is_hijax=True and is_mutable=False simultaneously.' - ) new_metadata = self.get_metadata() | updates - if updates.get('has_ref', False): - new_metadata['is_hijax'] = False - new_metadata.pop('was_hijax', None) - if updates.get('is_hijax', False): - new_metadata['has_ref'] = False - new_metadata.pop('had_ref', None) if not updates.get('is_mutable', True) and self.is_mutable: new_metadata['has_ref'] = False - new_metadata['is_hijax'] = False if self.has_ref: new_metadata['had_ref'] = True - if self.is_hijax: - new_metadata['was_hijax'] = True if updates.get('is_mutable', False) or updates.get('has_ref', False): new_metadata.pop('had_ref', None) - if updates.get('is_mutable', False) or updates.get('is_hijax', False): - new_metadata.pop('was_hijax', None) if not isinstance(value, Missing): pass @@ -1531,8 +1649,6 @@ def copy( ): value = jax.new_ref(value) new_metadata['has_ref'] = True - if new_metadata['is_mutable'] and self.get_metadata('was_hijax', False): - new_metadata['is_hijax'] = True obj = self.from_metadata(value, new_metadata) return obj diff --git a/flax/typing.py b/flax/typing.py index 350be2e36..0e9b1d0dd 100644 --- a/flax/typing.py +++ b/flax/typing.py @@ -24,6 +24,7 @@ TypeVar, Union, ) +from collections.abc import Iterator from collections.abc import Callable, Hashable, Mapping, Sequence import jax @@ -135,6 +136,8 @@ class Out(Generic[T]): Sharding = tuple[AxisName, ...] A = TypeVar('A') +HA = TypeVar('HA', bound=Hashable) +HB = TypeVar('HB') class PytreeDeque(deque[A]): @@ -233,4 +236,50 @@ def from_any(cls, x): class PromoteDtypeFn(Protocol): def __call__( self, args: TupleArg, /, *, dtype: Any = None, inexact: bool = True - ) -> TupleArg: ... \ No newline at end of file + ) -> TupleArg: ... + + +class HashableMapping(Mapping[HA, HB], Hashable): + _mapping: dict[HA, HB] | Mapping[HA, HB] + + def __init__(self, mapping: Mapping[HA, HB], copy: bool = True): + self._mapping = dict(mapping) if copy else mapping + + def __contains__(self, key: object) -> bool: + return key in self._mapping + + def __getitem__(self, key: HA) -> HB: + return self._mapping[key] + + def __iter__(self) -> Iterator[HA]: + return iter(self._mapping) + + def __len__(self) -> int: + return len(self._mapping) + + def __hash__(self) -> int: + # use type-aware sorting to support int keys + def _pytree__key_sort_fn(item: tuple[Any, Any]) -> tuple[int, Any]: + key, _ = item + if isinstance(key, int): + return (0, key) + elif isinstance(key, str): + return (1, key) + else: + raise ValueError(f'Unsupported key type: {type(key)!r}') + + return hash(tuple(sorted(self._mapping.items(), key=_pytree__key_sort_fn))) + + def __eq__(self, other: Any) -> bool: + return ( + isinstance(other, HashableMapping) and self._mapping == other._mapping + ) + + def __repr__(self) -> str: + return repr(self._mapping) + + def update(self, other: Mapping[HA, HB]) -> HashableMapping[HA, HB]: + """Updates the mapping with another mapping.""" + mapping = dict(self._mapping) + mapping.update(other) + return HashableMapping(mapping, copy=False) diff --git a/tests/nnx/mutable_array_test.py b/tests/nnx/mutable_array_test.py index 823539e38..990c32509 100644 --- a/tests/nnx/mutable_array_test.py +++ b/tests/nnx/mutable_array_test.py @@ -151,7 +151,13 @@ def __init__(self): m5 = nnx.as_hijax_vars(m2) self.assertFalse(m5.a.has_ref) self.assertTrue(m5.a.is_hijax) - self.assertNotIn('had_ref', m5.a.get_metadata()) + self.assertIn('had_ref', m5.a.get_metadata()) + + m6 = nnx.as_mutable_vars(m5) + self.assertIsInstance(m6.a.get_raw_value(), jax.Ref) + self.assertTrue(m6.a.has_ref) + self.assertTrue(m6.a.is_hijax) + self.assertNotIn('had_ref', m6.a.get_metadata()) def test_to_arrays_example(self): node = [nnx.Variable(1.0), nnx.Variable(2.0, mode='ref')] @@ -706,7 +712,7 @@ def loss_fn(params): model = nnx.merge(graphdef, params, nondiff) return jnp.mean((model(x) - y) ** 2) - loss, grads = jax.value_and_grad(loss_fn)(nnx.as_immutable_vars(params)) + loss, grads = jax.value_and_grad(loss_fn)(nnx.as_pytree_vars(params)) optimizer.update(params, grads) return loss @@ -783,7 +789,8 @@ def set(v_hi, a): v_low = nnx.as_immutable_vars(v_hi) - assert not v_low.is_hijax and not v_low.is_mutable + assert v_low.is_hijax + assert not v_low.is_mutable assert v_low[...] == 10 def test_immutable_variable(self): @@ -910,6 +917,55 @@ def f(v): self.assertEqual(y.shape, ()) + @nnx.use_hijax(True) + def test_qdd_grad(self): + v = nnx.Param(jnp.array(3.0)) + + self.assertTrue(v.is_mutable) + self.assertTrue(v.is_hijax) + + def f(v): + self.assertFalse(v.is_mutable) + self.assertTrue(v.is_hijax) + return v[...] ** 2 + + grad = jax.grad(f)(v) + + self.assertIsInstance(grad, nnx.Param) + self.assertEqual(grad[...], 6.0) + + @nnx.use_hijax(True) + def test_no_qdd_grad(self): + v = nnx.Param(jnp.array(3.0), is_mutable=False) + + self.assertFalse(v.is_mutable) + self.assertTrue(v.is_hijax) + + def f(v): + self.assertTrue(v.is_hijax) + return v[...] ** 2 + + grad = jax.grad(f)(v) + + self.assertIsInstance(grad, nnx.Param) + self.assertEqual(grad[...], 6.0) + + @nnx.use_hijax(True) + def test_no_qdd_grad_new(self): + x = jnp.array(3.0) + + def f(x): + v = nnx.Param(x, is_mutable=False) + self.assertFalse(v.is_mutable) + self.assertTrue(v.is_hijax) + self.assertTrue(v.is_hijax) + return v[...] ** 2 + + grad = jax.grad(f)(x) + + self.assertIsInstance(grad, jax.Array) + self.assertEqual(grad, 6.0) + if __name__ == '__main__': absltest.main()