We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 88d7497 commit 33a601eCopy full SHA for 33a601e
examples/sst2/train.py
@@ -39,8 +39,8 @@
39
class Metrics(struct.PyTreeNode):
40
"""Computed metrics."""
41
42
- loss: float
43
- accuracy: float
+ loss: Array
+ accuracy: Array
44
count: int | None = None
45
46
@@ -166,7 +166,7 @@ def normalize_batch_metrics(batch_metrics: Sequence[Metrics]) -> Metrics:
166
)
167
168
169
-def batch_to_numpy(batch: dict[str, tf.Tensor]) -> dict[str, Array]:
+def batch_to_numpy(batch: dict[str, Array]) -> dict[str, Array]:
170
"""Converts a batch with TF tensors to a batch of NumPy arrays."""
171
# _numpy() reuses memory, does not make a copy.
172
# pylint: disable=protected-access
0 commit comments