feat(pt): DPLR in PyTorch backend #5138
base: master
Conversation
📝 Walkthrough
Adds a new DipoleChargeModifier (with serialization, WFCC model inference, and Ewald reciprocal corrections), re-exports it, changes PyTorch dependency handling/pinning and model/modifier serialization (pickle-based), updates related tests/fixtures, and adjusts the backend PyTorch requirement API and a few argument docs.
Sequence Diagram(s)
sequenceDiagram
autonumber
participant Caller
participant DipoleMod as DipoleChargeModifier
participant WFCC as WFCC Model (TorchScript)
participant Ewald as CoulombForceModule
participant BaseMod as BaseModifier (serialize/deserialize)
Caller->>DipoleMod: forward(coords, box, atype)
alt box is None
DipoleMod-->>Caller: raise RuntimeError
else
DipoleMod->>WFCC: request per-frame WFCC dipoles (batched)
WFCC-->>DipoleMod: dipole vectors (WFCC coords)
DipoleMod->>DipoleMod: build extended_coord & extended_charge
loop per batch/frame
DipoleMod->>Ewald: compute reciprocal-space correction
Ewald-->>DipoleMod: energy/force/virial contributions
end
DipoleMod-->>Caller: aggregated energy, force, virial
end
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes
🚥 Pre-merge checks: ✅ 2 passed | ❌ 1 failed (1 warning)
Actionable comments posted: 2
🤖 Fix all issues with AI agents
In @deepmd/pt/modifier/dipole_charge.py:
- Around line 141-144: Replace the incorrect use of RuntimeWarning in the "if
box is None:" check with a real exception (e.g., raise
RuntimeError("dipole_charge data modifier can only be applied for periodic
systems.")) inside the dipole_charge modifier
(deepmd.pt.modifier.dipole_charge), or alternatively call warnings.warn(...) if
you want non-fatal behavior; also update the test(s) that expect a
RuntimeWarning to assertRaises(RuntimeError) (or the equivalent check if you
choose warnings.warn) so tests reflect the change.
In @source/tests/pt/modifier/test_dipole_charge.py:
- Around line 173-190: The code currently raises RuntimeWarning directly when
box is None (exercised by test_box_none_warning calling self.dm_pt), which is
semantically incorrect; change the modifier implementation that currently "raise
RuntimeWarning(...)" to either issue a warnings.warn("...", RuntimeWarning) (and
update the test to use assertWarns or catch the warning) or raise a proper
exception type such as ValueError/RuntimeError (and update test_box_none_warning
to assertRaises(ValueError/RuntimeError) accordingly); locate the check in the
dipole_charge modifier method (the function/method that accepts coord, atype,
box and is invoked via self.dm_pt) and replace the direct raise with the chosen
approach and adjust the test assertion to match.
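A minimal sketch of the fix both prompts converge on, shown here with a stand-in class and an illustrative test rather than the real `deepmd.pt.modifier.dipole_charge` module:

```python
import unittest

import torch


class DipoleChargeModifier(torch.nn.Module):
    # stand-in for deepmd.pt.modifier.dipole_charge.DipoleChargeModifier
    def forward(self, coord, atype, box=None):
        if box is None:
            # a real exception instead of `raise RuntimeWarning(...)`
            raise RuntimeError(
                "dipole_charge data modifier can only be applied for periodic systems."
            )
        return {}


class TestBoxNone(unittest.TestCase):
    def test_box_none_raises(self):
        dm = DipoleChargeModifier()
        with self.assertRaises(RuntimeError):  # was a RuntimeWarning expectation
            dm(torch.zeros(1, 3, 3), torch.zeros(1, 3, dtype=torch.long), box=None)
```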
🧹 Nitpick comments (4)
deepmd/pt/modifier/dipole_charge.py (4)
38-41: Unused parameters in `__new__` method. The `*args` and `**kwargs` parameters are not used and not passed to `super().__new__()`. If these are present for interface compatibility, consider documenting why. Otherwise, they can be removed.
♻️ Potential simplification

    - def __new__(
    -     cls, *args: tuple, model_name: str | None = None, **kwargs: dict
    - ) -> "DipoleChargeModifier":
    + def __new__(cls, model_name: str | None = None) -> "DipoleChargeModifier":
          return super().__new__(cls, model_name)
101-109: Unused parameter `do_atomic_virial`. The `do_atomic_virial` parameter is declared but never used in the method. If this is kept for interface compatibility with `BaseModifier`, add a comment explaining this. Otherwise, consider removing it.
📝 Add clarifying comment

      def forward(
          self,
          coord: torch.Tensor,
          atype: torch.Tensor,
          box: torch.Tensor | None = None,
          fparam: torch.Tensor | None = None,
          aparam: torch.Tensor | None = None,
    +     # Kept for interface compatibility with BaseModifier
          do_atomic_virial: bool = False,
      ) -> dict[str, torch.Tensor]:
255-317: Frame-by-frame processing is clear but potentially inefficient. The method correctly computes WFCC coordinates by adding dipole vectors to atomic positions for selected types. However, the frame-by-frame loop (lines 294-304) prevents batching, which could impact performance for large batch sizes.
If the model supports batched inference, consider refactoring to process all frames at once:
♻️ Potential optimization
    # Instead of looping over frames:
    dipole_batch = self.model(
        coord=coord,  # All frames
        atype=atype,
        box=box,
        do_atomic_virial=False,
        fparam=fparam,
        aparam=aparam,
    )
    dipole = dipole_batch["dipole"]

Only implement if the model interface supports batched inference and profiling shows this is a bottleneck.
320-347: Clear and correct mask creation. The function properly creates a boolean mask for selected atom types. The loop-based approach is readable and correct.
♻️ Optional vectorization
For a more vectorized approach (though the current implementation is fine):
    - # Create mask using broadcasting
    - mask = torch.zeros_like(atype, dtype=torch.bool)
    - for t in sel_type:
    -     mask = mask | (atype == t)
    + # Create mask using broadcasting
    + mask = (atype.unsqueeze(-1) == sel_type.unsqueeze(0)).any(dim=-1)
      return mask

Only consider this if profiling shows the loop is a bottleneck. The current implementation is clearer.
📜 Review details
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (8)
- deepmd/pt/modifier/__init__.py
- deepmd/pt/modifier/dipole_charge.py
- pyproject.toml
- source/tests/pt/modifier/__init__.py
- source/tests/pt/modifier/test_data_modifier.py
- source/tests/pt/modifier/test_dipole_charge.py
- source/tests/pt/modifier/water
- source/tests/pt/modifier/water_tensor
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-08-14T07:11:51.357Z
Learnt from: njzjz
Repo: deepmodeling/deepmd-kit PR: 4884
File: .github/workflows/test_cuda.yml:46-46
Timestamp: 2025-08-14T07:11:51.357Z
Learning: As of PyTorch 2.8 (August 2025), the default wheel on PyPI installed by `pip install torch` is CPU-only. CUDA-enabled wheels are available on PyPI for Linux x86 and Windows x86 platforms, but require explicit specification via index URLs or variant-aware installers. For CUDA support, use `--index-url https://download.pytorch.org/whl/cu126` (or appropriate CUDA version).
Applied to files:
pyproject.toml
🧬 Code graph analysis (2)
deepmd/pt/modifier/dipole_charge.py (1)
deepmd/tf/modifier/dipole_charge.py (1)
DipoleChargeModifier(40-546)
deepmd/pt/modifier/__init__.py (1)
deepmd/pt/modifier/dipole_charge.py (1)
DipoleChargeModifier(23-317)
🪛 Ruff (0.14.10)
deepmd/pt/modifier/dipole_charge.py
39-39: Unused static method argument: args
(ARG004)
39-39: Unused static method argument: kwargs
(ARG004)
108-108: Unused method argument: do_atomic_virial
(ARG002)
142-144: Avoid specifying long messages outside the exception class
(TRY003)
source/tests/pt/modifier/test_dipole_charge.py
175-175: Unpacked variable box is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (28)
- GitHub Check: Test Python (9, 3.10)
- GitHub Check: Test Python (11, 3.10)
- GitHub Check: Test Python (10, 3.10)
- GitHub Check: Test Python (7, 3.10)
- GitHub Check: Test Python (8, 3.10)
- GitHub Check: Test Python (12, 3.10)
- GitHub Check: Test Python (5, 3.10)
- GitHub Check: Test Python (4, 3.10)
- GitHub Check: Test Python (3, 3.10)
- GitHub Check: Test Python (6, 3.10)
- GitHub Check: Test Python (1, 3.10)
- GitHub Check: Test Python (2, 3.10)
- GitHub Check: Test C++ (true, false, false, true)
- GitHub Check: Test C++ (false, false, false, true)
- GitHub Check: Test C++ (false, true, true, false)
- GitHub Check: Analyze (c-cpp)
- GitHub Check: Test C++ (true, true, true, false)
- GitHub Check: Analyze (python)
- GitHub Check: Build wheels for cp311-manylinux_x86_64
- GitHub Check: Build wheels for cp310-manylinux_aarch64
- GitHub Check: Build wheels for cp311-macosx_x86_64
- GitHub Check: Build wheels for cp311-win_amd64
- GitHub Check: Build C++ (clang, clang)
- GitHub Check: Build wheels for cp311-macosx_arm64
- GitHub Check: Build C++ (cuda120, cuda)
- GitHub Check: Build C++ (rocm, rocm)
- GitHub Check: Build C++ (cpu, cpu)
- GitHub Check: Build C library (2.18, libdeepmd_c.tar.gz)
🔇 Additional comments (17)
deepmd/pt/modifier/__init__.py (1)
10-12: LGTM! The import and export of `DipoleChargeModifier` follows the established pattern in this module and properly exposes the new modifier class for external use.
Also applies to: 16-16
source/tests/pt/modifier/test_data_modifier.py (1)
55-57: This import path change is incorrect. The file `source/consistent/` does not exist in the repository. The `parameterized` symbol is defined in `source/tests/consistent/common.py`. The original import path `..consistent.common` correctly resolves to `source/tests/consistent/common`, while the changed path `...consistent.common` attempts to resolve to the non-existent `source/consistent/common`. Revert to the correct relative import depth of `..consistent.common`.
Likely an incorrect or invalid review comment.
pyproject.toml (1)
167-167: Verify the torch_admp dependency availability and consider version stability.The new dependency uses a Git URL with an alpha version tag (
v1.1.0a). Please ensure:
- The repository and tag are publicly accessible and stable
- Using an alpha version is appropriate for production use
- Consider pinning to a specific commit hash for better reproducibility
This applies to both lines 167 and 171 where
torch_admpis added.source/tests/pt/modifier/__init__.py (1)
1-1: LGTM! Standard SPDX license header correctly identifies the LGPL-3.0-or-later license.
source/tests/pt/modifier/test_dipole_charge.py (8)
1-30: LGTM! Import organization is clean and all dependencies are appropriate for the test suite. The use of a fixed seed ensures reproducible test results.
33-48: LGTM! The helper function correctly loads test data with deterministic random frame selection using a seeded RNG. Proper reshaping ensures batch dimension compatibility.
52-108: Comprehensive integration test setup. The setUp method performs end-to-end model training, freezing, and backend conversion to create test fixtures. While heavyweight, this approach ensures realistic cross-framework testing. The minimal training (1 step) keeps test duration reasonable.
110-111: LGTM! Simple and effective test for JIT scripting compatibility.
113-139: LGTM! Cross-framework consistency test properly validates that PyTorch and TensorFlow modifiers produce equivalent results for energy, force, and virial predictions. The tolerance of `rtol=1e-4` is appropriate for numerical comparisons.
141-171: LGTM! Serialization round-trip test correctly validates that the modifier can be serialized and deserialized without loss of functionality.
192-209: LGTM! Training workflow test ensures the modifier integrates correctly with the training pipeline. Minimal training steps keep the test efficient while validating functionality.
211-220: LGTM! Thorough cleanup of generated artifacts. The straightforward `os.remove()` approach is appropriate for test teardown.
deepmd/pt/modifier/dipole_charge.py (5)
1-36: LGTM! Clean import organization and proper modifier registration. The docstring clearly documents the class parameters including Ewald summation settings and charge mappings.
81-99: LGTM! Serialization method correctly captures all parameters needed to reconstruct the modifier. Using the list versions of charge maps ensures JSON compatibility.
199-253: LGTM! The `extend_system` method correctly constructs extended charge arrays by combining ion charges and WFCC charges for selected atom types. The masking logic properly filters selected atoms.
43-79: LGTM with note on placeholder tensors. Initialization logic correctly loads the model, extracts necessary parameters, and sets up the Ewald summation module. The placeholder tensors (lines 75-79) are intentionally passed to the `CoulombForceModule` when `rspace=False`; the shapes `(1, 2)` for pairs, `(1,)` for distances, and buffer scales are used in the reciprocal space calculation.
Note: `torch_admp` is not publicly documented, so verification against the full API specification requires access to internal project documentation.
145-197: Verify coordinate transformation and virial computation against TensorFlow implementation. The `sfactor` computation using `detached_box` is mathematically sound for maintaining gradient flow: it uses a fixed transformation matrix while allowing gradients through `input_box`.
However, the PT implementation (lines 145-197) shows only Ewald energy gradient contributions to the virial (`tot_v = grad(tot_e, input_box)` followed by `-tot_v.T @ input_box`). The TensorFlow reference implementation explicitly includes three virial components:
- `all_v`: Ewald reciprocal contributions (analogous to PT)
- `corr_v`: DNN model force corrections from `_eval_fv`
- `fd_corr_v`: Force-dipole correction (`-ext_f3 @ dipole3`)
The visible PT code does not show the explicit force corrections and force-dipole corrections present in TF. Confirm whether the automatic differentiation approach through `torch_admp` implicitly captures these additional virial terms or if the implementation is missing corrections relative to the TensorFlow version.
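For concreteness, a hedged sketch of the three-term composition described above; the names `all_v`, `corr_v`, `fd_corr_v`, `ext_f3`, and `dipole3` come from the review text, while the shapes (a `(3, natoms)` force array against a `(natoms, 3)` dipole array) are assumptions made so the quoted expression yields a 3x3 virial:

```python
import torch


def tf_style_virial(
    all_v: torch.Tensor,    # (3, 3) Ewald reciprocal contribution
    corr_v: torch.Tensor,   # (3, 3) DNN model force correction (from _eval_fv)
    ext_f3: torch.Tensor,   # (3, natoms) external forces, shape assumed
    dipole3: torch.Tensor,  # (natoms, 3) WFCC dipoles, shape assumed
) -> torch.Tensor:
    fd_corr_v = -ext_f3 @ dipole3  # force-dipole correction, as quoted above
    return all_v + corr_v + fd_corr_v
```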
Codecov Report
❌ Patch coverage is …
Additional details and impacted files
@@ Coverage Diff @@
## master #5138 +/- ##
==========================================
+ Coverage 81.92% 82.00% +0.07%
==========================================
Files 714 714
Lines 73301 73227 -74
Branches 3616 3616
==========================================
- Hits 60055 60052 -3
+ Misses 12083 12011 -72
- Partials 1163 1164 +1
☔ View full report in Codecov by Sentry.
Actionable comments posted: 0
🧹 Nitpick comments (4)
deepmd/pt/modifier/dipole_charge.py (3)
38-41: Consider prefixing unused arguments with underscores. The `args` and `kwargs` parameters in `__new__` are unused. While they may be needed for base class compatibility, prefixing with underscores clarifies intent:
♻️ Suggested change

      def __new__(
    -     cls, *args: tuple, model_name: str | None = None, **kwargs: dict
    +     cls, *_args: tuple, model_name: str | None = None, **_kwargs: dict
      ) -> "DipoleChargeModifier":
          return super().__new__(cls, model_name)
108-108: Unused `do_atomic_virial` parameter should be documented or implemented. The `do_atomic_virial` parameter is accepted but not used in the implementation. Either implement atomic virial support or document that it's not supported:
♻️ Option 1: Document the limitation

      def forward(
          self,
          coord: torch.Tensor,
          atype: torch.Tensor,
          box: torch.Tensor | None = None,
          fparam: torch.Tensor | None = None,
          aparam: torch.Tensor | None = None,
    -     do_atomic_virial: bool = False,
    +     do_atomic_virial: bool = False,  # noqa: ARG002 - atomic virial not implemented
      ) -> dict[str, torch.Tensor]:

♻️ Option 2: Raise if atomic virial is requested

      if box is None:
          raise RuntimeError(
              "dipole_charge data modifier can only be applied for periodic systems."
          )
    + if do_atomic_virial:
    +     raise NotImplementedError(
    +         "Atomic virial is not supported by dipole_charge modifier."
    +     )
      else:
293-307: Per-frame model inference may be a performance bottleneck. Similar to the Ewald loop, the per-frame model inference in `extend_system_coord` could impact performance. If the underlying model supports batched inference, this could be optimized.
pyproject.toml (1)
167-167: Consider publishing `torch_admp` to PyPI for broader compatibility and improved reliability. The Git URL dependency with tag v1.1.1 is currently valid and pinned correctly. However, relying on Git-based dependencies introduces ongoing risks:
- Package managers cannot cache these packages as effectively as PyPI releases
- If the repository becomes unavailable in the future, builds will fail despite the pinned tag
- Installation requires a git client and is less resilient to repository hosting changes
For a production dependency, publishing `torch_admp` to PyPI would be the more maintainable approach. If staying with Git-based installation, document the rationale for this decision.
Also applies to: 171-171
📜 Review details
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (3)
- deepmd/pt/modifier/dipole_charge.py
- pyproject.toml
- source/tests/pt/modifier/test_dipole_charge.py
🧰 Additional context used
🧠 Learnings (2)
📓 Common learnings
Learnt from: njzjz
Repo: deepmodeling/deepmd-kit PR: 4144
File: source/api_cc/tests/test_deeppot_dpa_pt.cc:166-246
Timestamp: 2024-10-08T15:32:11.479Z
Learning: Refactoring between test classes `TestInferDeepPotDpaPt` and `TestInferDeepPotDpaPtNopbc` is addressed in PR #3905.
Learnt from: njzjz
Repo: deepmodeling/deepmd-kit PR: 4144
File: source/api_cc/tests/test_deeppot_dpa_pt.cc:166-246
Timestamp: 2024-09-19T04:25:12.408Z
Learning: Refactoring between test classes `TestInferDeepPotDpaPt` and `TestInferDeepPotDpaPtNopbc` is addressed in PR #3905.
📚 Learning: 2025-08-14T07:11:51.357Z
Learnt from: njzjz
Repo: deepmodeling/deepmd-kit PR: 4884
File: .github/workflows/test_cuda.yml:46-46
Timestamp: 2025-08-14T07:11:51.357Z
Learning: As of PyTorch 2.8 (August 2025), the default wheel on PyPI installed by `pip install torch` is CPU-only. CUDA-enabled wheels are available on PyPI for Linux x86 and Windows x86 platforms, but require explicit specification via index URLs or variant-aware installers. For CUDA support, use `--index-url https://download.pytorch.org/whl/cu126` (or appropriate CUDA version).
Applied to files:
pyproject.toml
🧬 Code graph analysis (2)
deepmd/pt/modifier/dipole_charge.py (2)
source/api_cc/src/common.cc (1)
sel_type (80-80)
deepmd/tf/modifier/dipole_charge.py (1)
DipoleChargeModifier(40-546)
source/tests/pt/modifier/test_dipole_charge.py (1)
deepmd/pt/modifier/dipole_charge.py (2)
DipoleChargeModifier (23-317)
serialize (81-99)
🪛 Ruff (0.14.10)
deepmd/pt/modifier/dipole_charge.py
39-39: Unused static method argument: args
(ARG004)
39-39: Unused static method argument: kwargs
(ARG004)
108-108: Unused method argument: do_atomic_virial
(ARG002)
142-144: Avoid specifying long messages outside the exception class
(TRY003)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (4)
- GitHub Check: Analyze (python)
- GitHub Check: Build wheels for cp311-manylinux_x86_64
- GitHub Check: Build wheels for cp310-manylinux_aarch64
- GitHub Check: Build C library (2.18, libdeepmd_c.tar.gz)
🔇 Additional comments (12)
source/tests/pt/modifier/test_dipole_charge.py (7)
33-48: LGTM! The `ref_data()` helper function provides reproducible test data loading with proper seeding and path handling.
51-108: LGTM! The setUp method correctly orchestrates the training, freezing, and backend conversion workflow to create both PyTorch and TensorFlow modifiers for cross-framework consistency testing.
110-111: LGTM! JIT compilation compatibility test ensures the modifier can be scripted for production deployment.
113-139: LGTM! Cross-framework consistency tests ensure PyTorch and TensorFlow implementations produce numerically equivalent results within appropriate tolerance.
141-171: LGTM! Serialization round-trip test validates that the modifier can be saved and restored with identical behavior.
173-190: LGTM! The test correctly validates that `RuntimeError` is raised for non-periodic systems, which is the appropriate exception type for this error condition.
192-220: LGTM! Training workflow test and tearDown cleanup ensure the modifier integrates properly with the training pipeline and leaves no artifacts.
deepmd/pt/modifier/dipole_charge.py (5)
1-19: LGTM! Imports are appropriately organized with the new `torch_admp` dependency for Coulomb force calculations.
81-99: LGTM! Serialization method follows the established pattern and correctly uses raw Python lists for JSON compatibility.
199-253: LGTM! The WFCC extension logic correctly assigns charges to selected atom types and concatenates with ion charges.
320-346: LGTM! The `make_mask` function is correctly implemented with JIT export support, enabling use in scripted modules.
168-181: Verify batching capability of CoulombForceModule before optimization. The per-frame loop for Ewald reciprocal energy computation (lines 168-181) processes each frame individually and concatenates results. If the underlying `CoulombForceModule` from torch-admp supports batched operations, this loop could potentially be refactored to improve batch performance. However, this requires checking the torch-admp library's API to confirm batch support feasibility.
Force-pushed from 2c83995 to ba65a54
Actionable comments posted: 1
🧹 Nitpick comments (5)
deepmd/pt/modifier/dipole_charge.py (3)
38-41: Consider prefixing unused parameters with underscore. The `args` and `kwargs` parameters are not used in `__new__`. While they may be present for API compatibility, prefixing them with underscores (`*_args`, `**_kwargs`) would silence the static analysis warning and clarify intent.
♻️ Suggested fix

      def __new__(
    -     cls, *args: tuple, model_name: str | None = None, **kwargs: dict
    +     cls, *_args: tuple, model_name: str | None = None, **_kwargs: dict
      ) -> "DipoleChargeModifier":
          return super().__new__(cls, model_name)
168-181: Consider documenting the per-frame loop performance implications. The frame-by-frame loop for Ewald reciprocal energy computation may become a performance bottleneck for large batch sizes. If the `CoulombForceModule` API supports batched operations, consider refactoring for better throughput. Otherwise, a brief comment explaining this limitation would be helpful for future maintainers.
293-308: Consider replacing assert with explicit validation. The `assert` statement on line 308 will be skipped if Python runs with optimization flags (`-O`). For critical shape validation, consider using an explicit `if` check with a descriptive error.
Also, similar to `forward`, the per-frame model inference loop could be a performance consideration for large batches.
♻️ Suggested fix

      # nframe x natoms x 3
      dipole = torch.cat(all_dipole, dim=0)
    - assert dipole.shape[0] == nframes
    + if dipole.shape[0] != nframes:
    +     raise RuntimeError(
    +         f"Dipole shape mismatch: expected {nframes} frames, got {dipole.shape[0]}"
    +     )

source/tests/pt/modifier/test_dipole_charge.py (2)
173-190: Consider renaming the test method. The test correctly validates the `RuntimeError` for non-periodic systems. However, the method name `test_box_none_warning` is slightly misleading since it now tests for an error, not a warning. Consider renaming to `test_box_none_error` or `test_box_none_raises` for clarity.
211-220: Consider using a temporary directory for test artifacts. The `tearDown` method manually cleans up files from the current working directory. Using `tempfile.TemporaryDirectory` or pytest's `tmp_path` fixture would provide automatic cleanup and avoid potential conflicts with other tests or leftover files if the test fails mid-execution.
♻️ Alternative approach using tempfile

    import tempfile

    class TestDipoleChargeModifier(unittest.TestCase):
        def setUp(self) -> None:
            self.test_dir = tempfile.TemporaryDirectory()
            self.orig_dir = os.getcwd()
            os.chdir(self.test_dir.name)
            # ... rest of setUp ...

        def tearDown(self) -> None:
            os.chdir(self.orig_dir)
            self.test_dir.cleanup()
📜 Review details
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (3)
- deepmd/pt/modifier/dipole_charge.py
- pyproject.toml
- source/tests/pt/modifier/test_dipole_charge.py
🚧 Files skipped from review as they are similar to previous changes (1)
- pyproject.toml
🧰 Additional context used
🧠 Learnings (1)
📓 Common learnings
Learnt from: njzjz
Repo: deepmodeling/deepmd-kit PR: 4144
File: source/api_cc/tests/test_deeppot_dpa_pt.cc:166-246
Timestamp: 2024-10-08T15:32:11.479Z
Learning: Refactoring between test classes `TestInferDeepPotDpaPt` and `TestInferDeepPotDpaPtNopbc` is addressed in PR #3905.
Learnt from: njzjz
Repo: deepmodeling/deepmd-kit PR: 4144
File: source/api_cc/tests/test_deeppot_dpa_pt.cc:166-246
Timestamp: 2024-09-19T04:25:12.408Z
Learning: Refactoring between test classes `TestInferDeepPotDpaPt` and `TestInferDeepPotDpaPtNopbc` is addressed in PR #3905.
🧬 Code graph analysis (2)
deepmd/pt/modifier/dipole_charge.py (1)
deepmd/tf/modifier/dipole_charge.py (1)
DipoleChargeModifier(40-546)
source/tests/pt/modifier/test_dipole_charge.py (2)
deepmd/entrypoints/convert_backend.py (1)
convert_backend (11-31)
deepmd/pt/modifier/dipole_charge.py (2)
DipoleChargeModifier (23-317)
serialize (81-99)
🪛 Ruff (0.14.10)
deepmd/pt/modifier/dipole_charge.py
39-39: Unused static method argument: args
(ARG004)
39-39: Unused static method argument: kwargs
(ARG004)
108-108: Unused method argument: do_atomic_virial
(ARG002)
142-144: Avoid specifying long messages outside the exception class
(TRY003)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (36)
- GitHub Check: Test Python (11, 3.13)
- GitHub Check: Test Python (11, 3.10)
- GitHub Check: Test Python (12, 3.13)
- GitHub Check: Test Python (10, 3.10)
- GitHub Check: Test Python (9, 3.13)
- GitHub Check: Test Python (8, 3.13)
- GitHub Check: Test Python (10, 3.13)
- GitHub Check: Test Python (7, 3.10)
- GitHub Check: Test Python (8, 3.10)
- GitHub Check: Test Python (9, 3.10)
- GitHub Check: Test Python (7, 3.13)
- GitHub Check: Test Python (1, 3.10)
- GitHub Check: Test Python (6, 3.13)
- GitHub Check: Test Python (5, 3.10)
- GitHub Check: Test Python (4, 3.10)
- GitHub Check: Test Python (2, 3.13)
- GitHub Check: Test Python (6, 3.10)
- GitHub Check: Test Python (4, 3.13)
- GitHub Check: Test Python (3, 3.13)
- GitHub Check: Test Python (2, 3.10)
- GitHub Check: Test Python (5, 3.13)
- GitHub Check: Test Python (1, 3.13)
- GitHub Check: Test Python (3, 3.10)
- GitHub Check: Test C++ (false, true, true, false)
- GitHub Check: Test C++ (true, false, false, true)
- GitHub Check: Test C++ (false, false, false, true)
- GitHub Check: Build C++ (clang, clang)
- GitHub Check: Build wheels for cp311-macosx_x86_64
- GitHub Check: Build C++ (rocm, rocm)
- GitHub Check: Build wheels for cp311-win_amd64
- GitHub Check: Build wheels for cp310-manylinux_aarch64
- GitHub Check: Build C++ (cuda120, cuda)
- GitHub Check: Build wheels for cp311-macosx_arm64
- GitHub Check: Build wheels for cp311-manylinux_x86_64
- GitHub Check: Analyze (python)
- GitHub Check: Analyze (c-cpp)
🔇 Additional comments (12)
deepmd/pt/modifier/dipole_charge.py (5)
1-19: LGTM! The imports are well-organized and appropriately include the necessary dependencies from `torch_admp` for Ewald summation and local utilities.
43-79: LGTM! The initialization correctly loads the model, extracts necessary parameters, and sets up the Ewald reciprocal space module. The dual storage of charge maps (as tensors and lists) properly supports both computation and serialization.
81-99: LGTM! The serialization method properly captures all constructor parameters needed for deserialization. Using the Python list versions (`_model_charge_map`, `_sys_charge_map`) ensures JSON compatibility.
199-253: LGTM! The `extend_system` method correctly constructs extended coordinates and charges by combining ion charges with WFCC charges for selected atom types.
320-347: LGTM! The `make_mask` utility is correctly implemented with JIT compatibility. The loop-based approach ensures compatibility with TorchScript.
source/tests/pt/modifier/test_dipole_charge.py (7)
1-30: LGTM! The imports are comprehensive, covering all necessary modules for testing both PyTorch and TensorFlow backends, along with conversion utilities.
33-48: LGTM! The `ref_data` helper provides deterministic test data using a fixed seed, ensuring reproducible test results.
51-108: LGTM! The `setUp` method properly configures and trains a minimal model for integration testing, enabling cross-backend consistency validation. The conversion from PyTorch to TensorFlow format ensures both modifiers use equivalent underlying models.
110-111: LGTM! The JIT compatibility test correctly validates that the modifier can be scripted without errors.
113-139: LGTM! The consistency test comprehensively validates that PyTorch and TensorFlow backends produce equivalent results for energy, force, and virial within acceptable numerical tolerance.
141-171: LGTM! The serialization test correctly validates the round-trip serialize/deserialize workflow, ensuring the modifier produces identical results before and after serialization.
192-209: LGTM! The training workflow smoke test validates that the training pipeline runs without errors. This is appropriate for integration testing.
Actionable comments posted: 4
🤖 Fix all issues with AI agents
In @deepmd/pt/modifier/dipole_charge.py:
- Around line 248-249: The loop in dipole_charge.py that assigns wc_charge uses
parallel indexing of model_charge_map and sel_type (used in the for ii, charge
in enumerate(self.model_charge_map) loop), which can raise IndexError if lengths
differ; add a validation in the class __init__ to check len(model_charge_map) ==
len(sel_type) and raise a ValueError with a clear message if not, so the
constructor enforces correspondence before code later (including the wc_charge
assignment) runs; see the first sketch after this list.
- Around line 38-41: In DipoleChargeModifier's __new__ method remove the extra
model_name positional argument passed to super().__new__; change the call to
invoke super().__new__(cls) only (do not forward model_name), since
object.__new__ rejects extra args and model_name is handled in __init__.
- Around line 4-9: The imports from torch_admp (CoulombForceModule, calc_grads)
reference a non-published package; either declare torch_admp as an internal
dependency or replace/remove the imports: if torch_admp is an internal/private
repo, add it to project dependencies (requirements/dev docs or pyproject) and
reference the correct package name; if there is a public alternative, change the
imports to that package and update usages of CoulombForceModule and calc_grads
accordingly; if these symbols are no longer needed, remove the imports and
delete or refactor any code that calls CoulombForceModule or calc_grads (search
for CoulombForceModule and calc_grads to update all usages).
- Around line 348-350: Replace the manual loop that builds mask from atype and
sel_type with a JIT-safe broadcasting expression: instead of iterating over
sel_type and OR-ing comparisons, compute mask = (atype.unsqueeze(-1) ==
sel_type).any(dim=-1) so the boolean mask is built via broadcasting (use the
existing variables mask, atype, sel_type) which avoids torch.isin and ensures
compatibility with torch.jit/export and accelerators.
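Two quick sketches for the prompts above. First, the constructor guard from the first item, assuming `model_charge_map` and `sel_type` are plain Python sequences at `__init__` time; the helper name is hypothetical:

```python
def validate_charge_maps(model_charge_map: list[float], sel_type: list[int]) -> None:
    # Called from __init__ so the parallel indexing later (the wc_charge
    # assignment) can never hit an IndexError.
    if len(model_charge_map) != len(sel_type):
        raise ValueError(
            f"model_charge_map has {len(model_charge_map)} entries but sel_type "
            f"selects {len(sel_type)} types; one charge per selected type is required."
        )
```

Second, a self-contained demonstration of the broadcast mask expression named in the last item; the example tensors are made up:

```python
import torch

atype = torch.tensor([[0, 1, 2, 1]])  # (nframes, natoms)
sel_type = torch.tensor([1, 2])       # selected atom types

# (nframes, natoms, 1) == (len(sel_type),) broadcasts to (nframes, natoms, len(sel_type))
mask = (atype.unsqueeze(-1) == sel_type).any(dim=-1)
print(mask)  # tensor([[False,  True,  True,  True]])
```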
🧹 Nitpick comments (1)
deepmd/pt/modifier/dipole_charge.py (1)
101-109: Document the interface compatibility reason for unused parameter. The `do_atomic_virial` parameter is documented as "not implemented" but kept in the signature. While this likely maintains interface compatibility with `BaseModifier`, the docstring could be clearer about this being reserved for future implementation or required by the base class contract.
📜 Review details
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
- deepmd/pt/modifier/dipole_charge.py
- source/tests/pt/modifier/test_dipole_charge.py
🚧 Files skipped from review as they are similar to previous changes (1)
- source/tests/pt/modifier/test_dipole_charge.py
🧰 Additional context used
🪛 Ruff (0.14.10)
deepmd/pt/modifier/dipole_charge.py
39-39: Unused static method argument: args
(ARG004)
39-39: Unused static method argument: kwargs
(ARG004)
108-108: Unused method argument: do_atomic_virial
(ARG002)
143-145: Avoid specifying long messages outside the exception class
(TRY003)
310-312: Avoid specifying long messages outside the exception class
(TRY003)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (40)
- GitHub Check: Test Python (10, 3.10)
- GitHub Check: Test Python (11, 3.10)
- GitHub Check: Test Python (8, 3.10)
- GitHub Check: Test Python (9, 3.10)
- GitHub Check: Test Python (12, 3.13)
- GitHub Check: Test Python (10, 3.13)
- GitHub Check: Test Python (11, 3.13)
- GitHub Check: Build wheels for cp311-manylinux_x86_64
- GitHub Check: Test Python (12, 3.10)
- GitHub Check: Test Python (7, 3.10)
- GitHub Check: Test C++ (false, true, true, false)
- GitHub Check: Test Python (5, 3.10)
- GitHub Check: Build C++ (cpu, cpu)
- GitHub Check: Test C++ (false, false, false, true)
- GitHub Check: Test C++ (true, false, false, true)
- GitHub Check: Test Python (8, 3.13)
- GitHub Check: Test Python (9, 3.13)
- GitHub Check: Test Python (2, 3.10)
- GitHub Check: Test Python (1, 3.10)
- GitHub Check: Test Python (4, 3.10)
- GitHub Check: Test Python (4, 3.13)
- GitHub Check: Test C++ (true, true, true, false)
- GitHub Check: Test Python (5, 3.13)
- GitHub Check: Test Python (7, 3.13)
- GitHub Check: Build wheels for cp311-macosx_arm64
- GitHub Check: Build C++ (clang, clang)
- GitHub Check: Build C++ (rocm, rocm)
- GitHub Check: Test Python (6, 3.13)
- GitHub Check: Test Python (1, 3.13)
- GitHub Check: Test Python (3, 3.10)
- GitHub Check: Test Python (2, 3.13)
- GitHub Check: Build C++ (cuda120, cuda)
- GitHub Check: Build wheels for cp310-manylinux_aarch64
- GitHub Check: Test Python (6, 3.10)
- GitHub Check: Build wheels for cp311-win_amd64
- GitHub Check: Build wheels for cp311-macosx_x86_64
- GitHub Check: Test Python (3, 3.13)
- GitHub Check: Build C library (2.18, libdeepmd_c.tar.gz)
- GitHub Check: Analyze (python)
- GitHub Check: Analyze (c-cpp)
🔇 Additional comments (9)
deepmd/pt/modifier/dipole_charge.py (9)
43-79: LGTM! The initialization properly loads the pre-trained model, converts charge maps to tensors, and configures the Ewald summation module. The dual storage of charge maps (as both lists and tensors) correctly supports serialization requirements.
81-99: LGTM! The serialization method correctly captures all necessary state including model path, charge maps, and Ewald parameters, enabling proper deserialization.
142-145: LGTM! The periodic boundary check is correct and necessary. The modifier appropriately requires periodic systems for Ewald summation to work properly.
151-159: LGTM! Clever gradient tracking setup. The coordinate transformation using `sfactor` correctly enables gradient tracking through the box tensor for proper virial computation. The detached box inverse prevents unwanted gradient accumulation while the gradient-enabled box multiplication ensures forces and virials are computed correctly.
184-198: LGTM! Correct force and virial computation. The force computation using negative gradients (line 184) and the virial calculation via chain rule through the box tensor (lines 188-193) are both mathematically correct for molecular dynamics applications.
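A hedged sketch of the re-parameterization pattern these two comments describe, with a placeholder energy standing in for the Ewald term; `input_box` and the `-tot_v.T @ input_box` form follow the review, everything else is illustrative:

```python
import torch

coord = torch.rand(4, 3, dtype=torch.float64)
input_box = (10.0 * torch.eye(3, dtype=torch.float64)).requires_grad_(True)

# fixed fractional coordinates: no gradient flows through the box inverse
sfactor = coord @ torch.linalg.inv(input_box.detach())
tracked_coord = sfactor @ input_box  # gradients now flow through input_box

tot_e = (tracked_coord**2).sum()  # placeholder for the Ewald reciprocal energy

coord_grad, tot_v = torch.autograd.grad(tot_e, (tracked_coord, input_box))
force = -coord_grad
virial = -tot_v.T @ input_box
```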
232-254: Approve charge assignment logic with reservation. The charge mapping logic correctly distinguishes between ion charges (from `sys_charge_map`) and WFCC charges (from `model_charge_map`), then properly filters and concatenates them. The tensor indexing operations are efficient.
314-321: LGTM! Correct WFCC coordinate computation. The masking operation correctly filters WFCC coordinates for selected atom types. The use of `unsqueeze(-1).expand_as()` properly broadcasts the boolean mask across the coordinate dimension, and the subsequent reshape handles the flattened result appropriately.
169-182: Verify batching capability of CoulombForceModule before optimization. The suggestion to batch process frames depends on whether `torch_admp.CoulombForceModule` supports batched input. The external `torch-admp` package (https://github.com/chiahsinchu/[email protected]) is not publicly documented, making it difficult to confirm whether the module's API accepts multi-frame tensors simultaneously. Before implementing this optimization, verify with the package maintainers or source code whether batching is supported and would measurably improve performance for your typical frame counts.
294-306: Verify batched model inference capability. The frame-wise loop reshapes each frame to batch size 1 before calling `self.model()`. If the model supports processing multiple frames in a single call, this could be optimized to pass all frames together. However, this depends on whether the model handles batched `fparam`/`aparam` correctly and whether there are constraints from the Ewald reciprocal energy calculation (which also processes frames individually at lines 171–180).
    def __new__(
        cls, *args: tuple, model_name: str | None = None, **kwargs: dict
    ) -> "DipoleChargeModifier":
        return super().__new__(cls, model_name)
🧩 Analysis chain
🏁 Script executed:
#!/bin/bash
# Find BaseModifier definition and examine __new__ method
find . -type f -name "*.py" | xargs grep -l "class BaseModifier"
Repository: deepmodeling/deepmd-kit
Length of output: 185
🏁 Script executed:
#!/bin/bash
# Search for BaseModifier __new__ implementation
rg "class BaseModifier|def __new__" -A 10 --type pyRepository: deepmodeling/deepmd-kit
Length of output: 18120
🏁 Script executed:
#!/bin/bash
# Check the inheritance chain and __init__ patterns in the modifier module
find . -type f -name "*.py" -path "*/modifier/*" | head -20
Repository: deepmodeling/deepmd-kit
Length of output: 502
🏁 Script executed:
#!/bin/bash
# Find make_base_modifier implementation in pt/modifier
rg "def make_base_modifier|class.*make_base_modifier" -A 20 deepmd/pt/modifier/Repository: deepmodeling/deepmd-kit
Length of output: 1422
🏁 Script executed:
#!/bin/bash
# Look for make_base_modifier in the entire pt directory
find deepmd/pt -name "*.py" -type f | xargs grep -l "make_base_modifier"
Repository: deepmodeling/deepmd-kit
Length of output: 102
🏁 Script executed:
#!/bin/bash
# Also check if there's a base modifier factory or mixin that provides __new__
rg "def __new__" deepmd/pt/modifier/base_modifier.py -B 5 -A 15Repository: deepmodeling/deepmd-kit
Length of output: 49
🏁 Script executed:
#!/bin/bash
# Search for make_base_modifier definition - likely in dpmodel
rg "def make_base_modifier" -A 30Repository: deepmodeling/deepmd-kit
Length of output: 2208
🏁 Script executed:
#!/bin/bash
# Check imports in pt/modifier/base_modifier.py to find make_base_modifier source
head -50 deepmd/pt/modifier/base_modifier.py
Repository: deepmodeling/deepmd-kit
Length of output: 1224
🏁 Script executed:
#!/bin/bash
# Try to instantiate DipoleChargeModifier to see if the __new__ pattern causes errors
# First check the test file for how it's instantiated
cat deepmd/pt/modifier/dipole_charge.py | head -50
Repository: deepmodeling/deepmd-kit
Length of output: 1301
🏁 Script executed:
#!/bin/bash
# Verify the exact parent __new__ behavior by looking at the full make_base_modifier code
sed -n '1,100p' deepmd/dpmodel/modifier/base_modifier.py | grep -A 10 "def __new__"
Repository: deepmodeling/deepmd-kit
Length of output: 415
🏁 Script executed:
#!/bin/bash
# Check if there are any test cases that actually call DipoleChargeModifier instantiation
grep -n "DipoleChargeModifier(" source/tests/pt/modifier/test_dipole_charge.py | head -5Repository: deepmodeling/deepmd-kit
Length of output: 217
🏁 Script executed:
#!/bin/bash
# Run the actual test to confirm if it fails
cat source/tests/pt/modifier/test_dipole_charge.py | sed -n '95,120p'
Repository: deepmodeling/deepmd-kit
Length of output: 871
Remove model_name argument from super().__new__() call — it will cause a TypeError at runtime.
The __new__ method passes model_name to the parent's __new__, but the parent implementation calls super().__new__(cls) with only the cls argument. This cascades to object.__new__(cls), which rejects extra positional arguments and will raise a TypeError during instantiation.
Change line 41 from:
return super().__new__(cls, model_name)
to:
return super().__new__(cls)
The model_name parameter is properly handled in __init__, not __new__.
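A minimal, runnable reproduction of the failure mode under the stated assumption that the parent `__new__` accepts only `cls`; the class names here are stand-ins, not the real `BaseModifier` hierarchy:

```python
class Base:
    def __new__(cls):
        return super().__new__(cls)


class Child(Base):
    def __new__(cls, model_name=None):
        # forwarding model_name reproduces the bug flagged above
        return super().__new__(cls, model_name)

    def __init__(self, model_name=None):
        self.model_name = model_name


try:
    Child("frozen_model.pth")  # hypothetical argument
except TypeError as err:
    print(err)  # Base.__new__() takes 1 positional argument but 2 were given
```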
🧰 Tools
🪛 Ruff (0.14.10)
39-39: Unused static method argument: args
(ARG004)
39-39: Unused static method argument: kwargs
(ARG004)
🤖 Prompt for AI Agents
In @deepmd/pt/modifier/dipole_charge.py around lines 38 - 41, In
DipoleChargeModifier's __new__ method remove the extra model_name positional
argument passed to super().__new__; change the call to invoke
super().__new__(cls) only (do not forward model_name), since object.__new__
rejects extra args and model_name is handled in __init__.
Force-pushed from 7984900 to b0b0047
Actionable comments posted: 0
🧹 Nitpick comments (3)
source/tests/pt/modifier/test_dipole_charge.py (2)
34-49: Consider using `setUpClass` for expensive setup operations. The `setUp` method trains a model, freezes it, and converts backends for every test method. This is computationally expensive. Since the trained model state is not modified by the tests, consider using `setUpClass` to run this setup once per test class.
♻️ Suggested refactor

    -class TestDipoleChargeModifier(unittest.TestCase):
    -    def setUp(self) -> None:
    -        self.test_dir = tempfile.TemporaryDirectory()
    -        self.orig_dir = os.getcwd()
    -        os.chdir(self.test_dir.name)
    +class TestDipoleChargeModifier(unittest.TestCase):
    +    @classmethod
    +    def setUpClass(cls) -> None:
    +        cls.test_dir = tempfile.TemporaryDirectory()
    +        cls.orig_dir = os.getcwd()
    +        os.chdir(cls.test_dir.name)
             # ... rest of setup ...
    +
    +    @classmethod
    +    def tearDownClass(cls) -> None:
    +        os.chdir(cls.orig_dir)
    +        cls.test_dir.cleanup()
196-213: Consider adding an assertion to validate the training outcome. The test verifies the training workflow runs without errors, but adding at least one assertion (e.g., checking that the model checkpoint was created) would strengthen the test.
💡 Optional enhancement

      trainer = get_trainer(config)
      trainer.run()
    + # Verify model checkpoint was created
    + self.assertTrue(os.path.exists("model.ckpt.pt"))

deepmd/pt/modifier/dipole_charge.py (1)
301-312: Per-frame model inference loop may be inefficient. Similar to the Ewald loop, the dipole model is called per-frame. If the underlying model supports batched inference (which is typical for PyTorch models), this could be vectorized for better performance.
💡 Potential optimization

    # Instead of per-frame loop, try batched inference:
    dipole_batch = self.model(
        coord=coord,
        atype=atype,
        box=box,
        do_atomic_virial=False,
        fparam=fparam,
        aparam=aparam,
    )
    dipole = dipole_batch["dipole"]

If the model doesn't support this shape, the loop is necessary.
📜 Review details
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (8)
- deepmd/pt/modifier/__init__.py
- deepmd/pt/modifier/dipole_charge.py
- pyproject.toml
- source/tests/pt/modifier/__init__.py
- source/tests/pt/modifier/test_data_modifier.py
- source/tests/pt/modifier/test_dipole_charge.py
- source/tests/pt/modifier/water
- source/tests/pt/modifier/water_tensor
✅ Files skipped from review due to trivial changes (2)
- source/tests/pt/modifier/water
- source/tests/pt/modifier/water_tensor
🚧 Files skipped from review as they are similar to previous changes (2)
- deepmd/pt/modifier/__init__.py
- source/tests/pt/modifier/__init__.py
🧰 Additional context used
🧠 Learnings (2)
📓 Common learnings
Learnt from: njzjz
Repo: deepmodeling/deepmd-kit PR: 4144
File: source/api_cc/tests/test_deeppot_dpa_pt.cc:166-246
Timestamp: 2024-10-08T15:32:11.479Z
Learning: Refactoring between test classes `TestInferDeepPotDpaPt` and `TestInferDeepPotDpaPtNopbc` is addressed in PR #3905.
Learnt from: njzjz
Repo: deepmodeling/deepmd-kit PR: 4144
File: source/api_cc/tests/test_deeppot_dpa_pt.cc:166-246
Timestamp: 2024-09-19T04:25:12.408Z
Learning: Refactoring between test classes `TestInferDeepPotDpaPt` and `TestInferDeepPotDpaPtNopbc` is addressed in PR #3905.
📚 Learning: 2025-08-14T07:11:51.357Z
Learnt from: njzjz
Repo: deepmodeling/deepmd-kit PR: 4884
File: .github/workflows/test_cuda.yml:46-46
Timestamp: 2025-08-14T07:11:51.357Z
Learning: As of PyTorch 2.8 (August 2025), the default wheel on PyPI installed by `pip install torch` is CPU-only. CUDA-enabled wheels are available on PyPI for Linux x86 and Windows x86 platforms, but require explicit specification via index URLs or variant-aware installers. For CUDA support, use `--index-url https://download.pytorch.org/whl/cu126` (or appropriate CUDA version).
Applied to files:
pyproject.toml
🧬 Code graph analysis (2)
deepmd/pt/modifier/dipole_charge.py (1)
deepmd/tf/modifier/dipole_charge.py (1)
DipoleChargeModifier(40-546)
source/tests/pt/modifier/test_dipole_charge.py (1)
deepmd/pt/modifier/dipole_charge.py (2)
DipoleChargeModifier (23-328)
serialize (88-106)
🪛 Ruff (0.14.10)
deepmd/pt/modifier/dipole_charge.py
39-39: Unused static method argument: args
(ARG004)
39-39: Unused static method argument: kwargs
(ARG004)
68-71: Avoid specifying long messages outside the exception class
(TRY003)
115-115: Unused method argument: do_atomic_virial
(ARG002)
150-152: Avoid specifying long messages outside the exception class
(TRY003)
317-319: Avoid specifying long messages outside the exception class
(TRY003)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (40)
- GitHub Check: Test Python (11, 3.13)
- GitHub Check: Test Python (4, 3.10)
- GitHub Check: Test Python (12, 3.10)
- GitHub Check: Test Python (3, 3.13)
- GitHub Check: Test Python (5, 3.10)
- GitHub Check: Test Python (1, 3.13)
- GitHub Check: Test Python (9, 3.10)
- GitHub Check: Test Python (7, 3.13)
- GitHub Check: Test Python (10, 3.10)
- GitHub Check: Test Python (1, 3.10)
- GitHub Check: Test Python (11, 3.10)
- GitHub Check: Test Python (8, 3.10)
- GitHub Check: Test Python (6, 3.10)
- GitHub Check: Test Python (2, 3.10)
- GitHub Check: Test Python (6, 3.13)
- GitHub Check: Test Python (12, 3.13)
- GitHub Check: Test Python (8, 3.13)
- GitHub Check: Test Python (3, 3.10)
- GitHub Check: Test Python (4, 3.13)
- GitHub Check: Test Python (10, 3.13)
- GitHub Check: Test Python (9, 3.13)
- GitHub Check: Test Python (2, 3.13)
- GitHub Check: Test Python (7, 3.10)
- GitHub Check: Test Python (5, 3.13)
- GitHub Check: Build wheels for cp311-win_amd64
- GitHub Check: Build wheels for cp311-manylinux_x86_64
- GitHub Check: Build wheels for cp311-macosx_arm64
- GitHub Check: Build wheels for cp310-manylinux_aarch64
- GitHub Check: Build wheels for cp311-macosx_x86_64
- GitHub Check: Build C library (2.18, libdeepmd_c.tar.gz)
- GitHub Check: Analyze (python)
- GitHub Check: Analyze (c-cpp)
- GitHub Check: Build C++ (cuda120, cuda)
- GitHub Check: Build C++ (clang, clang)
- GitHub Check: Build C++ (cpu, cpu)
- GitHub Check: Build C++ (rocm, rocm)
- GitHub Check: Test C++ (false, false, false, true)
- GitHub Check: Test C++ (true, false, false, true)
- GitHub Check: Test C++ (true, true, true, false)
- GitHub Check: Test C++ (false, true, true, false)
🔇 Additional comments (13)
source/tests/pt/modifier/test_dipole_charge.py (4)
114-115: LGTM! Simple and effective JIT compatibility check. If scripting fails, `torch.jit.script` will raise an exception.
117-143: LGTM! The consistency test properly validates cross-framework behavior by comparing energy, force, and virial outputs between PyTorch and TensorFlow implementations with appropriate tolerances.
145-175: LGTM! The serialization round-trip test validates that the modifier's state is correctly preserved through serialize/deserialize operations.
177-194: LGTM! Good error handling test that validates the `RuntimeError` is raised with the expected message when a non-periodic system (None box) is passed.
deepmd/pt/modifier/dipole_charge.py (5)
38-41: Unused arguments in `__new__` are intentional for inheritance compatibility. The static analysis correctly identifies `*args` and `**kwargs` as unused. This pattern is common in `__new__` overrides to maintain compatibility with parent class signatures. No change needed.
43-86: LGTM! The initialization is well-structured with proper validation of charge map lengths and correct device placement for tensors.
88-106: LGTM! Serialization correctly preserves all necessary parameters using Python-native types for JSON compatibility.
115-115: Unused `do_atomic_virial` is documented as not implemented. The docstring at line 138-139 explicitly notes this parameter is not implemented. This is acceptable.
249-260: LGTM! The charge extension logic correctly assigns ion charges based on atom types and WFCC charges for selected atom types.
331-356: LGTM! The `make_mask` function is correctly implemented with JIT export compatibility and efficient broadcasting for atom type selection.
176-189: Verify batching support in torch_admp CoulombForceModule before suggesting vectorization. The per-frame loop pattern here (lines 178–187) may be a performance bottleneck if batched operations are available. However, `CoulombForceModule` is from the external `torch_admp` package, which is not publicly documented. Without access to its API, it cannot be confirmed whether it supports batched processing of multiple frames in a single call. If it does, the loop could be replaced with a vectorized call; if not, this optimization is not feasible. Coordinate directly with the torch_admp maintainers or review their source code to determine batching capability before implementing any refactor.
pyproject.toml (1)
167-171: The exact version pin `torch-admp==1.1.1` is valid and compatible. Verification confirms that torch-admp 1.1.1 exists on PyPI and requires `torch>=2.1.0`, making it fully compatible with the pinned PyTorch versions (>=2.7, <2.10). The exact version pin is appropriate for reproducibility.
source/tests/pt/modifier/test_data_modifier.py (1)
55-57: The updated relative import path is correct. The import `from ...consistent.common import parameterized` correctly resolves from `source/tests/pt/modifier/test_data_modifier.py` to `source/tests/consistent/common.py`, where the `parameterized` function is defined.
Actionable comments posted: 3
🤖 Fix all issues with AI agents
In @backend/dynamic_metadata.py:
- Around line 51-57: The code extracts torch_static_requirement as a list then
passes it to get_pt_requirement which is @lru_cache-decorated and therefore
requires hashable args; convert torch_static_requirement to an immutable tuple
before calling get_pt_requirement (replace usage of torch_static_requirement
with tuple(torch_static_requirement) or None when absent) and update
get_pt_requirement's signature/type hint in backend/find_pytorch.py to accept
tuple[str, ...] | None instead of list to match and avoid unhashable arguments;
also adjust any internal normalization in find_pytorch.py to expect/handle a
tuple.
In @backend/find_pytorch.py:
- Around line 133-134: The normalization currently sets static_requirement to an
empty list; if you change static_requirement to be a tuple for @lru_cache
compatibility, update the normalization to use an empty tuple instead (replace
static_requirement = [] with static_requirement = () where static_requirement is
initialized/normalized in backend/find_pytorch.py), and ensure any subsequent
code that relies on list methods is adapted to treat static_requirement as an
immutable tuple or explicitly convert to a list only where mutation is required.
- Around line 92-96: The get_pt_requirement function is decorated with
@lru_cache but accepts a mutable list (static_requirement: list[str] | None)
which is unhashable and breaks caching; change the parameter type to a hashable
tuple (static_requirement: tuple[str, ...] | None), update any call sites to
pass tuples (or convert lists to tuples inside callers), and update the function
docstring to reflect the new tuple type for static_requirement so lru_cache
works correctly.
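A sketch of the tuple-based pattern the two prompts above describe; the signature mirrors `get_pt_requirement`, but the body is illustrative only:

```python
from functools import lru_cache


@lru_cache
def get_pt_requirement(static_requirement: tuple[str, ...] | None = None) -> dict:
    # normalize with an immutable empty tuple instead of a mutable list
    if static_requirement is None:
        static_requirement = ()
    return {"torch": list(static_requirement)}


# callers convert lists to tuples so the arguments stay hashable
torch_static_requirement = ["torch>=2.7,<2.10"]
get_pt_requirement(tuple(torch_static_requirement))
```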
📜 Review details
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (3)
- backend/dynamic_metadata.py
- backend/find_pytorch.py
- pyproject.toml
🧰 Additional context used
🧠 Learnings (2)
📚 Learning: 2024-10-30T20:08:12.531Z
Learnt from: njzjz
Repo: deepmodeling/deepmd-kit PR: 4284
File: deepmd/jax/__init__.py:8-8
Timestamp: 2024-10-30T20:08:12.531Z
Learning: In the DeepMD project, entry points like `deepmd.jax` may be registered in external projects, so their absence in the local configuration files is acceptable.
Applied to files:
pyproject.toml
📚 Learning: 2025-08-14T07:11:51.357Z
Learnt from: njzjz
Repo: deepmodeling/deepmd-kit PR: 4884
File: .github/workflows/test_cuda.yml:46-46
Timestamp: 2025-08-14T07:11:51.357Z
Learning: As of PyTorch 2.8 (August 2025), the default wheel on PyPI installed by `pip install torch` is CPU-only. CUDA-enabled wheels are available on PyPI for Linux x86 and Windows x86 platforms, but require explicit specification via index URLs or variant-aware installers. For CUDA support, use `--index-url https://download.pytorch.org/whl/cu126` (or appropriate CUDA version).
Applied to files:
pyproject.toml
🧬 Code graph analysis (1)
backend/dynamic_metadata.py (1)
backend/find_pytorch.py (1)
get_pt_requirement(93-150)
🪛 GitHub Actions: Test C++
backend/dynamic_metadata.py
[error] 55-55: TypeError: unhashable type: 'list' in dynamic_metadata (likely a build-environment issue) during editable wheel build for deepmd-kit. This caused the build to fail while running 'pip install --system -e .[cpu,test,lmp,jax] mpi4py mpich'.
🪛 GitHub Actions: Test Python
backend/dynamic_metadata.py
[error] 55-55: TypeError: unhashable type: 'list' during dynamic_metadata() in backend/dynamic_metadata.py. This usually indicates a problem with the package or the build environment.
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (8)
- GitHub Check: Analyze (python)
- GitHub Check: Analyze (c-cpp)
- GitHub Check: Build C library (2.18, libdeepmd_c.tar.gz)
- GitHub Check: Build C++ (rocm, rocm)
- GitHub Check: Build C++ (cuda120, cuda)
- GitHub Check: Build C++ (clang, clang)
- GitHub Check: Build C++ (cpu, cpu)
- GitHub Check: Build wheels for cp311-win_amd64
🔇 Additional comments (2)
pyproject.toml (2)
142-144: LGTM! Follows existing patterns. The new `torch` optional dependency group is structured correctly and follows the same pattern as the `jax` group above it. The unpinned version here paired with pinned versions in the dependency groups below is consistent with the project's dependency management approach.
170-170: torch-admp==1.1.1 is correctly pinned to the latest version with no known security vulnerabilities. No action required.
Actionable comments posted: 1
🤖 Fix all issues with AI agents
In @backend/dynamic_metadata.py:
- Around line 51-57: The code currently calls `tuple(torch_static_requirement)`, which raises a `TypeError` when `optional_dependencies.pop("torch")` returns `None`. Instead, pass `torch_static_requirement` directly to `get_pt_requirement` (or wrap it in `tuple(...)` only when it is not `None`), since `get_pt_requirement` already accepts `None` and normalizes it. Update the call in the return statement to use `static_requirement=torch_static_requirement` and remove the unconditional `tuple(...)` conversion.
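For context, a minimal standalone sketch of why this matters: `lru_cache` hashes the arguments before the function body runs, so callers must pass a tuple or `None` (never a list), and `tuple(None)` fails outright. The simplified body below is an assumption for illustration, not the real implementation in `backend/find_pytorch.py`:

```python
from functools import lru_cache


@lru_cache
def get_pt_requirement(
    pt_version: str = "",
    static_requirement: tuple[str, ...] | None = None,
) -> dict:
    # Normalize inside the cached function: None becomes an empty tuple,
    # so callers never need an unconditional tuple(...) conversion.
    static_requirement = static_requirement or ()
    return {"torch": [f"torch{pt_version}", *static_requirement]}


# Hashable arguments work with lru_cache; lists would raise
# "TypeError: unhashable type: 'list'" before the body even runs,
# and tuple(None) would raise "TypeError: 'NoneType' object is not iterable".
for value in (None, ("torch==2.8.0",)):
    print(get_pt_requirement(static_requirement=value))
```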
📜 Review details
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
- backend/dynamic_metadata.py
- backend/find_pytorch.py
🚧 Files skipped from review as they are similar to previous changes (1)
- backend/find_pytorch.py
🧰 Additional context used
🧬 Code graph analysis (1)
backend/dynamic_metadata.py (1)
backend/find_pytorch.py (1)
get_pt_requirement(93-150)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (33)
- GitHub Check: Test Python (12, 3.13)
- GitHub Check: Test Python (11, 3.10)
- GitHub Check: Test Python (10, 3.10)
- GitHub Check: Test Python (6, 3.10)
- GitHub Check: Test Python (12, 3.10)
- GitHub Check: Test Python (5, 3.13)
- GitHub Check: Test Python (6, 3.13)
- GitHub Check: Test Python (5, 3.10)
- GitHub Check: Test Python (7, 3.13)
- GitHub Check: Test Python (7, 3.10)
- GitHub Check: Test Python (9, 3.13)
- GitHub Check: Test Python (4, 3.13)
- GitHub Check: Test Python (3, 3.10)
- GitHub Check: Test Python (1, 3.13)
- GitHub Check: Test Python (3, 3.13)
- GitHub Check: Test Python (1, 3.10)
- GitHub Check: Test Python (2, 3.13)
- GitHub Check: Build wheels for cp311-win_amd64
- GitHub Check: Build C library (2.18, libdeepmd_c.tar.gz)
- GitHub Check: Test C++ (false, true, true, false)
- GitHub Check: Test C++ (true, true, true, false)
- GitHub Check: Test C++ (false, false, false, true)
- GitHub Check: Build wheels for cp310-manylinux_aarch64
- GitHub Check: Test C++ (true, false, false, true)
- GitHub Check: Build wheels for cp311-macosx_arm64
- GitHub Check: Build wheels for cp311-macosx_x86_64
- GitHub Check: Build wheels for cp311-manylinux_x86_64
- GitHub Check: Analyze (c-cpp)
- GitHub Check: Analyze (python)
- GitHub Check: Build C++ (cuda120, cuda)
- GitHub Check: Build C++ (clang, clang)
- GitHub Check: Build C++ (rocm, rocm)
- GitHub Check: Build C++ (cpu, cpu)
njzjz
left a comment
Need documentation.
Actionable comments posted: 0
🧹 Nitpick comments (1)
doc/model/dplr.md (1)
4-4: Update title icon for consistency with supported backends. The title on Line 1 includes only the TensorFlow icon, but Line 4 now lists both TensorFlow and PyTorch as supported backends. To avoid visual inconsistency, consider updating the title to include the PyTorch icon as well.

Proposed change

```diff
-# Deep potential long-range (DPLR) {{ tensorflow_icon }}
+# Deep potential long-range (DPLR) {{ tensorflow_icon }} {{ pytorch_icon }}
```
📜 Review details
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
- deepmd/utils/argcheck.py
- doc/model/dplr.md
🔇 Additional comments (2)
doc/model/dplr.md (1)
7-7: Clear communication of PyTorch LAMMPS limitation. The updated notice correctly informs users that PyTorch DPLR models are not yet supported in the DP/LAMMPS interface and provides the necessary workaround. However, consider whether the documentation should include guidance on PyTorch-specific training workflows (e.g., how the PyTorch backend differs from TensorFlow in the training process, if at all).
Does the PyTorch implementation have any training-time differences or special considerations that should be documented separately, or is the training workflow identical to TensorFlow?
deepmd/utils/argcheck.py (1)
2307-2314: LGTM! Documentation correctly updated for multi-backend support. The removal of the `doc_only_tf_supported` prefix from the modifier documentation is appropriate, as this PR adds PyTorch backend support for the DipoleChargeModifier. The documentation now accurately reflects that the modifier is no longer TensorFlow-exclusive.
```python
    raise RuntimeError(
        "dipole_charge data modifier can only be applied for periodic systems."
    )
else:
```
We can simplify this logic by removing the `else` keyword, since the early `raise` makes it redundant.
In the current implementation (pt backend), periodicity is not required at the level of `BaseModifier`, so I need to add the raise here.
Signed-off-by: Jinzhe Zeng <[email protected]>
…put and improve serialization. Added support for both model file path and direct model object initialization, along with proper serialize/deserialize methods and caching functionality.
…ed memory efficiency and performance. This commit introduces configurable batch sizes for both Ewald calculations and dipole model inference, refactors the system extension logic to return atomic dipoles, and optimizes memory usage during large-scale simulations.
Force-pushed from 452a932 to 983e7d2.
Actionable comments posted: 1
🤖 Fix all issues with AI agents
In `@source/tests/pt/modifier/test_data_modifier.py`:
- Around line 192-199: The `__new__` implementation on ModifierScalingTester incorrectly forwards model/model_name to `super().__new__`, which can raise a `TypeError`; remove the extra argument so it calls `super().__new__(cls)` only, and move any handling of model/model_name into the class `__init__` (or process them after object creation) so initialization uses those values without passing them to the base `__new__` method.
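A minimal sketch of the recommended split, with a stand-in base class (the real `BaseModifier` signature may differ):

```python
class BaseModifier:
    def __new__(cls) -> "BaseModifier":
        # object.__new__ accepts only cls; forwarding extra positional
        # arguments here is what triggers the TypeError.
        return super().__new__(cls)


class ModifierScalingTester(BaseModifier):
    def __new__(cls, *args: tuple, model_name: str | None = None, **kwargs: dict):
        # __new__ only allocates; do not forward model/model_name.
        return super().__new__(cls)

    def __init__(self, model=None, model_name: str | None = None) -> None:
        # All argument handling happens after allocation, in __init__.
        self.model = model
        self.model_name = model_name


tester = ModifierScalingTester(model_name="frozen_model.pth")  # no TypeError
```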
♻️ Duplicate comments (2)
backend/find_pytorch.py (1)
92-96: Incorrect tuple type annotation for variable-length sequences. The type `tuple[str]` specifies a tuple with exactly one `str` element. For a variable-length tuple of strings, use `tuple[str, ...]`.

🐛 Proposed fix

```diff
 @lru_cache
 def get_pt_requirement(
     pt_version: str = "",
-    static_requirement: tuple[str] | None = None,
+    static_requirement: tuple[str, ...] | None = None,
 ) -> dict:
```

deepmd/pt/modifier/dipole_charge.py (1)
47-50: Remove `model_name` argument from `super().__new__()` call. This appears to be unresolved from a previous review. Passing `model_name` to `super().__new__()` will cause a `TypeError` at runtime, since the parent's `__new__` only accepts `cls`. The `model_name` parameter should be handled exclusively in `__init__`.

🐛 Proposed fix

```diff
 def __new__(
     cls, *args: tuple, model_name: str | None = None, **kwargs: dict
 ) -> "DipoleChargeModifier":
-    return super().__new__(cls, model_name)
+    return super().__new__(cls)
```
🧹 Nitpick comments (8)
backend/find_pytorch.py (1)
105-106: Update docstring to match corrected type. The docstring should reflect the variable-length tuple type `tuple[str, ...]` to be consistent with the corrected type annotation.

♻️ Proposed fix

```diff
-    static_requirement : tuple[str] or None, optional
+    static_requirement : tuple[str, ...] or None, optional
         Static requirements
```

deepmd/pt/modifier/dipole_charge.py (4)
194-198: Remove redundant `else` after `raise`. The `else` keyword is unnecessary since the `raise` statement exits the function. This was noted in a previous review.

🔧 Suggested simplification

```diff
 if box is None:
     raise RuntimeError(
         "dipole_charge data modifier can only be applied for periodic systems."
     )
-else:
-    modifier_pred = {}
+modifier_pred = {}
```

Then dedent the rest of the function body.
232-243: Add `strict=True` to `zip()` for safer iteration. Without `strict=True`, mismatched lengths between `chunk_coord`, `chunk_box`, and `chunk_charge` would silently truncate, potentially causing incorrect results.

🔧 Proposed fix

```diff
-for _coord, _box, _charge in zip(chunk_coord, chunk_box, chunk_charge):
+for _coord, _box, _charge in zip(chunk_coord, chunk_box, chunk_charge, strict=True):
```
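To see why this matters, here is a quick standalone illustration of the truncation behavior, using plain lists as stand-ins for the chunk arrays:

```python
chunk_coord = [1, 2, 3]
chunk_box = [10, 20, 30]
chunk_charge = [0.1, 0.2]  # one element short, e.g. after a bad split

# Default zip() silently stops at the shortest input: only 2 frames processed.
print(list(zip(chunk_coord, chunk_box, chunk_charge)))
# [(1, 10, 0.1), (2, 20, 0.2)]

# strict=True (Python 3.10+) turns the mismatch into a loud error instead.
try:
    list(zip(chunk_coord, chunk_box, chunk_charge, strict=True))
except ValueError as err:
    print(err)  # reports that argument 3 is shorter than the others
```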
370-372: Add `strict=True` to `zip()` for safer iteration. Same as above: mismatched chunk lengths would silently truncate.

🔧 Proposed fix

```diff
-for _coord, _atype, _box, _fparam, _aparam in zip(
-    chunk_coord, chunk_atype, chunk_box, chunk_fparam, chunk_aparam
-):
+for _coord, _atype, _box, _fparam, _aparam in zip(
+    chunk_coord, chunk_atype, chunk_box, chunk_fparam, chunk_aparam, strict=True
+):
```
160-160: Consider removing or implementing `do_atomic_virial`. The parameter is documented as "not implemented and is ignored" (line 184). If atomic virial computation is not planned, consider removing it from the signature to avoid confusion. Otherwise, add a `# noqa: ARG002` comment to suppress the linter warning.

deepmd/pt/train/wrapper.py (1)
191-195: LGTM with a minor suggestion. The reshape ensures shape compatibility between modifier and model outputs. Consider adding validation to provide clearer error messages if shapes are incompatible (e.g., different total element counts).

💡 Optional: Add shape validation for clearer errors

```diff
 if self.modifier is not None:
     modifier_pred = self.modifier(**input_dict)
     for k, v in modifier_pred.items():
+        if v.numel() != model_pred[k].numel():
+            raise RuntimeError(
+                f"Modifier output '{k}' has {v.numel()} elements, "
+                f"but model output has {model_pred[k].numel()} elements"
+            )
         model_pred[k] = model_pred[k] + v.reshape(model_pred[k].shape)
```

source/tests/pt/modifier/test_dipole_charge.py (2)
128-134: Consider explicit output labeling for clarity. The loop over indices works correctly, but explicitly naming the outputs would improve readability and make test failures easier to diagnose.

♻️ Suggested improvement

```diff
-for ii in range(3):
-    np.testing.assert_allclose(
-        pt_data[ii].reshape(-1),
-        tf_data[ii].reshape(-1),
-        atol=1e-6,
-        rtol=1e-6,
-    )
+output_names = ["energy", "force", "virial"]
+for ii, name in enumerate(output_names):
+    np.testing.assert_allclose(
+        pt_data[ii].reshape(-1),
+        tf_data[ii].reshape(-1),
+        atol=1e-6,
+        rtol=1e-6,
+        err_msg=f"Mismatch in {name}",
+    )
```
191-208: Consider adding a basic assertion for completeness. The smoke test validates that training runs without exceptions, which is valuable. Adding a simple assertion (e.g., checking that the model checkpoint was created) would make the test's success criteria more explicit.

♻️ Optional enhancement

```diff
 config["training"]["numb_steps"] = 1
 trainer = get_trainer(config)
 trainer.run()
+# Verify model checkpoint was created
+self.assertTrue(
+    Path("model.ckpt.pt").exists(),
+    "Training should produce a model checkpoint",
+)
```
📜 Review details
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (17)
- backend/dynamic_metadata.py
- backend/find_pytorch.py
- deepmd/dpmodel/modifier/base_modifier.py
- deepmd/pt/entrypoints/main.py
- deepmd/pt/infer/deep_eval.py
- deepmd/pt/modifier/__init__.py
- deepmd/pt/modifier/base_modifier.py
- deepmd/pt/modifier/dipole_charge.py
- deepmd/pt/train/wrapper.py
- deepmd/utils/argcheck.py
- doc/model/dplr.md
- pyproject.toml
- source/tests/pt/modifier/__init__.py
- source/tests/pt/modifier/test_data_modifier.py
- source/tests/pt/modifier/test_dipole_charge.py
- source/tests/pt/modifier/water
- source/tests/pt/modifier/water_tensor
💤 Files with no reviewable changes (1)
- deepmd/dpmodel/modifier/base_modifier.py
✅ Files skipped from review due to trivial changes (2)
- source/tests/pt/modifier/water_tensor
- deepmd/pt/modifier/base_modifier.py
🚧 Files skipped from review as they are similar to previous changes (3)
- backend/dynamic_metadata.py
- source/tests/pt/modifier/__init__.py
- deepmd/utils/argcheck.py
🧰 Additional context used
🧠 Learnings (3)
📚 Learning: 2024-11-25T07:42:55.735Z
Learnt from: njzjz
Repo: deepmodeling/deepmd-kit PR: 4302
File: deepmd/pd/infer/inference.py:35-38
Timestamp: 2024-11-25T07:42:55.735Z
Learning: In the file `deepmd/pd/infer/inference.py`, when loading the model checkpoint in the `Tester` class, it's acceptable to not include additional error handling for loading the model state dictionary.
Applied to files:
deepmd/pt/infer/deep_eval.py
📚 Learning: 2024-10-30T20:08:12.531Z
Learnt from: njzjz
Repo: deepmodeling/deepmd-kit PR: 4284
File: deepmd/jax/__init__.py:8-8
Timestamp: 2024-10-30T20:08:12.531Z
Learning: In the DeepMD project, entry points like `deepmd.jax` may be registered in external projects, so their absence in the local configuration files is acceptable.
Applied to files:
pyproject.toml
📚 Learning: 2025-08-14T07:11:51.357Z
Learnt from: njzjz
Repo: deepmodeling/deepmd-kit PR: 4884
File: .github/workflows/test_cuda.yml:46-46
Timestamp: 2025-08-14T07:11:51.357Z
Learning: As of PyTorch 2.8 (August 2025), the default wheel on PyPI installed by `pip install torch` is CPU-only. CUDA-enabled wheels are available on PyPI for Linux x86 and Windows x86 platforms, but require explicit specification via index URLs or variant-aware installers. For CUDA support, use `--index-url https://download.pytorch.org/whl/cu126` (or appropriate CUDA version).
Applied to files:
pyproject.toml
🧬 Code graph analysis (1)
deepmd/pt/modifier/dipole_charge.py (1)
deepmd/tf/modifier/dipole_charge.py (1)
DipoleChargeModifier(40-546)
🪛 Ruff (0.14.11)
deepmd/pt/modifier/dipole_charge.py
48-48: Unused static method argument: args
(ARG004)
48-48: Unused static method argument: kwargs
(ARG004)
69-69: Avoid specifying long messages outside the exception class
(TRY003)
71-73: Avoid specifying long messages outside the exception class
(TRY003)
95-98: Avoid specifying long messages outside the exception class
(TRY003)
160-160: Unused method argument: do_atomic_virial
(ARG002)
195-197: Avoid specifying long messages outside the exception class
(TRY003)
232-232: zip() without an explicit strict= parameter
Add explicit value for parameter strict=
(B905)
370-372: zip() without an explicit strict= parameter
Add explicit value for parameter strict=
(B905)
387-389: Avoid specifying long messages outside the exception class
(TRY003)
deepmd/pt/infer/deep_eval.py
187-187: pickle and modules that wrap it can be unsafe when used to deserialize untrusted data, possible security issue
(S301)
source/tests/pt/modifier/test_data_modifier.py
194-194: Unused static method argument: args
(ARG004)
197-197: Unused static method argument: kwargs
(ARG004)
214-214: Avoid specifying long messages outside the exception class
(TRY003)
216-218: Avoid specifying long messages outside the exception class
(TRY003)
🔇 Additional comments (25)
doc/model/dplr.md (1)
1-7: LGTM! Documentation updates clearly communicate PyTorch backend support and appropriately warn users about the current LAMMPS interface limitation.
pyproject.toml (2)
166-176: LGTM! The `torch-admp==1.1.2` pin is consistently applied to both CPU and GPU dependency groups, ensuring reproducible builds across environments.

143-145: The `torch-admp` package is available on PyPI with version 1.1.2 as the latest release. The dependency specification in `pyproject.toml` is correct and will not cause installation issues.

deepmd/pt/modifier/__init__.py (1)
10-17: LGTM! The `DipoleChargeModifier` is properly imported and exported via `__all__`, following the established pattern for `BaseModifier`.

backend/find_pytorch.py (1)
133-148: LGTM! The normalization of `None` to an empty tuple and the unpacking into the requirements list is correctly implemented. This properly addresses the `@lru_cache` compatibility requirement from the previous review.

source/tests/pt/modifier/water (1)
1-1: No issues found. The symlink correctly targets `source/tests/pt/model/water/`, which exists and contains all necessary test fixtures.

source/tests/pt/modifier/test_data_modifier.py (5)
39-50: LGTM! New imports for model deserialization support. The imports for `EnergyModel` and `serialize_from_file` are correctly added to support the new serialization/deserialization functionality.
61-63: Import path adjustment looks correct. The path change from `..consistent.common` to `...consistent.common` appears to reflect the directory structure, with the test file located under `source/tests/pt/modifier/`.
94-96: LGTM! Alias improves API flexibility. Adding `alias=["model"]` allows users to specify either `model_name` or `model` in configuration, improving usability while maintaining backward compatibility.
201-227: LGTM! Model initialization logic is well-structured. The mutual exclusivity check between `model` and `model_name` is clear, and the model loading/scripting flow is consistent with the pattern used in `DipoleChargeModifier`.
229-254: LGTM! Serialization methods follow the established pattern. The `serialize` and `deserialize` methods correctly handle the model serialization and metadata stripping, consistent with the pattern shown in the relevant code snippets.

deepmd/pt/modifier/dipole_charge.py (2)
261-313: LGTM! System extension logic is well-implemented. The `extend_system` and charge assignment logic correctly handles WFCC coordinates and charge mapping. The validation at lines 93-98 ensures `model_charge_map` and `sel_type` alignment.

397-435: LGTM! NumPy evaluation wrapper is correctly implemented. The `eval_np` method properly handles tensor conversions and forwards to the `forward` method.

deepmd/pt/infer/deep_eval.py (3)
4-4: Pickle import added for modifier deserialization. The `pickle` import is required for deserializing modifier data stored in the frozen model's extra files.

52-54: LGTM! BaseModifier import enables polymorphic deserialization. The import allows using `BaseModifier.get_class_by_type()` for proper modifier class resolution during deserialization.
178-192: Modifier deserialization logic looks correct. The flow properly:
- Checks if modifier data exists in extra_files
- Unpickles the bytes data
- Uses the type registry to instantiate the correct modifier class

The pickle security warning (S301) is acceptable here since users only load model files they trust; the same trust model applies to loading the model itself.
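Pieced together from the description above, the deserialization side looks roughly like the following; the `"modifier"` key and the payload layout are assumptions, while `pickle.loads` and `BaseModifier.get_class_by_type(...).deserialize(...)` come from the review itself:

```python
import pickle

from deepmd.pt.modifier import BaseModifier


def load_modifier(extra_files: dict[str, bytes]):
    """Rebuild the data modifier stored alongside a frozen model, if any."""
    raw = extra_files.get("modifier")  # assumed key written at freeze time
    if not raw:
        return None
    # Trusted input: the same trust model as loading the model file itself,
    # which is why the S301 pickle warning is acceptable here.
    data = pickle.loads(raw)
    # The "type" field selects the concrete class from the plugin registry.
    modifier_cls = BaseModifier.get_class_by_type(data["type"])
    return modifier_cls.deserialize(data)
```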
deepmd/pt/entrypoints/main.py (2)
7-7: Pickle import added for modifier serialization. Required for the new serialization approach in the freeze function.

404-413: LGTM! Simplified modifier serialization. The change from BytesIO-wrapped JIT saving to direct pickle serialization is cleaner and aligns with the deserialization logic in `deep_eval.py`. The serialized dict includes all necessary data for reconstruction via `BaseModifier.get_class_by_type().deserialize()`.

source/tests/pt/modifier/test_dipole_charge.py (7)
1-33: LGTM! The imports are well-organized, covering the necessary utilities for cross-backend testing between PyTorch and TensorFlow implementations of DipoleChargeModifier.

35-50: LGTM! The helper function provides reproducible test data loading with proper seeding and consistent output shapes.

53-110: LGTM! The setUp method properly isolates tests using a temporary directory, trains a minimal model for speed (`numb_steps=1`), and initializes both PT and TF modifiers for cross-backend comparison. The `sel` and `neuron` values are already reasonably sized to keep tests fast while maintaining coverage.

112-113: LGTM! Simple and effective JIT scripting compatibility test.

136-168: LGTM! The serialization round-trip test is well-structured, properly moving models to the correct device and validating that deserialized models produce identical outputs. Using default tolerances is appropriate here since serialization should preserve exact values.

170-189: LGTM! The error handling test correctly validates that `RuntimeError` is raised (not `RuntimeWarning`) when `box=None`, and properly verifies the error message content. The past review feedback has been addressed.
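For reference, a minimal self-contained sketch of the pattern such a check follows; the stub function and inputs are stand-ins, not the actual test fixture, though the error message matches the one quoted earlier:

```python
import unittest


def apply_modifier(coord, atype, box):
    # Stand-in for the PT DipoleChargeModifier call; only the box check matters here.
    if box is None:
        raise RuntimeError(
            "dipole_charge data modifier can only be applied for periodic systems."
        )
    return 0.0


class TestBoxNone(unittest.TestCase):
    def test_box_none_raises(self) -> None:
        # assertRaises as a context manager exposes the exception object,
        # so the message content can be verified afterwards.
        with self.assertRaises(RuntimeError) as ctx:
            apply_modifier(coord=[0.0] * 3, atype=[0], box=None)
        self.assertIn("periodic systems", str(ctx.exception))


if __name__ == "__main__":
    unittest.main()
```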
210-212: LGTM! Proper cleanup that restores the working directory and removes the temporary test directory.
This PR introduces the DPLR (Dipole Charge) modifier implementation for the PyTorch backend of DeepMD-kit, enabling quantum mechanical charge distribution modeling in molecular dynamics simulations.
Key Changes
- New `DipoleChargeModifier` class (`deepmd/pt/modifier/dipole_charge.py`)
- Dependency addition: the `torch_admp` package is added to the project dependencies in `pyproject.toml`
- Comprehensive testing (`source/tests/pt/modifier/test_dipole_charge.py`)

Benefits

Files Modified
Summary by CodeRabbit
New Features
Tests
Chores
Documentation