Skip to content

Commit 9350b44

Browse files
author
Flax Authors
committed
Merge pull request #1478 from jheek:add-dataclass-autocomplete-support
PiperOrigin-RevId: 389608780
2 parents f558f49 + 2d05863 commit 9350b44

File tree

4 files changed

+46
-8
lines changed

4 files changed

+46
-8
lines changed

CHANGELOG.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ vNext
2323
-
2424
- Fix the serialization of named tuples. Tuple fields are no longer stored in the state dict and the named tuple class is no longer recreated ([bug](https://github.com/google/flax/issues/1429)).
2525
-
26-
-
26+
- linen Modules and dataclasses made with `flax.struct.dataclass` or `flax.struct.PyTreeNode` are now correctly recognized as dataclasses by static analysis tools like PyLance. Autocomplete of constructors has been verified to work with VSCode.
2727
-
2828
-
2929
- `flax.linen.Conv` no longer interprets an int past as kernel_size as a 1d convolution. Instead a type error is raised stating that

flax/linen/module.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
from flax.core import Scope
3939
from flax.core.scope import CollectionFilter, DenyList, Variable, VariableDict, FrozenVariableDict, union_filters
4040
from flax.core.frozen_dict import FrozenDict, freeze
41+
from flax.struct import __dataclass_transform__
4142

4243
# from .dotgetter import DotGetter
4344

@@ -366,7 +367,17 @@ def reimport(self, other):
366367
# -----------------------------------------------------------------------------
367368

368369

369-
class Module:
370+
# This metaclass + decorator is used by static analysis tools recognize that
371+
# Module behaves as a dataclass (attributes are constructor args).
372+
if typing.TYPE_CHECKING:
373+
@__dataclass_transform__()
374+
class ModuleMeta(type):
375+
pass
376+
else:
377+
ModuleMeta = type
378+
379+
380+
class Module(metaclass=ModuleMeta):
370381
"""Base class for all neural network modules. Layers and models should subclass this class.
371382
372383
All Flax Modules are Python 3.7

flax/struct.py

Lines changed: 33 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,8 @@
3030
"""Utilities for defining custom classes that can be used with jax transformations.
3131
"""
3232

33-
from typing import TypeVar
33+
import typing
34+
from typing import TypeVar, Callable, Tuple, Union, Any
3435

3536
from . import serialization
3637

@@ -39,7 +40,26 @@
3940
import jax
4041

4142

42-
def dataclass(clz: type):
43+
44+
# This decorator is interpreted by static analysis tools as a hint
45+
# that a decorator or metaclass causes dataclass-like behavior.
46+
# See https://github.com/microsoft/pyright/blob/main/specs/dataclass_transforms.md
47+
# for more information about the __dataclass_transform__ magic.
48+
_T = TypeVar("_T")
49+
def __dataclass_transform__(
50+
*,
51+
eq_default: bool = True,
52+
order_default: bool = False,
53+
kw_only_default: bool = False,
54+
field_descriptors: Tuple[Union[type, Callable[..., Any]], ...] = (()),
55+
) -> Callable[[_T], _T]:
56+
# If used within a stub file, the following implementation can be
57+
# replaced with "...".
58+
return lambda a: a
59+
60+
61+
@__dataclass_transform__()
62+
def dataclass(clz: _T) -> _T:
4363
"""Create a class which can be passed to functional transformations.
4464
4565
NOTE: Inherit from ``PyTreeNode`` instead to avoid type checking issues when
@@ -77,7 +97,8 @@ def __apply__(self, *args):
7797
Returns:
7898
The new class.
7999
"""
80-
data_clz = dataclasses.dataclass(frozen=True)(clz)
100+
# workaround for pytype not recognizing __dataclass_fields__
101+
data_clz: Any = dataclasses.dataclass(frozen=True)(clz)
81102
meta_fields = []
82103
data_fields = []
83104
for name, field_info in data_clz.__dataclass_fields__.items():
@@ -143,7 +164,15 @@ def field(pytree_node=True, **kwargs):
143164
TNode = TypeVar('TNode', bound='PyTreeNode')
144165

145166

146-
class PyTreeNode():
167+
if typing.TYPE_CHECKING:
168+
@__dataclass_transform__()
169+
class PyTreeNodeMeta(type):
170+
pass
171+
else:
172+
PyTreeNodeMeta = type
173+
174+
175+
class PyTreeNode(metaclass=PyTreeNodeMeta):
147176
"""Base class for dataclasses that should act like a JAX pytree node.
148177
149178
See ``flax.struct.dataclass`` for the ``jax.tree_util`` behavior.

tests/linen/module_test.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,6 @@
1818
import functools
1919
import operator
2020

21-
22-
2321
from absl.testing import absltest
2422

2523
import jax

0 commit comments

Comments
 (0)