@@ -17,7 +17,7 @@ def calc_energy_batch(
1717 edge_index : Tensor ,
1818 cell : Optional [Tensor ] = None ,
1919 pbc : Optional [Tensor ] = None ,
20- shift : Optional [Tensor ] = None ,
20+ shift_pos : Optional [Tensor ] = None ,
2121 batch : Optional [Tensor ] = None ,
2222 batch_edge : Optional [Tensor ] = None ,
2323 damping : str = "zero" ,
@@ -32,7 +32,7 @@ def calc_energy_batch(
3232 edge_index (Tensor): (2, n_edges) edge index within cutoff
3333 cell (Tensor): (n_atoms, 3) cell size in angstrom, None for non periodic system.
3434 pbc (Tensor): (bs, 3) pbc condition, None for non periodic system.
35- shift (Tensor): (n_atoms, 3) shift vector
35+ shift_pos (Tensor): (n_atoms, 3) shift vector (length unit).
3636 batch (Tensor): (n_atoms,) Specify which graph this atom belongs to
3737 batch_edge (Tensor): (n_edges, 3) Specify which graph this edge belongs to
3838 damping (str):
@@ -49,7 +49,7 @@ def calc_energy(
4949 edge_index : Tensor ,
5050 cell : Optional [Tensor ] = None ,
5151 pbc : Optional [Tensor ] = None ,
52- shift : Optional [Tensor ] = None ,
52+ shift_pos : Optional [Tensor ] = None ,
5353 batch : Optional [Tensor ] = None ,
5454 batch_edge : Optional [Tensor ] = None ,
5555 damping : str = "zero" ,
@@ -64,6 +64,7 @@ def calc_energy(
6464 edge_index (Tensor):
6565 cell (Tensor): cell size in angstrom, None for non periodic system.
6666 pbc (Tensor): pbc condition, None for non periodic system.
67+ shift_pos (Tensor): (n_atoms, 3) shift vector (length unit).
6768 batch (Tensor):
6869 batch_edge (Tensor):
6970 damping (str): damping method. "zero", "bj", "zerom", "bjm"
@@ -73,7 +74,7 @@ def calc_energy(
7374 """
7475 with torch .no_grad ():
7576 E_disp = self .calc_energy_batch (
76- Z , pos , edge_index , cell , pbc , shift , batch , batch_edge , damping = damping
77+ Z , pos , edge_index , cell , pbc , shift_pos , batch , batch_edge , damping = damping
7778 )
7879 if batch is None :
7980 return [{"energy" : E_disp .item ()}]
@@ -91,7 +92,7 @@ def calc_energy_and_forces(
9192 edge_index : Tensor ,
9293 cell : Optional [Tensor ] = None ,
9394 pbc : Optional [Tensor ] = None ,
94- shift : Optional [Tensor ] = None ,
95+ shift_pos : Optional [Tensor ] = None ,
9596 batch : Optional [Tensor ] = None ,
9697 batch_edge : Optional [Tensor ] = None ,
9798 damping : str = "zero" ,
@@ -103,6 +104,7 @@ def calc_energy_and_forces(
103104 pos (Tensor): atom positions in angstrom
104105 cell (Tensor): cell size in angstrom, None for non periodic system.
105106 pbc (Tensor): pbc condition, None for non periodic system.
107+ shift_pos (Tensor): (n_atoms, 3) shift vector (length unit).
106108 damping (str): damping method. "zero", "bj", "zerom", "bjm"
107109
108110 Returns:
@@ -117,11 +119,11 @@ def calc_energy_and_forces(
117119 # We need to explicitly include this dependency to calculate cell gradient
118120 # for stress computation.
119121 # pos is assumed to be inside "cell", so relative position `rel_pos` lies between 0~1.
120- assert isinstance (shift , Tensor )
121- shift .requires_grad_ (True )
122+ assert isinstance (shift_pos , Tensor )
123+ shift_pos .requires_grad_ (True )
122124
123125 E_disp = self .calc_energy_batch (
124- Z , pos , edge_index , cell , pbc , shift , batch , batch_edge , damping = damping
126+ Z , pos , edge_index , cell , pbc , shift_pos , batch , batch_edge , damping = damping
125127 )
126128
127129 E_disp .sum ().backward ()
@@ -140,7 +142,7 @@ def calc_energy_and_forces(
140142 if cell is not None :
141143 # stress = torch.mm(cell_grad, cell.T) / cell_volume
142144 # Get stress in Voigt notation (xx, yy, zz, yz, xz, xy)
143- assert isinstance (shift , Tensor )
145+ assert isinstance (shift_pos , Tensor )
144146 voigt_left = [0 , 1 , 2 , 1 , 2 , 0 ]
145147 voigt_right = [0 , 1 , 2 , 2 , 0 , 1 ]
146148 if batch is None :
@@ -149,7 +151,8 @@ def calc_energy_and_forces(
149151 (pos [:, voigt_left ] * pos .grad [:, voigt_right ]).to (torch .float64 ), dim = 0
150152 )
151153 cell_grad += torch .sum (
152- (shift [:, voigt_left ] * shift .grad [:, voigt_right ]).to (torch .float64 ), dim = 0
154+ (shift_pos [:, voigt_left ] * shift_pos .grad [:, voigt_right ]).to (torch .float64 ),
155+ dim = 0 ,
153156 )
154157 stress = cell_grad .to (cell .dtype ) / cell_volume
155158 results_list [0 ]["stress" ] = stress .detach ().cpu ().numpy ()
@@ -166,7 +169,7 @@ def calc_energy_and_forces(
166169 cell_grad .scatter_add_ (
167170 0 ,
168171 batch_edge .view (batch_edge .size ()[0 ], 1 ).expand (batch_edge .size ()[0 ], 6 ),
169- (shift [:, voigt_left ] * shift .grad [:, voigt_right ]).to (torch .float64 ),
172+ (shift_pos [:, voigt_left ] * shift_pos .grad [:, voigt_right ]).to (torch .float64 ),
170173 )
171174 stress = cell_grad .to (cell .dtype ) / cell_volume [:, None ]
172175 stress = stress .detach ().cpu ().numpy ()
0 commit comments