|
24 | 24 |
|
25 | 25 |
|
26 | 26 | class DatasetDictTest(TestCase):
|
27 |
| - def _create_dummy_dataset(self, multiple_columns=False): |
| 27 | + def _create_dummy_dataset(self, multiple_columns=False, int_to_float=False): |
28 | 28 | if multiple_columns:
|
29 | 29 | data = {"col_1": [3, 2, 1, 0], "col_2": ["a", "b", "c", "d"]}
|
30 | 30 | dset = Dataset.from_dict(data)
|
| 31 | + elif int_to_float: |
| 32 | + data = { |
| 33 | + "text": ["text1", "text2", "text3", "text4"], |
| 34 | + "labels": [[1, 1, 1, 0, 0], [0, 0, 0, 1, 0], [0, 0, 0, 1, 1], [0, 0, 0, 1, 0]], |
| 35 | + } |
| 36 | + dset = Dataset.from_dict(data) |
31 | 37 | else:
|
32 | 38 | dset = Dataset.from_dict(
|
33 | 39 | {"filename": ["my_name-train" + "_" + f"{x:03d}" for x in np.arange(30).tolist()]}
|
34 | 40 | )
|
35 | 41 | return dset
|
36 | 42 |
|
37 |
| - def _create_dummy_dataset_dict(self, multiple_columns=False) -> DatasetDict: |
| 43 | + def _create_dummy_dataset_dict(self, multiple_columns=False, int_to_float=False) -> DatasetDict: |
38 | 44 | return DatasetDict(
|
39 | 45 | {
|
40 |
| - "train": self._create_dummy_dataset(multiple_columns=multiple_columns), |
41 |
| - "test": self._create_dummy_dataset(multiple_columns=multiple_columns), |
| 46 | + "train": self._create_dummy_dataset(multiple_columns=multiple_columns, int_to_float=int_to_float), |
| 47 | + "test": self._create_dummy_dataset(multiple_columns=multiple_columns, int_to_float=int_to_float), |
42 | 48 | }
|
43 | 49 | )
|
44 | 50 |
|
@@ -325,6 +331,28 @@ def test_map(self):
|
325 | 331 | self.assertListEqual(sorted(mapped_dsets_2["train"].column_names), sorted(["filename", "foo", "bar"]))
|
326 | 332 | del dsets, mapped_dsets_1, mapped_dsets_2
|
327 | 333 |
|
| 334 | + # casting int labels to float labels |
| 335 | + with tempfile.TemporaryDirectory() as tmp_dir: |
| 336 | + dset_dict = self._create_dummy_dataset_dict(int_to_float=True) |
| 337 | + |
| 338 | + def _preprocess(examples): |
| 339 | + result = {"labels": [list(map(float, labels)) for labels in examples["labels"]]} |
| 340 | + return result |
| 341 | + |
| 342 | + with dset_dict.map( |
| 343 | + _preprocess, remove_columns=["labels", "text"], batched=True, try_original_type=True |
| 344 | + ) as dset_test: |
| 345 | + for labels in dset_test["test"]["labels"]: |
| 346 | + for label in labels: |
| 347 | + self.assertIsInstance(label, int) |
| 348 | + |
| 349 | + with dset_dict.map( |
| 350 | + _preprocess, remove_columns=["labels", "text"], batched=True, try_original_type=False |
| 351 | + ) as dset_test: |
| 352 | + for labels in dset_test["test"]["labels"]: |
| 353 | + for label in labels: |
| 354 | + self.assertIsInstance(label, float) |
| 355 | + |
328 | 356 | def test_iterable_map(self):
|
329 | 357 | dsets = self._create_dummy_iterable_dataset_dict()
|
330 | 358 | fn_kwargs = {"n": 3}
|
|
0 commit comments