Skip to content

Commit f237d24

Browse files
andsteingcopybara-github
authored andcommitted
Internal change
PiperOrigin-RevId: 390551401
1 parent f6d969b commit f237d24

File tree

4 files changed

+54
-25
lines changed

4 files changed

+54
-25
lines changed

CHANGELOG.md

+4
Original file line numberDiff line numberDiff line change
@@ -85,3 +85,7 @@ Documentation:
8585
- Makes `PreprocessFn` addable.
8686
- Allow n-dimensional arrays (and masks) to be passed to Metrics.Average().
8787
- Support slicing `PreprocessFn`.
88+
89+
## v0.0.6
90+
91+
- Makes `deterministic_data` work with `tfds>4.4.0` and `tfds<=4.4.0`.

clu/deterministic_data.py

+26-12
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@
6262
import jax
6363
import jax.numpy as jnp
6464
import numpy as np
65+
from packaging import version
6566
import tensorflow as tf
6667
import tensorflow_datasets as tfds
6768
import typing_extensions
@@ -73,6 +74,9 @@
7374

7475
AUTOTUNE = tf.data.experimental.AUTOTUNE
7576

77+
_use_split_info = version.parse("4.4.0") < version.parse(
78+
tfds.version.__version__)
79+
7680

7781
class DatasetBuilder(typing_extensions.Protocol):
7882
"""Protocol for dataset builders (subset of tfds.core.DatasetBuilder)."""
@@ -106,15 +110,18 @@ class RemainderOptions(enum.Enum):
106110
def _shard_read_instruction(
107111
absolute_instruction,
108112
*,
109-
split_infos: Dict[str, tfds.core.SplitInfo],
113+
split_infos: Dict[str, Union[int, tfds.core.SplitInfo]],
110114
host_id: int,
111115
host_count: int,
112116
remainder_options: RemainderOptions,
113117
) -> tfds.core.ReadInstruction:
114118
"""Shards a single ReadInstruction. See get_read_instruction_for_host()."""
115119
start = absolute_instruction.from_ or 0
116-
end = absolute_instruction.to or (
117-
split_infos[absolute_instruction.splitname].num_examples)
120+
if _use_split_info:
121+
end = absolute_instruction.to or (
122+
split_infos[absolute_instruction.splitname].num_examples) # pytype: disable=attribute-error
123+
else:
124+
end = absolute_instruction.to or split_infos[absolute_instruction.splitname]
118125
assert end >= start, f"start={start}, end={end}"
119126
num_examples = end - start
120127

@@ -208,16 +215,23 @@ def get_read_instruction_for_host(
208215
f"Invalid combination of host_id ({host_id}) and host_count "
209216
f"({host_count}).")
210217

211-
if dataset_info is None:
212-
split_infos = {
213-
split: tfds.core.SplitInfo(
214-
name=split,
215-
shard_lengths=[num_examples],
216-
num_bytes=0,
217-
),
218-
}
218+
if _use_split_info:
219+
if dataset_info is None:
220+
split_infos = {
221+
split: tfds.core.SplitInfo(
222+
name=split,
223+
shard_lengths=[num_examples],
224+
num_bytes=0,
225+
),
226+
}
227+
else:
228+
split_infos = dataset_info.splits
219229
else:
220-
split_infos = dataset_info.splits
230+
if dataset_info is None:
231+
split_infos = {split: num_examples}
232+
else:
233+
split_infos = {k: v.num_examples for k, v in dataset_info.splits.items()}
234+
221235
read_instruction = tfds.core.ReadInstruction.from_spec(split)
222236
sharded_read_instructions = []
223237
for ri in read_instruction.to_absolute(split_infos):

clu/deterministic_data_test.py

+22-12
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414

1515
"""Unit tests for the deterministic_data module."""
16+
import dataclasses
1617
import itertools
1718
import math
1819

@@ -21,11 +22,14 @@
2122

2223
from absl.testing import parameterized
2324
from clu import deterministic_data
24-
import dataclasses
2525
import jax
26+
from packaging import version
2627
import tensorflow as tf
2728
import tensorflow_datasets as tfds
2829

30+
_use_split_info = version.parse("4.4.0") < version.parse(
31+
tfds.version.__version__)
32+
2933

3034
@dataclasses.dataclass
3135
class MyDatasetBuilder:
@@ -35,11 +39,14 @@ class MyDatasetBuilder:
3539
def as_dataset(self, split: tfds.core.ReadInstruction, shuffle_files: bool,
3640
read_config: tfds.ReadConfig, decoders) -> tf.data.Dataset:
3741
del shuffle_files, read_config, decoders
38-
split_infos = {
39-
k: tfds.core.SplitInfo(name=k, shard_lengths=[v], num_bytes=0)
40-
for k, v in self.name2len.items()
41-
}
42-
instructions = split.to_absolute(split_infos)
42+
if _use_split_info:
43+
split_infos = {
44+
k: tfds.core.SplitInfo(name=k, shard_lengths=[v], num_bytes=0)
45+
for k, v in self.name2len.items()
46+
}
47+
instructions = split.to_absolute(split_infos)
48+
else:
49+
instructions = split.to_absolute(self.name2len)
4350
assert len(instructions) == 1
4451
from_ = instructions[0].from_ or 0
4552
to = instructions[0].to or self.name2len[instructions[0].splitname]
@@ -88,12 +95,15 @@ def test_get_read_instruction_for_host_deprecated(self, num_examples: int,
8895
host_id=host_id,
8996
host_count=host_count,
9097
drop_remainder=drop_remainder)
91-
split_infos = {
92-
"test": tfds.core.SplitInfo(
93-
name="test",
94-
shard_lengths=[9],
95-
num_bytes=0,
96-
)}
98+
if _use_split_info:
99+
split_infos = {
100+
"test": tfds.core.SplitInfo(
101+
name="test",
102+
shard_lengths=[9],
103+
num_bytes=0,
104+
)}
105+
else:
106+
split_infos = {"test": 9}
97107
self.assertEqual(
98108
expected.to_absolute(split_infos), actual.to_absolute(split_infos))
99109

setup.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333

3434
setup(
3535
name="clu",
36-
version="0.0.5",
36+
version="0.0.6",
3737
description=("Set of libraries for ML training loops in JAX."),
3838
author="Common Loop Utils Authors",
3939
author_email="[email protected]",
@@ -51,6 +51,7 @@
5151
"jaxlib",
5252
"ml_collections",
5353
"numpy",
54+
"packaging",
5455
"tensorflow",
5556
"tensorflow_datasets",
5657
],

0 commit comments

Comments
 (0)