Skip to content

Commit 4befbd1

Browse files
committed
Enable sdp as the first task
1 parent ee1effb commit 4befbd1

File tree

3 files changed

+27
-5
lines changed

3 files changed

+27
-5
lines changed

hanlp/components/mtl/tasks/sdp.py

Lines changed: 8 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -112,19 +112,20 @@ def build_metric(self, **kwargs):
112112

113113
def build_dataloader(self, data, transform: TransformList = None, training=False, device=None,
114114
logger: logging.Logger = None, gradient_accumulation=1, **kwargs) -> DataLoader:
115-
if isinstance(data, list):
116-
data = BiaffineSemanticDependencyParser.build_samples(self, data, self.config.use_pos)
117115
dataset = BiaffineSemanticDependencyParser.build_dataset(self, data, transform)
118116
if isinstance(data, str):
119117
dataset.purge_cache()
118+
length_field = 'token'
119+
else:
120+
length_field = 'FORM'
120121
if self.vocabs.mutable:
121122
BiaffineSemanticDependencyParser.build_vocabs(self, dataset, logger, transformer=True)
122123
if dataset.cache:
123124
timer = CountdownTimer(len(dataset))
124125
BiaffineSemanticDependencyParser.cache_dataset(self, dataset, timer, training, logger)
125126
return PadSequenceDataLoader(
126-
batch_sampler=self.sampler_builder.build(self.compute_lens(data, dataset), shuffle=training,
127-
gradient_accumulation=gradient_accumulation),
127+
batch_sampler=self.sampler_builder.build(self.compute_lens(data, dataset, length_field=length_field),
128+
shuffle=training, gradient_accumulation=gradient_accumulation),
128129
device=device,
129130
dataset=dataset,
130131
pad=self.get_pad_dict())
@@ -167,3 +168,6 @@ def prediction_to_result(self, prediction: Dict[str, Any], batch: Dict[str, Any]
167168
deprels = [vocab[r[i]] for i in range(sent_len + 1) if a[i]]
168169
result.append(list(zip(heads, deprels)))
169170
yield result
171+
172+
def build_samples(self, inputs, cls_is_bos=False, sep_is_eos=False):
173+
return BiaffineSemanticDependencyParser.build_samples(self, inputs, self.config.use_pos)

hanlp/version.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -2,5 +2,5 @@
22
# Author: hankcs
33
# Date: 2019-12-28 19:26
44

5-
__version__ = '2.1.0-alpha.55'
5+
__version__ = '2.1.0-alpha.56'
66
"""HanLP version"""

tests/test_mtl.py

Lines changed: 18 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -30,6 +30,24 @@ def test_skip_tok(self):
3030
doc: Document = mtl(pre_tokenized_sents, skip_tasks='tok*')
3131
self.assertSequenceEqual(doc['tok'], pre_tokenized_sents)
3232

33+
def test_sdp_as_the_first_task(self):
34+
doc: Document = mtl(['人', '吃', '鱼'], tasks='sdp', skip_tasks='tok*')
35+
self.assertDictEqual(
36+
doc.to_dict(),
37+
{
38+
"sdp": [
39+
[(2, "Agt")],
40+
[(0, "Root")],
41+
[(2, "Pat")]
42+
],
43+
"tok": [
44+
"人",
45+
"吃",
46+
"鱼"
47+
]
48+
}
49+
)
50+
3351
def test_threading(self):
3452
num_proc = 8
3553
with Pool(num_proc) as pool:

0 commit comments

Comments
 (0)