Commit dfeba54
feat(pt): add plugin for data modifier (#4661)
## Overview

This PR adds a data modifier plugin to the PyTorch implementation of DeepMD. The plugin enables on-the-fly modification of data during training and inference, so energies, forces, and virials can be adjusted before they reach the loss or the caller.

## Key Changes

### 1. Added Data Modifier to Training Pipeline

- **File**: `deepmd/pt/entrypoints/main.py`
- Added imports for data modifier functionality (`get_data_modifier`)
- Added modifier initialization in `get_trainer()`
- Added a modifier parameter to data loader initialization for both training and validation datasets
- Extended model freezing to embed the scripted modifier in the frozen model via an in-memory buffer

### 2. Added Data Modifier to Inference

- **File**: `deepmd/pt/infer/deep_eval.py`
- Added modifier loading and handling in the `DeepEval` class
- Extended model loading to read extra files containing modifier data
- Applied the modifier in inference methods to adjust model predictions

### 3. Implemented Data Modifier Framework

- **File**: `deepmd/pt/modifier/__init__.py` (entirely new)
- Created the base class `BaseModifier` with a registration system
- Implemented three specific modifier types:
  - `ModifierRandomTester`: applies random scaling to energy/force/virial data for testing
  - `ModifierZeroTester`: zeroes out energy/force/virial data for testing
  - `ModifierScalingTester`: applies scaled model predictions as data modifications
- Added argument parsing for modifier configuration

### 4. Added Data Modifier Tests

- **File**: `deepmd/pt/test/test_modifier.py` (entirely new)
- Created a test suite for the data modifier functionality, covering:
  - modifier initialization and verification of the applied data modification
  - ensuring the data modification is applied only once
  - inference with data modification, verified against scaled model predictions
- Added helper methods for test data management and comparison

## Summary by CodeRabbit

* **New Features**
  * Pluggable data modifier API: create, attach, serialize, and embed modifiers with models; modifiers propagate through loaders, the model wrapper, and inference.
* **Behavior**
  * Modifiers can be preloaded, applied, and optionally cached during data loading; their outputs adjust energies/forces/virials during training and inference.
  * Frozen model export/import preserves embedded modifiers.
* **Tests**
  * End-to-end tests covering zeroing, scaling, deterministic/random modifiers, and frozen-model inference.
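As a sketch of the intended usage (the type name below is invented for illustration; the concrete testers register their own keys), a `modifier` block sits in the `model` section of the training input and is resolved through `get_data_modifier`:

```python
# Hypothetical sketch: a "modifier" block in the "model" section of the
# training input. prepare_trainer_input_single() reads
# model_params_single.get("modifier") and, when present, builds the modifier
# with get_data_modifier(...) and hands it to both data loaders.
# The type name "random_tester" is assumed for illustration only.
model_params_single = {
    "type_map": ["O", "H"],
    # descriptor / fitting_net omitted
    "modifier": {
        "type": "random_tester",  # dispatched via BaseModifier.get_class_by_type
    },
}
```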
1 parent be3b876 commit dfeba54

File tree

12 files changed: +888 additions, -24 deletions

deepmd/pd/utils/dataset.py

Lines changed: 4 additions & 1 deletion

```diff
@@ -5,6 +5,9 @@
     Dataset,
 )

+from deepmd.pd.utils.env import (
+    NUM_WORKERS,
+)
 from deepmd.utils.data import (
     DataRequirementItem,
     DeepmdData,
@@ -32,7 +35,7 @@ def __len__(self):

     def __getitem__(self, index):
         """Get a frame from the selected system."""
-        b_data = self._data_system.get_item_paddle(index)
+        b_data = self._data_system.get_item_paddle(index, max(1, NUM_WORKERS))
         b_data["natoms"] = self._natoms_vec
         return b_data
```

deepmd/pt/entrypoints/main.py

Lines changed: 26 additions & 2 deletions

```diff
@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: LGPL-3.0-or-later
 import argparse
 import copy
+import io
 import json
 import logging
 import os
@@ -47,6 +48,9 @@
 from deepmd.pt.model.model import (
     BaseModel,
 )
+from deepmd.pt.modifier import (
+    get_data_modifier,
+)
 from deepmd.pt.train import (
     training,
 )
@@ -111,6 +115,12 @@ def prepare_trainer_input_single(
         rank: int = 0,
         seed: int | None = None,
     ) -> tuple[DpLoaderSet, DpLoaderSet | None, DPPath | None]:
+        # get data modifier
+        modifier = None
+        modifier_params = model_params_single.get("modifier", None)
+        if modifier_params is not None:
+            modifier = get_data_modifier(modifier_params).to(DEVICE)
+
         training_dataset_params = data_dict_single["training_data"]
         validation_dataset_params = data_dict_single.get("validation_data", None)
         validation_systems = (
@@ -145,6 +155,7 @@ def prepare_trainer_input_single(
                 validation_dataset_params["batch_size"],
                 model_params_single["type_map"],
                 seed=rank_seed,
+                modifier=modifier,
             )
             if validation_systems
             else None
@@ -154,6 +165,7 @@ def prepare_trainer_input_single(
             training_dataset_params["batch_size"],
             model_params_single["type_map"],
             seed=rank_seed,
+            modifier=modifier,
         )
         return (
             train_data_single,
@@ -372,10 +384,22 @@ def freeze(
     output: str = "frozen_model.pth",
     head: str | None = None,
 ) -> None:
-    model = inference.Tester(model, head=head).model
+    tester = inference.Tester(model, head=head)
+    model = tester.model
     model.eval()
     model = torch.jit.script(model)
-    extra_files = {}
+
+    dm_output = "data_modifier.pth"
+    extra_files = {dm_output: ""}
+    if tester.modifier is not None:
+        dm = tester.modifier
+        dm.eval()
+        buffer = io.BytesIO()
+        torch.jit.save(
+            torch.jit.script(dm),
+            buffer,
+        )
+        extra_files = {dm_output: buffer.getvalue()}
     torch.jit.save(
         model,
         output,
```

deepmd/pt/infer/deep_eval.py

Lines changed: 16 additions & 2 deletions

```diff
@@ -1,4 +1,5 @@
 # SPDX-License-Identifier: LGPL-3.0-or-later
+import io
 import json
 import logging
 from collections.abc import (
@@ -171,8 +172,21 @@ def __init__(
             self.dp = ModelWrapper(model)
             self.dp.load_state_dict(state_dict)
         elif str(self.model_path).endswith(".pth"):
-            model = torch.jit.load(model_file, map_location=env.DEVICE)
-            self.dp = ModelWrapper(model)
+            extra_files = {"data_modifier.pth": ""}
+            model = torch.jit.load(
+                model_file, map_location=env.DEVICE, _extra_files=extra_files
+            )
+            modifier = None
+            # Load modifier if it exists in extra_files
+            if len(extra_files["data_modifier.pth"]) > 0:
+                # Create a file-like object from the in-memory data
+                modifier_data = extra_files["data_modifier.pth"]
+                if isinstance(modifier_data, bytes):
+                    modifier_data = io.BytesIO(modifier_data)
+                # Load the modifier directly from the file-like object
+                modifier = torch.jit.load(modifier_data, map_location=env.DEVICE)
+            self.dp = ModelWrapper(model, modifier=modifier)
+            self.modifier = modifier
         model_def_script = self.dp.model["Default"].get_model_def_script()
         if model_def_script:
             self.model_def_script = json.loads(model_def_script)
```
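Both the freeze and load paths lean on the same TorchScript facility: `torch.jit.save` accepts an `_extra_files` mapping whose values are written into the archive, and `torch.jit.load` fills a caller-supplied mapping back in as bytes. A self-contained sketch of that round trip, with a stand-in module in place of a real modifier:

```python
import io

import torch


class TinyModifier(torch.nn.Module):
    """Stand-in for a jitable data modifier."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return 0.5 * x


# Save side (mirrors freeze()): script the modifier into an in-memory buffer
# and embed its bytes in the main model's TorchScript archive.
main_model = torch.jit.script(TinyModifier())  # stand-in for the real model
buf = io.BytesIO()
torch.jit.save(torch.jit.script(TinyModifier()), buf)
torch.jit.save(
    main_model, "frozen_model.pth", _extra_files={"data_modifier.pth": buf.getvalue()}
)

# Load side (mirrors DeepEval.__init__): request the extra file by name; an
# empty value means no modifier was embedded.
extra_files = {"data_modifier.pth": ""}
model = torch.jit.load("frozen_model.pth", _extra_files=extra_files)
modifier = None
if len(extra_files["data_modifier.pth"]) > 0:
    modifier = torch.jit.load(io.BytesIO(extra_files["data_modifier.pth"]))
```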

deepmd/pt/infer/inference.py

Lines changed: 8 additions & 0 deletions

```diff
@@ -9,6 +9,9 @@
 from deepmd.pt.model.model import (
     get_model,
 )
+from deepmd.pt.modifier import (
+    get_data_modifier,
+)
 from deepmd.pt.train.wrapper import (
     ModelWrapper,
 )
@@ -60,6 +63,11 @@ def __init__(
         )  # wrapper Hessian to Energy model due to JIT limit
         self.model_params = deepcopy(model_params)
         self.model = get_model(model_params).to(DEVICE)
+        self.modifier = None
+        if "modifier" in model_params:
+            modifier = get_data_modifier(model_params["modifier"]).to(DEVICE)
+            if modifier.jitable:
+                self.modifier = modifier

         # Model Wrapper
         self.wrapper = ModelWrapper(self.model)  # inference only
```

deepmd/pt/modifier/__init__.py

Lines changed: 23 additions & 0 deletions

```diff
@@ -0,0 +1,23 @@
+# SPDX-License-Identifier: LGPL-3.0-or-later
+import copy
+from typing import (
+    Any,
+)
+
+from .base_modifier import (
+    BaseModifier,
+)
+
+__all__ = [
+    "BaseModifier",
+    "get_data_modifier",
+]
+
+
+def get_data_modifier(_modifier_params: dict[str, Any]) -> BaseModifier:
+    modifier_params = copy.deepcopy(_modifier_params)
+    try:
+        modifier_type = modifier_params.pop("type")
+    except KeyError:
+        raise ValueError("Data modifier type not specified!") from None
+    return BaseModifier.get_class_by_type(modifier_type).get_modifier(modifier_params)
```
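A usage sketch for the factory above: the parameter dict is deep-copied, the mandatory `type` key is popped, and the lookup goes through the plugin registry. The registry key below is assumed for illustration:

```python
from deepmd.pt.modifier import get_data_modifier

# "type" selects the registered BaseModifier subclass; the remaining keys are
# forwarded to that class's get_modifier() hook. "zero_tester" is an assumed
# registry key used for illustration only.
modifier = get_data_modifier({"type": "zero_tester"})

# A dict without "type" raises ValueError("Data modifier type not specified!").
```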

deepmd/pt/modifier/base_modifier.py

Lines changed: 187 additions & 0 deletions

```diff
@@ -0,0 +1,187 @@
+# SPDX-License-Identifier: LGPL-3.0-or-later
+from abc import (
+    abstractmethod,
+)
+
+import numpy as np
+import torch
+
+from deepmd.dpmodel.array_api import (
+    Array,
+)
+from deepmd.dpmodel.common import PRECISION_DICT as NP_PRECISION_DICT
+from deepmd.dpmodel.modifier.base_modifier import (
+    make_base_modifier,
+)
+from deepmd.pt.utils.env import (
+    DEVICE,
+    GLOBAL_PT_FLOAT_PRECISION,
+    RESERVED_PRECISION_DICT,
+)
+from deepmd.pt.utils.utils import (
+    to_numpy_array,
+    to_torch_tensor,
+)
+from deepmd.utils.data import (
+    DeepmdData,
+)
+
+
+class BaseModifier(torch.nn.Module, make_base_modifier()):
+    def __init__(self, use_cache: bool = True) -> None:
+        """Construct a base modifier for data modification tasks."""
+        torch.nn.Module.__init__(self)
+        self.modifier_type = "base"
+        self.jitable = True
+
+        self.use_cache = use_cache
+
+    def serialize(self) -> dict:
+        """Serialize the modifier.
+
+        Returns
+        -------
+        dict
+            The serialized data
+        """
+        data = {
+            "@class": "Modifier",
+            "type": self.modifier_type,
+            "@version": 3,
+        }
+        return data
+
+    @classmethod
+    def deserialize(cls, data: dict) -> "BaseModifier":
+        """Deserialize the modifier.
+
+        Parameters
+        ----------
+        data : dict
+            The serialized data
+
+        Returns
+        -------
+        BaseModifier
+            The deserialized modifier
+        """
+        data = data.copy()
+        # Remove serialization metadata before passing to constructor
+        data.pop("@class", None)
+        data.pop("type", None)
+        data.pop("@version", None)
+        modifier = cls(**data)
+        return modifier
+
+    @abstractmethod
+    @torch.jit.export
+    def forward(
+        self,
+        coord: torch.Tensor,
+        atype: torch.Tensor,
+        box: torch.Tensor | None = None,
+        fparam: torch.Tensor | None = None,
+        aparam: torch.Tensor | None = None,
+        do_atomic_virial: bool = False,
+    ) -> dict[str, torch.Tensor]:
+        """Compute energy, force, and virial corrections."""
+
+    @torch.jit.unused
+    def modify_data(self, data: dict[str, Array | float], data_sys: DeepmdData) -> None:
+        """Modify data of a single frame.
+
+        Parameters
+        ----------
+        data
+            Internal data of DeepmdData.
+            A dict with the following keys:
+            - coord        coordinates (nat, 3)
+            - box          simulation box (9,)
+            - atype        atom types (nat,)
+            - fparam       frame parameter (nfp,)
+            - aparam       atom parameter (nat, nap)
+            - find_energy  tells if data has energy
+            - find_force   tells if data has force
+            - find_virial  tells if data has virial
+            - energy       energy (1,)
+            - force        force (nat, 3)
+            - virial       virial (9,)
+        """
+        if (
+            "find_energy" not in data
+            and "find_force" not in data
+            and "find_virial" not in data
+        ):
+            return
+
+        prec = NP_PRECISION_DICT[RESERVED_PRECISION_DICT[GLOBAL_PT_FLOAT_PRECISION]]
+
+        nframes = 1
+        natoms = len(data["atype"])
+        atom_types = np.tile(data["atype"], nframes).reshape(nframes, -1)
+
+        coord_input = torch.tensor(
+            data["coord"].reshape([nframes, natoms, 3]).astype(prec),
+            dtype=GLOBAL_PT_FLOAT_PRECISION,
+            device=DEVICE,
+        )
+        type_input = torch.tensor(
+            atom_types.astype(NP_PRECISION_DICT[RESERVED_PRECISION_DICT[torch.long]]),
+            dtype=torch.long,
+            device=DEVICE,
+        )
+        if data["box"] is not None:
+            box_input = torch.tensor(
+                data["box"].reshape([nframes, 3, 3]).astype(prec),
+                dtype=GLOBAL_PT_FLOAT_PRECISION,
+                device=DEVICE,
+            )
+        else:
+            box_input = None
+        if "fparam" in data:
+            fparam_input = to_torch_tensor(data["fparam"].reshape(nframes, -1))
+        else:
+            fparam_input = None
+        if "aparam" in data:
+            aparam_input = to_torch_tensor(data["aparam"].reshape(nframes, natoms, -1))
+        else:
+            aparam_input = None
+        do_atomic_virial = False
+
+        # the data modification itself is implemented in forward
+        modifier_data = self.forward(
+            coord_input,
+            type_input,
+            box_input,
+            fparam_input,
+            aparam_input,
+            do_atomic_virial,
+        )
+
+        if data.get("find_energy") == 1.0:
+            if "energy" not in modifier_data:
+                raise KeyError(
+                    f"Modifier {self.__class__.__name__} did not provide 'energy' "
+                    "in its output while 'find_energy' is set."
+                )
+            data["energy"] -= to_numpy_array(modifier_data["energy"]).reshape(
+                data["energy"].shape
+            )
+        if data.get("find_force") == 1.0:
+            if "force" not in modifier_data:
+                raise KeyError(
+                    f"Modifier {self.__class__.__name__} did not provide 'force' "
+                    "in its output while 'find_force' is set."
+                )
+            data["force"] -= to_numpy_array(modifier_data["force"]).reshape(
+                data["force"].shape
+            )
+        if data.get("find_virial") == 1.0:
+            if "virial" not in modifier_data:
+                raise KeyError(
+                    f"Modifier {self.__class__.__name__} did not provide 'virial' "
+                    "in its output while 'find_virial' is set."
+                )
+            data["virial"] -= to_numpy_array(modifier_data["virial"]).reshape(
+                data["virial"].shape
+            )
```
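A concrete modifier then only has to implement `forward` and return the correction tensors that `modify_data` subtracts from the labels. A minimal sketch, assuming (neither is shown in this diff) that subclasses register via a `BaseModifier.register(...)` decorator, the usual deepmd plugin pattern behind `get_class_by_type`, and that `get_modifier` forwards the parameter dict to `__init__`:

```python
# Hypothetical subclass sketch. Two assumptions not shown in this diff:
# BaseModifier.register(...) is the plugin decorator behind
# get_class_by_type, and get_modifier() forwards the parameter dict to
# __init__.
import torch

from deepmd.pt.modifier import BaseModifier


@BaseModifier.register("constant_shift")  # assumed registry key
class ConstantShiftModifier(BaseModifier):
    """Subtract a constant per-frame energy shift from the labels."""

    def __init__(self, shift: float = 1.0, use_cache: bool = True) -> None:
        super().__init__(use_cache=use_cache)
        self.modifier_type = "constant_shift"
        self.shift = shift

    @torch.jit.export
    def forward(
        self,
        coord: torch.Tensor,
        atype: torch.Tensor,
        box: torch.Tensor | None = None,
        fparam: torch.Tensor | None = None,
        aparam: torch.Tensor | None = None,
        do_atomic_virial: bool = False,
    ) -> dict[str, torch.Tensor]:
        nframes = coord.shape[0]
        # Return all three keys so modify_data never hits its KeyError paths,
        # whichever find_* flags the dataset sets.
        energy = torch.full(
            (nframes, 1), self.shift, dtype=coord.dtype, device=coord.device
        )
        force = torch.zeros_like(coord)
        virial = torch.zeros((nframes, 9), dtype=coord.dtype, device=coord.device)
        return {"energy": energy, "force": force, "virial": virial}
```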
