@@ -44,12 +44,12 @@ def _ncoord(
4444 """
4545 if cutoff is not None :
4646 # Calculate _ncoord only for r < cutoff
47- within_cutoff = r <= cutoff
48- r = r [within_cutoff ]
49- # Zi = Zi[within_cutoff ]
50- # Zj = Zj[within_cutoff ]
51- idx_i = idx_i [within_cutoff ]
52- idx_j = idx_j [within_cutoff ]
47+ indices = torch . nonzero ( r <= cutoff ). reshape ( - 1 )
48+ r = r [indices ]
49+ # Zi = Zi[indices ]
50+ # Zj = Zj[indices ]
51+ idx_i = idx_i [indices ]
52+ idx_j = idx_j [indices ]
5353 Zi = Z [idx_i ]
5454 Zj = Z [idx_j ]
5555 rco = rcov [Zi ] + rcov [Zj ] # (n_edges,)
@@ -67,7 +67,13 @@ def _ncoord(
6767
6868
def _getc6(
    Zi: Tensor,
    Zj: Tensor,
    nci: Tensor,
    ncj: Tensor,
    c6ab: Tensor,
    k3: float = d3_k3,
    n_chunks: Optional[int] = None,
) -> Tensor:
    """interpolate c6

    Dispatches to `_getc6_impl`, either in one call or, when `n_chunks` is
    given, on consecutive slices of the edge dimension whose results are
    concatenated back together along dim 0.

    Args:
        Zi: per-edge index into the first axis of `c6ab`
            (presumably (n_edges,), like `ncj` -- confirm against caller)
        Zj: per-edge index into the second axis of `c6ab`
            (presumably (n_edges,) -- confirm against caller)
        nci: per-edge value for the `i` side, sliced together with `Zi`
            (presumably (n_edges,) -- confirm against caller)
        ncj: (n_edges,)
        c6ab: reference table, passed unsliced to every `_getc6_impl` call
        k3: forwarded unchanged to `_getc6_impl`
        n_chunks: number of slices the edge dimension is split into before
            calling `_getc6_impl`. `None` (default) disables splitting.
            Must be >= 1 when given.

    Returns:
        c6 (Tensor): (n_edges,)

    Raises:
        ValueError: if `n_chunks` is given and is smaller than 1.
    """
    if n_chunks is None:
        return _getc6_impl(Zi, Zj, nci, ncj, c6ab, k3=k3)

    # Without this guard a non-positive n_chunks produces an empty result list
    # and fails later inside torch.cat with an unrelated-looking error.
    if n_chunks < 1:
        raise ValueError(f"n_chunks must be a positive integer, got {n_chunks}")

    # TODO(takagi) More balanced split like torch.tensor_split as, for example,
    # trying to split 13 elements into 6 chunks currently gives 5 chunks:
    # ([0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11], [12])
    # NOTE(review): the chunk size is computed with tensor ops instead of plain
    # Python ints -- presumably deliberate (e.g. to stay traceable); confirm
    # before simplifying.
    n_chunks_t = torch.tensor(n_chunks)
    chunk_size = torch.ceil(Zi.shape[0] / n_chunks_t).to(torch.int64)
    c6s = []
    for i in range(n_chunks):
        chunk_start = i * chunk_size
        # Slices past the end of the inputs are simply empty; the loop still
        # calls _getc6_impl on them and concatenates the empty results.
        slc = slice(chunk_start, chunk_start + chunk_size)
        c6s.append(_getc6_impl(Zi[slc], Zj[slc], nci[slc], ncj[slc], c6ab, k3=k3))
    return torch.cat(c6s, 0)
106+
107+
108+ def _getc6_impl (
109+ Zi : Tensor , Zj : Tensor , nci : Tensor , ncj : Tensor , c6ab : Tensor , k3 : float = d3_k3
110+ ) -> Tensor :
85111 # gather the relevant entries from the table
86112 # c6ab (95, 95, 5, 5, 3) --> c6ab_ (n_edges, 5, 5, 3)
87113 c6ab_ = c6ab [Zi , Zj ].type (nci .dtype )
@@ -98,7 +124,7 @@ def _getc6(
98124 if cn0 .size (0 ) == 0 :
99125 k3_rnc = (k3 * r ).view (n_edges , n_c6ab )
100126 else :
101- k3_rnc = torch .where (cn0 > 0.0 , k3 * r , - 1.0e20 * torch . ones_like ( r ) ).view (n_edges , n_c6ab )
127+ k3_rnc = torch .where (cn0 > 0.0 , k3 * r , - 1.0e20 ).view (n_edges , n_c6ab )
102128 r_ratio = torch .softmax (k3_rnc , dim = 1 )
103129 c6 = (r_ratio * cn0 .view (n_edges , n_c6ab )).sum (dim = 1 )
104130 return c6
@@ -130,6 +156,7 @@ def edisp(
130156 damping : str = "zero" ,
131157 bidirectional : bool = False ,
132158 abc : bool = False ,
159+ n_chunks : Optional [int ] = None ,
133160):
134161 """compute d3 dispersion energy in Hartree
135162
@@ -159,6 +186,7 @@ def edisp(
159186 damping (str): damping method, only "zero" is supported.
160187 bidirectional (bool): calculated `edge_index` is bidirectional or not.
161188 abc (bool): ATM 3-body interaction
189+ n_chunks (int or None): number of times to split c6 computation to reduce peak memory
162190
163191 Returns:
164192 energy: (n_graphs,) Energy in Hartree unit.
@@ -190,7 +218,7 @@ def edisp(
190218
191219 nci = nc [idx_i ]
192220 ncj = nc [idx_j ]
193- c6 = _getc6 (Zi , Zj , nci , ncj , c6ab = c6ab , k3 = k3 ) # c6 coefficients
221+ c6 = _getc6 (Zi , Zj , nci , ncj , c6ab = c6ab , k3 = k3 , n_chunks = n_chunks ) # c6 coefficients
194222
195223 c8 = 3 * c6 * r2r4 [Zi ].type (c6 .dtype ) * r2r4 [Zj ].type (c6 .dtype ) # c8 coefficient
196224
0 commit comments