Skip to content

Commit 2b3ba52

Browse files
authored
sciarg: fix partitioning (#159)
* test label counts for labeled_partitions in sciarg * fix and simplify tests * fix allowing newlines between matching tags (important for abstract)
1 parent 06e3af5 commit 2b3ba52

File tree

2 files changed

+23
-14
lines changed

2 files changed

+23
-14
lines changed

dataset_builders/pie/sciarg/sciarg.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,8 @@ def _generate_document(self, example, **kwargs):
123123
def document_converters(self) -> DocumentConvertersType:
124124
regex_partitioner = RegexPartitioner(
125125
partition_layer_name="labeled_partitions",
126-
pattern="<([^>/]+)>.*</\\1>",
126+
# find matching tags, allow newlines in between (s flag) and capture the tag name
127+
pattern="<([^>/]+)>(?s:.)*?</\\1>",
127128
label_group_id=1,
128129
label_whitelist=["Title", "Abstract", "H1"],
129130
skip_initial_partition=True,

tests/dataset_builders/pie/sciarg/test_sciarg.py

Lines changed: 21 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,21 @@
5656
"spans": {"background_claim": 2752, "data": 4093, "own_claim": 5450},
5757
},
5858
}
59+
CONVERTED_LAYER_MAPPING = {
60+
"default": {
61+
"spans": "labeled_spans",
62+
"relations": "binary_relations",
63+
},
64+
"resolve_parts_of_same": {
65+
"spans": "labeled_multi_spans",
66+
"relations": "binary_relations",
67+
},
68+
}
69+
FULL_LABEL_COUNTS_CONVERTED = {
70+
variant: {CONVERTED_LAYER_MAPPING[variant][ln]: value for ln, value in counts.items()}
71+
for variant, counts in FULL_LABEL_COUNTS.items()
72+
}
73+
LABELED_PARTITION_COUNTS = {"Abstract": 40, "H1": 340, "Title": 40}
5974

6075

6176
def resolve_annotation(annotation: Annotation) -> Any:
@@ -257,32 +272,25 @@ def converted_dataset(dataset, target_document_type) -> Optional[DatasetDict]:
257272
return dataset.to_document_type(target_document_type)
258273

259274

260-
def test_converted_datasets(converted_dataset, dataset_variant):
275+
def test_converted_datasets(converted_dataset, dataset_variant, target_document_type):
261276
if converted_dataset is not None:
262277
split_sizes = {name: len(ds) for name, ds in converted_dataset.items()}
263278
assert split_sizes == SPLIT_SIZES
264279
if dataset_variant == "default":
265280
expected_document_type = TextDocumentWithLabeledSpansAndBinaryRelations
266-
layer_name_mapping = {
267-
"spans": "labeled_spans",
268-
"relations": "binary_relations",
269-
}
270281
elif dataset_variant == "resolve_parts_of_same":
271282
expected_document_type = TextDocumentWithLabeledMultiSpansAndBinaryRelations
272-
layer_name_mapping = {
273-
"spans": "labeled_multi_spans",
274-
"relations": "binary_relations",
275-
}
276283
else:
277284
raise ValueError(f"Unknown dataset variant: {dataset_variant}")
278285

286+
assert issubclass(converted_dataset.document_type, expected_document_type)
279287
assert isinstance(converted_dataset["train"][0], expected_document_type)
280288

281289
if TEST_FULL_DATASET:
282-
expected_label_counts = {
283-
layer_name_mapping[ln]: value
284-
for ln, value in FULL_LABEL_COUNTS[dataset_variant].items()
285-
}
290+
# copy to avoid modifying the original dict
291+
expected_label_counts = {**FULL_LABEL_COUNTS_CONVERTED[dataset_variant]}
292+
if issubclass(target_document_type, TextDocumentWithLabeledPartitions):
293+
expected_label_counts["labeled_partitions"] = LABELED_PARTITION_COUNTS
286294
assert_dataset_label_counts(converted_dataset, expected_label_counts)
287295

288296

0 commit comments

Comments
 (0)