Skip to content

Commit d4919c3

Browse files
committed
fix typo in wrapper and UT
1 parent 1acc58e commit d4919c3

File tree

2 files changed

+3
-3
lines changed

2 files changed

+3
-3
lines changed

deepmd/pt/train/wrapper.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -191,7 +191,7 @@ def forward(
191191
if self.modifier is not None:
192192
modifier_pred = self.modifier(**input_dict)
193193
for k, v in modifier_pred.items():
194-
model_pred[k] = model_pred[k] - v
194+
model_pred[k] = model_pred[k] + v
195195
return model_pred, None, None
196196
else:
197197
natoms = atype.shape[-1]

source/tests/pt/test_data_modifier.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -374,11 +374,11 @@ def test_inference(self):
374374
model_pred = model.eval(coord, cell, atype)
375375
modifier_pred = modifier.eval(coord, cell, atype)
376376
model_pred_ref = model_ref.eval(coord, cell, atype)
377-
# expected: output_model - sfactor * output_modifier
377+
# expected: output_model = output_model_ref + sfactor * output_modifier
378378
for ii in range(3):
379379
np.testing.assert_allclose(
380380
model_pred[ii],
381-
model_pred_ref[ii] - sfactor * modifier_pred[ii],
381+
model_pred_ref[ii] + sfactor * modifier_pred[ii],
382382
rtol=1e-5,
383383
atol=1e-8,
384384
)

0 commit comments

Comments
 (0)