Skip to content

Commit 74b72bd

Browse files
author
Flax Authors
committed
Merge pull request #5066 from google:dataclass
PiperOrigin-RevId: 831131118
2 parents 49ef97c + 5b4c600 commit 74b72bd

File tree

8 files changed

+331
-124
lines changed

8 files changed

+331
-124
lines changed

docs_nnx/guides/pytree.ipynb

Lines changed: 60 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -395,8 +395,8 @@
395395
"id": "8055b72c",
396396
"metadata": {},
397397
"source": [
398-
"### Class Annotations\n",
399-
"Annotations can also be added at the type level via `nnx.Static` and `nnx.Data`. This is useful for creating `dataclasses` but the mechanism also works for regular classes."
398+
"### Dataclasses\n",
399+
"`nnx.Pytree` dataclasses can be created by using the `nnx.dataclass` decorator. To control the status of each field, `nnx.static` and `nnx.data` can be used as `field` specifiers."
400400
]
401401
},
402402
{
@@ -420,25 +420,58 @@
420420
"source": [
421421
"import dataclasses\n",
422422
"\n",
423-
"@dataclasses.dataclass\n",
423+
"@nnx.dataclass\n",
424424
"class Foo(nnx.Pytree):\n",
425-
" i: nnx.Data[int]\n",
426-
" s: nnx.Static[str]\n",
425+
" i: int = nnx.data()\n",
427426
" x: jax.Array\n",
428427
" a: int\n",
428+
" s: str = nnx.static(default='hi', kw_only=True)\n",
429429
"\n",
430-
"@dataclasses.dataclass\n",
430+
"@nnx.dataclass\n",
431431
"class Bar(nnx.Pytree):\n",
432-
" ls: nnx.Data[list[Foo]]\n",
432+
" ls: list[Foo] = nnx.data()\n",
433433
" shapes: list[int]\n",
434434
"\n",
435435
"pytree = Bar(\n",
436-
" ls=[Foo(i, \"Hi\" + \"!\" * i, jnp.array(42 * i), hash(i)) for i in range(2)],\n",
436+
" ls=[Foo(i, jnp.array(42 * i), hash(i)) for i in range(2)],\n",
437437
" shapes=[8, 16, 32]\n",
438438
")\n",
439439
"pytree_structure(pytree)"
440440
]
441441
},
442+
{
443+
"cell_type": "markdown",
444+
"id": "fca51f65",
445+
"metadata": {},
446+
"source": [
447+
"`dataclasses.dataclass` can also be used directly, however type checkers will not handle `nnx.static` and `nnx.data` correctly. To solve this `dataclasses.field` can be used by setting `metadata` with the appropriate entry for `static`."
448+
]
449+
},
450+
{
451+
"cell_type": "code",
452+
"execution_count": 10,
453+
"id": "ff54e732",
454+
"metadata": {},
455+
"outputs": [
456+
{
457+
"name": "stdout",
458+
"output_type": "stream",
459+
"text": [
460+
"dataclass pytree structure:\n",
461+
" - pytree.a = 10\n"
462+
]
463+
}
464+
],
465+
"source": [
466+
"@dataclasses.dataclass\n",
467+
"class Bar(nnx.Pytree):\n",
468+
" a: int = dataclasses.field(metadata={'static': False}) # data\n",
469+
" b: str = dataclasses.field(metadata={'static': True}) # static\n",
470+
"\n",
471+
"pytree = Bar(a=10, b=\"hello\")\n",
472+
"pytree_structure(pytree, title='dataclass pytree structure')"
473+
]
474+
},
442475
{
443476
"cell_type": "markdown",
444477
"id": "d6036a0e",
@@ -457,7 +490,7 @@
457490
},
458491
{
459492
"cell_type": "code",
460-
"execution_count": 10,
493+
"execution_count": 11,
461494
"id": "509a517e",
462495
"metadata": {},
463496
"outputs": [
@@ -501,7 +534,7 @@
501534
},
502535
{
503536
"cell_type": "code",
504-
"execution_count": 11,
537+
"execution_count": 12,
505538
"id": "98e04ff9",
506539
"metadata": {},
507540
"outputs": [
@@ -534,7 +567,7 @@
534567
},
535568
{
536569
"cell_type": "code",
537-
"execution_count": 12,
570+
"execution_count": 13,
538571
"id": "c864d5b1",
539572
"metadata": {},
540573
"outputs": [
@@ -568,7 +601,7 @@
568601
},
569602
{
570603
"cell_type": "code",
571-
"execution_count": 13,
604+
"execution_count": 14,
572605
"id": "628698dd",
573606
"metadata": {},
574607
"outputs": [
@@ -616,20 +649,20 @@
616649
"id": "37ee2429",
617650
"metadata": {},
618651
"source": [
619-
"Checking for `nnx.data` or `nnx.static` annotations stored in inside nested structures that are not `nnx.Pytree` instances:"
652+
"Checking for `nnx.data` or `nnx.static` annotations stored inside nested structures that are not `nnx.Pytree` instances:"
620653
]
621654
},
622655
{
623656
"cell_type": "code",
624-
"execution_count": 14,
657+
"execution_count": 15,
625658
"id": "f9d69634",
626659
"metadata": {},
627660
"outputs": [
628661
{
629662
"name": "stdout",
630663
"output_type": "stream",
631664
"text": [
632-
"ValueError: Found unexpected tags {'data', 'static'} on attribute 'Foo.a'. Values from nnx.data(...) and\n",
665+
"ValueError: Found unexpected tags {'static', 'data'} on attribute 'Foo.a'. Values from nnx.data(...) and\n",
633666
"nnx.static(...) should be assigned to nnx.Pytree attributes directly, they should not be inside other structures. Got value of type '<class 'list'>' on Pytree of type '<class '__main__.Foo'>'.\n"
634667
]
635668
}
@@ -656,7 +689,7 @@
656689
},
657690
{
658691
"cell_type": "code",
659-
"execution_count": 15,
692+
"execution_count": 16,
660693
"id": "668db479",
661694
"metadata": {},
662695
"outputs": [
@@ -704,7 +737,7 @@
704737
},
705738
{
706739
"cell_type": "code",
707-
"execution_count": 16,
740+
"execution_count": 17,
708741
"id": "32c46ce8",
709742
"metadata": {},
710743
"outputs": [
@@ -741,7 +774,7 @@
741774
},
742775
{
743776
"cell_type": "code",
744-
"execution_count": 17,
777+
"execution_count": 18,
745778
"id": "c33e4862",
746779
"metadata": {},
747780
"outputs": [
@@ -778,7 +811,7 @@
778811
},
779812
{
780813
"cell_type": "code",
781-
"execution_count": 18,
814+
"execution_count": 19,
782815
"id": "dda51b67",
783816
"metadata": {},
784817
"outputs": [
@@ -825,7 +858,7 @@
825858
},
826859
{
827860
"cell_type": "code",
828-
"execution_count": 19,
861+
"execution_count": 20,
829862
"id": "caa01e3b",
830863
"metadata": {},
831864
"outputs": [
@@ -864,7 +897,7 @@
864897
},
865898
{
866899
"cell_type": "code",
867-
"execution_count": 20,
900+
"execution_count": 21,
868901
"id": "d2e03753",
869902
"metadata": {},
870903
"outputs": [
@@ -929,7 +962,7 @@
929962
},
930963
{
931964
"cell_type": "code",
932-
"execution_count": 21,
965+
"execution_count": 22,
933966
"id": "ca9f58a2",
934967
"metadata": {},
935968
"outputs": [
@@ -1003,7 +1036,7 @@
10031036
},
10041037
{
10051038
"cell_type": "code",
1006-
"execution_count": 38,
1039+
"execution_count": 23,
10071040
"id": "41398e14",
10081041
"metadata": {},
10091042
"outputs": [
@@ -1050,20 +1083,20 @@
10501083
},
10511084
{
10521085
"cell_type": "code",
1053-
"execution_count": null,
1086+
"execution_count": 24,
10541087
"id": "d10effba",
10551088
"metadata": {},
10561089
"outputs": [
10571090
{
10581091
"name": "stdout",
10591092
"output_type": "stream",
10601093
"text": [
1061-
"step = 0, loss = Array(0.7326511, dtype=float32), perturbations = \u001b[38;2;79;201;177mState\u001b[0m\u001b[38;2;255;213;3m({\u001b[0m\u001b[38;2;105;105;105m\u001b[0m\n",
1094+
"step = 0, loss = Array(0.7326511, dtype=float32), iterm_grads = \u001b[38;2;79;201;177mState\u001b[0m\u001b[38;2;255;213;3m({\u001b[0m\u001b[38;2;105;105;105m\u001b[0m\n",
10621095
" \u001b[38;2;156;220;254m'xgrad'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0m\u001b[38;2;79;201;177mPerturbation\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # 3 (12 B)\u001b[0m\n",
10631096
" \u001b[38;2;156;220;254mvalue\u001b[0m\u001b[38;2;212;212;212m=\u001b[0mArray([[-0.430146 , -0.14356601, 0.2935633 ]], dtype=float32)\n",
10641097
" \u001b[38;2;255;213;3m)\u001b[0m\n",
10651098
"\u001b[38;2;255;213;3m})\u001b[0m\n",
1066-
"step = 1, loss = Array(0.65039134, dtype=float32), perturbations = \u001b[38;2;79;201;177mState\u001b[0m\u001b[38;2;255;213;3m({\u001b[0m\u001b[38;2;105;105;105m\u001b[0m\n",
1099+
"step = 1, loss = Array(0.65039134, dtype=float32), iterm_grads = \u001b[38;2;79;201;177mState\u001b[0m\u001b[38;2;255;213;3m({\u001b[0m\u001b[38;2;105;105;105m\u001b[0m\n",
10671100
" \u001b[38;2;156;220;254m'xgrad'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0m\u001b[38;2;79;201;177mPerturbation\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # 3 (12 B)\u001b[0m\n",
10681101
" \u001b[38;2;156;220;254mvalue\u001b[0m\u001b[38;2;212;212;212m=\u001b[0mArray([[-0.38535568, -0.11745065, 0.24441527]], dtype=float32)\n",
10691102
" \u001b[38;2;255;213;3m)\u001b[0m\n",
@@ -1108,7 +1141,7 @@
11081141
},
11091142
{
11101143
"cell_type": "code",
1111-
"execution_count": 24,
1144+
"execution_count": 25,
11121145
"id": "a9cab639",
11131146
"metadata": {},
11141147
"outputs": [

docs_nnx/guides/pytree.md

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -195,31 +195,43 @@ pytree = Bar(1.0, True)
195195
pytree_structure(pytree)
196196
```
197197

198-
### Class Annotations
199-
Annotations can also be added at the type level via `nnx.Static` and `nnx.Data`. This is useful for creating `dataclasses` but the mechanism also works for regular classes.
198+
### Dataclasses
199+
`nnx.Pytree` dataclasses can be created by using the `nnx.dataclass` decorator. To control the status of each field, `nnx.static` and `nnx.data` can be used as `field` specifiers.
200200

201201
```{code-cell} ipython3
202202
import dataclasses
203203
204-
@dataclasses.dataclass
204+
@nnx.dataclass
205205
class Foo(nnx.Pytree):
206-
i: nnx.Data[int]
207-
s: nnx.Static[str]
206+
i: int = nnx.data()
208207
x: jax.Array
209208
a: int
209+
s: str = nnx.static(default='hi', kw_only=True)
210210
211-
@dataclasses.dataclass
211+
@nnx.dataclass
212212
class Bar(nnx.Pytree):
213-
ls: nnx.Data[list[Foo]]
213+
ls: list[Foo] = nnx.data()
214214
shapes: list[int]
215215
216216
pytree = Bar(
217-
ls=[Foo(i, "Hi" + "!" * i, jnp.array(42 * i), hash(i)) for i in range(2)],
217+
ls=[Foo(i, jnp.array(42 * i), hash(i)) for i in range(2)],
218218
shapes=[8, 16, 32]
219219
)
220220
pytree_structure(pytree)
221221
```
222222

223+
`dataclasses.dataclass` can also be used directly, however type checkers will not handle `nnx.static` and `nnx.data` correctly. To solve this `dataclasses.field` can be used by setting `metadata` with the appropriate entry for `static`.
224+
225+
```{code-cell} ipython3
226+
@dataclasses.dataclass
227+
class Bar(nnx.Pytree):
228+
a: int = dataclasses.field(metadata={'static': False}) # data
229+
b: str = dataclasses.field(metadata={'static': True}) # static
230+
231+
pytree = Bar(a=10, b="hello")
232+
pytree_structure(pytree, title='dataclass pytree structure')
233+
```
234+
223235
### Attribute Updates
224236

225237
+++
@@ -281,7 +293,7 @@ except ValueError as e:
281293
print("ValueError:", e)
282294
```
283295

284-
Checking for `nnx.data` or `nnx.static` annotations stored in inside nested structures that are not `nnx.Pytree` instances:
296+
Checking for `nnx.data` or `nnx.static` annotations stored inside nested structures that are not `nnx.Pytree` instances:
285297

286298
```{code-cell} ipython3
287299
class Foo(nnx.Pytree):

flax/nnx/__init__.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
from .pytreelib import Object as Object
3737
from .pytreelib import Data as Data
3838
from .pytreelib import Static as Static
39+
from .pytreelib import dataclass as dataclass
3940
from .pytreelib import data as data
4041
from .pytreelib import static as static
4142
from .pytreelib import register_data_type as register_data_type
@@ -209,12 +210,12 @@
209210
from .summary import tabulate as tabulate
210211
from . import traversals as traversals
211212

212-
# alias VariableState
213-
VariableState = Variable
214213

215214
import typing as _tp
216215

217-
if not _tp.TYPE_CHECKING:
216+
if _tp.TYPE_CHECKING:
217+
VariableState = Variable
218+
else:
218219
def __getattr__(name):
219220
if name == "VariableState":
220221
import warnings
@@ -224,4 +225,5 @@ def __getattr__(name):
224225
DeprecationWarning,
225226
stacklevel=2,
226227
)
228+
return Variable
227229
raise AttributeError(f"Module {__name__} has no attribute '{name}'")

0 commit comments

Comments
 (0)