Skip to content

Commit

Permalink
fix type inference (#389)
Browse files (browse the repository at this point in the history)
  • Loading branch information
ArneBinder authored Dec 9, 2023
1 parent b4900a1 commit ced53f6
Showing 1 changed file with 62 additions and 59 deletions.
121 changes: 62 additions & 59 deletions src/pytorch_ie/core/document.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -93,20 +93,24 @@ def _get_reference_fields_and_container_types(
for field in dataclasses.fields(annotation_class):
if field.name == "_targets":
continue
if not _contains_annotation_type(field.type):
if isinstance(field.type, type):
field_type = field.type
else:
field_type = typing.get_type_hints(annotation_class)[field.name]
if not _contains_annotation_type(field_type):
continue
if _is_optional_annotation_type(field.type):
if _is_optional_annotation_type(field_type):
containers[field.name] = typing.Optional
continue
if _is_annotation_type(field.type):
if _is_annotation_type(field_type):
containers[field.name] = None
continue
if _is_tuple_of_annotation_types(field.type):
if _is_tuple_of_annotation_types(field_type):
containers[field.name] = tuple
continue
annot_name = annotation_class.__name__
raise TypeError(
f"The type '{field.type}' of the field '{field.name}' from Annotation subclass '{annot_name}' can not "
f"The type '{field_type}' of the field '{field.name}' from Annotation subclass '{annot_name}' can not "
f"be handled automatically. For automatic handling, type constructs that contain any Annotation subclasses "
f"need to be either (1) pure subclasses of Annotation, (2) tuples of Annotation subclasses, or their "
f"optional variants (examples: 1) Span, 2) Tuple[Span, ...], 3) Optional[Span]). Is the defined type "
Expand Down Expand Up @@ -527,62 +531,59 @@ def __len__(self):
def __post_init__(self):
targeted = set()
field_names = {field.name for field in dataclasses.fields(self)}
for field in dataclasses.fields(self):
if field.name == "_annotation_graph":
continue

field_origin = typing.get_origin(field.type)

if field_origin is AnnotationLayer:
self._annotation_fields.add(field.name)
field_types = self.field_types()
for field in self.annotation_fields():

self._annotation_fields.add(field.name)

targets = field.metadata.get("targets")
for target in targets:
targeted.add(target)
if field.name not in self._annotation_graph:
self._annotation_graph[field.name] = []
self._annotation_graph[field.name].append(target)
if target not in field_names:
raise TypeError(
f'annotation target "{target}" is not in field names of the document: {field_names}'
)

targets = field.metadata.get("targets")
for target in targets:
targeted.add(target)
if field.name not in self._annotation_graph:
self._annotation_graph[field.name] = []
self._annotation_graph[field.name].append(target)
if target not in field_names:
# check annotation target names and use them together with target names from the AnnotationLayer
# to reorder targets, if available
target_names = field.metadata.get("target_names")
field_type = field_types[field.name]
annotation_type = typing.get_args(field_type)[0]
annotation_target_names = annotation_type.TARGET_NAMES
if annotation_target_names is not None:
if target_names is not None:
if set(target_names) != set(annotation_target_names):
raise TypeError(
f'annotation target "{target}" is not in field names of the document: {field_names}'
f"keys of targets {sorted(target_names)} do not match "
f"{annotation_type.__name__}.TARGET_NAMES {sorted(annotation_target_names)}"
)
# reorder targets according to annotation_target_names
target_name_mapping = dict(zip(target_names, targets))
target_position_mapping = {
i: target_name_mapping[name]
for i, name in enumerate(annotation_target_names)
}
targets = [target_position_mapping[i] for i in range(len(targets))]
else:
if len(annotation_target_names) != len(targets):
raise TypeError(
f"number of targets {sorted(targets)} does not match number of entries in "
f"{annotation_type.__name__}.TARGET_NAMES: {sorted(annotation_target_names)}"
)
# disallow multiple targets when target names are specified in the definition of the Annotation
if len(annotation_target_names) > 1:
raise TypeError(
f"A target name mapping is required for AnnotationLayers containing Annotations with "
f'TARGET_NAMES, but AnnotationLayer "{field.name}" has no target_names. You should '
f"pass the named_targets dict containing the following keys (see Annotation "
f'"{annotation_type.__name__}") to annotation_field: {annotation_target_names}'
)

# check annotation target names and use them together with target names from the AnnotationLayer
# to reorder targets, if available
target_names = field.metadata.get("target_names")
annotation_type = typing.get_args(field.type)[0]
annotation_target_names = annotation_type.TARGET_NAMES
if annotation_target_names is not None:
if target_names is not None:
if set(target_names) != set(annotation_target_names):
raise TypeError(
f"keys of targets {sorted(target_names)} do not match "
f"{annotation_type.__name__}.TARGET_NAMES {sorted(annotation_target_names)}"
)
# reorder targets according to annotation_target_names
target_name_mapping = dict(zip(target_names, targets))
target_position_mapping = {
i: target_name_mapping[name]
for i, name in enumerate(annotation_target_names)
}
targets = [target_position_mapping[i] for i in range(len(targets))]
else:
if len(annotation_target_names) != len(targets):
raise TypeError(
f"number of targets {sorted(targets)} does not match number of entries in "
f"{annotation_type.__name__}.TARGET_NAMES: {sorted(annotation_target_names)}"
)
# disallow multiple targets when target names are specified in the definition of the Annotation
if len(annotation_target_names) > 1:
raise TypeError(
f"A target name mapping is required for AnnotationLayers containing Annotations with "
f'TARGET_NAMES, but AnnotationLayer "{field.name}" has no target_names. You should '
f"pass the named_targets dict containing the following keys (see Annotation "
f'"{annotation_type.__name__}") to annotation_field: {annotation_target_names}'
)

field_value = field.type(document=self, targets=targets)
setattr(self, field.name, field_value)
field_value = field_type(document=self, targets=targets)
setattr(self, field.name, field_value)

if "_artificial_root" in self._annotation_graph:
raise ValueError(
Expand Down Expand Up @@ -612,6 +613,7 @@ def asdict(self):
def fromdict(cls, dct):
fields = dataclasses.fields(cls)
annotation_fields = cls.annotation_fields()
field_types = cls.field_types()

cls_kwargs = {}
for field in fields:
Expand Down Expand Up @@ -648,9 +650,10 @@ def fromdict(cls, dct):
if value is None or not value:
continue

field_type = field_types[field_name]
# TODO: handle single annotations, e.g. a document-level label
if typing.get_origin(field.type) is AnnotationLayer:
annotation_class = typing.get_args(field.type)[0]
if typing.get_origin(field_type) is AnnotationLayer:
annotation_class = typing.get_args(field_type)[0]
# build annotations
for annotation_data in value["annotations"]:
annotation_dict = dict(annotation_data)
Expand Down

0 comments on commit ced53f6

Please sign in to comment.