Skip to content

Commit 199f3f3

Browse files
committed
Support reusing existing model and config for finetuning; fixes #1942
1 parent 81983df commit 199f3f3

File tree

2 files changed

+17
-11
lines changed

2 files changed

+17
-11
lines changed

hanlp/common/torch_component.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -246,12 +246,13 @@ def fit(self,
246246
first_device = -1
247247
if _device_placeholder and first_device >= 0:
248248
_dummy_placeholder = self._create_dummy_placeholder_on(first_device)
249-
if finetune:
250-
if isinstance(finetune, str):
251-
self.load(finetune, devices=devices, **self.config)
252-
else:
253-
self.load(save_dir, devices=devices, **self.config)
254-
self.config.finetune = finetune
249+
if finetune or self.model:
250+
if not self.model:
251+
if isinstance(finetune, str):
252+
self.load(finetune, devices=devices, **self.config)
253+
else:
254+
self.load(save_dir, devices=devices, **self.config)
255+
self.config.finetune = finetune or True
255256
self.vocabs.unlock() # For extending vocabs
256257
logger.info(
257258
f'Finetune model loaded with {sum(p.numel() for p in self.model.parameters() if p.requires_grad)}'

plugins/hanlp_demo/hanlp_demo/zh/train/finetune_ner.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
cdroot()
1111

12+
# 0. Prepare your dataset for finetuning
1213
your_training_corpus = 'data/ner/finetune/word_to_iobes.tsv'
1314
your_development_corpus = your_training_corpus # Use a different one in reality
1415
save_dir = 'data/ner/finetune/model'
@@ -25,18 +26,22 @@
2526
'''
2627
)
2728

29+
# 1. Load a pretrained model for finetuning
2830
ner = TransformerNamedEntityRecognizer()
31+
ner.load(hanlp.pretrained.ner.MSRA_NER_ELECTRA_SMALL_ZH)
32+
33+
# 2. Override hyper-parameters
34+
ner.config['epochs'] = 50 # Since the corpus is small, overfit it
35+
36+
# 3. Fit on your dataset
2937
ner.fit(
3038
trn_data=your_training_corpus,
3139
dev_data=your_development_corpus,
3240
save_dir=save_dir,
33-
epochs=50, # Since the corpus is small, overfit it
34-
finetune=hanlp.pretrained.ner.MSRA_NER_ELECTRA_SMALL_ZH,
35-
# You MUST set the same parameters with the fine-tuning model:
36-
average_subwords=True,
37-
transformer='hfl/chinese-electra-180g-small-discriminator',
41+
**ner.config
3842
)
3943

44+
# 4. Test it out on your data points
4045
HanLP = hanlp.pipeline()\
4146
.append(hanlp.load(hanlp.pretrained.tok.FINE_ELECTRA_SMALL_ZH), output_key='tok')\
4247
.append(ner, output_key='ner')

0 commit comments

Comments (0)