Skip to content

Commit c68fcbd

Browse files
authoredAug 6, 2024··
Update inputDict of MACEForce (#86)
1 parent 2d0fcef commit c68fcbd

File tree

1 file changed

+11
-11
lines changed

1 file changed

+11
-11
lines changed
 

‎openmmml/models/macepotential.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -283,13 +283,6 @@ def __init__(
283283
self.register_buffer("batch", torch.zeros(nodeAttrs.shape[0], dtype=torch.long, requires_grad=False))
284284
self.register_buffer("pbc", torch.tensor([periodic, periodic, periodic], dtype=torch.bool, requires_grad=False))
285285

286-
self.inputDict = {
287-
"ptr": self.ptr,
288-
"node_attrs": self.node_attrs,
289-
"batch": self.batch,
290-
"pbc": self.pbc,
291-
}
292-
293286
def _getNeighborPairs(
294287
self, positions: torch.Tensor, cell: Optional[torch.Tensor]
295288
) -> Tuple[torch.Tensor, torch.Tensor]:
@@ -369,12 +362,19 @@ def forward(
369362
edgeIndex, shifts = self._getNeighborPairs(positions, cell)
370363

371364
# Update input dictionary.
372-
self.inputDict["positions"] = positions
373-
self.inputDict["edge_index"] = edgeIndex
374-
self.inputDict["shifts"] = shifts
365+
inputDict = {
366+
"ptr": self.ptr,
367+
"node_attrs": self.node_attrs,
368+
"batch": self.batch,
369+
"pbc": self.pbc,
370+
"positions": positions,
371+
"edge_index": edgeIndex,
372+
"shifts": shifts,
373+
"cell": cell if cell is not None else torch.zeros(3, 3, dtype=self.dtype),
374+
}
375375

376376
# Predict the energy.
377-
energy = self.model(self.inputDict, compute_force=False)[
377+
energy = self.model(inputDict, compute_force=False)[
378378
self.returnEnergyType
379379
]
380380

0 commit comments

Comments
 (0)
Please sign in to comment.