Skip to content

Commit 508759c

Browse files
authored
fix(pt ut): make separated uts deterministic (#4162)
Fix failed uts in #4145 . <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit - **New Features** - Added a `"seed"` property to multiple JSON configuration files, enhancing control over randomness in model training and evaluation. - Introduced a global seed parameter in various test functions to improve reproducibility across test runs. - **Bug Fixes** - Ensured consistent random number generation in tests by integrating a global seed parameter. - **Documentation** - Updated configuration files and test methods to reflect the addition of the seed parameter for clarity and consistency. <!-- end of auto-generated comment: release notes by coderabbit.ai -->
1 parent 0b72dae commit 508759c

15 files changed

+62
-3
lines changed

source/tests/pt/model/models/dpa1.json

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,8 @@
2121
"activation_function": "tanh",
2222
"scaling_factor": 1.0,
2323
"normalize": true,
24-
"temperature": 1.0
24+
"temperature": 1.0,
25+
"seed": 1
2526
},
2627
"fitting_net": {
2728
"neuron": [

source/tests/pt/model/models/dpa2.json

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
"g1_out_conv": false,
4343
"g1_out_mlp": false
4444
},
45+
"seed": 1,
4546
"add_tebd_to_repinit_out": false
4647
},
4748
"fitting_net": {

source/tests/pt/model/test_descriptor_se_r.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ def test_consistency(
6363
resnet_dt=idt,
6464
old_impl=False,
6565
exclude_mask=em,
66+
seed=GLOBAL_SEED,
6667
).to(env.DEVICE)
6768
dd0.mean = torch.tensor(davg, dtype=dtype, device=env.DEVICE)
6869
dd0.dstd = torch.tensor(dstd, dtype=dtype, device=env.DEVICE)
@@ -130,6 +131,7 @@ def test_load_stat(self):
130131
precision=prec,
131132
resnet_dt=idt,
132133
old_impl=False,
134+
seed=GLOBAL_SEED,
133135
)
134136
dd0.mean = torch.tensor(davg, dtype=dtype, device=env.DEVICE)
135137
dd0.dstd = torch.tensor(dstd, dtype=dtype, device=env.DEVICE)
@@ -180,6 +182,7 @@ def test_jit(
180182
precision=prec,
181183
resnet_dt=idt,
182184
old_impl=False,
185+
seed=GLOBAL_SEED,
183186
)
184187
dd0.mean = torch.tensor(davg, dtype=dtype, device=env.DEVICE)
185188
dd0.dstd = torch.tensor(dstd, dtype=dtype, device=env.DEVICE)

source/tests/pt/model/test_dipole_fitting.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@ def test_consistency(
8787
numb_fparam=nfp,
8888
numb_aparam=nap,
8989
mixed_types=self.dd0.mixed_types(),
90+
seed=GLOBAL_SEED,
9091
).to(env.DEVICE)
9192
ft1 = DPDipoleFitting.deserialize(ft0.serialize())
9293
ft2 = DipoleFittingNet.deserialize(ft1.serialize())
@@ -139,6 +140,7 @@ def test_jit(
139140
numb_fparam=nfp,
140141
numb_aparam=nap,
141142
mixed_types=mixed_types,
143+
seed=GLOBAL_SEED,
142144
).to(env.DEVICE)
143145
torch.jit.script(ft0)
144146

@@ -180,6 +182,7 @@ def test_rot(self):
180182
numb_fparam=nfp,
181183
numb_aparam=nap,
182184
mixed_types=self.dd0.mixed_types(),
185+
seed=GLOBAL_SEED,
183186
).to(env.DEVICE)
184187
if nfp > 0:
185188
ifp = torch.tensor(
@@ -234,6 +237,7 @@ def test_permu(self):
234237
numb_fparam=0,
235238
numb_aparam=0,
236239
mixed_types=self.dd0.mixed_types(),
240+
seed=GLOBAL_SEED,
237241
).to(env.DEVICE)
238242
res = []
239243
for idx_perm in [[0, 1, 2, 3, 4], [1, 0, 4, 3, 2]]:
@@ -280,6 +284,7 @@ def test_trans(self):
280284
numb_fparam=0,
281285
numb_aparam=0,
282286
mixed_types=self.dd0.mixed_types(),
287+
seed=GLOBAL_SEED,
283288
).to(env.DEVICE)
284289
res = []
285290
for xyz in [self.coord, coord_s]:
@@ -327,6 +332,7 @@ def setUp(self):
327332
numb_fparam=0,
328333
numb_aparam=0,
329334
mixed_types=self.dd0.mixed_types(),
335+
seed=GLOBAL_SEED,
330336
).to(env.DEVICE)
331337
self.type_mapping = ["O", "H", "B"]
332338
self.model = DipoleModel(self.dd0, self.ft0, self.type_mapping)

