Commit 0b3f860
fix(pt): finetuning property/dipole/polar/dos fitting with multi-dimensional data causes error (#4145)
Fix issue #4108
A model pretrained on energy labels has a one-dimensional `out_bias`. If we finetune a dos/polar/dipole/property model from this pretrained model, the `out_bias` of the finetuning model is multi-dimensional (for example, `numb_dos = 250`), and loading the checkpoint fails with:
```
RuntimeError: Error(s) in loading state_dict for ModelWrapper:
    size mismatch for model.Default.atomic_model.out_bias: copying a param with shape torch.Size([1, 118, 1]) from checkpoint, the shape in current model is torch.Size([1, 118, 250]).
    size mismatch for model.Default.atomic_model.out_std: copying a param with shape torch.Size([1, 118, 1]) from checkpoint, the shape in current model is torch.Size([1, 118, 250]).
```
When finetuning with a new fitting net, the pretrained `out_bias` is unused, because the bias is recomputed later from the finetuning data. So we simply skip loading the old `out_bias` (and `out_std`) in that case.
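The skip above can be sketched as a filter over the checkpoint's state dict before it is handed to `load_state_dict`. This is a minimal illustration, not the actual patch: the helper name `filter_pretrained_state` is hypothetical, and plain shape tuples stand in for tensors.

```python
# Sketch of the idea (hypothetical helper): when the fitting head is
# replaced during finetuning, drop the pretrained out_bias/out_std
# entries so their shapes are never compared against the new model;
# they are recomputed from the finetuning data anyway.

def filter_pretrained_state(state_dict, new_fitting=True):
    """Return a copy of state_dict without out_bias/out_std entries
    when the fitting net is replaced."""
    if not new_fitting:
        return dict(state_dict)
    skip_suffixes = ("out_bias", "out_std")
    return {
        key: value
        for key, value in state_dict.items()
        if not key.endswith(skip_suffixes)
    }

# Shape tuples stand in for tensors: the pretrained bias is (1, 118, 1),
# while a DOS model with numb_dos = 250 expects (1, 118, 250).
pretrained = {
    "model.Default.atomic_model.out_bias": (1, 118, 1),
    "model.Default.atomic_model.out_std": (1, 118, 1),
    "model.Default.atomic_model.descriptor.weight": (64, 64),
}
filtered = filter_pretrained_state(pretrained)
print(sorted(filtered))  # only the descriptor weight survives
```

With the mismatched keys removed, `load_state_dict(filtered, strict=False)` would restore the shared parameters and leave `out_bias`/`out_std` at the shapes the new model allocated.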
<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit
- **New Features**
- Enhanced parameter collection for fine-tuning, refining criteria for
parameter retention.
- Introduced a model checkpoint file for saving and resuming training
states, facilitating iterative development.
- **Tests**
- Added a new test class to validate training and fine-tuning processes,
ensuring model performance consistency across configurations.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

Parent: 508759c
2 files changed: +69 −1 lines