Skip to content

Commit 6ac5a78

Browse files
danielsuoFlax Authors
authored andcommitted
[flax:examples:sst2] Fix pytype errors.
PiperOrigin-RevId: 838929177
1 parent 324f4b4 commit 6ac5a78

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

examples/sst2/train.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,8 @@
3939
class Metrics(struct.PyTreeNode):
4040
"""Computed metrics."""
4141

42-
loss: float
43-
accuracy: float
42+
loss: Array
43+
accuracy: Array
4444
count: int | None = None
4545

4646

@@ -166,7 +166,7 @@ def normalize_batch_metrics(batch_metrics: Sequence[Metrics]) -> Metrics:
166166
)
167167

168168

169-
def batch_to_numpy(batch: dict[str, tf.Tensor]) -> dict[str, Array]:
169+
def batch_to_numpy(batch: dict[str, Array]) -> dict[str, Array]:
170170
"""Converts a batch with TF tensors to a batch of NumPy arrays."""
171171
# _numpy() reuses memory, does not make a copy.
172172
# pylint: disable=protected-access

0 commit comments

Comments
 (0)