source/tests/pt/model/test_dpa1.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@ def test_consistency(
7171
use_econf_tebd=ect,
7272
type_map=["O", "H"] if ect else None,
7373
old_impl=False,
74+
seed=GLOBAL_SEED,
7475
).to(env.DEVICE)
7576
dd0.se_atten.mean = torch.tensor(davg, dtype=dtype, device=env.DEVICE)
7677
dd0.se_atten.stddev = torch.tensor(dstd, dtype=dtype, device=env.DEVICE)
@@ -125,6 +126,7 @@ def test_consistency(
125126
resnet_dt=idt,
126127
smooth_type_embedding=sm,
127128
old_impl=True,
129+
seed=GLOBAL_SEED,
128130
).to(env.DEVICE)
129131
dd0_state_dict = dd0.se_atten.state_dict()
130132
dd3_state_dict = dd3.se_atten.state_dict()
@@ -210,6 +212,7 @@ def test_jit(
210212
use_econf_tebd=ect,
211213
type_map=["O", "H"] if ect else None,
212214
old_impl=False,
215+
seed=GLOBAL_SEED,
213216
)
214217
dd0.se_atten.mean = torch.tensor(davg, dtype=dtype, device=env.DEVICE)
215218
dd0.se_atten.dstd = torch.tensor(dstd, dtype=dtype, device=env.DEVICE)

source/tests/pt/model/test_dpa2.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,9 @@
2020
PRECISION_DICT,
2121
)
2222

23+
from ...seed import (
24+
GLOBAL_SEED,
25+
)
2326
from .test_env_mat import (
2427
TestCaseSingleFrameWithNlist,
2528
)
@@ -152,6 +155,7 @@ def test_consistency(
152155
use_econf_tebd=ect,
153156
type_map=["O", "H"] if ect else None,
154157
old_impl=False,
158+
seed=GLOBAL_SEED,
155159
).to(env.DEVICE)
156160

157161
dd0.repinit.mean = torch.tensor(davg, dtype=dtype, device=env.DEVICE)
@@ -201,6 +205,7 @@ def test_consistency(
201205
add_tebd_to_repinit_out=False,
202206
precision=prec,
203207
old_impl=True,
208+
seed=GLOBAL_SEED,
204209
).to(env.DEVICE)
205210
dd0_state_dict = dd0.state_dict()
206211
dd3_state_dict = dd3.state_dict()
@@ -346,6 +351,7 @@ def test_jit(
346351
use_econf_tebd=ect,
347352
type_map=["O", "H"] if ect else None,
348353
old_impl=False,
354+
seed=GLOBAL_SEED,
349355
).to(env.DEVICE)
350356

351357
dd0.repinit.mean = torch.tensor(davg, dtype=dtype, device=env.DEVICE)

source/tests/pt/model/test_embedding_net.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,9 @@
3939
)
4040
from deepmd.tf.descriptor import DescrptSeA as DescrptSeA_tf
4141

42+
from ...seed import (
43+
GLOBAL_SEED,
44+
)
4245
from ..test_finetune import (
4346
energy_data_requirement,
4447
)
@@ -153,7 +156,7 @@ def test_consistency(self):
153156
sel=self.sel,
154157
neuron=self.filter_neuron,
155158
axis_neuron=self.axis_neuron,
156-
seed=1,
159+
seed=GLOBAL_SEED,
157160
)
158161
dp_embedding, dp_force, dp_vars = base_se_a(
159162
descriptor=dp_d,

source/tests/pt/model/test_ener_fitting.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@ def test_consistency(
6565
mixed_types=mixed_types,
6666
exclude_types=et,
6767
neuron=nn,
68+
seed=GLOBAL_SEED,
6869
).to(env.DEVICE)
6970
ft1 = DPInvarFitting.deserialize(ft0.serialize())
7071
ft2 = InvarFitting.deserialize(ft0.serialize())
@@ -168,6 +169,7 @@ def test_jit(
168169
numb_aparam=nap,
169170
mixed_types=mixed_types,
170171
exclude_types=et,
172+
seed=GLOBAL_SEED,
171173
).to(env.DEVICE)
172174
torch.jit.script(ft0)
173175

@@ -177,6 +179,7 @@ def test_get_set(self):
177179
self.nt,
178180
3,
179181
1,
182+
seed=GLOBAL_SEED,
180183
)
181184
rng = np.random.default_rng(GLOBAL_SEED)
182185
foo = rng.normal([3, 4])

