Skip to content

Commit 964f503

Browse files
vfdev-5cgarciae
authored andcommitted
Added workaround by P.Hawkins
1 parent 0ad978f commit 964f503

File tree

4 files changed

+42
-12
lines changed

4 files changed

+42
-12
lines changed

.github/workflows/flax_test.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,8 @@ jobs:
191191
run: |
192192
rm -fr .venv
193193
uv sync --extra testing --extra docs
194+
# temporary: install jax nightly
195+
uv pip install -U --pre jax jaxlib -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/
194196
- name: Test with pytest
195197
run: |
196198
export XLA_FLAGS='--xla_force_host_platform_device_count=4'

flax/linen/module.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import enum
2020
import functools
2121
import inspect
22+
import sys
2223
import threading
2324
import typing
2425
import weakref
@@ -1097,6 +1098,12 @@ def _customized_dataclass_transform(cls, kw_only: bool):
10971098
) in extra_fields:
10981099
setattr(cls, name, default)
10991100
cls.__annotations__[name] = annotation
1101+
1102+
# TODO: a workaround for the issue:
1103+
# https://github.com/google/flax/pull/5087#issuecomment-3536610568
1104+
if (sys.version_info.major, sys.version_info.minor) in [(3, 12), (3, 13)]:
1105+
setattr(cls, '__annotations__', cls.__annotations__)
1106+
11001107
dataclasses.dataclass( # type: ignore[call-overload]
11011108
unsafe_hash='__hash__' not in cls.__dict__,
11021109
repr=False,

flax/nnx/variablelib.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1712,8 +1712,8 @@ def __contains__(self, item) -> bool:
17121712

17131713
def __eq__(self, other) -> bool:
17141714
if isinstance(other, Variable):
1715-
other = other[...]
1716-
return self[...].__eq__(other) # type: ignore
1715+
other = other.get_value()
1716+
return self.get_value().__eq__(other) # type: ignore
17171717

17181718
def __iadd__(self: V, other) -> V:
17191719
raise NotImplementedError(
@@ -1736,61 +1736,61 @@ def __imul__(self: V, other) -> V:
17361736
def __imatmul__(self: V, other) -> V:
17371737
raise NotImplementedError(
17381738
'In-place operations are no longer supported for Variable.\n'
1739-
'Use `variable.value @= x` instead.'
1739+
'Use `variable[...] @= x` instead.'
17401740
)
17411741

17421742
def __itruediv__(self: V, other) -> V:
17431743
raise NotImplementedError(
17441744
'In-place operations are no longer supported for Variable.\n'
1745-
'Use `variable.value /= x` instead.'
1745+
'Use `variable[...] /= x` instead.'
17461746
)
17471747

17481748
def __ifloordiv__(self: V, other) -> V:
17491749
raise NotImplementedError(
17501750
'In-place operations are no longer supported for Variable.\n'
1751-
'Use `variable.value //= x`` instead.'
1751+
'Use `variable[...] //= x`` instead.'
17521752
)
17531753

17541754
def __imod__(self: V, other) -> V:
17551755
raise NotImplementedError(
17561756
'In-place operations are no longer supported for Variable.\n'
1757-
'Use `variable.value %= x` instead.'
1757+
'Use `variable[...] %= x` instead.'
17581758
)
17591759

17601760
def __ipow__(self: V, other) -> V:
17611761
raise NotImplementedError(
17621762
'In-place operations are no longer supported for Variable.\n'
1763-
'Use `variable.value **= x`` instead.'
1763+
'Use `variable[...] **= x`` instead.'
17641764
)
17651765

17661766
def __ilshift__(self: V, other) -> V:
17671767
raise NotImplementedError(
17681768
'In-place operations are no longer supported for Variable.\n'
1769-
'Use `variable.value <<= x`` instead.'
1769+
'Use `variable[...] <<= x`` instead.'
17701770
)
17711771

17721772
def __irshift__(self: V, other) -> V:
17731773
raise NotImplementedError(
17741774
'In-place operations are no longer supported for Variable.\n'
1775-
'Use `variable.value >>= x`` instead.'
1775+
'Use `variable[...] >>= x`` instead.'
17761776
)
17771777

17781778
def __iand__(self: V, other) -> V:
17791779
raise NotImplementedError(
17801780
'In-place operations are no longer supported for Variable.\n'
1781-
'Use `variable.value &= x` instead.'
1781+
'Use `variable[...] &= x` instead.'
17821782
)
17831783

17841784
def __ixor__(self: V, other) -> V:
17851785
raise NotImplementedError(
17861786
'In-place operations are no longer supported for Variable.\n'
1787-
'Use `variable.value ^= x` instead.'
1787+
'Use `variable[...] ^= x` instead.'
17881788
)
17891789

17901790
def __ior__(self: V, other) -> V:
17911791
raise NotImplementedError(
17921792
'In-place operations are no longer supported for Variable.\n'
1793-
'Use `variable.value |= x` instead.'
1793+
'Use `variable[...] |= x` instead.'
17941794
)
17951795

17961796
__neg__ = _variable_unary_operator('__neg__')

tests/linen/linen_module_test.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2315,6 +2315,27 @@ class Foo(nn.Module):
23152315
Foo(1, None)
23162316
Foo(a=1, parent=None) # type: ignore[call-arg]
23172317

2318+
def test_failure_with_sequencelayer(self):
2319+
# This is a minimal reproducer of the failure seen with
2320+
# SequenceLayer project and Flax Linen when enabled support for 3.14
2321+
# See PR: https://github.com/google/flax/pull/5087
2322+
# Code below is based on
2323+
# https://github.com/google/flax/pull/5087#issuecomment-3535067361
2324+
import abc
2325+
from collections.abc import Iterator
2326+
from typing import Protocol
2327+
2328+
class CheckpointableIterator(Iterator, Protocol):
2329+
pass
2330+
2331+
class Steppable(metaclass=abc.ABCMeta):
2332+
pass
2333+
2334+
isinstance(Steppable, Iterator)
2335+
2336+
class SequenceLayer(nn.Module, Steppable):
2337+
pass
2338+
23182339
def test_module_path_empty(self):
23192340
rngkey = jax.random.key(0)
23202341
scope = Scope({}, {'params': rngkey}, mutable=['params'])

0 commit comments

Comments
 (0)