|
56 | 56 | "spans": {"background_claim": 2752, "data": 4093, "own_claim": 5450}, |
57 | 57 | }, |
58 | 58 | } |
| 59 | +CONVERTED_LAYER_MAPPING = { |
| 60 | + "default": { |
| 61 | + "spans": "labeled_spans", |
| 62 | + "relations": "binary_relations", |
| 63 | + }, |
| 64 | + "resolve_parts_of_same": { |
| 65 | + "spans": "labeled_multi_spans", |
| 66 | + "relations": "binary_relations", |
| 67 | + }, |
| 68 | +} |
| 69 | +FULL_LABEL_COUNTS_CONVERTED = { |
| 70 | + variant: {CONVERTED_LAYER_MAPPING[variant][ln]: value for ln, value in counts.items()} |
| 71 | + for variant, counts in FULL_LABEL_COUNTS.items() |
| 72 | +} |
| 73 | +LABELED_PARTITION_COUNTS = {"Abstract": 40, "H1": 340, "Title": 40} |
59 | 74 |
|
60 | 75 |
|
61 | 76 | def resolve_annotation(annotation: Annotation) -> Any: |
@@ -257,32 +272,25 @@ def converted_dataset(dataset, target_document_type) -> Optional[DatasetDict]: |
257 | 272 | return dataset.to_document_type(target_document_type) |
258 | 273 |
|
259 | 274 |
|
260 | | -def test_converted_datasets(converted_dataset, dataset_variant): |
| 275 | +def test_converted_datasets(converted_dataset, dataset_variant, target_document_type): |
261 | 276 | if converted_dataset is not None: |
262 | 277 | split_sizes = {name: len(ds) for name, ds in converted_dataset.items()} |
263 | 278 | assert split_sizes == SPLIT_SIZES |
264 | 279 | if dataset_variant == "default": |
265 | 280 | expected_document_type = TextDocumentWithLabeledSpansAndBinaryRelations |
266 | | - layer_name_mapping = { |
267 | | - "spans": "labeled_spans", |
268 | | - "relations": "binary_relations", |
269 | | - } |
270 | 281 | elif dataset_variant == "resolve_parts_of_same": |
271 | 282 | expected_document_type = TextDocumentWithLabeledMultiSpansAndBinaryRelations |
272 | | - layer_name_mapping = { |
273 | | - "spans": "labeled_multi_spans", |
274 | | - "relations": "binary_relations", |
275 | | - } |
276 | 283 | else: |
277 | 284 | raise ValueError(f"Unknown dataset variant: {dataset_variant}") |
278 | 285 |
|
| 286 | + assert issubclass(converted_dataset.document_type, expected_document_type) |
279 | 287 | assert isinstance(converted_dataset["train"][0], expected_document_type) |
280 | 288 |
|
281 | 289 | if TEST_FULL_DATASET: |
282 | | - expected_label_counts = { |
283 | | - layer_name_mapping[ln]: value |
284 | | - for ln, value in FULL_LABEL_COUNTS[dataset_variant].items() |
285 | | - } |
| 290 | + # copy to avoid modifying the original dict |
| 291 | + expected_label_counts = {**FULL_LABEL_COUNTS_CONVERTED[dataset_variant]} |
| 292 | + if issubclass(target_document_type, TextDocumentWithLabeledPartitions): |
| 293 | + expected_label_counts["labeled_partitions"] = LABELED_PARTITION_COUNTS |
286 | 294 | assert_dataset_label_counts(converted_dataset, expected_label_counts) |
287 | 295 |
|
288 | 296 |
|
|
0 commit comments