From d7ed1703785e770ba9b4da7c3694ae01d8f56220 Mon Sep 17 00:00:00 2001 From: Anuja Negi Date: Thu, 12 Dec 2024 11:26:21 +0100 Subject: [PATCH] fix alpha for iterative methods --- bsi_zoo/estimators.py | 44 +++++++++++++++++++++++++++++++++---------- 1 file changed, 34 insertions(+), 10 deletions(-) diff --git a/bsi_zoo/estimators.py b/bsi_zoo/estimators.py index 7dedd86..44d9918 100644 --- a/bsi_zoo/estimators.py +++ b/bsi_zoo/estimators.py @@ -207,7 +207,7 @@ def _solve_reweighted_lasso( n_positions = L_w.shape[1] // n_orient lc = np.empty(n_positions) for j in range(n_positions): - L_j = L_w[:, (j * n_orient): ((j + 1) * n_orient)] + L_j = L_w[:, (j * n_orient) : ((j + 1) * n_orient)] lc[j] = np.linalg.norm(np.dot(L_j.T, L_j), ord=2) coef_, active_set, _ = _mixed_norm_solver_bcd( y, @@ -530,8 +530,10 @@ def gprime(w): return x + def norm_l2inf(A, n_orient, copy=True): from math import sqrt + """L2-inf norm.""" if A.size == 0: return 0.0 @@ -539,6 +541,7 @@ def norm_l2inf(A, n_orient, copy=True): A = A.copy() return sqrt(np.max(groups_norm2(A, n_orient))) + def iterative_L1(L, y, alpha=0.2, n_orient=1, max_iter=1000, max_iter_reweighting=10): """Iterative Type-I estimator with L1 regularizer. @@ -586,20 +589,20 @@ def gprime(w): grp_norms = np.sqrt(groups_norm2(w.copy(), n_orient)) return np.repeat(grp_norms, n_orient).ravel() + eps - if n_orient==1: + if n_orient == 1: alpha_max = abs(L.T.dot(y)).max() / len(L) - else: + else: n_dip_per_pos = 3 alpha_max = norm_l2inf(np.dot(L.T, y), n_dip_per_pos) - + alpha = alpha * alpha_max - eigen_fields, sing, eigen_leads = _safe_svd(L, full_matrices=False) + # eigen_fields, sing, eigen_leads = _safe_svd(L, full_matrices=False) # y->M # L->gain x = _solve_reweighted_lasso( - eigen_leads, y, alpha, n_orient, weights, max_iter, max_iter_reweighting, gprime + L, y, alpha, n_orient, weights, max_iter, max_iter_reweighting, gprime ) return x @@ -617,6 +620,7 @@ def iterative_L2(L, y, alpha=0.2, n_orient=1, max_iter=1000, max_iter_reweightin for solving the following problem: x^(k+1) <-- argmin_x ||y - Lx||^2_Fro + alpha * sum_i w_i^(k)|x_i| + Parameters ---------- L : array, shape (n_sensors, n_sources) @@ -651,7 +655,12 @@ def gprime(w): grp_norm2 = groups_norm2(w.copy(), n_orient) return np.repeat(grp_norm2, n_orient).ravel() + eps - alpha_max = abs(L.T.dot(y)).max() / len(L) + if n_orient == 1: + alpha_max = abs(L.T.dot(y)).max() / len(L) + else: + n_dip_per_pos = 3 + alpha_max = norm_l2inf(np.dot(L.T, y), n_dip_per_pos) + alpha = alpha * alpha_max x = _solve_reweighted_lasso( @@ -710,7 +719,12 @@ def g(w): def gprime(w): return 2.0 * np.repeat(g(w), n_orient).ravel() - alpha_max = abs(L.T.dot(y)).max() / len(L) + if n_orient == 1: + alpha_max = abs(L.T.dot(y)).max() / len(L) + else: + n_dip_per_pos = 3 + alpha_max = norm_l2inf(np.dot(L.T, y), n_dip_per_pos) + alpha = alpha * alpha_max x = _solve_reweighted_lasso( @@ -795,7 +809,12 @@ def iterative_L1_typeII( n_sensors, n_sources = L.shape weights = np.ones(n_sources) - alpha_max = abs(L.T.dot(y)).max() / len(L) + if n_orient == 1: + alpha_max = abs(L.T.dot(y)).max() / len(L) + else: + n_dip_per_pos = 3 + alpha_max = norm_l2inf(np.dot(L.T, y), n_dip_per_pos) + alpha = alpha * alpha_max if isinstance(cov, float): @@ -894,7 +913,12 @@ def iterative_L2_typeII( n_sensors, n_sources = L.shape weights = np.ones(n_sources) - alpha_max = abs(L.T.dot(y)).max() / len(L) + if n_orient == 1: + alpha_max = abs(L.T.dot(y)).max() / len(L) + else: + n_dip_per_pos = 3 + alpha_max = norm_l2inf(np.dot(L.T, y), n_dip_per_pos) + alpha = alpha * alpha_max if isinstance(cov, float):