Skip to content

Commit 859292f

Browse files
committed
Add unit tests (UT) for the data modifier in PyTorch (pt) model training
1 parent 03e4edc commit 859292f

File tree

1 file changed

+232
-0
lines changed

1 file changed

+232
-0
lines changed
Lines changed: 232 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,232 @@
1+
# SPDX-License-Identifier: LGPL-3.0-or-later
2+
import copy
import json
import os
import unittest
from pathlib import (
    Path,
)

import numpy as np

from deepmd.pt.entrypoints.main import (
    get_trainer,
)
from deepmd.pt.modifier.base_modifier import (
    BaseModifier,
)
from deepmd.pt.utils.utils import (
    to_numpy_array,
)
from deepmd.utils.argcheck import (
    modifier_args_plugin,
)
from deepmd.utils.data import (
    DeepmdData,
)
26+
27+
28+
@modifier_args_plugin.register("random_tester")
def modifier_random_tester() -> list:
    """Return the (empty) argument list for the ``random_tester`` modifier.

    Registered with ``modifier_args_plugin`` under the name
    ``"random_tester"``.

    Returns
    -------
    list
        A new empty list.
    """
    # NOTE(review): presumably this list holds the extra argcheck entries for
    # the modifier section of the input; the tester needs none -- confirm.
    no_arguments: list = []
    return no_arguments
31+
32+
33+
@modifier_args_plugin.register("zero_tester")
def modifier_zero_tester() -> list:
    """Return the (empty) argument list for the ``zero_tester`` modifier.

    Registered with ``modifier_args_plugin`` under the name
    ``"zero_tester"``.

    Returns
    -------
    list
        A new empty list.
    """
    # NOTE(review): presumably this list holds the extra argcheck entries for
    # the modifier section of the input; the tester needs none -- confirm.
    no_arguments: list = []
    return no_arguments
36+
37+
38+
@BaseModifier.register("random_tester")
class ModifierRandomTester(BaseModifier):
    """Test-only modifier that rescales available labels by a random factor.

    Each call to :meth:`modify_data` draws fresh random factors, so applying
    the modifier to the same frame twice yields different labels.
    """

    def __new__(cls) -> BaseModifier:
        # NOTE(review): presumably overrides a dispatching ``__new__`` on the
        # base class so the tester can be built without arguments -- confirm.
        return super().__new__(cls)

    def __init__(self) -> None:
        """Construct the tester modifier; it takes no options."""
        super().__init__()
        # NOTE(review): this differs from the registered name
        # "random_tester"; kept as in the original -- confirm it is intended.
        self.modifier_type = "tester"

    def serialize(self) -> dict:
        """Serialize the modifier.

        Returns
        -------
        dict
            The serialized data
        """
        return {
            "@class": "Modifier",
            "type": self.modifier_type,
            "@version": 3,
        }

    @classmethod
    def deserialize(cls, data: dict) -> "BaseModifier":
        """Deserialize the modifier.

        Parameters
        ----------
        data : dict
            The serialized data

        Returns
        -------
        BaseModifier
            The deserialized modifier
        """
        data = data.copy()
        # Drop the bookkeeping keys written by ``serialize``: the constructor
        # accepts no arguments, so forwarding them would raise ``TypeError``.
        for meta_key in ("@class", "type", "@version"):
            data.pop(meta_key, None)
        return cls(**data)

    def modify_data(self, data: dict, data_sys: DeepmdData) -> None:
        """Multiply each available label by a random factor.

        Parameters
        ----------
        data : dict
            Frame data, updated in place. ``energy``, ``force`` and ``virial``
            are rescaled only when the matching ``find_*`` entry equals 1.0.
        data_sys : DeepmdData
            The data system the frame belongs to (unused here).
        """
        rng = np.random.default_rng()
        for label in ("energy", "force", "virial"):
            # A missing ``find_*`` key yields None, which never equals 1.0.
            if data.get(f"find_{label}") == 1.0:
                # One scalar factor in [0, 1) per label.
                data[label] = data[label] * rng.random()
