Commit 3ebb55c

add regularization hyperparameter

1 parent a52a9bb

2 files changed: +124 −7 lines

python/ffsim/optimize/stochastic_reconfiguration.py (77 additions, 6 deletions)
@@ -37,7 +37,9 @@ def minimize_stochastic_reconfiguration(
     cond: float = 1e-4,
     epsilon: float = 1e-8,
     gtol: float = 1e-5,
+    regularization: float = 0.0,
     variation: float = 1.0,
+    optimize_regularization: bool = True,
     optimize_variation: bool = True,
     optimize_kwargs: dict | None = None,
     callback: Callable[[OptimizeResult], Any] | None = None,
@@ -60,9 +62,16 @@ def minimize_stochastic_reconfiguration(
         epsilon: Increment to use for approximating the gradient using
             finite difference.
         gtol: Convergence threshold for the norm of the projected gradient.
+        regularization: Hyperparameter controlling regularization of the
+            overlap matrix. Its value must be positive. A larger value results in
+            greater regularization.
         variation: TODO Hyperparameter controlling the size of parameter variations
             used in the linear expansion of the wavefunction. Its value must be
             positive.
+        optimize_regularization: Whether to optimize the `regularization` hyperparameter
+            in each iteration. Optimizing hyperparameters incurs more function and
+            energy evaluations in each iteration, but may improve convergence.
+            The optimization is performed using `scipy.optimize.minimize`_.
         optimize_variation: Whether to optimize the `variation` hyperparameter
             in each iteration. Optimizing hyperparameters incurs more function and
             energy evaluations in each iteration, but may improve convergence.
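
For orientation, here is a hedged sketch of how the new keyword arguments might be passed, using a toy two-level problem in place of a real ffsim ansatz. Only the call signature visible in this diff is assumed; whether a plain NumPy array is accepted as `hamiltonian` is an assumption, not confirmed by the commit.

```python
# Hedged usage sketch of the new keyword arguments on a toy two-level
# problem. `params_to_vec` and `hamiltonian` are illustrative stand-ins,
# not part of this diff.
import numpy as np

import ffsim


def params_to_vec(params: np.ndarray) -> np.ndarray:
    # Map a single angle to a normalized two-component state.
    (theta,) = params
    return np.array([np.cos(theta), np.sin(theta)])


hamiltonian = np.diag([-1.0, 1.0])  # ground-state energy is -1

result = ffsim.optimize.minimize_stochastic_reconfiguration(
    params_to_vec,
    x0=np.array([0.4]),
    hamiltonian=hamiltonian,
    regularization=1e-4,           # diagonal shift added to the overlap matrix
    variation=0.9,                 # scales the parameter update, must be positive
    optimize_regularization=True,  # re-tune the shift each iteration
    optimize_variation=True,       # re-tune the step size each iteration
)
print(result.x, result.fun)
```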
@@ -117,7 +126,6 @@ def minimize_stochastic_reconfiguration(
     if optimize_kwargs is None:
         optimize_kwargs = dict(method="L-BFGS-B")
 
-    variation_param = math.sqrt(variation)
     params = x0.copy()
     converged = False
     intermediate_result = OptimizeResult(
@@ -140,27 +148,81 @@ def minimize_stochastic_reconfiguration(
             intermediate_result.fun = energy
             intermediate_result.jac = grad
             intermediate_result.overlap_mat = overlap_mat
+            intermediate_result.regularization = regularization
             intermediate_result.variation = variation
             callback(intermediate_result)
 
         if np.linalg.norm(grad) < gtol:
             converged = True
             break
 
-        if optimize_variation:
+        if optimize_regularization and optimize_variation:
+
+            def f(x: np.ndarray) -> float:
+                (regularization_param, variation_param) = x
+                regularization = regularization_param**2
+                variation = variation_param**2
+                param_update = _get_param_update(
+                    grad,
+                    overlap_mat,
+                    regularization=regularization,
+                    variation=variation,
+                    cond=cond,
+                )
+                vec = params_to_vec(params + param_update)
+                return np.vdot(vec, hamiltonian @ vec).real
+
+            regularization_param = math.sqrt(regularization)
+            variation_param = math.sqrt(variation)
+            result = minimize(
+                f,
+                x0=[regularization_param, variation_param],
+                **optimize_kwargs,
+            )
+            (regularization_param, variation_param) = result.x
+            regularization = regularization_param**2
+            variation = variation_param**2
+
+        elif optimize_regularization:
+
+            def f(x: np.ndarray) -> float:
+                (regularization_param,) = x
+                regularization = regularization_param**2
+                param_update = _get_param_update(
+                    grad,
+                    overlap_mat,
+                    regularization=regularization,
+                    variation=variation,
+                    cond=cond,
+                )
+                vec = params_to_vec(params + param_update)
+                return np.vdot(vec, hamiltonian @ vec).real
+
+            regularization_param = math.sqrt(regularization)
+            result = minimize(
+                f,
+                x0=[regularization_param],
+                **optimize_kwargs,
+            )
+            (regularization_param,) = result.x
+            regularization = regularization_param**2
+
+        elif optimize_variation:
 
             def f(x: np.ndarray) -> float:
                 (variation_param,) = x
                 variation = variation_param**2
                 param_update = _get_param_update(
                     grad,
                     overlap_mat,
-                    variation,
+                    regularization=regularization,
+                    variation=variation,
                     cond=cond,
                 )
                 vec = params_to_vec(params + param_update)
                 return np.vdot(vec, hamiltonian @ vec).real
 
+            variation_param = math.sqrt(variation)
             result = minimize(
                 f,
                 x0=[variation_param],
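
A note on the pattern above: rather than passing bound constraints to `scipy.optimize.minimize`, the code optimizes the square roots of the hyperparameters and squares them inside the objective, so `regularization` and `variation` stay nonnegative for any unconstrained trial point. A minimal, self-contained sketch of that reparameterization, with a toy quadratic standing in for the energy evaluation:

```python
# Minimal illustration of the square-root reparameterization used above:
# the inner optimizer works on an unconstrained value t, and the
# hyperparameter is recovered as t**2, which is always nonnegative.
import math

import numpy as np
from scipy.optimize import minimize


def objective(hyperparameter: float) -> float:
    # Stand-in for the energy as a function of the hyperparameter;
    # the real code evaluates the Hamiltonian expectation value.
    return (hyperparameter - 0.5) ** 2


def f(x: np.ndarray) -> float:
    (t,) = x
    return objective(t**2)  # squaring keeps the hyperparameter >= 0


t0 = math.sqrt(0.1)  # start from hyperparameter = 0.1
result = minimize(f, x0=[t0], method="L-BFGS-B")
print(result.x[0] ** 2)  # optimized hyperparameter, approximately 0.5
```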
@@ -172,7 +234,8 @@ def f(x: np.ndarray) -> float:
         param_update = _get_param_update(
             grad,
             overlap_mat,
-            variation,
+            regularization=regularization,
+            variation=variation,
             cond=cond,
         )
         params = params + param_update
@@ -217,7 +280,15 @@ def _sr_matrices(
 
 
 def _get_param_update(
-    grad: np.ndarray, overlap_mat: np.ndarray, variation: float, cond: float
+    grad: np.ndarray,
+    overlap_mat: np.ndarray,
+    regularization: float,
+    variation: float,
+    cond: float,
 ) -> np.ndarray:
-    x, _, _, _ = scipy.linalg.lstsq(overlap_mat, -0.5 * variation * grad, cond=cond)
+    x, _, _, _ = scipy.linalg.lstsq(
+        overlap_mat + regularization * np.eye(overlap_mat.shape[0]),
+        -0.5 * variation * grad,
+        cond=cond,
+    )
     return x
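
The change to `_get_param_update` amounts to a Tikhonov (ridge) shift of the stochastic reconfiguration linear system: instead of solving `S x = -(variation / 2) * grad` for the overlap matrix `S`, it solves `(S + regularization * I) x = -(variation / 2) * grad`, which damps the ill-conditioned directions of `S`. A standalone sketch with random stand-in data (the matrix and gradient here are illustrative, not from ffsim):

```python
# Standalone sketch of the regularized update computed by
# `_get_param_update`: solve (S + r*I) x = -(v/2)*g in the least-squares
# sense, where S is the overlap matrix, r the regularization, v the
# variation, and g the energy gradient.
import numpy as np
import scipy.linalg

rng = np.random.default_rng(1234)
n = 4
a = rng.standard_normal((n, n))
overlap_mat = a @ a.T  # symmetric positive semidefinite stand-in
grad = rng.standard_normal(n)
regularization, variation, cond = 1e-4, 0.9, 1e-4

x, _, _, _ = scipy.linalg.lstsq(
    overlap_mat + regularization * np.eye(n),
    -0.5 * variation * grad,
    cond=cond,
)
print(x)  # parameter update; the shift keeps the solve well conditioned
```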

tests/python/optimize/stochastic_reconfiguration_test.py (47 additions, 1 deletion)
@@ -63,6 +63,8 @@ def callback(intermediate_result):
         )
         if hasattr(intermediate_result, "jac"):
             info["jac"].append(intermediate_result.jac)
+        if hasattr(intermediate_result, "regularization"):
+            info["regularization"].append(intermediate_result.regularization)
         if hasattr(intermediate_result, "variation"):
             info["variation"].append(intermediate_result.variation)
 
@@ -85,6 +87,9 @@ def callback(intermediate_result):
         params_to_vec,
         x0=x0,
         hamiltonian=hamiltonian,
+        regularization=1e-4,
+        variation=0.9,
+        optimize_regularization=False,
         optimize_variation=False,
         callback=callback,
     )
@@ -94,7 +99,48 @@ def callback(intermediate_result):
         np.testing.assert_allclose(energy(params), fun)
     assert result.nit <= 30
     assert result.nit < result.nlinop < result.nfev
-    assert set(info["variation"]) == {1.0}
+    assert set(info["regularization"]) == {1e-4}
+    assert set(info["variation"]) == {0.9}
+
+    # optimization without optimizing regularization
+    info = defaultdict(list)
+    result = ffsim.optimize.minimize_stochastic_reconfiguration(
+        params_to_vec,
+        x0=x0,
+        hamiltonian=hamiltonian,
+        regularization=1e-4,
+        variation=0.9,
+        optimize_regularization=False,
+        callback=callback,
+    )
+    np.testing.assert_allclose(energy(result.x), result.fun)
+    np.testing.assert_allclose(result.fun, -0.970773)
+    for params, fun in zip(info["x"], info["fun"]):
+        np.testing.assert_allclose(energy(params), fun)
+    assert result.nit <= 30
+    assert result.nit < result.nlinop < result.nfev
+    assert set(info["regularization"]) == {1e-4}
+    assert len(set(info["variation"])) > 1
+
+    # optimization without optimizing variation
+    info = defaultdict(list)
+    result = ffsim.optimize.minimize_stochastic_reconfiguration(
+        params_to_vec,
+        x0=x0,
+        hamiltonian=hamiltonian,
+        regularization=1e-4,
+        variation=0.9,
+        optimize_variation=False,
+        callback=callback,
+    )
+    np.testing.assert_allclose(energy(result.x), result.fun)
+    np.testing.assert_allclose(result.fun, -0.970773)
+    for params, fun in zip(info["x"], info["fun"]):
+        np.testing.assert_allclose(energy(params), fun)
+    assert result.nit <= 30
+    assert result.nit < result.nlinop < result.nfev
+    assert set(info["regularization"]) != {1e-4}
+    assert set(info["variation"]) == {0.9}
 
     # optimization with maxiter
     info = defaultdict(list)
