diff --git a/examples/sst2/train.py b/examples/sst2/train.py index 24543a615..2b555c3bb 100644 --- a/examples/sst2/train.py +++ b/examples/sst2/train.py @@ -39,8 +39,8 @@ class Metrics(struct.PyTreeNode): """Computed metrics.""" - loss: float - accuracy: float + loss: Array + accuracy: Array count: int | None = None @@ -166,7 +166,7 @@ def normalize_batch_metrics(batch_metrics: Sequence[Metrics]) -> Metrics: ) -def batch_to_numpy(batch: dict[str, tf.Tensor]) -> dict[str, Array]: +def batch_to_numpy(batch: dict[str, Array]) -> dict[str, Array]: """Converts a batch with TF tensors to a batch of NumPy arrays.""" # _numpy() reuses memory, does not make a copy. # pylint: disable=protected-access