From 08045b274cc95df3066e1bf71c8413756a96681d Mon Sep 17 00:00:00 2001 From: Anuja Negi Date: Fri, 16 Feb 2024 15:18:26 +0100 Subject: [PATCH 1/6] full benchmark --- bsi_zoo/run_benchmark.py | 204 ++++++++++++++++++++------------------- 1 file changed, 106 insertions(+), 98 deletions(-) diff --git a/bsi_zoo/run_benchmark.py b/bsi_zoo/run_benchmark.py index 0fc52c4..b73855f 100644 --- a/bsi_zoo/run_benchmark.py +++ b/bsi_zoo/run_benchmark.py @@ -18,11 +18,11 @@ from bsi_zoo.metrics import euclidean_distance, mse, emd, f1, reconstructed_noise from bsi_zoo.config import get_leadfield_path -n_jobs = 30 +n_jobs = 20 nruns = 10 spatial_cv = [False, True] -subjects = ["CC120264", "CC120313", "CC120309"] -# "CC120166", "CC120313", +# +subjects = ["CC120166", "CC120264", "CC120313", "CC120309"] metrics = [ euclidean_distance, mse, @@ -32,17 +32,8 @@ ] # list of metric functions here nnzs = [1, 2, 3, 5] alpha_SNR = [0.99, 0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1, 0.01] -# estimator_alphas = [ -# 0.01, -# 0.01544452, -# 0.02385332, -# 0.03684031, -# 0.0568981, -# 0.08787639, -# 0.13572088, -# 0.2096144, -# ] # logspaced -estimator_alphas = np.logspace(0, -2, 20)[1:] +estimator_alphas_I = np.logspace(0, -2, 20)[1:] +estimator_alphas_II = [0.0001, 0.001, 0.01, 0.1, 1, 10, 100, 1000, 10000, 100000] memory = Memory(".") for do_spatial_cv in spatial_cv: @@ -73,16 +64,16 @@ } estimators = [ - (fake_solver, data_args_I, {"alpha": estimator_alphas}, {}), - # (eloreta, data_args_I, {"alpha": estimator_alphas}, {}), - # (iterative_L1, data_args_I, {"alpha": estimator_alphas}, {}), - # (iterative_L2, data_args_I, {"alpha": estimator_alphas}, {}), - # (iterative_sqrt, data_args_I, {"alpha": estimator_alphas}, {}), - # (iterative_L1_typeII, data_args_II, {"alpha": estimator_alphas}, {}), - # (iterative_L2_typeII, data_args_II, {"alpha": estimator_alphas}, {}), - # (gamma_map, data_args_II, {"alpha": estimator_alphas}, {"update_mode": 1}), - # (gamma_map, data_args_II, {"alpha": estimator_alphas}, {"update_mode": 2}), - # (gamma_map, data_args_II, {"alpha": estimator_alphas}, {"update_mode": 3}), + (fake_solver, data_args_I, {"alpha": estimator_alphas_I}, {}), + (eloreta, data_args_I, {"alpha": estimator_alphas_II}, {}), + (iterative_L1, data_args_I, {"alpha": estimator_alphas_I}, {}), + (iterative_L2, data_args_I, {"alpha": estimator_alphas_I}, {}), + (iterative_sqrt, data_args_I, {"alpha": estimator_alphas_I}, {}), + (iterative_L1_typeII, data_args_II, {"alpha": estimator_alphas_I}, {}), + (iterative_L2_typeII, data_args_II, {"alpha": estimator_alphas_I}, {}), + # (gamma_map, data_args_II, {"alpha": estimator_alphas_I}, {"update_mode": 1}), + (gamma_map, data_args_II, {"alpha": estimator_alphas_II}, {"update_mode": 2}), + # (gamma_map, data_args_II, {"alpha": estimator_alphas_I}, {"update_mode": 3}), ] df_results = [] @@ -102,15 +93,103 @@ results = benchmark.run(nruns=nruns) df_results.append(results) # save results - data_path = Path("bsi_zoo/data") + data_path = Path("bsi_zoo/data/updated_alpha_grid") data_path.mkdir(exist_ok=True) - FILE_NAME = f"{estimator}_{subject}_{data_args['orientation_type'][0]}_{time.strftime('%b-%d-%Y_%H%M', time.localtime())}.pkl" + if do_spatial_cv: + FILE_NAME = f"{estimator}_{subject}_{data_args['orientation_type'][0]}_spatialCV_{time.strftime('%b-%d-%Y_%H%M', time.localtime())}.pkl" + else: + FILE_NAME = f"{estimator}_{subject}_{data_args['orientation_type'][0]}_{time.strftime('%b-%d-%Y_%H%M', time.localtime())}.pkl" results.to_pickle(data_path / FILE_NAME) df_results = pd.concat(df_results, axis=0) - data_path = Path("bsi_zoo/data") + data_path = Path("bsi_zoo/data/ramen") + data_path.mkdir(exist_ok=True) + if do_spatial_cv: + FILE_NAME = f"benchmark_data_{subject}_{data_args['orientation_type'][0]}_spatialCV_{time.strftime('%b-%d-%Y_%H%M', time.localtime())}.pkl" + else: + FILE_NAME = f"benchmark_data_{subject}_{data_args['orientation_type'][0]}_{time.strftime('%b-%d-%Y_%H%M', time.localtime())}.pkl" + df_results.to_pickle(data_path / FILE_NAME) + + print(df_results) + + """ Free orientation parameters for the benchmark """ + + orientation_type = "free" + data_args_I = { + "n_sensors": [50], + "n_times": [10], + "n_sources": [200], + "nnz": nnzs, + "cov_type": ["diag"], + "path_to_leadfield": [get_leadfield_path(subject, type=orientation_type)], + "orientation_type": [orientation_type], + "alpha": alpha_SNR, # this is actually SNR + } + + data_args_II = { + "n_sensors": [50], + "n_times": [10], + "n_sources": [200], + "nnz": nnzs, + "cov_type": ["full"], + "path_to_leadfield": [get_leadfield_path(subject, type=orientation_type)], + "orientation_type": [orientation_type], + "alpha": alpha_SNR, # this is actually SNR + } + + if spatial_cv: + # currently no support for type II methods + estimators = [ + (fake_solver, data_args_I, {"alpha": estimator_alphas_I}, {}), + (iterative_L1, data_args_I, {"alpha": estimator_alphas_I}, {}), + (iterative_L2, data_args_I, {"alpha": estimator_alphas_I}, {}), + (iterative_sqrt, data_args_I, {"alpha": estimator_alphas_I}, {}), + ] + else: + estimators = [ + (fake_solver, data_args_I, {"alpha": estimator_alphas_I}, {}), + (eloreta, data_args_I, {"alpha": estimator_alphas_II}, {}), + (iterative_L1, data_args_I, {"alpha": estimator_alphas_I}, {}), + (iterative_L2, data_args_I, {"alpha": estimator_alphas_I}, {}), + (iterative_sqrt, data_args_I, {"alpha": estimator_alphas_I}, {}), + (iterative_L1_typeII, data_args_II, {"alpha": estimator_alphas_I}, {}), + (iterative_L2_typeII, data_args_II, {"alpha": estimator_alphas_I}, {}), + # (gamma_map, data_args_II, {"alpha": estimator_alphas_I}, {"update_mode": 1}), + (gamma_map, data_args_II, {"alpha": estimator_alphas_II}, {"update_mode": 2}), + # (gamma_map, data_args_II, {"alpha": estimator_alphas_I}, {"update_mode": 3}), + ] + + df_results = [] + for estimator, data_args, estimator_args, estimator_extra_params in estimators: + benchmark = Benchmark( + estimator, + subject, + metrics, + data_args, + estimator_args, + random_state=42, + memory=memory, + n_jobs=n_jobs, + do_spatial_cv=do_spatial_cv, + estimator_extra_params=estimator_extra_params, + ) + results = benchmark.run(nruns=nruns) + df_results.append(results) + # save results + data_path = Path("bsi_zoo/data/free2") + data_path.mkdir(exist_ok=True) + + if do_spatial_cv: + FILE_NAME = f"{estimator}_{subject}_{data_args['orientation_type'][0]}_spatialCV_{time.strftime('%b-%d-%Y_%H%M', time.localtime())}.pkl" + else: + FILE_NAME = f"{estimator}_{subject}_{data_args['orientation_type'][0]}_{time.strftime('%b-%d-%Y_%H%M', time.localtime())}.pkl" + results.to_pickle(data_path / FILE_NAME) + + df_results = pd.concat(df_results, axis=0) + + data_path = Path("bsi_zoo/data/free2") data_path.mkdir(exist_ok=True) if do_spatial_cv: FILE_NAME = f"benchmark_data_{subject}_{data_args['orientation_type'][0]}_spatialCV_{time.strftime('%b-%d-%Y_%H%M', time.localtime())}.pkl" @@ -120,74 +199,3 @@ print(df_results) - # """ Free orientation parameters for the benchmark """ - - # orientation_type = "free" - # data_args_I = { - # "n_sensors": [50], - # "n_times": [10], - # "n_sources": [200], - # "nnz": nnzs, - # "cov_type": ["diag"], - # "path_to_leadfield": [get_leadfield_path(subject, type=orientation_type)], - # "orientation_type": [orientation_type], - # "alpha": alpha_SNR, # this is actually SNR - # } - - # data_args_II = { - # "n_sensors": [50], - # "n_times": [10], - # "n_sources": [200], - # "nnz": nnzs, - # "cov_type": ["full"], - # "path_to_leadfield": [get_leadfield_path(subject, type=orientation_type)], - # "orientation_type": [orientation_type], - # "alpha": alpha_SNR, # this is actually SNR - # } - - # estimators = [ - # (fake_solver, data_args_I, {"alpha": estimator_alphas}, {}), - # (eloreta, data_args_I, {"alpha": estimator_alphas}, {}), - # (iterative_L1, data_args_I, {"alpha": estimator_alphas}, {}), - # (iterative_L2, data_args_I, {"alpha": estimator_alphas}, {}), - # (iterative_sqrt, data_args_I, {"alpha": estimator_alphas}, {}), - # (iterative_L1_typeII, data_args_II, {"alpha": estimator_alphas}, {}), - # (iterative_L2_typeII, data_args_II, {"alpha": estimator_alphas}, {}), - # # (gamma_map, data_args_II, {"alpha": estimator_alphas}, {"update_mode": 1}), - # (gamma_map, data_args_II, {"alpha": estimator_alphas}, {"update_mode": 2}), - # # (gamma_map, data_args_II, {"alpha": estimator_alphas}, {"update_mode": 3}), - # ] - - # df_results = [] - # for estimator, data_args, estimator_args, estimator_extra_params in estimators: - # benchmark = Benchmark( - # estimator, - # subject, - # metrics, - # data_args, - # estimator_args, - # random_state=42, - # memory=memory, - # n_jobs=n_jobs, - # do_spatial_cv=do_spatial_cv, - # estimator_extra_params=estimator_extra_params, - # ) - # results = benchmark.run(nruns=nruns) - # df_results.append(results) - # # save results - # data_path = Path("bsi_zoo/data") - # data_path.mkdir(exist_ok=True) - # FILE_NAME = f"{estimator}_{subject}_{data_args['orientation_type'][0]}_{time.strftime('%b-%d-%Y_%H%M', time.localtime())}.pkl" - # results.to_pickle(data_path / FILE_NAME) - - # df_results = pd.concat(df_results, axis=0) - - # data_path = Path("bsi_zoo/data") - # data_path.mkdir(exist_ok=True) - # if do_spatial_cv: - # FILE_NAME = f"benchmark_data_{subject}_{data_args['orientation_type'][0]}_spatialCV_{time.strftime('%b-%d-%Y_%H%M', time.localtime())}.pkl" - # else: - # FILE_NAME = f"benchmark_data_{subject}_{data_args['orientation_type'][0]}_{time.strftime('%b-%d-%Y_%H%M', time.localtime())}.pkl" - # df_results.to_pickle(data_path / FILE_NAME) - - # print(df_results) From 4369d103d15d2ecfc935378d3ab70f7875e78cdc Mon Sep 17 00:00:00 2001 From: Anuja Negi Date: Mon, 18 Nov 2024 18:25:46 +0100 Subject: [PATCH 2/6] old plotting ipynb --- plot_benchmark_metrics.ipynb | 959 ++++++++++++++++++++++++++++++++++- 1 file changed, 952 insertions(+), 7 deletions(-) diff --git a/plot_benchmark_metrics.ipynb b/plot_benchmark_metrics.ipynb index 7365ff7..0fbe9b8 100644 --- a/plot_benchmark_metrics.ipynb +++ b/plot_benchmark_metrics.ipynb @@ -28,6 +28,956 @@ "\n" ] }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
estimatorerroralphacov_typen_sensorsn_sourcesn_timesnnzorientation_typepath_to_leadfieldextra_paramsestimator__alphaestimator__alpha_cv
0eloretamatmul: Input operand 1 has a mismatch in its ...0.99diag50200101freebsi_zoo/tests/data/lead_field_free_CC120264.npz{}[0.0001, 0.001, 0.01, 0.1, 1, 10, 100, 1000, 1...None
1eloretamatmul: Input operand 1 has a mismatch in its ...0.99diag50200102freebsi_zoo/tests/data/lead_field_free_CC120264.npz{}[0.0001, 0.001, 0.01, 0.1, 1, 10, 100, 1000, 1...None
2eloretamatmul: Input operand 1 has a mismatch in its ...0.99diag50200103freebsi_zoo/tests/data/lead_field_free_CC120264.npz{}[0.0001, 0.001, 0.01, 0.1, 1, 10, 100, 1000, 1...None
3eloretamatmul: Input operand 1 has a mismatch in its ...0.99diag50200105freebsi_zoo/tests/data/lead_field_free_CC120264.npz{}[0.0001, 0.001, 0.01, 0.1, 1, 10, 100, 1000, 1...None
4eloretamatmul: Input operand 1 has a mismatch in its ...0.90diag50200101freebsi_zoo/tests/data/lead_field_free_CC120264.npz{}[0.0001, 0.001, 0.01, 0.1, 1, 10, 100, 1000, 1...None
5eloretamatmul: Input operand 1 has a mismatch in its ...0.90diag50200102freebsi_zoo/tests/data/lead_field_free_CC120264.npz{}[0.0001, 0.001, 0.01, 0.1, 1, 10, 100, 1000, 1...None
6eloretamatmul: Input operand 1 has a mismatch in its ...0.90diag50200103freebsi_zoo/tests/data/lead_field_free_CC120264.npz{}[0.0001, 0.001, 0.01, 0.1, 1, 10, 100, 1000, 1...None
7eloretamatmul: Input operand 1 has a mismatch in its ...0.90diag50200105freebsi_zoo/tests/data/lead_field_free_CC120264.npz{}[0.0001, 0.001, 0.01, 0.1, 1, 10, 100, 1000, 1...None
8eloretamatmul: Input operand 1 has a mismatch in its ...0.80diag50200101freebsi_zoo/tests/data/lead_field_free_CC120264.npz{}[0.0001, 0.001, 0.01, 0.1, 1, 10, 100, 1000, 1...None
9eloretamatmul: Input operand 1 has a mismatch in its ...0.80diag50200102freebsi_zoo/tests/data/lead_field_free_CC120264.npz{}[0.0001, 0.001, 0.01, 0.1, 1, 10, 100, 1000, 1...None
10eloretamatmul: Input operand 1 has a mismatch in its ...0.80diag50200103freebsi_zoo/tests/data/lead_field_free_CC120264.npz{}[0.0001, 0.001, 0.01, 0.1, 1, 10, 100, 1000, 1...None
11eloretamatmul: Input operand 1 has a mismatch in its ...0.80diag50200105freebsi_zoo/tests/data/lead_field_free_CC120264.npz{}[0.0001, 0.001, 0.01, 0.1, 1, 10, 100, 1000, 1...None
12eloretamatmul: Input operand 1 has a mismatch in its ...0.70diag50200101freebsi_zoo/tests/data/lead_field_free_CC120264.npz{}[0.0001, 0.001, 0.01, 0.1, 1, 10, 100, 1000, 1...None
13eloretamatmul: Input operand 1 has a mismatch in its ...0.70diag50200102freebsi_zoo/tests/data/lead_field_free_CC120264.npz{}[0.0001, 0.001, 0.01, 0.1, 1, 10, 100, 1000, 1...None
14eloretamatmul: Input operand 1 has a mismatch in its ...0.70diag50200103freebsi_zoo/tests/data/lead_field_free_CC120264.npz{}[0.0001, 0.001, 0.01, 0.1, 1, 10, 100, 1000, 1...None
15eloretamatmul: Input operand 1 has a mismatch in its ...0.70diag50200105freebsi_zoo/tests/data/lead_field_free_CC120264.npz{}[0.0001, 0.001, 0.01, 0.1, 1, 10, 100, 1000, 1...None
16eloretamatmul: Input operand 1 has a mismatch in its ...0.60diag50200101freebsi_zoo/tests/data/lead_field_free_CC120264.npz{}[0.0001, 0.001, 0.01, 0.1, 1, 10, 100, 1000, 1...None
17eloretamatmul: Input operand 1 has a mismatch in its ...0.60diag50200102freebsi_zoo/tests/data/lead_field_free_CC120264.npz{}[0.0001, 0.001, 0.01, 0.1, 1, 10, 100, 1000, 1...None
18eloretamatmul: Input operand 1 has a mismatch in its ...0.60diag50200103freebsi_zoo/tests/data/lead_field_free_CC120264.npz{}[0.0001, 0.001, 0.01, 0.1, 1, 10, 100, 1000, 1...None
19eloretamatmul: Input operand 1 has a mismatch in its ...0.60diag50200105freebsi_zoo/tests/data/lead_field_free_CC120264.npz{}[0.0001, 0.001, 0.01, 0.1, 1, 10, 100, 1000, 1...None
20eloretamatmul: Input operand 1 has a mismatch in its ...0.50diag50200101freebsi_zoo/tests/data/lead_field_free_CC120264.npz{}[0.0001, 0.001, 0.01, 0.1, 1, 10, 100, 1000, 1...None
21eloretamatmul: Input operand 1 has a mismatch in its ...0.50diag50200102freebsi_zoo/tests/data/lead_field_free_CC120264.npz{}[0.0001, 0.001, 0.01, 0.1, 1, 10, 100, 1000, 1...None
22eloretamatmul: Input operand 1 has a mismatch in its ...0.50diag50200103freebsi_zoo/tests/data/lead_field_free_CC120264.npz{}[0.0001, 0.001, 0.01, 0.1, 1, 10, 100, 1000, 1...None
23eloretamatmul: Input operand 1 has a mismatch in its ...0.50diag50200105freebsi_zoo/tests/data/lead_field_free_CC120264.npz{}[0.0001, 0.001, 0.01, 0.1, 1, 10, 100, 1000, 1...None
24eloretamatmul: Input operand 1 has a mismatch in its ...0.40diag50200101freebsi_zoo/tests/data/lead_field_free_CC120264.npz{}[0.0001, 0.001, 0.01, 0.1, 1, 10, 100, 1000, 1...None
25eloretamatmul: Input operand 1 has a mismatch in its ...0.40diag50200102freebsi_zoo/tests/data/lead_field_free_CC120264.npz{}[0.0001, 0.001, 0.01, 0.1, 1, 10, 100, 1000, 1...None
26eloretamatmul: Input operand 1 has a mismatch in its ...0.40diag50200103freebsi_zoo/tests/data/lead_field_free_CC120264.npz{}[0.0001, 0.001, 0.01, 0.1, 1, 10, 100, 1000, 1...None
27eloretamatmul: Input operand 1 has a mismatch in its ...0.40diag50200105freebsi_zoo/tests/data/lead_field_free_CC120264.npz{}[0.0001, 0.001, 0.01, 0.1, 1, 10, 100, 1000, 1...None
28eloretamatmul: Input operand 1 has a mismatch in its ...0.30diag50200101freebsi_zoo/tests/data/lead_field_free_CC120264.npz{}[0.0001, 0.001, 0.01, 0.1, 1, 10, 100, 1000, 1...None
29eloretamatmul: Input operand 1 has a mismatch in its ...0.30diag50200102freebsi_zoo/tests/data/lead_field_free_CC120264.npz{}[0.0001, 0.001, 0.01, 0.1, 1, 10, 100, 1000, 1...None
30eloretamatmul: Input operand 1 has a mismatch in its ...0.30diag50200103freebsi_zoo/tests/data/lead_field_free_CC120264.npz{}[0.0001, 0.001, 0.01, 0.1, 1, 10, 100, 1000, 1...None
31eloretamatmul: Input operand 1 has a mismatch in its ...0.30diag50200105freebsi_zoo/tests/data/lead_field_free_CC120264.npz{}[0.0001, 0.001, 0.01, 0.1, 1, 10, 100, 1000, 1...None
32eloretamatmul: Input operand 1 has a mismatch in its ...0.20diag50200101freebsi_zoo/tests/data/lead_field_free_CC120264.npz{}[0.0001, 0.001, 0.01, 0.1, 1, 10, 100, 1000, 1...None
33eloretamatmul: Input operand 1 has a mismatch in its ...0.20diag50200102freebsi_zoo/tests/data/lead_field_free_CC120264.npz{}[0.0001, 0.001, 0.01, 0.1, 1, 10, 100, 1000, 1...None
34eloretamatmul: Input operand 1 has a mismatch in its ...0.20diag50200103freebsi_zoo/tests/data/lead_field_free_CC120264.npz{}[0.0001, 0.001, 0.01, 0.1, 1, 10, 100, 1000, 1...None
35eloretamatmul: Input operand 1 has a mismatch in its ...0.20diag50200105freebsi_zoo/tests/data/lead_field_free_CC120264.npz{}[0.0001, 0.001, 0.01, 0.1, 1, 10, 100, 1000, 1...None
36eloretamatmul: Input operand 1 has a mismatch in its ...0.10diag50200101freebsi_zoo/tests/data/lead_field_free_CC120264.npz{}[0.0001, 0.001, 0.01, 0.1, 1, 10, 100, 1000, 1...None
37eloretamatmul: Input operand 1 has a mismatch in its ...0.10diag50200102freebsi_zoo/tests/data/lead_field_free_CC120264.npz{}[0.0001, 0.001, 0.01, 0.1, 1, 10, 100, 1000, 1...None
38eloretamatmul: Input operand 1 has a mismatch in its ...0.10diag50200103freebsi_zoo/tests/data/lead_field_free_CC120264.npz{}[0.0001, 0.001, 0.01, 0.1, 1, 10, 100, 1000, 1...None
39eloretamatmul: Input operand 1 has a mismatch in its ...0.10diag50200105freebsi_zoo/tests/data/lead_field_free_CC120264.npz{}[0.0001, 0.001, 0.01, 0.1, 1, 10, 100, 1000, 1...None
40eloretamatmul: Input operand 1 has a mismatch in its ...0.01diag50200101freebsi_zoo/tests/data/lead_field_free_CC120264.npz{}[0.0001, 0.001, 0.01, 0.1, 1, 10, 100, 1000, 1...None
41eloretamatmul: Input operand 1 has a mismatch in its ...0.01diag50200102freebsi_zoo/tests/data/lead_field_free_CC120264.npz{}[0.0001, 0.001, 0.01, 0.1, 1, 10, 100, 1000, 1...None
42eloretamatmul: Input operand 1 has a mismatch in its ...0.01diag50200103freebsi_zoo/tests/data/lead_field_free_CC120264.npz{}[0.0001, 0.001, 0.01, 0.1, 1, 10, 100, 1000, 1...None
43eloretamatmul: Input operand 1 has a mismatch in its ...0.01diag50200105freebsi_zoo/tests/data/lead_field_free_CC120264.npz{}[0.0001, 0.001, 0.01, 0.1, 1, 10, 100, 1000, 1...None
\n", + "
" + ], + "text/plain": [ + " estimator error alpha \\\n", + "0 eloreta matmul: Input operand 1 has a mismatch in its ... 0.99 \n", + "1 eloreta matmul: Input operand 1 has a mismatch in its ... 0.99 \n", + "2 eloreta matmul: Input operand 1 has a mismatch in its ... 0.99 \n", + "3 eloreta matmul: Input operand 1 has a mismatch in its ... 0.99 \n", + "4 eloreta matmul: Input operand 1 has a mismatch in its ... 0.90 \n", + "5 eloreta matmul: Input operand 1 has a mismatch in its ... 0.90 \n", + "6 eloreta matmul: Input operand 1 has a mismatch in its ... 0.90 \n", + "7 eloreta matmul: Input operand 1 has a mismatch in its ... 0.90 \n", + "8 eloreta matmul: Input operand 1 has a mismatch in its ... 0.80 \n", + "9 eloreta matmul: Input operand 1 has a mismatch in its ... 0.80 \n", + "10 eloreta matmul: Input operand 1 has a mismatch in its ... 0.80 \n", + "11 eloreta matmul: Input operand 1 has a mismatch in its ... 0.80 \n", + "12 eloreta matmul: Input operand 1 has a mismatch in its ... 0.70 \n", + "13 eloreta matmul: Input operand 1 has a mismatch in its ... 0.70 \n", + "14 eloreta matmul: Input operand 1 has a mismatch in its ... 0.70 \n", + "15 eloreta matmul: Input operand 1 has a mismatch in its ... 0.70 \n", + "16 eloreta matmul: Input operand 1 has a mismatch in its ... 0.60 \n", + "17 eloreta matmul: Input operand 1 has a mismatch in its ... 0.60 \n", + "18 eloreta matmul: Input operand 1 has a mismatch in its ... 0.60 \n", + "19 eloreta matmul: Input operand 1 has a mismatch in its ... 0.60 \n", + "20 eloreta matmul: Input operand 1 has a mismatch in its ... 0.50 \n", + "21 eloreta matmul: Input operand 1 has a mismatch in its ... 0.50 \n", + "22 eloreta matmul: Input operand 1 has a mismatch in its ... 0.50 \n", + "23 eloreta matmul: Input operand 1 has a mismatch in its ... 0.50 \n", + "24 eloreta matmul: Input operand 1 has a mismatch in its ... 0.40 \n", + "25 eloreta matmul: Input operand 1 has a mismatch in its ... 0.40 \n", + "26 eloreta matmul: Input operand 1 has a mismatch in its ... 0.40 \n", + "27 eloreta matmul: Input operand 1 has a mismatch in its ... 0.40 \n", + "28 eloreta matmul: Input operand 1 has a mismatch in its ... 0.30 \n", + "29 eloreta matmul: Input operand 1 has a mismatch in its ... 0.30 \n", + "30 eloreta matmul: Input operand 1 has a mismatch in its ... 0.30 \n", + "31 eloreta matmul: Input operand 1 has a mismatch in its ... 0.30 \n", + "32 eloreta matmul: Input operand 1 has a mismatch in its ... 0.20 \n", + "33 eloreta matmul: Input operand 1 has a mismatch in its ... 0.20 \n", + "34 eloreta matmul: Input operand 1 has a mismatch in its ... 0.20 \n", + "35 eloreta matmul: Input operand 1 has a mismatch in its ... 0.20 \n", + "36 eloreta matmul: Input operand 1 has a mismatch in its ... 0.10 \n", + "37 eloreta matmul: Input operand 1 has a mismatch in its ... 0.10 \n", + "38 eloreta matmul: Input operand 1 has a mismatch in its ... 0.10 \n", + "39 eloreta matmul: Input operand 1 has a mismatch in its ... 0.10 \n", + "40 eloreta matmul: Input operand 1 has a mismatch in its ... 0.01 \n", + "41 eloreta matmul: Input operand 1 has a mismatch in its ... 0.01 \n", + "42 eloreta matmul: Input operand 1 has a mismatch in its ... 0.01 \n", + "43 eloreta matmul: Input operand 1 has a mismatch in its ... 0.01 \n", + "\n", + " cov_type n_sensors n_sources n_times nnz orientation_type \\\n", + "0 diag 50 200 10 1 free \n", + "1 diag 50 200 10 2 free \n", + "2 diag 50 200 10 3 free \n", + "3 diag 50 200 10 5 free \n", + "4 diag 50 200 10 1 free \n", + "5 diag 50 200 10 2 free \n", + "6 diag 50 200 10 3 free \n", + "7 diag 50 200 10 5 free \n", + "8 diag 50 200 10 1 free \n", + "9 diag 50 200 10 2 free \n", + "10 diag 50 200 10 3 free \n", + "11 diag 50 200 10 5 free \n", + "12 diag 50 200 10 1 free \n", + "13 diag 50 200 10 2 free \n", + "14 diag 50 200 10 3 free \n", + "15 diag 50 200 10 5 free \n", + "16 diag 50 200 10 1 free \n", + "17 diag 50 200 10 2 free \n", + "18 diag 50 200 10 3 free \n", + "19 diag 50 200 10 5 free \n", + "20 diag 50 200 10 1 free \n", + "21 diag 50 200 10 2 free \n", + "22 diag 50 200 10 3 free \n", + "23 diag 50 200 10 5 free \n", + "24 diag 50 200 10 1 free \n", + "25 diag 50 200 10 2 free \n", + "26 diag 50 200 10 3 free \n", + "27 diag 50 200 10 5 free \n", + "28 diag 50 200 10 1 free \n", + "29 diag 50 200 10 2 free \n", + "30 diag 50 200 10 3 free \n", + "31 diag 50 200 10 5 free \n", + "32 diag 50 200 10 1 free \n", + "33 diag 50 200 10 2 free \n", + "34 diag 50 200 10 3 free \n", + "35 diag 50 200 10 5 free \n", + "36 diag 50 200 10 1 free \n", + "37 diag 50 200 10 2 free \n", + "38 diag 50 200 10 3 free \n", + "39 diag 50 200 10 5 free \n", + "40 diag 50 200 10 1 free \n", + "41 diag 50 200 10 2 free \n", + "42 diag 50 200 10 3 free \n", + "43 diag 50 200 10 5 free \n", + "\n", + " path_to_leadfield extra_params \\\n", + "0 bsi_zoo/tests/data/lead_field_free_CC120264.npz {} \n", + "1 bsi_zoo/tests/data/lead_field_free_CC120264.npz {} \n", + "2 bsi_zoo/tests/data/lead_field_free_CC120264.npz {} \n", + "3 bsi_zoo/tests/data/lead_field_free_CC120264.npz {} \n", + "4 bsi_zoo/tests/data/lead_field_free_CC120264.npz {} \n", + "5 bsi_zoo/tests/data/lead_field_free_CC120264.npz {} \n", + "6 bsi_zoo/tests/data/lead_field_free_CC120264.npz {} \n", + "7 bsi_zoo/tests/data/lead_field_free_CC120264.npz {} \n", + "8 bsi_zoo/tests/data/lead_field_free_CC120264.npz {} \n", + "9 bsi_zoo/tests/data/lead_field_free_CC120264.npz {} \n", + "10 bsi_zoo/tests/data/lead_field_free_CC120264.npz {} \n", + "11 bsi_zoo/tests/data/lead_field_free_CC120264.npz {} \n", + "12 bsi_zoo/tests/data/lead_field_free_CC120264.npz {} \n", + "13 bsi_zoo/tests/data/lead_field_free_CC120264.npz {} \n", + "14 bsi_zoo/tests/data/lead_field_free_CC120264.npz {} \n", + "15 bsi_zoo/tests/data/lead_field_free_CC120264.npz {} \n", + "16 bsi_zoo/tests/data/lead_field_free_CC120264.npz {} \n", + "17 bsi_zoo/tests/data/lead_field_free_CC120264.npz {} \n", + "18 bsi_zoo/tests/data/lead_field_free_CC120264.npz {} \n", + "19 bsi_zoo/tests/data/lead_field_free_CC120264.npz {} \n", + "20 bsi_zoo/tests/data/lead_field_free_CC120264.npz {} \n", + "21 bsi_zoo/tests/data/lead_field_free_CC120264.npz {} \n", + "22 bsi_zoo/tests/data/lead_field_free_CC120264.npz {} \n", + "23 bsi_zoo/tests/data/lead_field_free_CC120264.npz {} \n", + "24 bsi_zoo/tests/data/lead_field_free_CC120264.npz {} \n", + "25 bsi_zoo/tests/data/lead_field_free_CC120264.npz {} \n", + "26 bsi_zoo/tests/data/lead_field_free_CC120264.npz {} \n", + "27 bsi_zoo/tests/data/lead_field_free_CC120264.npz {} \n", + "28 bsi_zoo/tests/data/lead_field_free_CC120264.npz {} \n", + "29 bsi_zoo/tests/data/lead_field_free_CC120264.npz {} \n", + "30 bsi_zoo/tests/data/lead_field_free_CC120264.npz {} \n", + "31 bsi_zoo/tests/data/lead_field_free_CC120264.npz {} \n", + "32 bsi_zoo/tests/data/lead_field_free_CC120264.npz {} \n", + "33 bsi_zoo/tests/data/lead_field_free_CC120264.npz {} \n", + "34 bsi_zoo/tests/data/lead_field_free_CC120264.npz {} \n", + "35 bsi_zoo/tests/data/lead_field_free_CC120264.npz {} \n", + "36 bsi_zoo/tests/data/lead_field_free_CC120264.npz {} \n", + "37 bsi_zoo/tests/data/lead_field_free_CC120264.npz {} \n", + "38 bsi_zoo/tests/data/lead_field_free_CC120264.npz {} \n", + "39 bsi_zoo/tests/data/lead_field_free_CC120264.npz {} \n", + "40 bsi_zoo/tests/data/lead_field_free_CC120264.npz {} \n", + "41 bsi_zoo/tests/data/lead_field_free_CC120264.npz {} \n", + "42 bsi_zoo/tests/data/lead_field_free_CC120264.npz {} \n", + "43 bsi_zoo/tests/data/lead_field_free_CC120264.npz {} \n", + "\n", + " estimator__alpha estimator__alpha_cv \n", + "0 [0.0001, 0.001, 0.01, 0.1, 1, 10, 100, 1000, 1... None \n", + "1 [0.0001, 0.001, 0.01, 0.1, 1, 10, 100, 1000, 1... None \n", + "2 [0.0001, 0.001, 0.01, 0.1, 1, 10, 100, 1000, 1... None \n", + "3 [0.0001, 0.001, 0.01, 0.1, 1, 10, 100, 1000, 1... None \n", + "4 [0.0001, 0.001, 0.01, 0.1, 1, 10, 100, 1000, 1... None \n", + "5 [0.0001, 0.001, 0.01, 0.1, 1, 10, 100, 1000, 1... None \n", + "6 [0.0001, 0.001, 0.01, 0.1, 1, 10, 100, 1000, 1... None \n", + "7 [0.0001, 0.001, 0.01, 0.1, 1, 10, 100, 1000, 1... None \n", + "8 [0.0001, 0.001, 0.01, 0.1, 1, 10, 100, 1000, 1... None \n", + "9 [0.0001, 0.001, 0.01, 0.1, 1, 10, 100, 1000, 1... None \n", + "10 [0.0001, 0.001, 0.01, 0.1, 1, 10, 100, 1000, 1... None \n", + "11 [0.0001, 0.001, 0.01, 0.1, 1, 10, 100, 1000, 1... None \n", + "12 [0.0001, 0.001, 0.01, 0.1, 1, 10, 100, 1000, 1... None \n", + "13 [0.0001, 0.001, 0.01, 0.1, 1, 10, 100, 1000, 1... None \n", + "14 [0.0001, 0.001, 0.01, 0.1, 1, 10, 100, 1000, 1... None \n", + "15 [0.0001, 0.001, 0.01, 0.1, 1, 10, 100, 1000, 1... None \n", + "16 [0.0001, 0.001, 0.01, 0.1, 1, 10, 100, 1000, 1... None \n", + "17 [0.0001, 0.001, 0.01, 0.1, 1, 10, 100, 1000, 1... None \n", + "18 [0.0001, 0.001, 0.01, 0.1, 1, 10, 100, 1000, 1... None \n", + "19 [0.0001, 0.001, 0.01, 0.1, 1, 10, 100, 1000, 1... None \n", + "20 [0.0001, 0.001, 0.01, 0.1, 1, 10, 100, 1000, 1... None \n", + "21 [0.0001, 0.001, 0.01, 0.1, 1, 10, 100, 1000, 1... None \n", + "22 [0.0001, 0.001, 0.01, 0.1, 1, 10, 100, 1000, 1... None \n", + "23 [0.0001, 0.001, 0.01, 0.1, 1, 10, 100, 1000, 1... None \n", + "24 [0.0001, 0.001, 0.01, 0.1, 1, 10, 100, 1000, 1... None \n", + "25 [0.0001, 0.001, 0.01, 0.1, 1, 10, 100, 1000, 1... None \n", + "26 [0.0001, 0.001, 0.01, 0.1, 1, 10, 100, 1000, 1... None \n", + "27 [0.0001, 0.001, 0.01, 0.1, 1, 10, 100, 1000, 1... None \n", + "28 [0.0001, 0.001, 0.01, 0.1, 1, 10, 100, 1000, 1... None \n", + "29 [0.0001, 0.001, 0.01, 0.1, 1, 10, 100, 1000, 1... None \n", + "30 [0.0001, 0.001, 0.01, 0.1, 1, 10, 100, 1000, 1... None \n", + "31 [0.0001, 0.001, 0.01, 0.1, 1, 10, 100, 1000, 1... None \n", + "32 [0.0001, 0.001, 0.01, 0.1, 1, 10, 100, 1000, 1... None \n", + "33 [0.0001, 0.001, 0.01, 0.1, 1, 10, 100, 1000, 1... None \n", + "34 [0.0001, 0.001, 0.01, 0.1, 1, 10, 100, 1000, 1... None \n", + "35 [0.0001, 0.001, 0.01, 0.1, 1, 10, 100, 1000, 1... None \n", + "36 [0.0001, 0.001, 0.01, 0.1, 1, 10, 100, 1000, 1... None \n", + "37 [0.0001, 0.001, 0.01, 0.1, 1, 10, 100, 1000, 1... None \n", + "38 [0.0001, 0.001, 0.01, 0.1, 1, 10, 100, 1000, 1... None \n", + "39 [0.0001, 0.001, 0.01, 0.1, 1, 10, 100, 1000, 1... None \n", + "40 [0.0001, 0.001, 0.01, 0.1, 1, 10, 100, 1000, 1... None \n", + "41 [0.0001, 0.001, 0.01, 0.1, 1, 10, 100, 1000, 1... None \n", + "42 [0.0001, 0.001, 0.01, 0.1, 1, 10, 100, 1000, 1... None \n", + "43 [0.0001, 0.001, 0.01, 0.1, 1, 10, 100, 1000, 1... None " + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import pandas as pd\n", + "#read .pkl file\n", + "\n", + "data=pd.read_pickle('/home/anujanegi/BSI-Zoo/bsi_zoo/data/free2/_CC120264_free_spatialCV_Feb-04-2024_2352.pkl')\n", + "# data[data['estimator'] == 'iterative_L1']\n", + "data" + ] + }, { "cell_type": "code", "execution_count": 2, @@ -2835,7 +3785,7 @@ ], "metadata": { "kernelspec": { - "display_name": "mne", + "display_name": "Python 3", "language": "python", "name": "python3" }, @@ -2849,12 +3799,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.4" - }, - "vscode": { - "interpreter": { - "hash": "1b12537d3f4d8569bf39ad348dc0220ce69024f99df9bda6ed04efb090b409f6" - } + "version": "3.8.10" } }, "nbformat": 4, From 515539e3bb27e6be9d3d2e0d59234ebccf5ca6b2 Mon Sep 17 00:00:00 2001 From: Anuja Negi Date: Wed, 20 Nov 2024 11:57:13 +0100 Subject: [PATCH 3/6] alpha max fix for iterative L1 --- bsi_zoo/estimators.py | 19 ++++- bsi_zoo/run_benchmark.py | 176 ++++++++++++++++++++------------------- 2 files changed, 106 insertions(+), 89 deletions(-) diff --git a/bsi_zoo/estimators.py b/bsi_zoo/estimators.py index 25447c3..a26807d 100644 --- a/bsi_zoo/estimators.py +++ b/bsi_zoo/estimators.py @@ -253,7 +253,7 @@ def _gamma_map_opt( Parameters ---------- - M : array, shape=(n_sensors, n_times) + : array, shape=(n_sensors, n_times) Observation. G : array, shape=(n_sensors, n_sources) Forward operator. @@ -530,6 +530,14 @@ def gprime(w): return x +def norm_l2inf(A, n_orient, copy=True): + from math import sqrt + """L2-inf norm.""" + if A.size == 0: + return 0.0 + if copy: + A = A.copy() + return sqrt(np.max(groups_norm2(A, n_orient))) def iterative_L1(L, y, alpha=0.2, n_orient=1, max_iter=1000, max_iter_reweighting=10): """Iterative Type-I estimator with L1 regularizer. @@ -578,9 +586,16 @@ def gprime(w): grp_norms = np.sqrt(groups_norm2(w.copy(), n_orient)) return np.repeat(grp_norms, n_orient).ravel() + eps - alpha_max = abs(L.T.dot(y)).max() / len(L) + if n_orient==1: + alpha_max = abs(L.T.dot(y)).max() / len(L) + else: + n_dip_per_pos = 3 + alpha_max = norm_l2inf(np.dot(L.T, y), n_dip_per_pos) + alpha = alpha * alpha_max + # y->M + # L->gain x = _solve_reweighted_lasso( L, y, alpha, n_orient, weights, max_iter, max_iter_reweighting, gprime ) diff --git a/bsi_zoo/run_benchmark.py b/bsi_zoo/run_benchmark.py index b73855f..a8bd2d9 100644 --- a/bsi_zoo/run_benchmark.py +++ b/bsi_zoo/run_benchmark.py @@ -19,8 +19,10 @@ from bsi_zoo.config import get_leadfield_path n_jobs = 20 -nruns = 10 -spatial_cv = [False, True] +nruns = 1 +# spatial_cv = [False, True] +spatial_cv = [False] + # subjects = ["CC120166", "CC120264", "CC120313", "CC120309"] metrics = [ @@ -38,81 +40,81 @@ for do_spatial_cv in spatial_cv: for subject in subjects: - """Fixed orientation parameters for the benchmark""" - - orientation_type = "fixed" - data_args_I = { - # "n_sensors": [50], - "n_times": [10], - # "n_sources": [200], - "nnz": nnzs, - "cov_type": ["diag"], - "path_to_leadfield": [get_leadfield_path(subject, type=orientation_type)], - "orientation_type": [orientation_type], - "alpha": alpha_SNR, # this is actually SNR - } - - data_args_II = { - # "n_sensors": [50], - "n_times": [10], - # "n_sources": [200], - "nnz": nnzs, - "cov_type": ["full"], - "path_to_leadfield": [get_leadfield_path(subject, type=orientation_type)], - "orientation_type": [orientation_type], - "alpha": alpha_SNR, # this is actually SNR - } - - estimators = [ - (fake_solver, data_args_I, {"alpha": estimator_alphas_I}, {}), - (eloreta, data_args_I, {"alpha": estimator_alphas_II}, {}), - (iterative_L1, data_args_I, {"alpha": estimator_alphas_I}, {}), - (iterative_L2, data_args_I, {"alpha": estimator_alphas_I}, {}), - (iterative_sqrt, data_args_I, {"alpha": estimator_alphas_I}, {}), - (iterative_L1_typeII, data_args_II, {"alpha": estimator_alphas_I}, {}), - (iterative_L2_typeII, data_args_II, {"alpha": estimator_alphas_I}, {}), - # (gamma_map, data_args_II, {"alpha": estimator_alphas_I}, {"update_mode": 1}), - (gamma_map, data_args_II, {"alpha": estimator_alphas_II}, {"update_mode": 2}), - # (gamma_map, data_args_II, {"alpha": estimator_alphas_I}, {"update_mode": 3}), - ] - - df_results = [] - for estimator, data_args, estimator_args, estimator_extra_params in estimators: - benchmark = Benchmark( - estimator, - subject, - metrics, - data_args, - estimator_args, - random_state=42, - memory=memory, - n_jobs=n_jobs, - do_spatial_cv=do_spatial_cv, - estimator_extra_params=estimator_extra_params, - ) - results = benchmark.run(nruns=nruns) - df_results.append(results) - # save results - data_path = Path("bsi_zoo/data/updated_alpha_grid") - data_path.mkdir(exist_ok=True) - if do_spatial_cv: - FILE_NAME = f"{estimator}_{subject}_{data_args['orientation_type'][0]}_spatialCV_{time.strftime('%b-%d-%Y_%H%M', time.localtime())}.pkl" - else: - FILE_NAME = f"{estimator}_{subject}_{data_args['orientation_type'][0]}_{time.strftime('%b-%d-%Y_%H%M', time.localtime())}.pkl" - results.to_pickle(data_path / FILE_NAME) - - - df_results = pd.concat(df_results, axis=0) - - data_path = Path("bsi_zoo/data/ramen") - data_path.mkdir(exist_ok=True) - if do_spatial_cv: - FILE_NAME = f"benchmark_data_{subject}_{data_args['orientation_type'][0]}_spatialCV_{time.strftime('%b-%d-%Y_%H%M', time.localtime())}.pkl" - else: - FILE_NAME = f"benchmark_data_{subject}_{data_args['orientation_type'][0]}_{time.strftime('%b-%d-%Y_%H%M', time.localtime())}.pkl" - df_results.to_pickle(data_path / FILE_NAME) - - print(df_results) + # """Fixed orientation parameters for the benchmark""" + + # orientation_type = "fixed" + # data_args_I = { + # # "n_sensors": [50], + # "n_times": [10], + # # "n_sources": [200], + # "nnz": nnzs, + # "cov_type": ["diag"], + # "path_to_leadfield": [get_leadfield_path(subject, type=orientation_type)], + # "orientation_type": [orientation_type], + # "alpha": alpha_SNR, # this is actually SNR + # } + + # data_args_II = { + # # "n_sensors": [50], + # "n_times": [10], + # # "n_sources": [200], + # "nnz": nnzs, + # "cov_type": ["full"], + # "path_to_leadfield": [get_leadfield_path(subject, type=orientation_type)], + # "orientation_type": [orientation_type], + # "alpha": alpha_SNR, # this is actually SNR + # } + + # estimators = [ + # (fake_solver, data_args_I, {"alpha": estimator_alphas_I}, {}), + # (eloreta, data_args_I, {"alpha": estimator_alphas_II}, {}), + # (iterative_L1, data_args_I, {"alpha": estimator_alphas_I}, {}), + # (iterative_L2, data_args_I, {"alpha": estimator_alphas_I}, {}), + # (iterative_sqrt, data_args_I, {"alpha": estimator_alphas_I}, {}), + # (iterative_L1_typeII, data_args_II, {"alpha": estimator_alphas_I}, {}), + # (iterative_L2_typeII, data_args_II, {"alpha": estimator_alphas_I}, {}), + # # (gamma_map, data_args_II, {"alpha": estimator_alphas_I}, {"update_mode": 1}), + # (gamma_map, data_args_II, {"alpha": estimator_alphas_II}, {"update_mode": 2}), + # # (gamma_map, data_args_II, {"alpha": estimator_alphas_I}, {"update_mode": 3}), + # ] + + # df_results = [] + # for estimator, data_args, estimator_args, estimator_extra_params in estimators: + # benchmark = Benchmark( + # estimator, + # subject, + # metrics, + # data_args, + # estimator_args, + # random_state=42, + # memory=memory, + # n_jobs=n_jobs, + # do_spatial_cv=do_spatial_cv, + # estimator_extra_params=estimator_extra_params, + # ) + # results = benchmark.run(nruns=nruns) + # df_results.append(results) + # # save results + # data_path = Path("bsi_zoo/data/updated_alpha_grid") + # data_path.mkdir(exist_ok=True) + # if do_spatial_cv: + # FILE_NAME = f"{estimator}_{subject}_{data_args['orientation_type'][0]}_spatialCV_{time.strftime('%b-%d-%Y_%H%M', time.localtime())}.pkl" + # else: + # FILE_NAME = f"{estimator}_{subject}_{data_args['orientation_type'][0]}_{time.strftime('%b-%d-%Y_%H%M', time.localtime())}.pkl" + # results.to_pickle(data_path / FILE_NAME) + + + # df_results = pd.concat(df_results, axis=0) + + # data_path = Path("bsi_zoo/data/ramen") + # data_path.mkdir(exist_ok=True) + # if do_spatial_cv: + # FILE_NAME = f"benchmark_data_{subject}_{data_args['orientation_type'][0]}_spatialCV_{time.strftime('%b-%d-%Y_%H%M', time.localtime())}.pkl" + # else: + # FILE_NAME = f"benchmark_data_{subject}_{data_args['orientation_type'][0]}_{time.strftime('%b-%d-%Y_%H%M', time.localtime())}.pkl" + # df_results.to_pickle(data_path / FILE_NAME) + + # print(df_results) """ Free orientation parameters for the benchmark """ @@ -142,22 +144,22 @@ if spatial_cv: # currently no support for type II methods estimators = [ - (fake_solver, data_args_I, {"alpha": estimator_alphas_I}, {}), + # (fake_solver, data_args_I, {"alpha": estimator_alphas_I}, {}), (iterative_L1, data_args_I, {"alpha": estimator_alphas_I}, {}), (iterative_L2, data_args_I, {"alpha": estimator_alphas_I}, {}), (iterative_sqrt, data_args_I, {"alpha": estimator_alphas_I}, {}), ] else: estimators = [ - (fake_solver, data_args_I, {"alpha": estimator_alphas_I}, {}), - (eloreta, data_args_I, {"alpha": estimator_alphas_II}, {}), + # (fake_solver, data_args_I, {"alpha": estimator_alphas_I}, {}), + # (eloreta, data_args_I, {"alpha": estimator_alphas_II}, {}), (iterative_L1, data_args_I, {"alpha": estimator_alphas_I}, {}), - (iterative_L2, data_args_I, {"alpha": estimator_alphas_I}, {}), - (iterative_sqrt, data_args_I, {"alpha": estimator_alphas_I}, {}), - (iterative_L1_typeII, data_args_II, {"alpha": estimator_alphas_I}, {}), - (iterative_L2_typeII, data_args_II, {"alpha": estimator_alphas_I}, {}), + # (iterative_L2, data_args_I, {"alpha": estimator_alphas_I}, {}), + # (iterative_sqrt, data_args_I, {"alpha": estimator_alphas_I}, {}), + # (iterative_L1_typeII, data_args_II, {"alpha": estimator_alphas_I}, {}), + # (iterative_L2_typeII, data_args_II, {"alpha": estimator_alphas_I}, {}), # (gamma_map, data_args_II, {"alpha": estimator_alphas_I}, {"update_mode": 1}), - (gamma_map, data_args_II, {"alpha": estimator_alphas_II}, {"update_mode": 2}), + # (gamma_map, data_args_II, {"alpha": estimator_alphas_II}, {"update_mode": 2}), # (gamma_map, data_args_II, {"alpha": estimator_alphas_I}, {"update_mode": 3}), ] @@ -178,7 +180,7 @@ results = benchmark.run(nruns=nruns) df_results.append(results) # save results - data_path = Path("bsi_zoo/data/free2") + data_path = Path("bsi_zoo/data/free3") data_path.mkdir(exist_ok=True) if do_spatial_cv: @@ -189,7 +191,7 @@ df_results = pd.concat(df_results, axis=0) - data_path = Path("bsi_zoo/data/free2") + data_path = Path("bsi_zoo/data/free3") data_path.mkdir(exist_ok=True) if do_spatial_cv: FILE_NAME = f"benchmark_data_{subject}_{data_args['orientation_type'][0]}_spatialCV_{time.strftime('%b-%d-%Y_%H%M', time.localtime())}.pkl" From 36e22b9c10790bcebb3309f17dc8768f74ffc0b9 Mon Sep 17 00:00:00 2001 From: Anuja Negi Date: Wed, 11 Dec 2024 12:44:01 +0100 Subject: [PATCH 4/6] use eigen leads for solving lasso --- bsi_zoo/estimators.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/bsi_zoo/estimators.py b/bsi_zoo/estimators.py index a26807d..7dedd86 100644 --- a/bsi_zoo/estimators.py +++ b/bsi_zoo/estimators.py @@ -594,10 +594,12 @@ def gprime(w): alpha = alpha * alpha_max + eigen_fields, sing, eigen_leads = _safe_svd(L, full_matrices=False) + # y->M # L->gain x = _solve_reweighted_lasso( - L, y, alpha, n_orient, weights, max_iter, max_iter_reweighting, gprime + eigen_leads, y, alpha, n_orient, weights, max_iter, max_iter_reweighting, gprime ) return x From d7ed1703785e770ba9b4da7c3694ae01d8f56220 Mon Sep 17 00:00:00 2001 From: Anuja Negi Date: Thu, 12 Dec 2024 11:26:21 +0100 Subject: [PATCH 5/6] fix alpha for iterative methods --- bsi_zoo/estimators.py | 44 +++++++++++++++++++++++++++++++++---------- 1 file changed, 34 insertions(+), 10 deletions(-) diff --git a/bsi_zoo/estimators.py b/bsi_zoo/estimators.py index 7dedd86..44d9918 100644 --- a/bsi_zoo/estimators.py +++ b/bsi_zoo/estimators.py @@ -207,7 +207,7 @@ def _solve_reweighted_lasso( n_positions = L_w.shape[1] // n_orient lc = np.empty(n_positions) for j in range(n_positions): - L_j = L_w[:, (j * n_orient): ((j + 1) * n_orient)] + L_j = L_w[:, (j * n_orient) : ((j + 1) * n_orient)] lc[j] = np.linalg.norm(np.dot(L_j.T, L_j), ord=2) coef_, active_set, _ = _mixed_norm_solver_bcd( y, @@ -530,8 +530,10 @@ def gprime(w): return x + def norm_l2inf(A, n_orient, copy=True): from math import sqrt + """L2-inf norm.""" if A.size == 0: return 0.0 @@ -539,6 +541,7 @@ def norm_l2inf(A, n_orient, copy=True): A = A.copy() return sqrt(np.max(groups_norm2(A, n_orient))) + def iterative_L1(L, y, alpha=0.2, n_orient=1, max_iter=1000, max_iter_reweighting=10): """Iterative Type-I estimator with L1 regularizer. @@ -586,20 +589,20 @@ def gprime(w): grp_norms = np.sqrt(groups_norm2(w.copy(), n_orient)) return np.repeat(grp_norms, n_orient).ravel() + eps - if n_orient==1: + if n_orient == 1: alpha_max = abs(L.T.dot(y)).max() / len(L) - else: + else: n_dip_per_pos = 3 alpha_max = norm_l2inf(np.dot(L.T, y), n_dip_per_pos) - + alpha = alpha * alpha_max - eigen_fields, sing, eigen_leads = _safe_svd(L, full_matrices=False) + # eigen_fields, sing, eigen_leads = _safe_svd(L, full_matrices=False) # y->M # L->gain x = _solve_reweighted_lasso( - eigen_leads, y, alpha, n_orient, weights, max_iter, max_iter_reweighting, gprime + L, y, alpha, n_orient, weights, max_iter, max_iter_reweighting, gprime ) return x @@ -617,6 +620,7 @@ def iterative_L2(L, y, alpha=0.2, n_orient=1, max_iter=1000, max_iter_reweightin for solving the following problem: x^(k+1) <-- argmin_x ||y - Lx||^2_Fro + alpha * sum_i w_i^(k)|x_i| + Parameters ---------- L : array, shape (n_sensors, n_sources) @@ -651,7 +655,12 @@ def gprime(w): grp_norm2 = groups_norm2(w.copy(), n_orient) return np.repeat(grp_norm2, n_orient).ravel() + eps - alpha_max = abs(L.T.dot(y)).max() / len(L) + if n_orient == 1: + alpha_max = abs(L.T.dot(y)).max() / len(L) + else: + n_dip_per_pos = 3 + alpha_max = norm_l2inf(np.dot(L.T, y), n_dip_per_pos) + alpha = alpha * alpha_max x = _solve_reweighted_lasso( @@ -710,7 +719,12 @@ def g(w): def gprime(w): return 2.0 * np.repeat(g(w), n_orient).ravel() - alpha_max = abs(L.T.dot(y)).max() / len(L) + if n_orient == 1: + alpha_max = abs(L.T.dot(y)).max() / len(L) + else: + n_dip_per_pos = 3 + alpha_max = norm_l2inf(np.dot(L.T, y), n_dip_per_pos) + alpha = alpha * alpha_max x = _solve_reweighted_lasso( @@ -795,7 +809,12 @@ def iterative_L1_typeII( n_sensors, n_sources = L.shape weights = np.ones(n_sources) - alpha_max = abs(L.T.dot(y)).max() / len(L) + if n_orient == 1: + alpha_max = abs(L.T.dot(y)).max() / len(L) + else: + n_dip_per_pos = 3 + alpha_max = norm_l2inf(np.dot(L.T, y), n_dip_per_pos) + alpha = alpha * alpha_max if isinstance(cov, float): @@ -894,7 +913,12 @@ def iterative_L2_typeII( n_sensors, n_sources = L.shape weights = np.ones(n_sources) - alpha_max = abs(L.T.dot(y)).max() / len(L) + if n_orient == 1: + alpha_max = abs(L.T.dot(y)).max() / len(L) + else: + n_dip_per_pos = 3 + alpha_max = norm_l2inf(np.dot(L.T, y), n_dip_per_pos) + alpha = alpha * alpha_max if isinstance(cov, float): From ade971e6f9f0bd99b70bd5e823ffb44cd77829ca Mon Sep 17 00:00:00 2001 From: Anuja Negi Date: Thu, 2 Jan 2025 09:38:58 +0100 Subject: [PATCH 6/6] alpha max fix for iterative methods --- bsi_zoo/estimators.py | 21 ++++++++++++--------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/bsi_zoo/estimators.py b/bsi_zoo/estimators.py index a26807d..6927202 100644 --- a/bsi_zoo/estimators.py +++ b/bsi_zoo/estimators.py @@ -75,7 +75,7 @@ def _compute_reginv2(sing, n_nzero, lambda2): reginv = np.zeros_like(sing) sing = sing[:n_nzero] with np.errstate(invalid="ignore"): # if lambda2==0 - reginv[:n_nzero] = np.where(sing > 0, sing / (sing ** 2 + lambda2), 0) + reginv[:n_nzero] = np.where(sing > 0, sing / (sing**2 + lambda2), 0) return reginv @@ -119,7 +119,7 @@ def _compute_eloreta_kernel(L, *, lambda2, n_orient, whitener, loose=1.0, max_it # Outer product R_prior = source_std.reshape(n_src, 1, 3) * source_std.reshape(n_src, 3, 1) else: - R_prior = source_std ** 2 + R_prior = source_std**2 # The following was adapted under BSD license by permission of Guido Nolte if force_equal or n_orient == 1: @@ -207,7 +207,7 @@ def _solve_reweighted_lasso( n_positions = L_w.shape[1] // n_orient lc = np.empty(n_positions) for j in range(n_positions): - L_j = L_w[:, (j * n_orient): ((j + 1) * n_orient)] + L_j = L_w[:, (j * n_orient) : ((j + 1) * n_orient)] lc[j] = np.linalg.norm(np.dot(L_j.T, L_j), ord=2) coef_, active_set, _ = _mixed_norm_solver_bcd( y, @@ -341,7 +341,7 @@ def denom_fun(x): if update_mode == 1: # MacKay fixed point update (10) in [1] - numer = gammas ** 2 * np.mean((A * A.conj()).real, axis=1) + numer = gammas**2 * np.mean((A * A.conj()).real, axis=1) denom = gammas * np.sum(G * CMinvG, axis=0) elif update_mode == 2: # modified MacKay fixed point update (11) in [1] @@ -350,7 +350,7 @@ def denom_fun(x): elif update_mode == 3: # Expectation Maximization (EM) update denom = None - numer = gammas ** 2 * np.mean((A * A.conj()).real, axis=1) + gammas * ( + numer = gammas**2 * np.mean((A * A.conj()).real, axis=1) + gammas * ( 1 - gammas * np.sum(G * CMinvG, axis=0) ) else: @@ -530,8 +530,10 @@ def gprime(w): return x + def norm_l2inf(A, n_orient, copy=True): from math import sqrt + """L2-inf norm.""" if A.size == 0: return 0.0 @@ -539,6 +541,7 @@ def norm_l2inf(A, n_orient, copy=True): A = A.copy() return sqrt(np.max(groups_norm2(A, n_orient))) + def iterative_L1(L, y, alpha=0.2, n_orient=1, max_iter=1000, max_iter_reweighting=10): """Iterative Type-I estimator with L1 regularizer. @@ -586,12 +589,12 @@ def gprime(w): grp_norms = np.sqrt(groups_norm2(w.copy(), n_orient)) return np.repeat(grp_norms, n_orient).ravel() + eps - if n_orient==1: + if n_orient == 1: alpha_max = abs(L.T.dot(y)).max() / len(L) - else: + else: n_dip_per_pos = 3 alpha_max = norm_l2inf(np.dot(L.T, y), n_dip_per_pos) - + alpha = alpha * alpha_max # y->M @@ -919,7 +922,7 @@ def epsilon_update(L, weights, cov): # w_mat(weights) # - np.multiply(w_mat(weights ** 2), np.diag((L_T @ sigmaY_inv) @ L)) # ) - return weights_ - (weights_ ** 2) * ((L_T @ sigmaY_inv) * L_T).sum(axis=1) + return weights_ - (weights_**2) * ((L_T @ sigmaY_inv) * L_T).sum(axis=1) def g_coef(coef): return groups_norm2(coef.copy(), n_orient)