@@ -37,7 +37,9 @@ def minimize_stochastic_reconfiguration(
3737 cond : float = 1e-4 ,
3838 epsilon : float = 1e-8 ,
3939 gtol : float = 1e-5 ,
40+ regularization : float = 0.0 ,
4041 variation : float = 1.0 ,
42+ optimize_regularization : bool = True ,
4143 optimize_variation : bool = True ,
4244 optimize_kwargs : dict | None = None ,
4345 callback : Callable [[OptimizeResult ], Any ] | None = None ,
@@ -60,9 +62,16 @@ def minimize_stochastic_reconfiguration(
6062 epsilon: Increment to use for approximating the gradient using
6163 finite difference.
6264 gtol: Convergence threshold for the norm of the projected gradient.
65+ regularization: Hyperparameter controlling regularization of the
66+ overlap matrix. Its value must be positive. A larger value results in
67+ greater regularization.
6368 variation: TODO Hyperparameter controlling the size of parameter variations
6469 used in the linear expansion of the wavefunction. Its value must be
6570 positive.
71+ optimize_regularization: Whether to optimize the `regularization` hyperparameter
72+ in each iteration. Optimizing hyperparameters incurs more function and
73+ energy evaluations in each iteration, but may improve convergence.
74+ The optimization is performed using `scipy.optimize.minimize`_.
6675 optimize_variation: Whether to optimize the `variation` hyperparameter
6776 in each iteration. Optimizing hyperparameters incurs more function and
6877 energy evaluations in each iteration, but may improve convergence.
@@ -117,7 +126,6 @@ def minimize_stochastic_reconfiguration(
117126 if optimize_kwargs is None :
118127 optimize_kwargs = dict (method = "L-BFGS-B" )
119128
120- variation_param = math .sqrt (variation )
121129 params = x0 .copy ()
122130 converged = False
123131 intermediate_result = OptimizeResult (
@@ -140,27 +148,81 @@ def minimize_stochastic_reconfiguration(
140148 intermediate_result .fun = energy
141149 intermediate_result .jac = grad
142150 intermediate_result .overlap_mat = overlap_mat
151+ intermediate_result .regularization = regularization
143152 intermediate_result .variation = variation
144153 callback (intermediate_result )
145154
146155 if np .linalg .norm (grad ) < gtol :
147156 converged = True
148157 break
149158
150- if optimize_variation :
159+ if optimize_regularization and optimize_variation :
160+
161+ def f (x : np .ndarray ) -> float :
162+ (regularization_param , variation_param ) = x
163+ regularization = regularization_param ** 2
164+ variation = variation_param ** 2
165+ param_update = _get_param_update (
166+ grad ,
167+ overlap_mat ,
168+ regularization = regularization ,
169+ variation = variation ,
170+ cond = cond ,
171+ )
172+ vec = params_to_vec (params + param_update )
173+ return np .vdot (vec , hamiltonian @ vec ).real
174+
175+ regularization_param = math .sqrt (regularization )
176+ variation_param = math .sqrt (variation )
177+ result = minimize (
178+ f ,
179+ x0 = [regularization_param , variation_param ],
180+ ** optimize_kwargs ,
181+ )
182+ (regularization_param , variation_param ) = result .x
183+ regularization = regularization_param ** 2
184+ variation = variation_param ** 2
185+
186+ elif optimize_regularization :
187+
188+ def f (x : np .ndarray ) -> float :
189+ (regularization_param ,) = x
190+ regularization = regularization_param ** 2
191+ param_update = _get_param_update (
192+ grad ,
193+ overlap_mat ,
194+ regularization = regularization ,
195+ variation = variation ,
196+ cond = cond ,
197+ )
198+ vec = params_to_vec (params + param_update )
199+ return np .vdot (vec , hamiltonian @ vec ).real
200+
201+ regularization_param = math .sqrt (regularization )
202+ result = minimize (
203+ f ,
204+ x0 = [regularization_param ],
205+ ** optimize_kwargs ,
206+ )
207+ (regularization_param ,) = result .x
208+ regularization = regularization_param ** 2
209+
210+ elif optimize_variation :
151211
152212 def f (x : np .ndarray ) -> float :
153213 (variation_param ,) = x
154214 variation = variation_param ** 2
155215 param_update = _get_param_update (
156216 grad ,
157217 overlap_mat ,
158- variation ,
218+ regularization = regularization ,
219+ variation = variation ,
159220 cond = cond ,
160221 )
161222 vec = params_to_vec (params + param_update )
162223 return np .vdot (vec , hamiltonian @ vec ).real
163224
225+ variation_param = math .sqrt (variation )
164226 result = minimize (
165227 f ,
166228 x0 = [variation_param ],
@@ -172,7 +234,8 @@ def f(x: np.ndarray) -> float:
172234 param_update = _get_param_update (
173235 grad ,
174236 overlap_mat ,
175- variation ,
237+ regularization = regularization ,
238+ variation = variation ,
176239 cond = cond ,
177240 )
178241 params = params + param_update
@@ -217,7 +280,15 @@ def _sr_matrices(
217280
218281
219282def _get_param_update (
220- grad : np .ndarray , overlap_mat : np .ndarray , variation : float , cond : float
283+ grad : np .ndarray ,
284+ overlap_mat : np .ndarray ,
285+ regularization : float ,
286+ variation : float ,
287+ cond : float ,
221288) -> np .ndarray :
222- x , _ , _ , _ = scipy .linalg .lstsq (overlap_mat , - 0.5 * variation * grad , cond = cond )
289+ x , _ , _ , _ = scipy .linalg .lstsq (
290+ overlap_mat + regularization * np .eye (overlap_mat .shape [0 ]),
291+ - 0.5 * variation * grad ,
292+ cond = cond ,
293+ )
223294 return x
0 commit comments