Skip to content

Commit 238177e

Browse files
authored
Merge branch 'main' into main
2 parents a32257f + b4692a5 commit 238177e

File tree

1 file changed

+41
-42
lines changed

1 file changed

+41
-42
lines changed

pyspi/statistics/spectral.py

Lines changed: 41 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import numpy as np
22
from copy import deepcopy
33

4-
import spectral_connectivity as sc # For directed spectral statistics (excl. spectral GC)
4+
import spectral_connectivity as sc # For directed spectral statistics (excl. spectral GC)
55
from pyspi.base import directed, parse_bivariate, undirected, parse_multivariate, unsigned
66
import nitime.analysis as nta
77
import nitime.timeseries as ts
@@ -18,7 +18,7 @@ class kramer(unsigned):
1818
def __init__(self,fs=1,fmin=0,fmax=None,statistic='mean'):
1919
if fmax is None:
2020
fmax = fs/2
21-
21+
2222
self._fs = fs
2323
if fs != 1:
2424
warnings.warn('Multiple sampling frequencies not yet handled.')
@@ -265,7 +265,7 @@ def __init__(self,**kwargs):
265265
self.identifier = 'psi'
266266
super().__init__(**kwargs)
267267
self._measure = 'phase_slope_index'
268-
268+
269269
def _get_statistic(self,C):
270270
return C.phase_slope_index(frequencies_of_interest=[self._fmin,self._fmax],
271271
frequency_resolution=(self._fmax-self._fmin)/50)
@@ -278,7 +278,7 @@ def __init__(self,**kwargs):
278278
self.identifier = 'gd'
279279
super().__init__(**kwargs)
280280
self._measure = 'group_delay'
281-
281+
282282
def _get_statistic(self,C):
283283
return C.group_delay(frequencies_of_interest=[self._fmin,self._fmax],
284284
frequency_resolution=(self._fmax-self._fmin)/50)
@@ -288,17 +288,24 @@ class spectral_granger(kramer_mv,directed,unsigned):
288288
identifier = 'sgc'
289289
labels = ['unsigned','embedding','spectral','directed','lagged']
290290

