From 6ac5a7835799cfbc0912736958494924517f36a1 Mon Sep 17 00:00:00 2001 From: Daniel Suo Date: Mon, 1 Dec 2025 14:16:13 -0800 Subject: [PATCH] [flax:examples:sst2] Fix pytype errors. PiperOrigin-RevId: 838929177 --- examples/sst2/train.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/sst2/train.py b/examples/sst2/train.py index 24543a615..2b555c3bb 100644 --- a/examples/sst2/train.py +++ b/examples/sst2/train.py @@ -39,8 +39,8 @@ class Metrics(struct.PyTreeNode): """Computed metrics.""" - loss: float - accuracy: float + loss: Array + accuracy: Array count: int | None = None @@ -166,7 +166,7 @@ def normalize_batch_metrics(batch_metrics: Sequence[Metrics]) -> Metrics: ) -def batch_to_numpy(batch: dict[str, tf.Tensor]) -> dict[str, Array]: +def batch_to_numpy(batch: dict[str, Array]) -> dict[str, Array]: """Converts a batch with TF tensors to a batch of NumPy arrays.""" # _numpy() reuses memory, does not make a copy. # pylint: disable=protected-access