Skip to content

Commit 8d02057

Browse files
committed
add UT for data modifier in pt model training
1 parent 8e4d3f3 commit 8d02057

File tree

5 files changed

+250
-8
lines changed

5 files changed

+250
-8
lines changed

deepmd/pt/train/training.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -391,6 +391,9 @@ def get_lr(lr_params: dict[str, Any]) -> LearningRateExp:
391391
if self.finetune_links is not None
392392
else False,
393393
)
394+
training_data[model_key].preload_and_modify_all_data()
395+
if validation_data[model_key] is not None:
396+
validation_data[model_key].preload_and_modify_all_data()
394397
(
395398
self.training_dataloader[model_key],
396399
self.training_data[model_key],

deepmd/pt/utils/dataloader.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -242,6 +242,10 @@ def preload_and_modify_all_data(self) -> None:
242242
for system in self.systems:
243243
system.preload_and_modify_all_data()
244244

245+
# def clear_modified_frame_cache(self) -> None:
246+
# for system in self.systems:
247+
# system.clear_modified_frame_cache()
248+
245249

246250
def collate_batch(batch: list[dict[str, Any]]) -> dict[str, Any]:
247251
example = batch[0]

deepmd/pt/utils/dataset.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,3 +70,6 @@ def add_data_requirement(self, data_requirement: list[DataRequirementItem]) -> N
7070

7171
def preload_and_modify_all_data(self) -> None:
7272
self._data_system.preload_and_modify_all_data()
73+
74+
# def clear_modified_frame_cache(self) -> None:
75+
# self._data_system.clear_modified_frame_cache()

deepmd/utils/data.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -512,14 +512,14 @@ def preload_and_modify_all_data(self) -> None:
512512
log.info(f"Processed {i + 1}/{self.nframes} frames")
513513
log.info("All frames preloaded and modified.")
514514

515-
def clear_modified_frame_cache(self) -> None:
516-
"""Clear the modified frame cache.
517-
518-
This method is useful when you want to free up memory or force
519-
recomputation of modified frames.
520-
"""
521-
self._modified_frame_cache.clear()
522-
log.info("Modified frame cache cleared.")
515+
# def clear_modified_frame_cache(self) -> None:
516+
# """Clear the modified frame cache.
517+
518+
# This method is useful when you want to free up memory or force
519+
# recomputation of modified frames.
520+
# """
521+
# self._modified_frame_cache.clear()
522+
# log.info("Modified frame cache cleared.")
523523

524524
def avg(self, key: str) -> float:
525525
"""Return the average value of an item."""
Lines changed: 232 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,232 @@
1+
# SPDX-License-Identifier: LGPL-3.0-or-later
2+
import json
3+
import os
4+
import unittest
5+
from pathlib import (
6+
Path,
7+
)
8+
9+
import numpy as np
10+
11+
from deepmd.pt.entrypoints.main import (
12+
get_trainer,
13+
)
14+
from deepmd.pt.modifier.base_modifier import (
15+
BaseModifier,
16+
)
17+
from deepmd.pt.utils.utils import (
18+
to_numpy_array,
19+
)
20+
from deepmd.utils.argcheck import (
21+
modifier_args_plugin,
22+
)
23+
from deepmd.utils.data import (
24+
DeepmdData,
25+
)
26+
27+
28+
@modifier_args_plugin.register("random_tester")
29+
def modifier_random_tester() -> list:
30+
return []
31+
32+
33+
@modifier_args_plugin.register("zero_tester")
34+
def modifier_zero_tester() -> list:
35+
return []
36+
37+
38+
@BaseModifier.register("random_tester")
39+
class ModifierRandomTester(BaseModifier):
40+
def __new__(cls) -> BaseModifier:
41+
return super().__new__(cls)
42+
43+
def __init__(self) -> None:
44+
"""Construct a basic model for different tasks."""
45+
super().__init__()
46+
self.modifier_type = "tester"
47+
48+
def serialize(self) -> dict:
49+
"""Serialize the modifier.
50+
51+
Returns
52+
-------
53+
dict
54+
The serialized data
55+
"""
56+
data = {
57+
"@class": "Modifier",
58+
"type": self.modifier_type,
59+
"@version": 3,
60+
}
61+
return data
62+
63+
@classmethod
64+
def deserialize(cls, data: dict) -> "BaseModifier":
65+
"""Deserialize the modifier.
66+
67+
Parameters
68+
----------
69+
data : dict
70+
The serialized data
71+
72+
Returns
73+
-------
74+
BaseModifier
75+
The deserialized modifier
76+
"""
77+
data = data.copy()
78+
modifier = cls(**data)
79+
return modifier
80+
81+
def modify_data(self, data: dict, data_sys: DeepmdData) -> None:
82+
"""Multiply by a random factor."""
83+
if (
84+
"find_energy" not in data
85+
and "find_force" not in data
86+
and "find_virial" not in data
87+
):
88+
return
89+
90+
if "find_energy" in data and data["find_energy"] == 1.0:
91+
data["energy"] = data["energy"] * np.random.Generator()
92+
if "find_force" in data and data["find_force"] == 1.0:
93+
data["force"] = data["force"] * np.random.Generator()
94+
if "find_virial" in data and data["find_virial"] == 1.0:
95+
data["virial"] = data["virial"] * np.random.Generator()
96+
97+
98+
@BaseModifier.register("zero_tester")
99+
class ModifierZeroTester(BaseModifier):
100+
def __new__(cls) -> BaseModifier:
101+
return super().__new__(cls)
102+
103+
def __init__(self) -> None:
104+
"""Construct a basic model for different tasks."""
105+
super().__init__()
106+
self.modifier_type = "tester"
107+
108+
def serialize(self) -> dict:
109+
"""Serialize the modifier.
110+
111+
Returns
112+
-------
113+
dict
114+
The serialized data
115+
"""
116+
data = {
117+
"@class": "Modifier",
118+
"type": self.modifier_type,
119+
"@version": 3,
120+
}
121+
return data
122+
123+
@classmethod
124+
def deserialize(cls, data: dict) -> "BaseModifier":
125+
"""Deserialize the modifier.
126+
127+
Parameters
128+
----------
129+
data : dict
130+
The serialized data
131+
132+
Returns
133+
-------
134+
BaseModifier
135+
The deserialized modifier
136+
"""
137+
data = data.copy()
138+
modifier = cls(**data)
139+
return modifier
140+
141+
def modify_data(self, data: dict, data_sys: DeepmdData) -> None:
142+
"""Multiply by a random factor."""
143+
if (
144+
"find_energy" not in data
145+
and "find_force" not in data
146+
and "find_virial" not in data
147+
):
148+
return
149+
150+
if "find_energy" in data and data["find_energy"] == 1.0:
151+
data["energy"] -= data["energy"]
152+
if "find_force" in data and data["find_force"] == 1.0:
153+
data["force"] -= data["force"]
154+
if "find_virial" in data and data["find_virial"] == 1.0:
155+
data["virial"] -= data["virial"]
156+
157+
158+
class TestDataModifier(unittest.TestCase):
159+
def setUp(self) -> None:
160+
"""Set up test fixtures."""
161+
input_json = str(Path(__file__).parent / "water/se_e2_a.json")
162+
with open(input_json, encoding="utf-8") as f:
163+
config = json.load(f)
164+
config["training"]["numb_steps"] = 10
165+
config["training"]["save_freq"] = 1
166+
config["learning_rate"]["start_lr"] = 1.0
167+
config["training"]["training_data"]["systems"] = [
168+
str(Path(__file__).parent / "water/data/single")
169+
]
170+
config["training"]["validation_data"]["systems"] = [
171+
str(Path(__file__).parent / "water/data/single")
172+
]
173+
self.config = config
174+
175+
def test_init_modify_data(self):
176+
"""Ensure modify_data applied."""
177+
tmp_config = self.config.copy()
178+
# add tester data modifier
179+
tmp_config["model"]["modifier"] = {"type": "zero_tester"}
180+
181+
# data modification is finished in __init__
182+
trainer = get_trainer(tmp_config)
183+
184+
# training data
185+
training_data = trainer.get_data(is_train=True)
186+
# validation data
187+
validation_data = trainer.get_data(is_train=False)
188+
189+
for dataset in [training_data, validation_data]:
190+
for kw in ["energy", "force"]:
191+
data = to_numpy_array(dataset[1][kw])
192+
np.testing.assert_allclose(data, np.zeros_like(data))
193+
194+
def test_full_modify_data(self):
195+
"""Ensure modify_data only applied once."""
196+
tmp_config = self.config.copy()
197+
# add tester data modifier
198+
tmp_config["model"]["modifier"] = {"type": "random_tester"}
199+
200+
# data modification is finished in __init__
201+
trainer = get_trainer(tmp_config)
202+
203+
# training data
204+
training_data_before = trainer.get_data(is_train=True)
205+
# validation data
206+
validation_data_before = trainer.get_data(is_train=False)
207+
208+
trainer.run()
209+
210+
# training data
211+
training_data_after = trainer.get_data(is_train=True)
212+
# validation data
213+
validation_data_after = trainer.get_data(is_train=False)
214+
215+
for kw in ["energy", "force"]:
216+
np.testing.assert_allclose(
217+
to_numpy_array(training_data_before[1][kw]),
218+
to_numpy_array(training_data_after[1][kw]),
219+
)
220+
np.testing.assert_allclose(
221+
to_numpy_array(validation_data_before[1][kw]),
222+
to_numpy_array(validation_data_after[1][kw]),
223+
)
224+
225+
def tearDown(self) -> None:
226+
for f in os.listdir("."):
227+
if f.startswith("frozen_model") and f.endswith(".pth"):
228+
os.remove(f)
229+
if f.startswith("model") and f.endswith(".pt"):
230+
os.remove(f)
231+
if f in ["lcurve.out", "checkpoint"]:
232+
os.remove(f)

0 commit comments

Comments
 (0)