@@ -57,6 +57,7 @@ def tokens_per_second(self):
5757
5858class GraphParser (AbstractParser ):
5959 """ Parser for a single graph, has a state and optionally an oracle """
60+
6061 def __init__ (self , graph , * args , target = None , ** kwargs ):
6162 """
6263 :param graph: gold Graph to get the correct nodes and edges from (in training), or just to get id from (in test)
@@ -74,20 +75,54 @@ def __init__(self, graph, *args, target=None, **kwargs):
7475 assert self .lang , "Attribute 'lang' is required per passage when using multilingual BERT"
7576 self .state_hash_history = set ()
7677 self .state = self .oracle = None
77- if self .framework == "amr" and self .alignment : # Copy alignments to anchors, updating graph
78+ if self .framework in ( "amr" , "drg" , "ptg" ) and self .alignment : # Copy alignments to anchors, updating graph
7879 for alignment_node in self .alignment .nodes :
7980 node = self .graph .find_node (alignment_node .id )
8081 if node is None :
8182 self .config .log ("graph %s: invalid alignment node %s" % (self .graph .id , alignment_node .id ))
8283 continue
8384 if node .anchors is None :
8485 node .anchors = []
85- for conllu_node_id in (alignment_node .label or []) + list (chain (* alignment_node .values or [])):
86- conllu_node = self .conllu .find_node (conllu_node_id )
87- if conllu_node is None :
88- raise ValueError ("Alignments incompatible with tokenization: token %s "
89- "not found in graph %s" % (conllu_node_id , self .graph .id ))
90- node .anchors += conllu_node .anchors
86+
87+ conllu_node_id_list = None
88+ alignment_node_anchor_char_range_list = None
89+ if self .alignment .framework == "alignment" :
90+ conllu_node_id_list = (alignment_node .label or []) + list (chain (* alignment_node .values or []))
91+ elif self .alignment .framework == "anchoring" and self .framework in ("amr" , "ptg" ):
92+ conllu_node_id_list = set ([alignment_dict ["#" ] for alignment_dict in
93+ (alignment_node .anchors or [])
94+ + ([anchor for anchor_list in (alignment_node .anchorings or []) for anchor in anchor_list ])])
95+ elif self .alignment .framework == "anchoring" and self .framework == "drg" :
96+ alignment_node_anchor_char_range_list = [(int (alignment_dict ["from" ]),(int (alignment_dict ["to" ]))) for alignment_dict in
97+ (alignment_node .anchors or [])
98+ + ([anchor for anchor_list in (alignment_node .anchorings or []) for anchor in anchor_list ])]
99+ assert all ([len (conllu_node .anchors ) == 1 for conllu_node in self .conllu .nodes ])
100+ anchors_to_conllu_node = {(int (conllu_node .anchors [0 ]["from" ]), int (conllu_node .anchors [0 ]["to" ])):
101+ conllu_node
102+ for conllu_node in self .conllu .nodes }
103+ else :
104+ raise ValueError (f'Unknown alignments framework: { alignment_node .framework } ' )
105+
106+ if conllu_node_id_list is not None :
107+ assert self .framework in ("amr" , "ptg" )
108+ for conllu_node_id in conllu_node_id_list :
109+ conllu_node = self .conllu .find_node (conllu_node_id + 1 )
110+
111+ if conllu_node is None :
112+ raise ValueError ("Alignments incompatible with tokenization: token %s "
113+ "not found in graph %s" % (conllu_node_id , self .graph .id ))
114+ node .anchors += conllu_node .anchors
115+
116+ elif alignment_node_anchor_char_range_list is not None :
117+ for alignment_node_char_range in alignment_node_anchor_char_range_list :
118+ for conllu_anchor_range in anchors_to_conllu_node :
119+ if alignment_node_char_range [0 ] <= conllu_anchor_range [0 ] \
120+ and alignment_node_char_range [1 ] >= conllu_anchor_range [1 ]:
121+ conllu_node = anchors_to_conllu_node [conllu_anchor_range ]
122+ if conllu_node is None :
123+ raise ValueError ("Alignments incompatible with tokenization: token %s "
124+ "not found in graph %s" % (conllu_anchor_range , self .graph .id ))
125+ node .anchors += conllu_node .anchors
91126
92127 def init (self ):
93128 self .config .set_framework (self .framework )
@@ -320,6 +355,7 @@ def num_tokens(self, _):
320355
321356class BatchParser (AbstractParser ):
322357 """ Parser for a single training iteration or single pass over dev/test graphs """
358+
323359 def __init__ (self , * args , ** kwargs ):
324360 super ().__init__ (* args , ** kwargs )
325361 self .seen_per_framework = defaultdict (int )
@@ -335,14 +371,11 @@ def parse(self, graphs, display=True, write=False, accuracies=None):
335371 if conllu is None :
336372 self .config .print ("skipped '%s', no companion conllu data found" % graph .id )
337373 continue
338- alignment = self .alignment .get (graph .id )
374+ alignment = self .alignment .get (graph .id ) if self . alignment else None
339375 for target in graph .targets () or [graph .framework ]:
340376 if not self .training and target not in self .model .classifier .labels :
341377 self .config .print ("skipped target '%s' for '%s': did not train on it" % (target , graph .id ), level = 1 )
342378 continue
343- if target == "amr" and alignment is None :
344- self .config .print ("skipped target 'amr' for '%s': no companion alignment found" % graph .id , level = 1 )
345- continue
346379 parser = GraphParser (
347380 graph , self .config , self .model , self .training , conllu = conllu , alignment = alignment , target = target )
348381 if self .config .args .verbose and display :
@@ -403,6 +436,7 @@ def time_per_graph(self):
403436
404437class Parser (AbstractParser ):
405438 """ Main class to implement transition-based meaning representation parser """
439+
406440 def __init__ (self , model_file = None , config = None , training = None , conllu = None , alignment = None ):
407441 super ().__init__ (config = config or Config (), model = Model (model_file or config .args .model ),
408442 training = config .args .train if training is None else training ,
@@ -646,7 +680,7 @@ def read_graphs_with_progress_bar(file_handle_or_graphs):
646680 if isinstance (file_handle_or_graphs , IOBase ):
647681 graphs , _ = read_graphs (
648682 tqdm (file_handle_or_graphs , desc = "Reading " + getattr (file_handle_or_graphs , "name" , "input" ),
649- unit = " graphs" ), format = "mrp" )
683+ unit = " graphs" ), format = "mrp" , robust = True )
650684 return graphs
651685 return file_handle_or_graphs
652686
0 commit comments