Skip to content

Commit 352c149

Browse files
committed
fix typo in wrapper and UT
1 parent 1acc58e commit 352c149

File tree

3 files changed

+3
-5
lines changed

3 files changed

+3
-5
lines changed

deepmd/pt/train/wrapper.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -191,7 +191,7 @@ def forward(
191191
if self.modifier is not None:
192192
modifier_pred = self.modifier(**input_dict)
193193
for k, v in modifier_pred.items():
194-
model_pred[k] = model_pred[k] - v
194+
model_pred[k] = model_pred[k] + v
195195
return model_pred, None, None
196196
else:
197197
natoms = atype.shape[-1]

deepmd/utils/data.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -507,8 +507,6 @@ def get_single_frame(self, index: int, num_worker: int) -> dict:
507507
frame_data,
508508
self,
509509
)
510-
# Wait for completion and propagate any exceptions
511-
future.result()
512510
if self.use_modifier_cache:
513511
# Cache the modified frame to avoid recomputation
514512
self._modified_frame_cache[index] = copy.deepcopy(frame_data)

source/tests/pt/test_data_modifier.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -374,11 +374,11 @@ def test_inference(self):
374374
model_pred = model.eval(coord, cell, atype)
375375
modifier_pred = modifier.eval(coord, cell, atype)
376376
model_pred_ref = model_ref.eval(coord, cell, atype)
377-
# expected: output_model - sfactor * output_modifier
377+
# expected: output_model = output_model_ref + sfactor * output_modifier
378378
for ii in range(3):
379379
np.testing.assert_allclose(
380380
model_pred[ii],
381-
model_pred_ref[ii] - sfactor * modifier_pred[ii],
381+
model_pred_ref[ii] + sfactor * modifier_pred[ii],
382382
rtol=1e-5,
383383
atol=1e-8,
384384
)

0 commit comments

Comments
 (0)