291-
def __init__(self,fs=1,fmin=0.0,fmax=0.5,method='nonparametric',order=None,max_order=50,statistic='mean',ignore_NaN=True):
291+
def __init__(self, fs = 1, fmin = 1e-5, fmax = 0.5, method = 'nonparametric', order = None, max_order = 50, statistic = 'mean', ignore_nan = True, nan_threshold = 0.5):
292292
self._fs = fs # Not yet implemented
293293
self._fmin = fmin
294294
self._fmax = fmax
295+
self.ignore_nan = ignore_nan
296+
self.nan_threshold = nan_threshold
297+
298+
if self._fmin <= 0.:
299+
warnings.warn(f"Frequency minimum set to {self._fmin}; overriding to 1e-5.")
300+
self._fmin = 1e-5
301+
295302
if statistic == 'mean':
296-
if ignore_NaN:
303+
if self.ignore_nan:
297304
self._statfn = np.nanmean
298305
else:
299306
self._statfn = np.mean
300307
elif statistic == 'max':
301-
if ignore_NaN:
308+
if self.ignore_nan:
302309
self._statfn = np.nanmax
303310
else:
304311
self._statfn = np.max
@@ -313,16 +320,18 @@ def __init__(self,fs=1,fmin=0.0,fmax=0.5,method='nonparametric',order=None,max_o
313320
self._order = order
314321
self._max_order = max_order
315322
paramstr = f'_parametric_{statistic}_fs-{fs}_fmin-{fmin:.3g}_fmax-{fmax:.3g}_order-{order}'.replace('.','-')
323+
316324
self.identifier = self.identifier + paramstr
317325

318326
def _getkey(self):
319327
if self._method == 'nonparametric':
320-
return (self._method,-1,-1)
328+
return (self._method, -1, -1)
321329
else:
322330
return (self._method,self._order,self._max_order)
323331

324332
def _get_cache(self,data):
325333
key = self._getkey()
334+
326335
try:
327336
F = data.spectral_granger[key]['F']
328337
freq = data.spectral_granger[key]['freq']
@@ -332,52 +341,42 @@ def _get_cache(self,data):
332341
F, freq = super()._get_cache(data)
333342
else:
334343
z = data.to_numpy(squeeze=True)
335-
time_series = ts.TimeSeries(z,sampling_interval=1)
344+
time_series = ts.TimeSeries(z, sampling_interval=1)
336345
GA = nta.GrangerAnalyzer(time_series, order=self._order, max_order=self._max_order)
337346

338347
triu_id = np.triu_indices(data.n_processes)
339-
F = np.full(GA.causality_xy.shape,np.nan)
340-
F[triu_id[0],triu_id[1],:] = GA.causality_xy[triu_id[0],triu_id[1],:]
341-
F[triu_id[1],triu_id[0],:] = GA.causality_yx[triu_id[0],triu_id[1],:]
342-
F = np.transpose(np.expand_dims(F,axis=3),axes=[3,2,1,0])
348+
349+
F = np.full(GA.causality_xy.shape, np.nan)
350+
F[triu_id[0], triu_id[1], :] = GA.causality_xy[triu_id[0], triu_id[1], :]
351+
F[triu_id[1], triu_id[0], :] = GA.causality_yx[triu_id[0], triu_id[1], :]
352+
353+
F = np.transpose(np.expand_dims(F, axis=3), axes=[3, 2, 1, 0])
343354
freq = GA.frequencies
344355
try:
345356
data.spectral_granger[key] = {'freq': freq, 'F': F}
346357
except AttributeError:
347358
data.spectral_granger = {key: {'freq': freq, 'F': F}}
359+
348360
return F, freq
349361

350362
@parse_multivariate
351-
def multivariate(self,data):
363+
def multivariate(self, data):
352364
try:
353-
F, freq = self._get_cache(data)
354-
# Restrict frequencies to those greater than 0
355-
if self._fmin == 0:
356-
freq_id = np.where((freq > self._fmin) * (freq <= self._fmax))[0]
357-
else:
358-
freq_id = np.where((freq >= self._fmin) * (freq <= self._fmax))[0]
359-
360-
result = self._statfn(F[0,freq_id,:,:], axis=0)
361-
362-
# extract proc0 to proc1 SGC F values
363-
proc0_proc1_SGC = F[0,freq_id,0,1]
364-
# extract proc1 to proc0 SGC F values
365-
proc1_proc0_SGC = F[0,freq_id,1,0]
366-
367-
# Get number of frequency values
368-
num_freqs = len(proc0_proc1_SGC)
369-
370-
# If more than 10% of values are NaN in either direction,
371-
# set the SGC result to NaN
372-
perc_proc0_proc1_NaN = np.count_nonzero(np.isnan(proc0_proc1_SGC))/num_freqs
373-
perc_proc1_proc0_NaN = np.count_nonzero(np.isnan(proc1_proc0_SGC))/num_freqs
374-
375-
if perc_proc0_proc1_NaN > 0.1:
376-
warnings.warn("More than 10% NaN from proc0 to proc1, setting to NaN.")
377-
result[0,1] = float('NaN')
378-
if perc_proc1_proc0_NaN > 0.1:
379-
warnings.warn("More than 10% NaN from proc1 to proc0, setting to NaN.")
380-
result[1,0] = float('NaN')
365+
cache, freq = self._get_cache(data)
366+
367+
freq_id = np.where((freq >= self._fmin) * (freq <= self._fmax))[0]
368+
369+
result = self._statfn(cache[0, freq_id, :, :], axis=0)
370+
371+
nan_pct = np.isnan(cache[0, freq_id, :, :]).mean(axis=0)
372+
np.fill_diagonal(nan_pct, 0.0)
373+
374+
isna = nan_pct > self.nan_threshold
375+
if isna.any():
376+
warnings.warn(f"Spectral GC: the following processes have >{self.nan_threshold*100:.1f}% " \
377+
f"NaN values:\n{np.transpose(np.where(isna))}\nThese indices will be set to NaN. " \
378+
"Set ignore_nan to False, or modify nan_threshold parameter if needed.")
379+
result[isna] = np.nan
381380

382381
return result
383382
except ValueError as err:

0 commit comments

Comments
 (0)