source/tests/pt/model/test_permutation.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,7 @@
8888
"temperature": 1.0,
8989
"set_davg_zero": True,
9090
"type_one_side": True,
91+
"seed": 1,
9192
},
9293
"fitting_net": {
9394
"neuron": [24, 24, 24],
@@ -155,6 +156,7 @@
155156
"update_g2_has_attn": True,
156157
"attn2_has_gate": True,
157158
},
159+
"seed": 1,
158160
"add_tebd_to_repinit_out": False,
159161
},
160162
"fitting_net": {
@@ -207,6 +209,7 @@
207209
"g1_out_conv": True,
208210
"g1_out_mlp": True,
209211
},
212+
"seed": 1,
210213
"add_tebd_to_repinit_out": False,
211214
},
212215
"fitting_net": {
@@ -235,6 +238,7 @@
235238
"temperature": 1.0,
236239
"set_davg_zero": True,
237240
"type_one_side": True,
241+
"seed": 1,
238242
},
239243
"fitting_net": {
240244
"neuron": [24, 24, 24],
@@ -264,6 +268,7 @@
264268
"scaling_factor": 1.0,
265269
"normalize": True,
266270
"temperature": 1.0,
271+
"seed": 1,
267272
},
268273
{
269274
"type": "dpa2",
@@ -296,6 +301,7 @@
296301
"update_g2_has_attn": True,
297302
"attn2_has_gate": True,
298303
},
304+
"seed": 1,
299305
"add_tebd_to_repinit_out": False,
300306
},
301307
],

source/tests/pt/model/test_polarizability_fitting.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@ def test_consistency(
7777
mixed_types=self.dd0.mixed_types(),
7878
fit_diag=fit_diag,
7979
scale=scale,
80+
seed=GLOBAL_SEED,
8081
).to(env.DEVICE)
8182
ft1 = DPPolarFitting.deserialize(ft0.serialize())
8283
ft2 = PolarFittingNet.deserialize(ft0.serialize())
@@ -143,6 +144,7 @@ def test_jit(
143144
numb_aparam=nap,
144145
mixed_types=mixed_types,
145146
fit_diag=fit_diag,
147+
seed=GLOBAL_SEED,
146148
).to(env.DEVICE)
147149
torch.jit.script(ft0)
148150

@@ -186,6 +188,7 @@ def test_rot(self):
186188
mixed_types=self.dd0.mixed_types(),
187189
fit_diag=fit_diag,
188190
scale=scale,
191+
seed=GLOBAL_SEED,
189192
).to(env.DEVICE)
190193
if nfp > 0:
191194
ifp = torch.tensor(
@@ -248,6 +251,7 @@ def test_permu(self):
248251
mixed_types=self.dd0.mixed_types(),
249252
fit_diag=fit_diag,
250253
scale=scale,
254+
seed=GLOBAL_SEED,
251255
).to(env.DEVICE)
252256
res = []
253257
for idx_perm in [[0, 1, 2, 3, 4], [1, 0, 4, 3, 2]]:
@@ -298,6 +302,7 @@ def test_trans(self):
298302
mixed_types=self.dd0.mixed_types(),
299303
fit_diag=fit_diag,
300304
scale=scale,
305+
seed=GLOBAL_SEED,
301306
).to(env.DEVICE)
302307
res = []
303308
for xyz in [self.coord, coord_s]:
@@ -347,6 +352,7 @@ def setUp(self):
347352
numb_fparam=0,
348353
numb_aparam=0,
349354
mixed_types=self.dd0.mixed_types(),
355+
seed=GLOBAL_SEED,
350356
).to(env.DEVICE)
351357
self.type_mapping = ["O", "H", "B"]
352358
self.model = PolarModel(self.dd0, self.ft0, self.type_mapping)

0 commit comments

Comments
 (0)