Skip to content

Commit 485e101

Browse files
author
Flax Authors
committed
Merge pull request #4909 from google:fix-mypy
PiperOrigin-RevId: 799720766
2 parents c24e655 + 011b69f commit 485e101

File tree

5 files changed

+7
-7
lines changed

5 files changed

+7
-7
lines changed

.github/workflows/flax_test.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ jobs:
7272
python-version: ${{ matrix.python-version }}
7373
- uses: astral-sh/setup-uv@887a942a15af3a7626099df99e897a18d9e5ab3a # v5.1.0
7474
with:
75-
uv-version: "0.3.0"
75+
version: "0.8.13"
7676
- name: Install standalone dependencies only
7777
run: |
7878
uv sync
@@ -108,7 +108,7 @@ jobs:
108108
- name: Setup uv
109109
uses: astral-sh/setup-uv@887a942a15af3a7626099df99e897a18d9e5ab3a # v5.1.0
110110
with:
111-
version: "0.3.0"
111+
version: "0.8.13"
112112

113113
- name: Install dependencies
114114
run: |

.github/workflows/jax_nightly.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ jobs:
3434
- name: Setup uv
3535
uses: astral-sh/setup-uv@887a942a15af3a7626099df99e897a18d9e5ab3a # v5.1.0
3636
with:
37-
version: "0.3.0"
37+
version: "0.8.13"
3838
- name: Install dependencies
3939
run: |
4040
uv sync --extra testing --extra docs

flax/nnx/graph.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ class Repeated: ...
7373
@jax.tree_util.register_dataclass
7474
@dataclasses.dataclass(frozen=True, slots=True, repr=False)
7575
class ArrayRefOutput(reprlib.Representable):
76-
value: jax.Array | NoUpdate | Repeated
76+
value: jax.Array
7777

7878
def __nnx_repr__(self):
7979
yield reprlib.Object(type=type(self))

flax/nnx/rnglib.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -840,7 +840,7 @@ def split_rngs_wrapper(*args, **kwargs):
840840
if squeeze:
841841
key = key[0]
842842
if variablelib.is_array_ref(stream.key.raw_value):
843-
stream.key.raw_value = variablelib.array_ref(key)
843+
stream.key.raw_value = variablelib.array_ref(key) # type: ignore[assignment]
844844
else:
845845
stream.key.value = key
846846
if squeeze:
@@ -852,7 +852,7 @@ def split_rngs_wrapper(*args, **kwargs):
852852

853853
count = jnp.zeros(counts_shape, dtype=jnp.uint32)
854854
if variablelib.is_array_ref(stream.count.raw_value):
855-
stream.count.raw_value = variablelib.array_ref(count)
855+
stream.count.raw_value = variablelib.array_ref(count) # type: ignore[assignment]
856856
else:
857857
stream.count.value = count
858858

flax/nnx/variablelib.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -247,7 +247,7 @@ def __init__(
247247
if is_array_ref(value):
248248
_value = tp.cast(A, value)
249249
else:
250-
_value = array_ref(jnp.asarray(value))
250+
_value = array_ref(jnp.asarray(value)) # type: ignore[assignment]
251251
else:
252252
_value = value
253253

0 commit comments

Comments
 (0)