Skip to content

Commit

Permalink
alpha max fix for iterative methods
Browse files Browse the repository at this point in the history
  • Loading branch information
anujanegi authored Jan 2, 2025
2 parents bd9af0a + 7e5c7d4 commit f369387
Show file tree
Hide file tree
Showing 3 changed files with 1,124 additions and 128 deletions.
65 changes: 53 additions & 12 deletions bsi_zoo/estimators.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def _compute_reginv2(sing, n_nzero, lambda2):
reginv = np.zeros_like(sing)
sing = sing[:n_nzero]
with np.errstate(invalid="ignore"): # if lambda2==0
reginv[:n_nzero] = np.where(sing > 0, sing / (sing ** 2 + lambda2), 0)
reginv[:n_nzero] = np.where(sing > 0, sing / (sing**2 + lambda2), 0)
return reginv


Expand Down Expand Up @@ -119,7 +119,7 @@ def _compute_eloreta_kernel(L, *, lambda2, n_orient, whitener, loose=1.0, max_it
# Outer product
R_prior = source_std.reshape(n_src, 1, 3) * source_std.reshape(n_src, 3, 1)
else:
R_prior = source_std ** 2
R_prior = source_std**2

# The following was adapted under BSD license by permission of Guido Nolte
if force_equal or n_orient == 1:
Expand Down Expand Up @@ -207,7 +207,7 @@ def _solve_reweighted_lasso(
n_positions = L_w.shape[1] // n_orient
lc = np.empty(n_positions)
for j in range(n_positions):
L_j = L_w[:, (j * n_orient): ((j + 1) * n_orient)]
L_j = L_w[:, (j * n_orient) : ((j + 1) * n_orient)]
lc[j] = np.linalg.norm(np.dot(L_j.T, L_j), ord=2)
coef_, active_set, _ = _mixed_norm_solver_bcd(
y,
Expand Down Expand Up @@ -253,7 +253,7 @@ def _gamma_map_opt(
Parameters
----------
M : array, shape=(n_sensors, n_times)
: array, shape=(n_sensors, n_times)
Observation.
G : array, shape=(n_sensors, n_sources)
Forward operator.
Expand Down Expand Up @@ -341,7 +341,7 @@ def denom_fun(x):

if update_mode == 1:
# MacKay fixed point update (10) in [1]
numer = gammas ** 2 * np.mean((A * A.conj()).real, axis=1)
numer = gammas**2 * np.mean((A * A.conj()).real, axis=1)
denom = gammas * np.sum(G * CMinvG, axis=0)
elif update_mode == 2:
# modified MacKay fixed point update (11) in [1]
Expand All @@ -350,7 +350,7 @@ def denom_fun(x):
elif update_mode == 3:
# Expectation Maximization (EM) update
denom = None
numer = gammas ** 2 * np.mean((A * A.conj()).real, axis=1) + gammas * (
numer = gammas**2 * np.mean((A * A.conj()).real, axis=1) + gammas * (
1 - gammas * np.sum(G * CMinvG, axis=0)
)
else:
Expand Down Expand Up @@ -531,6 +531,17 @@ def gprime(w):
return x


def norm_l2inf(A, n_orient, copy=True):
from math import sqrt

"""L2-inf norm."""
if A.size == 0:
return 0.0
if copy:
A = A.copy()
return sqrt(np.max(groups_norm2(A, n_orient)))


def iterative_L1(L, y, alpha=0.2, n_orient=1, max_iter=1000, max_iter_reweighting=10):
"""Iterative Type-I estimator with L1 regularizer.
Expand Down Expand Up @@ -578,9 +589,18 @@ def gprime(w):
grp_norms = np.sqrt(groups_norm2(w.copy(), n_orient))
return np.repeat(grp_norms, n_orient).ravel() + eps

alpha_max = abs(L.T.dot(y)).max() / len(L)
if n_orient == 1:
alpha_max = abs(L.T.dot(y)).max() / len(L)
else:
n_dip_per_pos = 3
alpha_max = norm_l2inf(np.dot(L.T, y), n_dip_per_pos)

alpha = alpha * alpha_max

# eigen_fields, sing, eigen_leads = _safe_svd(L, full_matrices=False)

# y->M
# L->gain
x = _solve_reweighted_lasso(
L, y, alpha, n_orient, weights, max_iter, max_iter_reweighting, gprime
)
Expand All @@ -600,6 +620,7 @@ def iterative_L2(L, y, alpha=0.2, n_orient=1, max_iter=1000, max_iter_reweightin
for solving the following problem:
x^(k+1) <-- argmin_x ||y - Lx||^2_Fro + alpha * sum_i w_i^(k)|x_i|
Parameters
----------
L : array, shape (n_sensors, n_sources)
Expand Down Expand Up @@ -634,7 +655,12 @@ def gprime(w):
grp_norm2 = groups_norm2(w.copy(), n_orient)
return np.repeat(grp_norm2, n_orient).ravel() + eps

alpha_max = abs(L.T.dot(y)).max() / len(L)
if n_orient == 1:
alpha_max = abs(L.T.dot(y)).max() / len(L)
else:
n_dip_per_pos = 3
alpha_max = norm_l2inf(np.dot(L.T, y), n_dip_per_pos)

alpha = alpha * alpha_max

x = _solve_reweighted_lasso(
Expand Down Expand Up @@ -693,7 +719,12 @@ def g(w):
def gprime(w):
return 2.0 * np.repeat(g(w), n_orient).ravel()

alpha_max = abs(L.T.dot(y)).max() / len(L)
if n_orient == 1:
alpha_max = abs(L.T.dot(y)).max() / len(L)
else:
n_dip_per_pos = 3
alpha_max = norm_l2inf(np.dot(L.T, y), n_dip_per_pos)

alpha = alpha * alpha_max

x = _solve_reweighted_lasso(
Expand Down Expand Up @@ -778,7 +809,12 @@ def iterative_L1_typeII(
n_sensors, n_sources = L.shape
weights = np.ones(n_sources)

alpha_max = abs(L.T.dot(y)).max() / len(L)
if n_orient == 1:
alpha_max = abs(L.T.dot(y)).max() / len(L)
else:
n_dip_per_pos = 3
alpha_max = norm_l2inf(np.dot(L.T, y), n_dip_per_pos)

alpha = alpha * alpha_max

if isinstance(cov, float):
Expand Down Expand Up @@ -877,7 +913,12 @@ def iterative_L2_typeII(
n_sensors, n_sources = L.shape
weights = np.ones(n_sources)

alpha_max = abs(L.T.dot(y)).max() / len(L)
if n_orient == 1:
alpha_max = abs(L.T.dot(y)).max() / len(L)
else:
n_dip_per_pos = 3
alpha_max = norm_l2inf(np.dot(L.T, y), n_dip_per_pos)

alpha = alpha * alpha_max

if isinstance(cov, float):
Expand All @@ -904,7 +945,7 @@ def epsilon_update(L, weights, cov):
# w_mat(weights)
# - np.multiply(w_mat(weights ** 2), np.diag((L_T @ sigmaY_inv) @ L))
# )
return weights_ - (weights_ ** 2) * ((L_T @ sigmaY_inv) * L_T).sum(axis=1)
return weights_ - (weights_**2) * ((L_T @ sigmaY_inv) * L_T).sum(axis=1)

def g_coef(coef):
return groups_norm2(coef.copy(), n_orient)
Expand Down
Loading

0 comments on commit f369387

Please sign in to comment.