Skip to content

Commit

Permalink
alpha max fix for iterative methods
Browse files Browse the repository at this point in the history
  • Loading branch information
anujanegi committed Jan 2, 2025
1 parent 515539e commit ade971e
Showing 1 changed file with 12 additions and 9 deletions.
21 changes: 12 additions & 9 deletions bsi_zoo/estimators.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def _compute_reginv2(sing, n_nzero, lambda2):
reginv = np.zeros_like(sing)
sing = sing[:n_nzero]
with np.errstate(invalid="ignore"): # if lambda2==0
reginv[:n_nzero] = np.where(sing > 0, sing / (sing ** 2 + lambda2), 0)
reginv[:n_nzero] = np.where(sing > 0, sing / (sing**2 + lambda2), 0)
return reginv


Expand Down Expand Up @@ -119,7 +119,7 @@ def _compute_eloreta_kernel(L, *, lambda2, n_orient, whitener, loose=1.0, max_it
# Outer product
R_prior = source_std.reshape(n_src, 1, 3) * source_std.reshape(n_src, 3, 1)
else:
R_prior = source_std ** 2
R_prior = source_std**2

# The following was adapted under BSD license by permission of Guido Nolte
if force_equal or n_orient == 1:
Expand Down Expand Up @@ -207,7 +207,7 @@ def _solve_reweighted_lasso(
n_positions = L_w.shape[1] // n_orient
lc = np.empty(n_positions)
for j in range(n_positions):
L_j = L_w[:, (j * n_orient): ((j + 1) * n_orient)]
L_j = L_w[:, (j * n_orient) : ((j + 1) * n_orient)]
lc[j] = np.linalg.norm(np.dot(L_j.T, L_j), ord=2)
coef_, active_set, _ = _mixed_norm_solver_bcd(
y,
Expand Down Expand Up @@ -341,7 +341,7 @@ def denom_fun(x):

if update_mode == 1:
# MacKay fixed point update (10) in [1]
numer = gammas ** 2 * np.mean((A * A.conj()).real, axis=1)
numer = gammas**2 * np.mean((A * A.conj()).real, axis=1)
denom = gammas * np.sum(G * CMinvG, axis=0)
elif update_mode == 2:
# modified MacKay fixed point update (11) in [1]
Expand All @@ -350,7 +350,7 @@ def denom_fun(x):
elif update_mode == 3:
# Expectation Maximization (EM) update
denom = None
numer = gammas ** 2 * np.mean((A * A.conj()).real, axis=1) + gammas * (
numer = gammas**2 * np.mean((A * A.conj()).real, axis=1) + gammas * (
1 - gammas * np.sum(G * CMinvG, axis=0)
)
else:
Expand Down Expand Up @@ -530,15 +530,18 @@ def gprime(w):

return x


def norm_l2inf(A, n_orient, copy=True):
from math import sqrt

"""L2-inf norm."""
if A.size == 0:
return 0.0
if copy:
A = A.copy()
return sqrt(np.max(groups_norm2(A, n_orient)))


def iterative_L1(L, y, alpha=0.2, n_orient=1, max_iter=1000, max_iter_reweighting=10):
"""Iterative Type-I estimator with L1 regularizer.
Expand Down Expand Up @@ -586,12 +589,12 @@ def gprime(w):
grp_norms = np.sqrt(groups_norm2(w.copy(), n_orient))
return np.repeat(grp_norms, n_orient).ravel() + eps

if n_orient==1:
if n_orient == 1:
alpha_max = abs(L.T.dot(y)).max() / len(L)
else:
else:
n_dip_per_pos = 3
alpha_max = norm_l2inf(np.dot(L.T, y), n_dip_per_pos)

alpha = alpha * alpha_max

# y->M
Expand Down Expand Up @@ -919,7 +922,7 @@ def epsilon_update(L, weights, cov):
# w_mat(weights)
# - np.multiply(w_mat(weights ** 2), np.diag((L_T @ sigmaY_inv) @ L))
# )
return weights_ - (weights_ ** 2) * ((L_T @ sigmaY_inv) * L_T).sum(axis=1)
return weights_ - (weights_**2) * ((L_T @ sigmaY_inv) * L_T).sum(axis=1)

def g_coef(coef):
return groups_norm2(coef.copy(), n_orient)
Expand Down

0 comments on commit ade971e

Please sign in to comment.