@@ -112,19 +112,20 @@ def build_metric(self, **kwargs):
112112
113113 def build_dataloader (self , data , transform : TransformList = None , training = False , device = None ,
114114 logger : logging .Logger = None , gradient_accumulation = 1 , ** kwargs ) -> DataLoader :
115- if isinstance (data , list ):
116- data = BiaffineSemanticDependencyParser .build_samples (self , data , self .config .use_pos )
117115 dataset = BiaffineSemanticDependencyParser .build_dataset (self , data , transform )
118116 if isinstance (data , str ):
119117 dataset .purge_cache ()
118+ length_field = 'token'
119+ else :
120+ length_field = 'FORM'
120121 if self .vocabs .mutable :
121122 BiaffineSemanticDependencyParser .build_vocabs (self , dataset , logger , transformer = True )
122123 if dataset .cache :
123124 timer = CountdownTimer (len (dataset ))
124125 BiaffineSemanticDependencyParser .cache_dataset (self , dataset , timer , training , logger )
125126 return PadSequenceDataLoader (
126- batch_sampler = self .sampler_builder .build (self .compute_lens (data , dataset ), shuffle = training ,
127- gradient_accumulation = gradient_accumulation ),
127+ batch_sampler = self .sampler_builder .build (self .compute_lens (data , dataset , length_field = length_field ) ,
128+ shuffle = training , gradient_accumulation = gradient_accumulation ),
128129 device = device ,
129130 dataset = dataset ,
130131 pad = self .get_pad_dict ())
@@ -167,3 +168,6 @@ def prediction_to_result(self, prediction: Dict[str, Any], batch: Dict[str, Any]
167168 deprels = [vocab [r [i ]] for i in range (sent_len + 1 ) if a [i ]]
168169 result .append (list (zip (heads , deprels )))
169170 yield result
171+
172+ def build_samples (self , inputs , cls_is_bos = False , sep_is_eos = False ):
173+ return BiaffineSemanticDependencyParser .build_samples (self , inputs , self .config .use_pos )
0 commit comments