
Commit 0b3f860

fix(pt): finetuning property/dipole/polar/dos fitting with multi-dimensional data causes error (#4145)
Fixes issue #4108.

When a pretrained model is labeled with energy, its `out_bias` has a single output dimension. If that pretrained model is then used to fine-tune a dos/polar/dipole/property model, the fine-tuned model's `out_bias` is multi-dimensional (for example, `numb_dos = 250`), and loading the checkpoint fails with:

`RuntimeError: Error(s) in loading state_dict for ModelWrapper:`
`size mismatch for model.Default.atomic_model.out_bias: copying a param with shape torch.Size([1, 118, 1]) from checkpoint, the shape in current model is torch.Size([1, 118, 250]).`
`size mismatch for model.Default.atomic_model.out_std: copying a param with shape torch.Size([1, 118, 1]) from checkpoint, the shape in current model is torch.Size([1, 118, 250]).`

When fine-tuning with a new fitting net, the old `out_bias` is useless because a new bias is recomputed later in the code, so there is no need to load the old `out_bias` in that case.

## Summary by CodeRabbit

- **New Features**
  - Enhanced parameter collection for fine-tuning, refining the criteria for parameter retention.
  - Introduced a model checkpoint file for saving and resuming training states, facilitating iterative development.

- **Tests**
  - Added a new test class to validate the training and fine-tuning processes, ensuring model consistency across configurations.

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 508759c commit 0b3f860
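For context, the reported size mismatch is easy to reproduce in isolation. Below is a minimal sketch, assuming a hypothetical stand-in module (`AtomicModelStub` is not deepmd-kit's actual class; 118 is the number of atom types, and the last dimension stands in for the fitting net's output size):

```python
import torch


class AtomicModelStub(torch.nn.Module):
    """Hypothetical stand-in for the atomic model's bias buffers."""

    def __init__(self, out_dim: int) -> None:
        super().__init__()
        # one bias/std entry per atom type and per output channel
        self.register_buffer("out_bias", torch.zeros(1, 118, out_dim))
        self.register_buffer("out_std", torch.ones(1, 118, out_dim))


pretrained = AtomicModelStub(out_dim=1)  # energy fitting: 1 output channel
finetune = AtomicModelStub(out_dim=250)  # DOS fitting: numb_dos = 250 channels

try:
    # Loading the 1-channel checkpoint into the 250-channel model fails,
    # mirroring the RuntimeError quoted in the commit message.
    finetune.load_state_dict(pretrained.state_dict())
except RuntimeError as err:
    print(err)  # size mismatch for out_bias ... and for out_std ...
```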

2 files changed: +69 −1

deepmd/pt/train/training.py

Lines changed: 1 addition & 1 deletion
@@ -484,7 +484,7 @@ def collect_single_finetune_params(
                 if i != "_extra_state" and f".{_model_key}." in i
             ]
             for item_key in target_keys:
-                if _new_fitting and ".fitting_net." in item_key:
+                if _new_fitting and (".descriptor." not in item_key):
                     # print(f'Keep {item_key} in old model!')
                     _new_state_dict[item_key] = (
                         _random_state_dict[item_key].clone().detach()
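The one-line change above widens which parameters keep their freshly initialized values when the fitting net type changes. A toy sketch of the old versus new condition, using keys modeled on those in the error message (illustrative only, not deepmd-kit code):

```python
# Keys modeled on the ones quoted in the commit message (illustrative only).
keys = [
    "model.Default.atomic_model.descriptor.layers.0.weight",
    "model.Default.atomic_model.fitting_net.layers.0.weight",
    "model.Default.atomic_model.out_bias",
    "model.Default.atomic_model.out_std",
]
_new_fitting = True  # fine-tune run whose fitting net differs from the pretrained one

# Old condition: only fitting-net weights kept their new initialization, so
# the pretrained 1-channel out_bias/out_std were still loaded and clashed
# with the new multi-channel shapes.
old_keep_new = [k for k in keys if _new_fitting and ".fitting_net." in k]

# New condition: everything outside the descriptor keeps its new values, so
# out_bias/out_std are not loaded and are recomputed later from the new data.
new_keep_new = [k for k in keys if _new_fitting and ".descriptor." not in k]

print(old_keep_new)  # only the fitting_net key
print(new_keep_new)  # fitting_net key plus out_bias and out_std
```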

source/tests/pt/test_training.py

Lines changed: 68 additions & 0 deletions
@@ -448,5 +448,73 @@ def tearDown(self) -> None:
         DPTrainTest.tearDown(self)
 
 
+class TestPropFintuFromEnerModel(unittest.TestCase):
+    def setUp(self):
+        input_json = str(Path(__file__).parent / "water/se_atten.json")
+        with open(input_json) as f:
+            self.config = json.load(f)
+        data_file = [str(Path(__file__).parent / "water/data/data_0")]
+        self.config["training"]["training_data"]["systems"] = data_file
+        self.config["training"]["validation_data"]["systems"] = data_file
+        self.config["model"] = deepcopy(model_dpa1)
+        self.config["model"]["type_map"] = ["H", "C", "N", "O"]
+        self.config["training"]["numb_steps"] = 1
+        self.config["training"]["save_freq"] = 1
+
+        property_input = str(Path(__file__).parent / "property/input.json")
+        with open(property_input) as f:
+            self.config_property = json.load(f)
+        prop_data_file = [str(Path(__file__).parent / "property/single")]
+        self.config_property["training"]["training_data"]["systems"] = prop_data_file
+        self.config_property["training"]["validation_data"]["systems"] = prop_data_file
+        self.config_property["model"]["descriptor"] = deepcopy(model_dpa1["descriptor"])
+        self.config_property["training"]["numb_steps"] = 1
+        self.config_property["training"]["save_freq"] = 1
+
+    def test_dp_train(self):
+        # test training from scratch
+        trainer = get_trainer(deepcopy(self.config))
+        trainer.run()
+        state_dict_trained = trainer.wrapper.model.state_dict()
+
+        # test fine-tuning using a different fitting_net, here using property fitting
+        finetune_model = self.config["training"].get("save_ckpt", "model.ckpt") + ".pt"
+        self.config_property["model"], finetune_links = get_finetune_rules(
+            finetune_model,
+            self.config_property["model"],
+            model_branch="RANDOM",
+        )
+        trainer_finetune = get_trainer(
+            deepcopy(self.config_property),
+            finetune_model=finetune_model,
+            finetune_links=finetune_links,
+        )
+
+        # check parameters
+        state_dict_finetuned = trainer_finetune.wrapper.model.state_dict()
+        for state_key in state_dict_finetuned:
+            if (
+                "out_bias" not in state_key
+                and "out_std" not in state_key
+                and "fitting" not in state_key
+            ):
+                torch.testing.assert_close(
+                    state_dict_trained[state_key],
+                    state_dict_finetuned[state_key],
+                )
+
+        # check running
+        trainer_finetune.run()
+
+    def tearDown(self):
+        for f in os.listdir("."):
+            if f.startswith("model") and f.endswith(".pt"):
+                os.remove(f)
+            if f in ["lcurve.out"]:
+                os.remove(f)
+            if f in ["stat_files"]:
+                shutil.rmtree(f)
+
+
 if __name__ == "__main__":
     unittest.main()
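Note that the parameter check in the test deliberately skips `out_bias`, `out_std`, and fitting keys: with a new fitting net, these are re-initialized and the bias is refit from the new data rather than copied from the checkpoint. A rough sketch of that idea, assuming hypothetical shapes and a plain least-squares fit (deepmd-kit's own statistics pass performs the actual recomputation):

```python
import torch

ntypes, numb_dos, nframes = 4, 250, 32

# per-frame labels (e.g. a DOS curve) and per-frame atom-type counts
labels = torch.randn(nframes, numb_dos).double()
type_count = torch.randint(1, 5, (nframes, ntypes)).double()

# Fit per-type biases so that type_count @ bias ~ labels; this replaces
# whatever single-channel bias the pretrained energy model carried.
bias = torch.linalg.lstsq(type_count, labels).solution  # (ntypes, numb_dos)
out_bias = bias.unsqueeze(0)  # shape (1, ntypes, numb_dos)
print(out_bias.shape)
```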
