|
395 | 395 | "id": "8055b72c", |
396 | 396 | "metadata": {}, |
397 | 397 | "source": [ |
398 | | - "### Class Annotations\n", |
399 | | - "Annotations can also be added at the type level via `nnx.Static` and `nnx.Data`. This is useful for creating `dataclasses` but the mechanism also works for regular classes." |
| 398 | + "### Dataclasses\n", |
| 399 | + "`nnx.Pytree` dataclasses can be created by using the `nnx.dataclass` decorator. To control the status of each field, `nnx.static` and `nnx.data` can be used as `field` specifiers." |
400 | 400 | ] |
401 | 401 | }, |
402 | 402 | { |
|
420 | 420 | "source": [ |
421 | 421 | "import dataclasses\n", |
422 | 422 | "\n", |
423 | | - "@dataclasses.dataclass\n", |
| 423 | + "@nnx.dataclass\n", |
424 | 424 | "class Foo(nnx.Pytree):\n", |
425 | | - " i: nnx.Data[int]\n", |
426 | | - " s: nnx.Static[str]\n", |
| 425 | + " i: int = nnx.data()\n", |
427 | 426 | " x: jax.Array\n", |
428 | 427 | " a: int\n", |
| 428 | + " s: str = nnx.static(default='hi', kw_only=True)\n", |
429 | 429 | "\n", |
430 | | - "@dataclasses.dataclass\n", |
| 430 | + "@nnx.dataclass\n", |
431 | 431 | "class Bar(nnx.Pytree):\n", |
432 | | - " ls: nnx.Data[list[Foo]]\n", |
| 432 | + " ls: list[Foo] = nnx.data()\n", |
433 | 433 | " shapes: list[int]\n", |
434 | 434 | "\n", |
435 | 435 | "pytree = Bar(\n", |
436 | | - " ls=[Foo(i, \"Hi\" + \"!\" * i, jnp.array(42 * i), hash(i)) for i in range(2)],\n", |
| 436 | + " ls=[Foo(i, jnp.array(42 * i), hash(i)) for i in range(2)],\n", |
437 | 437 | " shapes=[8, 16, 32]\n", |
438 | 438 | ")\n", |
439 | 439 | "pytree_structure(pytree)" |
440 | 440 | ] |
441 | 441 | }, |
| 442 | + { |
| 443 | + "cell_type": "markdown", |
| 444 | + "id": "fca51f65", |
| 445 | + "metadata": {}, |
| 446 | + "source": [ |
| 447 | + "`dataclasses.dataclass` can also be used directly, however type checkers will not handle `nnx.static` and `nnx.data` correctly. To solve this `dataclasses.field` can be used by setting `metadata` with the appropriate entry for `static`." |
| 448 | + ] |
| 449 | + }, |
| 450 | + { |
| 451 | + "cell_type": "code", |
| 452 | + "execution_count": 10, |
| 453 | + "id": "ff54e732", |
| 454 | + "metadata": {}, |
| 455 | + "outputs": [ |
| 456 | + { |
| 457 | + "name": "stdout", |
| 458 | + "output_type": "stream", |
| 459 | + "text": [ |
| 460 | + "dataclass pytree structure:\n", |
| 461 | + " - pytree.a = 10\n" |
| 462 | + ] |
| 463 | + } |
| 464 | + ], |
| 465 | + "source": [ |
| 466 | + "@dataclasses.dataclass\n", |
| 467 | + "class Bar(nnx.Pytree):\n", |
| 468 | + " a: int = dataclasses.field(metadata={'static': False}) # data\n", |
| 469 | + " b: str = dataclasses.field(metadata={'static': True}) # static\n", |
| 470 | + "\n", |
| 471 | + "pytree = Bar(a=10, b=\"hello\")\n", |
| 472 | + "pytree_structure(pytree, title='dataclass pytree structure')" |
| 473 | + ] |
| 474 | + }, |
442 | 475 | { |
443 | 476 | "cell_type": "markdown", |
444 | 477 | "id": "d6036a0e", |
|
457 | 490 | }, |
458 | 491 | { |
459 | 492 | "cell_type": "code", |
460 | | - "execution_count": 10, |
| 493 | + "execution_count": 11, |
461 | 494 | "id": "509a517e", |
462 | 495 | "metadata": {}, |
463 | 496 | "outputs": [ |
|
501 | 534 | }, |
502 | 535 | { |
503 | 536 | "cell_type": "code", |
504 | | - "execution_count": 11, |
| 537 | + "execution_count": 12, |
505 | 538 | "id": "98e04ff9", |
506 | 539 | "metadata": {}, |
507 | 540 | "outputs": [ |
|
534 | 567 | }, |
535 | 568 | { |
536 | 569 | "cell_type": "code", |
537 | | - "execution_count": 12, |
| 570 | + "execution_count": 13, |
538 | 571 | "id": "c864d5b1", |
539 | 572 | "metadata": {}, |
540 | 573 | "outputs": [ |
|
568 | 601 | }, |
569 | 602 | { |
570 | 603 | "cell_type": "code", |
571 | | - "execution_count": 13, |
| 604 | + "execution_count": 14, |
572 | 605 | "id": "628698dd", |
573 | 606 | "metadata": {}, |
574 | 607 | "outputs": [ |
|
616 | 649 | "id": "37ee2429", |
617 | 650 | "metadata": {}, |
618 | 651 | "source": [ |
619 | | - "Checking for `nnx.data` or `nnx.static` annotations stored in inside nested structures that are not `nnx.Pytree` instances:" |
| 652 | + "Checking for `nnx.data` or `nnx.static` annotations stored inside nested structures that are not `nnx.Pytree` instances:" |
620 | 653 | ] |
621 | 654 | }, |
622 | 655 | { |
623 | 656 | "cell_type": "code", |
624 | | - "execution_count": 14, |
| 657 | + "execution_count": 15, |
625 | 658 | "id": "f9d69634", |
626 | 659 | "metadata": {}, |
627 | 660 | "outputs": [ |
628 | 661 | { |
629 | 662 | "name": "stdout", |
630 | 663 | "output_type": "stream", |
631 | 664 | "text": [ |
632 | | - "ValueError: Found unexpected tags {'data', 'static'} on attribute 'Foo.a'. Values from nnx.data(...) and\n", |
| 665 | + "ValueError: Found unexpected tags {'static', 'data'} on attribute 'Foo.a'. Values from nnx.data(...) and\n", |
633 | 666 | "nnx.static(...) should be assigned to nnx.Pytree attributes directly, they should not be inside other structures. Got value of type '<class 'list'>' on Pytree of type '<class '__main__.Foo'>'.\n" |
634 | 667 | ] |
635 | 668 | } |
|
656 | 689 | }, |
657 | 690 | { |
658 | 691 | "cell_type": "code", |
659 | | - "execution_count": 15, |
| 692 | + "execution_count": 16, |
660 | 693 | "id": "668db479", |
661 | 694 | "metadata": {}, |
662 | 695 | "outputs": [ |
|
704 | 737 | }, |
705 | 738 | { |
706 | 739 | "cell_type": "code", |
707 | | - "execution_count": 16, |
| 740 | + "execution_count": 17, |
708 | 741 | "id": "32c46ce8", |
709 | 742 | "metadata": {}, |
710 | 743 | "outputs": [ |
|
741 | 774 | }, |
742 | 775 | { |
743 | 776 | "cell_type": "code", |
744 | | - "execution_count": 17, |
| 777 | + "execution_count": 18, |
745 | 778 | "id": "c33e4862", |
746 | 779 | "metadata": {}, |
747 | 780 | "outputs": [ |
|
778 | 811 | }, |
779 | 812 | { |
780 | 813 | "cell_type": "code", |
781 | | - "execution_count": 18, |
| 814 | + "execution_count": 19, |
782 | 815 | "id": "dda51b67", |
783 | 816 | "metadata": {}, |
784 | 817 | "outputs": [ |
|
825 | 858 | }, |
826 | 859 | { |
827 | 860 | "cell_type": "code", |
828 | | - "execution_count": 19, |
| 861 | + "execution_count": 20, |
829 | 862 | "id": "caa01e3b", |
830 | 863 | "metadata": {}, |
831 | 864 | "outputs": [ |
|
864 | 897 | }, |
865 | 898 | { |
866 | 899 | "cell_type": "code", |
867 | | - "execution_count": 20, |
| 900 | + "execution_count": 21, |
868 | 901 | "id": "d2e03753", |
869 | 902 | "metadata": {}, |
870 | 903 | "outputs": [ |
|
929 | 962 | }, |
930 | 963 | { |
931 | 964 | "cell_type": "code", |
932 | | - "execution_count": 21, |
| 965 | + "execution_count": 22, |
933 | 966 | "id": "ca9f58a2", |
934 | 967 | "metadata": {}, |
935 | 968 | "outputs": [ |
|
1003 | 1036 | }, |
1004 | 1037 | { |
1005 | 1038 | "cell_type": "code", |
1006 | | - "execution_count": 38, |
| 1039 | + "execution_count": 23, |
1007 | 1040 | "id": "41398e14", |
1008 | 1041 | "metadata": {}, |
1009 | 1042 | "outputs": [ |
|
1050 | 1083 | }, |
1051 | 1084 | { |
1052 | 1085 | "cell_type": "code", |
1053 | | - "execution_count": null, |
| 1086 | + "execution_count": 24, |
1054 | 1087 | "id": "d10effba", |
1055 | 1088 | "metadata": {}, |
1056 | 1089 | "outputs": [ |
1057 | 1090 | { |
1058 | 1091 | "name": "stdout", |
1059 | 1092 | "output_type": "stream", |
1060 | 1093 | "text": [ |
1061 | | - "step = 0, loss = Array(0.7326511, dtype=float32), perturbations = \u001b[38;2;79;201;177mState\u001b[0m\u001b[38;2;255;213;3m({\u001b[0m\u001b[38;2;105;105;105m\u001b[0m\n", |
| 1094 | + "step = 0, loss = Array(0.7326511, dtype=float32), iterm_grads = \u001b[38;2;79;201;177mState\u001b[0m\u001b[38;2;255;213;3m({\u001b[0m\u001b[38;2;105;105;105m\u001b[0m\n", |
1062 | 1095 | " \u001b[38;2;156;220;254m'xgrad'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0m\u001b[38;2;79;201;177mPerturbation\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # 3 (12 B)\u001b[0m\n", |
1063 | 1096 | " \u001b[38;2;156;220;254mvalue\u001b[0m\u001b[38;2;212;212;212m=\u001b[0mArray([[-0.430146 , -0.14356601, 0.2935633 ]], dtype=float32)\n", |
1064 | 1097 | " \u001b[38;2;255;213;3m)\u001b[0m\n", |
1065 | 1098 | "\u001b[38;2;255;213;3m})\u001b[0m\n", |
1066 | | - "step = 1, loss = Array(0.65039134, dtype=float32), perturbations = \u001b[38;2;79;201;177mState\u001b[0m\u001b[38;2;255;213;3m({\u001b[0m\u001b[38;2;105;105;105m\u001b[0m\n", |
| 1099 | + "step = 1, loss = Array(0.65039134, dtype=float32), iterm_grads = \u001b[38;2;79;201;177mState\u001b[0m\u001b[38;2;255;213;3m({\u001b[0m\u001b[38;2;105;105;105m\u001b[0m\n", |
1067 | 1100 | " \u001b[38;2;156;220;254m'xgrad'\u001b[0m\u001b[38;2;212;212;212m: \u001b[0m\u001b[38;2;79;201;177mPerturbation\u001b[0m\u001b[38;2;255;213;3m(\u001b[0m\u001b[38;2;105;105;105m # 3 (12 B)\u001b[0m\n", |
1068 | 1101 | " \u001b[38;2;156;220;254mvalue\u001b[0m\u001b[38;2;212;212;212m=\u001b[0mArray([[-0.38535568, -0.11745065, 0.24441527]], dtype=float32)\n", |
1069 | 1102 | " \u001b[38;2;255;213;3m)\u001b[0m\n", |
|
1108 | 1141 | }, |
1109 | 1142 | { |
1110 | 1143 | "cell_type": "code", |
1111 | | - "execution_count": 24, |
| 1144 | + "execution_count": 25, |
1112 | 1145 | "id": "a9cab639", |
1113 | 1146 | "metadata": {}, |
1114 | 1147 | "outputs": [ |
|
0 commit comments