Skip to content

Commit cdf9701

Browse files
authored
aae2 dataset: implement paragraphs variant (#169)
* implement "paragraphs" dataset variant * add test_remove_cross_partition_relations()
1 parent 5a01869 commit cdf9701

File tree

2 files changed

+341
-84
lines changed

2 files changed

+341
-84
lines changed

Diff for: dataset_builders/pie/aae2/aae2.py

+75-18
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,51 @@ def get_common_pipeline_steps(conversion_method: str) -> dict:
123123
)
124124

125125

126+
def remove_cross_partition_relations(
127+
document: TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
128+
) -> TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions:
129+
# for each labeled_spans entry, get the labeled_partitions entry it belongs to
130+
labeled_span2partition = {}
131+
for labeled_span in document.labeled_spans:
132+
for partition in document.labeled_partitions:
133+
if partition.start <= labeled_span.start and labeled_span.end <= partition.end:
134+
labeled_span2partition[labeled_span] = partition
135+
break
136+
else:
137+
raise ValueError(f"Could not find partition for labeled_span: {labeled_span}")
138+
139+
result = document.copy(with_annotations=True)
140+
idx = 0
141+
for relation in document.binary_relations:
142+
head_partition = labeled_span2partition[relation.head]
143+
tail_partition = labeled_span2partition[relation.tail]
144+
if head_partition != tail_partition:
145+
result.binary_relations.pop(idx)
146+
else:
147+
idx += 1
148+
return result
149+
150+
151+
# def split_documents_into_partitions(
152+
# document: TextDocumentWithLabeledSpansAndBinaryRelations,
153+
# ) -> TextDocumentWithLabeledSpansAndBinaryRelations:
154+
# raise NotImplementedError("split_documents_into_partitions is not implemented yet.")
155+
156+
157+
def get_common_pipeline_steps_paragraphs(conversion_method: str) -> dict:
158+
return dict(
159+
**get_common_pipeline_steps(conversion_method=conversion_method),
160+
cast=Caster(document_type=TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions),
161+
add_partitions=RegexPartitioner(
162+
partition_layer_name="labeled_partitions",
163+
default_partition_label="paragraph",
164+
pattern="\n",
165+
strip_whitespace=True,
166+
verbose=False,
167+
),
168+
)
169+
170+
126171
class ArgumentAnnotatedEssaysV2Config(BratConfig):
127172
def __init__(self, conversion_method: str, **kwargs):
128173
"""BuilderConfig for ArgumentAnnotatedEssaysV2.
@@ -140,43 +185,55 @@ class ArgumentAnnotatedEssaysV2(BratBuilder):
140185
BASE_DATASET_PATH = "DFKI-SLT/brat"
141186
BASE_DATASET_REVISION = "bb8c37d84ddf2da1e691d226c55fef48fd8149b5"
142187

143-
# we need to add None to the list of dataset variants to support the default dataset variant
144-
BASE_BUILDER_KWARGS_DICT = {
145-
dataset_variant: {"url": URL, "split_paths": SPLIT_PATHS}
146-
for dataset_variant in [BratBuilder.DEFAULT_CONFIG_NAME, None]
147-
}
148-
149188
BUILDER_CONFIGS = [
150189
ArgumentAnnotatedEssaysV2Config(
151190
name=BratBuilder.DEFAULT_CONFIG_NAME,
152191
conversion_method="connect_first",
153192
),
193+
ArgumentAnnotatedEssaysV2Config(
194+
name="paragraphs",
195+
conversion_method="connect_all",
196+
),
154197
]
155198

156-
DOCUMENT_TYPES = {
157-
BratBuilder.DEFAULT_CONFIG_NAME: BratDocumentWithMergedSpans,
199+
# we need to add None to the list of dataset variants to support the default dataset variant
200+
BASE_BUILDER_KWARGS_DICT = {
201+
dataset_variant: {"url": URL, "split_paths": SPLIT_PATHS}
202+
for dataset_variant in [None] + [config.name for config in BUILDER_CONFIGS]
158203
}
159204

205+
DOCUMENT_TYPES = {config.name: BratDocumentWithMergedSpans for config in BUILDER_CONFIGS}
206+
160207
@property
161208
def document_converters(self) -> DocumentConvertersType:
162-
if self.config.name == "default" or None:
209+
if self.config.name in [None, "main_claim_connect_all", BratBuilder.DEFAULT_CONFIG_NAME]:
163210
return {
164211
TextDocumentWithLabeledSpansAndBinaryRelations: Pipeline(
165212
**get_common_pipeline_steps(conversion_method=self.config.conversion_method)
166213
),
167214
TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions: Pipeline(
168-
**get_common_pipeline_steps(conversion_method=self.config.conversion_method),
169-
cast=Caster(
170-
document_type=TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions
215+
**get_common_pipeline_steps_paragraphs(
216+
conversion_method=self.config.conversion_method
217+
)
218+
),
219+
}
220+
elif self.config.name == "paragraphs":
221+
return {
222+
# return one document per paragraph
223+
# TextDocumentWithLabeledSpansAndBinaryRelations: Pipeline(
224+
# **get_common_pipeline_steps_paragraphs(conversion_method=self.config.conversion_method),
225+
# split_documents=Converter(function=split_documents_into_partitions),
226+
# ),
227+
# just remove the cross-paragraph relations
228+
TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions: Pipeline(
229+
**get_common_pipeline_steps_paragraphs(
230+
conversion_method=self.config.conversion_method
171231
),
172-
add_partitions=RegexPartitioner(
173-
partition_layer_name="labeled_partitions",
174-
default_partition_label="paragraph",
175-
pattern="\n",
176-
strip_whitespace=True,
177-
verbose=False,
232+
remove_cross_partition_relations=Converter(
233+
function=remove_cross_partition_relations
178234
),
179235
),
180236
}
237+
181238
else:
182239
raise ValueError(f"Unknown dataset variant: {self.config.name}")

0 commit comments

Comments
 (0)