11import numpy as np
22from copy import deepcopy
33
4- import spectral_connectivity as sc # For directed spectral statistics (excl. spectral GC)
4+ import spectral_connectivity as sc # For directed spectral statistics (excl. spectral GC)
55from pyspi .base import directed , parse_bivariate , undirected , parse_multivariate , unsigned
66import nitime .analysis as nta
77import nitime .timeseries as ts
@@ -18,7 +18,7 @@ class kramer(unsigned):
1818 def __init__ (self ,fs = 1 ,fmin = 0 ,fmax = None ,statistic = 'mean' ):
1919 if fmax is None :
2020 fmax = fs / 2
21-
21+
2222 self ._fs = fs
2323 if fs != 1 :
2424 warnings .warn ('Multiple sampling frequencies not yet handled.' )
@@ -265,7 +265,7 @@ def __init__(self,**kwargs):
265265 self .identifier = 'psi'
266266 super ().__init__ (** kwargs )
267267 self ._measure = 'phase_slope_index'
268-
268+
269269 def _get_statistic (self ,C ):
270270 return C .phase_slope_index (frequencies_of_interest = [self ._fmin ,self ._fmax ],
271271 frequency_resolution = (self ._fmax - self ._fmin )/ 50 )
@@ -278,7 +278,7 @@ def __init__(self,**kwargs):
278278 self .identifier = 'gd'
279279 super ().__init__ (** kwargs )
280280 self ._measure = 'group_delay'
281-
281+
282282 def _get_statistic (self ,C ):
283283 return C .group_delay (frequencies_of_interest = [self ._fmin ,self ._fmax ],
284284 frequency_resolution = (self ._fmax - self ._fmin )/ 50 )
@@ -288,17 +288,24 @@ class spectral_granger(kramer_mv,directed,unsigned):
288288 identifier = 'sgc'
289289 labels = ['unsigned' ,'embedding' ,'spectral' ,'directed' ,'lagged' ]
290290
291- def __init__ (self ,fs = 1 , fmin = 0.0 , fmax = 0.5 ,method = 'nonparametric' ,order = None ,max_order = 50 ,statistic = 'mean' ,ignore_NaN = True ):
291+ def __init__ (self , fs = 1 , fmin = 1e-5 , fmax = 0.5 , method = 'nonparametric' , order = None , max_order = 50 , statistic = 'mean' , ignore_nan = True , nan_threshold = 0.5 ):
292292 self ._fs = fs # Not yet implemented
293293 self ._fmin = fmin
294294 self ._fmax = fmax
295+ self .ignore_nan = ignore_nan
296+ self .nan_threshold = nan_threshold
297+
298+ if self ._fmin <= 0. :
299+ warnings .warn (f"Frequency minimum set to { self ._fmin } ; overriding to 1e-5." )
300+ self ._fmin = 1e-5
301+
295302 if statistic == 'mean' :
296- if ignore_NaN :
303+ if self . ignore_nan :
297304 self ._statfn = np .nanmean
298305 else :
299306 self ._statfn = np .mean
300307 elif statistic == 'max' :
301- if ignore_NaN :
308+ if self . ignore_nan :
302309 self ._statfn = np .nanmax
303310 else :
304311 self ._statfn = np .max
@@ -313,16 +320,18 @@ def __init__(self,fs=1,fmin=0.0,fmax=0.5,method='nonparametric',order=None,max_o
313320 self ._order = order
314321 self ._max_order = max_order
315322 paramstr = f'_parametric_{ statistic } _fs-{ fs } _fmin-{ fmin :.3g} _fmax-{ fmax :.3g} _order-{ order } ' .replace ('.' ,'-' )
323+
316324 self .identifier = self .identifier + paramstr
317325
318326 def _getkey (self ):
319327 if self ._method == 'nonparametric' :
320- return (self ._method ,- 1 ,- 1 )
328+ return (self ._method , - 1 , - 1 )
321329 else :
322330 return (self ._method ,self ._order ,self ._max_order )
323331
324332 def _get_cache (self ,data ):
325333 key = self ._getkey ()
334+
326335 try :
327336 F = data .spectral_granger [key ]['F' ]
328337 freq = data .spectral_granger [key ]['freq' ]
@@ -332,52 +341,42 @@ def _get_cache(self,data):
332341 F , freq = super ()._get_cache (data )
333342 else :
334343 z = data .to_numpy (squeeze = True )
335- time_series = ts .TimeSeries (z ,sampling_interval = 1 )
344+ time_series = ts .TimeSeries (z , sampling_interval = 1 )
336345 GA = nta .GrangerAnalyzer (time_series , order = self ._order , max_order = self ._max_order )
337346
338347 triu_id = np .triu_indices (data .n_processes )
339- F = np .full (GA .causality_xy .shape ,np .nan )
340- F [triu_id [0 ],triu_id [1 ],:] = GA .causality_xy [triu_id [0 ],triu_id [1 ],:]
341- F [triu_id [1 ],triu_id [0 ],:] = GA .causality_yx [triu_id [0 ],triu_id [1 ],:]
342- F = np .transpose (np .expand_dims (F ,axis = 3 ),axes = [3 ,2 ,1 ,0 ])
348+
349+ F = np .full (GA .causality_xy .shape , np .nan )
350+ F [triu_id [0 ], triu_id [1 ], :] = GA .causality_xy [triu_id [0 ], triu_id [1 ], :]
351+ F [triu_id [1 ], triu_id [0 ], :] = GA .causality_yx [triu_id [0 ], triu_id [1 ], :]
352+
353+ F = np .transpose (np .expand_dims (F , axis = 3 ), axes = [3 , 2 , 1 , 0 ])
343354 freq = GA .frequencies
344355 try :
345356 data .spectral_granger [key ] = {'freq' : freq , 'F' : F }
346357 except AttributeError :
347358 data .spectral_granger = {key : {'freq' : freq , 'F' : F }}
359+
348360 return F , freq
349361
350362 @parse_multivariate
351- def multivariate (self ,data ):
363+ def multivariate (self , data ):
352364 try :
353- F , freq = self ._get_cache (data )
354- # Restrict frequencies to those greater than 0
355- if self ._fmin == 0 :
356- freq_id = np .where ((freq > self ._fmin ) * (freq <= self ._fmax ))[0 ]
357- else :
358- freq_id = np .where ((freq >= self ._fmin ) * (freq <= self ._fmax ))[0 ]
359-
360- result = self ._statfn (F [0 ,freq_id ,:,:], axis = 0 )
361-
362- # extract proc0 to proc1 SGC F values
363- proc0_proc1_SGC = F [0 ,freq_id ,0 ,1 ]
364- # extract proc0 to proc1 SGC F values
365- proc1_proc0_SGC = F [0 ,freq_id ,1 ,0 ]
366-
367- # Get number of frequency values
368- num_freqs = len (proc0_proc1_SGC )
369-
370- # If more than 10% of values are NaN in either direction,
371- # set the SGC result to NaN
372- perc_proc0_proc1_NaN = np .count_nonzero (np .isnan (proc0_proc1_SGC ))/ num_freqs
373- perc_proc1_proc0_NaN = np .count_nonzero (np .isnan (proc1_proc0_SGC ))/ num_freqs
374-
375- if perc_proc0_proc1_NaN > 0.1 :
376- warnings .warn ("More than 10% NaN from proc0 to proc1, setting to NaN." )
377- result [0 ,1 ] = float ('NaN' )
378- if perc_proc1_proc0_NaN > 0.1 :
379- warnings .warn ("More than 10% NaN from proc0 to proc1, setting to NaN." )
380- result [1 ,0 ] = float ('NaN' )
365+ cache , freq = self ._get_cache (data )
366+
367+ freq_id = np .where ((freq >= self ._fmin ) * (freq <= self ._fmax ))[0 ]
368+
369+ result = self ._statfn (cache [0 , freq_id , :, :], axis = 0 )
370+
371+ nan_pct = np .isnan (cache [0 , freq_id , :, :]).mean (axis = 0 )
372+ np .fill_diagonal (nan_pct , 0.0 )
373+
374+ isna = nan_pct > self .nan_threshold
375+ if isna .any ():
376+ warnings .warn (f"Spectral GC: the following processes have >{ self .nan_threshold * 100 :.1f} % " \
377+ f"NaN values:\n { np .transpose (np .where (isna ))} \n These indices will be set to NaN. " \
378+ "Set ignore_nan to False, or modify nan_threshold parameter if needed." )
379+ result [isna ] = np .nan
381380
382381 return result
383382 except ValueError as err :
0 commit comments