96+
97+
98+
@BaseModifier.register("zero_tester")
class ModifierZeroTester(BaseModifier):
    """Test-only modifier that zeroes out the available labels.

    Lets a test detect whether ``modify_data`` has been applied: afterwards
    the affected labels are expected to be all zeros.
    """

    def __new__(cls) -> BaseModifier:
        # NOTE(review): presumably overrides a dispatching ``__new__`` on the
        # base class so the tester can be built without arguments -- confirm.
        return super().__new__(cls)

    def __init__(self) -> None:
        """Construct the tester modifier; it takes no options."""
        super().__init__()
        # NOTE(review): this differs from the registered name
        # "zero_tester"; kept as in the original -- confirm it is intended.
        self.modifier_type = "tester"

    def serialize(self) -> dict:
        """Serialize the modifier.

        Returns
        -------
        dict
            The serialized data
        """
        return {
            "@class": "Modifier",
            "type": self.modifier_type,
            "@version": 3,
        }

    @classmethod
    def deserialize(cls, data: dict) -> "BaseModifier":
        """Deserialize the modifier.

        Parameters
        ----------
        data : dict
            The serialized data

        Returns
        -------
        BaseModifier
            The deserialized modifier
        """
        data = data.copy()
        # Drop the bookkeeping keys written by ``serialize``: the constructor
        # accepts no arguments, so forwarding them would raise ``TypeError``.
        for meta_key in ("@class", "type", "@version"):
            data.pop(meta_key, None)
        return cls(**data)

    def modify_data(self, data: dict, data_sys: DeepmdData) -> None:
        """Zero out each available label by subtracting it from itself.

        Parameters
        ----------
        data : dict
            Frame data, updated in place. ``energy``, ``force`` and ``virial``
            are zeroed only when the matching ``find_*`` entry equals 1.0.
        data_sys : DeepmdData
            The data system the frame belongs to (unused here).
        """
        for label in ("energy", "force", "virial"):
            # A missing ``find_*`` key yields None, which never equals 1.0.
            if data.get(f"find_{label}") == 1.0:
                # Augmented assignment keeps the original in-place update.
                data[label] -= data[label]
156+
157+
158+
class TestDataModifier(unittest.TestCase):
    """Check that a registered data modifier is applied to the training data."""

    def setUp(self) -> None:
        """Load the water ``se_e2_a`` config and shrink it to a short run."""
        test_dir = Path(__file__).parent
        input_json = str(test_dir / "water/se_e2_a.json")
        with open(input_json, encoding="utf-8") as f:
            config = json.load(f)

        training = config["training"]
        training["numb_steps"] = 10
        training["save_freq"] = 1
        config["learning_rate"]["start_lr"] = 1.0
        # Training and validation both read the same single-frame system.
        training["training_data"]["systems"] = [
            str(test_dir / "water/data/single")
        ]
        training["validation_data"]["systems"] = [
            str(test_dir / "water/data/single")
        ]
        self.config = config

    def test_init_modify_data(self) -> None:
        """Ensure modify_data applied."""
        # Deep copy: a shallow ``dict.copy()`` would share the nested
        # "model" dict, so adding the modifier would also edit self.config.
        tmp_config = copy.deepcopy(self.config)
        # add tester data modifier
        tmp_config["model"]["modifier"] = {"type": "zero_tester"}

        # data modification is finished in __init__
        trainer = get_trainer(tmp_config)

        training_data = trainer.get_data(is_train=True)
        validation_data = trainer.get_data(is_train=False)

        for dataset in (training_data, validation_data):
            for kw in ("energy", "force"):
                # NOTE(review): index 1 is presumably the label dict of the
                # batch returned by get_data -- confirm against the trainer.
                data = to_numpy_array(dataset[1][kw])
                np.testing.assert_allclose(data, np.zeros_like(data))

    def test_full_modify_data(self) -> None:
        """Ensure modify_data only applied once."""
        # Deep copy: a shallow ``dict.copy()`` would share the nested
        # "model" dict, so adding the modifier would also edit self.config.
        tmp_config = copy.deepcopy(self.config)
        # add tester data modifier
        tmp_config["model"]["modifier"] = {"type": "random_tester"}

        # data modification is finished in __init__
        trainer = get_trainer(tmp_config)

        # Snapshot the labels before training ...
        training_data_before = trainer.get_data(is_train=True)
        validation_data_before = trainer.get_data(is_train=False)

        trainer.run()

        # ... and again afterwards; a second random rescaling would show up
        # as a mismatch between the two snapshots.
        training_data_after = trainer.get_data(is_train=True)
        validation_data_after = trainer.get_data(is_train=False)

        for kw in ("energy", "force"):
            np.testing.assert_allclose(
                to_numpy_array(training_data_before[1][kw]),
                to_numpy_array(training_data_after[1][kw]),
            )
            np.testing.assert_allclose(
                to_numpy_array(validation_data_before[1][kw]),
                to_numpy_array(validation_data_after[1][kw]),
            )

    def tearDown(self) -> None:
        """Remove model and log files left in the working directory."""
        for name in os.listdir("."):
            is_frozen_model = name.startswith("frozen_model") and name.endswith(
                ".pth"
            )
            is_checkpoint = name.startswith("model") and name.endswith(".pt")
            is_run_log = name in ("lcurve.out", "checkpoint")
            # The three patterns cannot match the same name, so one removal
            # per file is equivalent to checking them separately.
            if is_frozen_model or is_checkpoint or is_run_log:
                os.remove(name)

0 commit comments

Comments
 (0)