@@ -283,13 +283,6 @@ def __init__(
283
283
self .register_buffer ("batch" , torch .zeros (nodeAttrs .shape [0 ], dtype = torch .long , requires_grad = False ))
284
284
self .register_buffer ("pbc" , torch .tensor ([periodic , periodic , periodic ], dtype = torch .bool , requires_grad = False ))
285
285
286
- self .inputDict = {
287
- "ptr" : self .ptr ,
288
- "node_attrs" : self .node_attrs ,
289
- "batch" : self .batch ,
290
- "pbc" : self .pbc ,
291
- }
292
-
293
286
def _getNeighborPairs (
294
287
self , positions : torch .Tensor , cell : Optional [torch .Tensor ]
295
288
) -> Tuple [torch .Tensor , torch .Tensor ]:
@@ -369,12 +362,19 @@ def forward(
369
362
edgeIndex , shifts = self ._getNeighborPairs (positions , cell )
370
363
371
364
# Update input dictionary.
372
- self .inputDict ["positions" ] = positions
373
- self .inputDict ["edge_index" ] = edgeIndex
374
- self .inputDict ["shifts" ] = shifts
365
+ inputDict = {
366
+ "ptr" : self .ptr ,
367
+ "node_attrs" : self .node_attrs ,
368
+ "batch" : self .batch ,
369
+ "pbc" : self .pbc ,
370
+ "positions" : positions ,
371
+ "edge_index" : edgeIndex ,
372
+ "shifts" : shifts ,
373
+ "cell" : cell if cell is not None else torch .zeros (3 , 3 , dtype = self .dtype ),
374
+ }
375
375
376
376
# Predict the energy.
377
- energy = self .model (self . inputDict , compute_force = False )[
377
+ energy = self .model (inputDict , compute_force = False )[
378
378
self .returnEnergyType
379
379
]
380
380
0 commit comments