|
30 | 30 | """Utilities for defining custom classes that can be used with jax transformations. |
31 | 31 | """ |
32 | 32 |
|
33 | | -from typing import TypeVar |
| 33 | +import typing |
| 34 | +from typing import TypeVar, Callable, Tuple, Union, Any |
34 | 35 |
|
35 | 36 | from . import serialization |
36 | 37 |
|
|
39 | 40 | import jax |
40 | 41 |
|
41 | 42 |
|
42 | | -def dataclass(clz: type): |
| 43 | + |
| 44 | +# This decorator is interpreted by static analysis tools as a hint |
| 45 | +# that a decorator or metaclass causes dataclass-like behavior. |
| 46 | +# See https://github.com/microsoft/pyright/blob/main/specs/dataclass_transforms.md |
| 47 | +# for more information about the __dataclass_transform__ magic. |
| 48 | +_T = TypeVar("_T") |
| 49 | +def __dataclass_transform__( |
| 50 | + *, |
| 51 | + eq_default: bool = True, |
| 52 | + order_default: bool = False, |
| 53 | + kw_only_default: bool = False, |
| 54 | + field_descriptors: Tuple[Union[type, Callable[..., Any]], ...] = (()), |
| 55 | +) -> Callable[[_T], _T]: |
| 56 | + # If used within a stub file, the following implementation can be |
| 57 | + # replaced with "...". |
| 58 | + return lambda a: a |
| 59 | + |
| 60 | + |
| 61 | +@__dataclass_transform__() |
| 62 | +def dataclass(clz: _T) -> _T: |
43 | 63 | """Create a class which can be passed to functional transformations. |
44 | 64 |
|
45 | 65 | NOTE: Inherit from ``PyTreeNode`` instead to avoid type checking issues when |
@@ -77,7 +97,8 @@ def __apply__(self, *args): |
77 | 97 | Returns: |
78 | 98 | The new class. |
79 | 99 | """ |
80 | | - data_clz = dataclasses.dataclass(frozen=True)(clz) |
| 100 | + # workaround for pytype not recognizing __dataclass_fields__ |
| 101 | + data_clz: Any = dataclasses.dataclass(frozen=True)(clz) |
81 | 102 | meta_fields = [] |
82 | 103 | data_fields = [] |
83 | 104 | for name, field_info in data_clz.__dataclass_fields__.items(): |
@@ -143,7 +164,15 @@ def field(pytree_node=True, **kwargs): |
143 | 164 | TNode = TypeVar('TNode', bound='PyTreeNode') |
144 | 165 |
|
145 | 166 |
|
146 | | -class PyTreeNode(): |
| 167 | +if typing.TYPE_CHECKING: |
| 168 | + @__dataclass_transform__() |
| 169 | + class PyTreeNodeMeta(type): |
| 170 | + pass |
| 171 | +else: |
| 172 | + PyTreeNodeMeta = type |
| 173 | + |
| 174 | + |
| 175 | +class PyTreeNode(metaclass=PyTreeNodeMeta): |
147 | 176 | """Base class for dataclasses that should act like a JAX pytree node. |
148 | 177 |
|
149 | 178 | See ``flax.struct.dataclass`` for the ``jax.tree_util`` behavior. |
|
0 commit comments