Commit e939959

Add try_original_type to DatasetDict.map (#7544)

* Add try_original_type to DatasetDict.map
* Add test cases
* Apply make style

1 parent 0cf7f2b commit e939959

File tree: 2 files changed, +37 -4 lines changed

src/datasets/dataset_dict.py
Lines changed: 5 additions & 0 deletions

@@ -830,6 +830,7 @@ def map(
         fn_kwargs: Optional[dict] = None,
         num_proc: Optional[int] = None,
         desc: Optional[str] = None,
+        try_original_type: Optional[bool] = True,
     ) -> "DatasetDict":
         """
         Apply a function to all the examples in the table (individually or in batches) and update the table.
@@ -908,6 +909,9 @@ def map(
                 use multiprocessing.
             desc (`str`, *optional*, defaults to `None`):
                 Meaningful description to be displayed alongside with the progress bar while mapping examples.
+            try_original_type (`Optional[bool]`, defaults to `True`):
+                Try to keep the types of the original columns (e.g. int32 -> int32).
+                Set to False if you want to always infer new types.

         Example:

@@ -956,6 +960,7 @@ def map(
                 fn_kwargs=fn_kwargs,
                 num_proc=num_proc,
                 desc=desc,
+                try_original_type=try_original_type,
             )

             if with_split:
tests/test_dataset_dict.py
Lines changed: 32 additions & 4 deletions

@@ -24,21 +24,27 @@


 class DatasetDictTest(TestCase):
-    def _create_dummy_dataset(self, multiple_columns=False):
+    def _create_dummy_dataset(self, multiple_columns=False, int_to_float=False):
         if multiple_columns:
             data = {"col_1": [3, 2, 1, 0], "col_2": ["a", "b", "c", "d"]}
             dset = Dataset.from_dict(data)
+        elif int_to_float:
+            data = {
+                "text": ["text1", "text2", "text3", "text4"],
+                "labels": [[1, 1, 1, 0, 0], [0, 0, 0, 1, 0], [0, 0, 0, 1, 1], [0, 0, 0, 1, 0]],
+            }
+            dset = Dataset.from_dict(data)
         else:
             dset = Dataset.from_dict(
                 {"filename": ["my_name-train" + "_" + f"{x:03d}" for x in np.arange(30).tolist()]}
             )
         return dset

-    def _create_dummy_dataset_dict(self, multiple_columns=False) -> DatasetDict:
+    def _create_dummy_dataset_dict(self, multiple_columns=False, int_to_float=False) -> DatasetDict:
         return DatasetDict(
             {
-                "train": self._create_dummy_dataset(multiple_columns=multiple_columns),
-                "test": self._create_dummy_dataset(multiple_columns=multiple_columns),
+                "train": self._create_dummy_dataset(multiple_columns=multiple_columns, int_to_float=int_to_float),
+                "test": self._create_dummy_dataset(multiple_columns=multiple_columns, int_to_float=int_to_float),
             }
         )

@@ -325,6 +331,28 @@ def test_map(self):
             self.assertListEqual(sorted(mapped_dsets_2["train"].column_names), sorted(["filename", "foo", "bar"]))
             del dsets, mapped_dsets_1, mapped_dsets_2

+        # casting int labels to float labels
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            dset_dict = self._create_dummy_dataset_dict(int_to_float=True)
+
+            def _preprocess(examples):
+                result = {"labels": [list(map(float, labels)) for labels in examples["labels"]]}
+                return result
+
+            with dset_dict.map(
+                _preprocess, remove_columns=["labels", "text"], batched=True, try_original_type=True
+            ) as dset_test:
+                for labels in dset_test["test"]["labels"]:
+                    for label in labels:
+                        self.assertIsInstance(label, int)
+
+            with dset_dict.map(
+                _preprocess, remove_columns=["labels", "text"], batched=True, try_original_type=False
+            ) as dset_test:
+                for labels in dset_test["test"]["labels"]:
+                    for label in labels:
+                        self.assertIsInstance(label, float)
+
     def test_iterable_map(self):
         dsets = self._create_dummy_iterable_dataset_dict()
         fn_kwargs = {"n": 3}
