Skip to content

Commit 515539e

Browse files
committed
alpha max fix for iterative L1
1 parent 4369d10 commit 515539e

File tree

2 files changed

+106
-89
lines changed

2 files changed

+106
-89
lines changed

bsi_zoo/estimators.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -253,7 +253,7 @@ def _gamma_map_opt(
253253
254254
Parameters
255255
----------
256-
M : array, shape=(n_sensors, n_times)
256+
: array, shape=(n_sensors, n_times)
257257
Observation.
258258
G : array, shape=(n_sensors, n_sources)
259259
Forward operator.
@@ -530,6 +530,14 @@ def gprime(w):
530530

531531
return x
532532

533+
def norm_l2inf(A, n_orient, copy=True):
534+
from math import sqrt
535+
"""L2-inf norm."""
536+
if A.size == 0:
537+
return 0.0
538+
if copy:
539+
A = A.copy()
540+
return sqrt(np.max(groups_norm2(A, n_orient)))
533541

534542
def iterative_L1(L, y, alpha=0.2, n_orient=1, max_iter=1000, max_iter_reweighting=10):
535543
"""Iterative Type-I estimator with L1 regularizer.
@@ -578,9 +586,16 @@ def gprime(w):
578586
grp_norms = np.sqrt(groups_norm2(w.copy(), n_orient))
579587
return np.repeat(grp_norms, n_orient).ravel() + eps
580588

581-
alpha_max = abs(L.T.dot(y)).max() / len(L)
589+
if n_orient==1:
590+
alpha_max = abs(L.T.dot(y)).max() / len(L)
591+
else:
592+
n_dip_per_pos = 3
593+
alpha_max = norm_l2inf(np.dot(L.T, y), n_dip_per_pos)
594+
582595
alpha = alpha * alpha_max
583596

597+
# y->M
598+
# L->gain
584599
x = _solve_reweighted_lasso(
585600
L, y, alpha, n_orient, weights, max_iter, max_iter_reweighting, gprime
586601
)

bsi_zoo/run_benchmark.py

Lines changed: 89 additions & 87 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,10 @@
1919
from bsi_zoo.config import get_leadfield_path
2020

2121
n_jobs = 20
22-
nruns = 10
23-
spatial_cv = [False, True]
22+
nruns = 1
23+
# spatial_cv = [False, True]
24+
spatial_cv = [False]
25+
2426
#
2527
subjects = ["CC120166", "CC120264", "CC120313", "CC120309"]
2628
metrics = [
@@ -38,81 +40,81 @@
3840

3941
for do_spatial_cv in spatial_cv:
4042
for subject in subjects:
41-
"""Fixed orientation parameters for the benchmark"""
42-
43-
orientation_type = "fixed"
44-
data_args_I = {
45-
# "n_sensors": [50],
46-
"n_times": [10],
47-
# "n_sources": [200],
48-
"nnz": nnzs,
49-
"cov_type": ["diag"],
50-
"path_to_leadfield": [get_leadfield_path(subject, type=orientation_type)],
51-
"orientation_type": [orientation_type],
52-
"alpha": alpha_SNR, # this is actually SNR
53-
}
54-
55-
data_args_II = {
56-
# "n_sensors": [50],
57-
"n_times": [10],
58-
# "n_sources": [200],
59-
"nnz": nnzs,
60-
"cov_type": ["full"],
61-
"path_to_leadfield": [get_leadfield_path(subject, type=orientation_type)],
62-
"orientation_type": [orientation_type],
63-
"alpha": alpha_SNR, # this is actually SNR
64-
}
65-
66-
estimators = [
67-
(fake_solver, data_args_I, {"alpha": estimator_alphas_I}, {}),
68-
(eloreta, data_args_I, {"alpha": estimator_alphas_II}, {}),
69-
(iterative_L1, data_args_I, {"alpha": estimator_alphas_I}, {}),
70-
(iterative_L2, data_args_I, {"alpha": estimator_alphas_I}, {}),
71-
(iterative_sqrt, data_args_I, {"alpha": estimator_alphas_I}, {}),
72-
(iterative_L1_typeII, data_args_II, {"alpha": estimator_alphas_I}, {}),
73-
(iterative_L2_typeII, data_args_II, {"alpha": estimator_alphas_I}, {}),
74-
# (gamma_map, data_args_II, {"alpha": estimator_alphas_I}, {"update_mode": 1}),
75-
(gamma_map, data_args_II, {"alpha": estimator_alphas_II}, {"update_mode": 2}),
76-
# (gamma_map, data_args_II, {"alpha": estimator_alphas_I}, {"update_mode": 3}),
77-
]
78-
79-
df_results = []
80-
for estimator, data_args, estimator_args, estimator_extra_params in estimators:
81-
benchmark = Benchmark(
82-
estimator,
83-
subject,
84-
metrics,
85-
data_args,
86-
estimator_args,
87-
random_state=42,
88-
memory=memory,
89-
n_jobs=n_jobs,
90-
do_spatial_cv=do_spatial_cv,
91-
estimator_extra_params=estimator_extra_params,
92-
)
93-
results = benchmark.run(nruns=nruns)
94-
df_results.append(results)
95-
# save results
96-
data_path = Path("bsi_zoo/data/updated_alpha_grid")
97-
data_path.mkdir(exist_ok=True)
98-
if do_spatial_cv:
99-
FILE_NAME = f"{estimator}_{subject}_{data_args['orientation_type'][0]}_spatialCV_{time.strftime('%b-%d-%Y_%H%M', time.localtime())}.pkl"
100-
else:
101-
FILE_NAME = f"{estimator}_{subject}_{data_args['orientation_type'][0]}_{time.strftime('%b-%d-%Y_%H%M', time.localtime())}.pkl"
102-
results.to_pickle(data_path / FILE_NAME)
103-
104-
105-
df_results = pd.concat(df_results, axis=0)
106-
107-
data_path = Path("bsi_zoo/data/ramen")
108-
data_path.mkdir(exist_ok=True)
109-
if do_spatial_cv:
110-
FILE_NAME = f"benchmark_data_{subject}_{data_args['orientation_type'][0]}_spatialCV_{time.strftime('%b-%d-%Y_%H%M', time.localtime())}.pkl"
111-
else:
112-
FILE_NAME = f"benchmark_data_{subject}_{data_args['orientation_type'][0]}_{time.strftime('%b-%d-%Y_%H%M', time.localtime())}.pkl"
113-
df_results.to_pickle(data_path / FILE_NAME)
114-
115-
print(df_results)
43+
# """Fixed orientation parameters for the benchmark"""
44+
45+
# orientation_type = "fixed"
46+
# data_args_I = {
47+
# # "n_sensors": [50],
48+
# "n_times": [10],
49+
# # "n_sources": [200],
50+
# "nnz": nnzs,
51+
# "cov_type": ["diag"],
52+
# "path_to_leadfield": [get_leadfield_path(subject, type=orientation_type)],
53+
# "orientation_type": [orientation_type],
54+
# "alpha": alpha_SNR, # this is actually SNR
55+
# }
56+
57+
# data_args_II = {
58+
# # "n_sensors": [50],
59+
# "n_times": [10],
60+
# # "n_sources": [200],
61+
# "nnz": nnzs,
62+
# "cov_type": ["full"],
63+
# "path_to_leadfield": [get_leadfield_path(subject, type=orientation_type)],
64+
# "orientation_type": [orientation_type],
65+
# "alpha": alpha_SNR, # this is actually SNR
66+
# }
67+
68+
# estimators = [
69+
# (fake_solver, data_args_I, {"alpha": estimator_alphas_I}, {}),
70+
# (eloreta, data_args_I, {"alpha": estimator_alphas_II}, {}),
71+
# (iterative_L1, data_args_I, {"alpha": estimator_alphas_I}, {}),
72+
# (iterative_L2, data_args_I, {"alpha": estimator_alphas_I}, {}),
73+
# (iterative_sqrt, data_args_I, {"alpha": estimator_alphas_I}, {}),
74+
# (iterative_L1_typeII, data_args_II, {"alpha": estimator_alphas_I}, {}),
75+
# (iterative_L2_typeII, data_args_II, {"alpha": estimator_alphas_I}, {}),
76+
# # (gamma_map, data_args_II, {"alpha": estimator_alphas_I}, {"update_mode": 1}),
77+
# (gamma_map, data_args_II, {"alpha": estimator_alphas_II}, {"update_mode": 2}),
78+
# # (gamma_map, data_args_II, {"alpha": estimator_alphas_I}, {"update_mode": 3}),
79+
# ]
80+
81+
# df_results = []
82+
# for estimator, data_args, estimator_args, estimator_extra_params in estimators:
83+
# benchmark = Benchmark(
84+
# estimator,
85+
# subject,
86+
# metrics,
87+
# data_args,
88+
# estimator_args,
89+
# random_state=42,
90+
# memory=memory,
91+
# n_jobs=n_jobs,
92+
# do_spatial_cv=do_spatial_cv,
93+
# estimator_extra_params=estimator_extra_params,
94+
# )
95+
# results = benchmark.run(nruns=nruns)
96+
# df_results.append(results)
97+
# # save results
98+
# data_path = Path("bsi_zoo/data/updated_alpha_grid")
99+
# data_path.mkdir(exist_ok=True)
100+
# if do_spatial_cv:
101+
# FILE_NAME = f"{estimator}_{subject}_{data_args['orientation_type'][0]}_spatialCV_{time.strftime('%b-%d-%Y_%H%M', time.localtime())}.pkl"
102+
# else:
103+
# FILE_NAME = f"{estimator}_{subject}_{data_args['orientation_type'][0]}_{time.strftime('%b-%d-%Y_%H%M', time.localtime())}.pkl"
104+
# results.to_pickle(data_path / FILE_NAME)
105+
106+
107+
# df_results = pd.concat(df_results, axis=0)
108+
109+
# data_path = Path("bsi_zoo/data/ramen")
110+
# data_path.mkdir(exist_ok=True)
111+
# if do_spatial_cv:
112+
# FILE_NAME = f"benchmark_data_{subject}_{data_args['orientation_type'][0]}_spatialCV_{time.strftime('%b-%d-%Y_%H%M', time.localtime())}.pkl"
113+
# else:
114+
# FILE_NAME = f"benchmark_data_{subject}_{data_args['orientation_type'][0]}_{time.strftime('%b-%d-%Y_%H%M', time.localtime())}.pkl"
115+
# df_results.to_pickle(data_path / FILE_NAME)
116+
117+
# print(df_results)
116118

117119
""" Free orientation parameters for the benchmark """
118120

@@ -142,22 +144,22 @@
142144
if spatial_cv:
143145
# currently no support for type II methods
144146
estimators = [
145-
(fake_solver, data_args_I, {"alpha": estimator_alphas_I}, {}),
147+
# (fake_solver, data_args_I, {"alpha": estimator_alphas_I}, {}),
146148
(iterative_L1, data_args_I, {"alpha": estimator_alphas_I}, {}),
147149
(iterative_L2, data_args_I, {"alpha": estimator_alphas_I}, {}),
148150
(iterative_sqrt, data_args_I, {"alpha": estimator_alphas_I}, {}),
149151
]
150152
else:
151153
estimators = [
152-
(fake_solver, data_args_I, {"alpha": estimator_alphas_I}, {}),
153-
(eloreta, data_args_I, {"alpha": estimator_alphas_II}, {}),
154+
# (fake_solver, data_args_I, {"alpha": estimator_alphas_I}, {}),
155+
# (eloreta, data_args_I, {"alpha": estimator_alphas_II}, {}),
154156
(iterative_L1, data_args_I, {"alpha": estimator_alphas_I}, {}),
155-
(iterative_L2, data_args_I, {"alpha": estimator_alphas_I}, {}),
156-
(iterative_sqrt, data_args_I, {"alpha": estimator_alphas_I}, {}),
157-
(iterative_L1_typeII, data_args_II, {"alpha": estimator_alphas_I}, {}),
158-
(iterative_L2_typeII, data_args_II, {"alpha": estimator_alphas_I}, {}),
157+
# (iterative_L2, data_args_I, {"alpha": estimator_alphas_I}, {}),
158+
# (iterative_sqrt, data_args_I, {"alpha": estimator_alphas_I}, {}),
159+
# (iterative_L1_typeII, data_args_II, {"alpha": estimator_alphas_I}, {}),
160+
# (iterative_L2_typeII, data_args_II, {"alpha": estimator_alphas_I}, {}),
159161
# (gamma_map, data_args_II, {"alpha": estimator_alphas_I}, {"update_mode": 1}),
160-
(gamma_map, data_args_II, {"alpha": estimator_alphas_II}, {"update_mode": 2}),
162+
# (gamma_map, data_args_II, {"alpha": estimator_alphas_II}, {"update_mode": 2}),
161163
# (gamma_map, data_args_II, {"alpha": estimator_alphas_I}, {"update_mode": 3}),
162164
]
163165

@@ -178,7 +180,7 @@
178180
results = benchmark.run(nruns=nruns)
179181
df_results.append(results)
180182
# save results
181-
data_path = Path("bsi_zoo/data/free2")
183+
data_path = Path("bsi_zoo/data/free3")
182184
data_path.mkdir(exist_ok=True)
183185

184186
if do_spatial_cv:
@@ -189,7 +191,7 @@
189191

190192
df_results = pd.concat(df_results, axis=0)
191193

192-
data_path = Path("bsi_zoo/data/free2")
194+
data_path = Path("bsi_zoo/data/free3")
193195
data_path.mkdir(exist_ok=True)
194196
if do_spatial_cv:
195197
FILE_NAME = f"benchmark_data_{subject}_{data_args['orientation_type'][0]}_spatialCV_{time.strftime('%b-%d-%Y_%H%M', time.localtime())}.pkl"

0 commit comments

Comments
 (0)