Skip to content

Commit ba78e67

Browse files
committed
Renamed mpi -> multivariate (the method that computes each SPI) so that it matches the bivariate signature. Also fixed the pdist method signature (it was still named adjacency).
1 parent 5e58cee commit ba78e67

File tree

9 files changed

+44
-27
lines changed

9 files changed

+44
-27
lines changed

demos/demo.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,31 @@
77

88
import seaborn as sns
99

10-
calc = Calculator(dataset=load_dataset('forex'))
10+
# Load one of our stored datasets
11+
dataset = load_dataset('forex')
12+
13+
# Visualize the dataset as a heat map (also called a temporal raster plot or carpet plot)
14+
plt.pcolormesh(dataset.to_numpy(squeeze=True),cmap='coolwarm',vmin=-2,vmax=2)
15+
plt.show()
16+
17+
# Instantiate the calculator (inputting the dataset)
18+
calc = Calculator(dataset=dataset)
19+
20+
# Compute all SPIs (this may take a while)
1121
calc.compute()
22+
23+
# Now, we can access all of the matrices through the calc.table property.
24+
# This property will be an N x (S*N) pandas DataFrame, where N is the number of processes in the dataset and S is the number of SPIs.
25+
print(calc.table)
26+
27+
# We can use this to compute the correlation between all of the methods on this dataset...
1228
corrmat = calc.table.stack().corr(method='spearman').abs()
1329

30+
# ...and plot this correlation matrix
1431
sns.set(font_scale=0.5)
1532
g = sns.clustermap( corrmat.fillna(0), mask=corrmat.isna(),
1633
center=0.0,
1734
cmap='RdYlBu_r',
1835
xticklabels=1, yticklabels=1 )
1936
plt.setp(g.ax_heatmap.xaxis.get_majorticklabels(), rotation=45, ha='right')
20-
plt.show()
37+
plt.show()

pyspi/base.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ def bivariate(self,data,i=None,j=None):
8989
raise NotImplementedError("Method not yet overloaded.")
9090

9191
@parse_multivariate
92-
def mpi(self,data):
92+
def multivariate(self,data):
9393
""" Compute the dependency statistics for the entire multivariate dataset
9494
"""
9595
A = np.empty((data.n_processes,data.n_processes))
@@ -135,8 +135,8 @@ def ispositive(self):
135135
return False
136136

137137
@parse_multivariate
138-
def mpi(self,data):
139-
A = super(undirected,self).mpi(data)
138+
def multivariate(self,data):
139+
A = super(undirected,self).multivariate(data)
140140

141141
li = np.tril_indices(data.n_processes,-1)
142142
A[li] = A.T[li]

pyspi/calculator.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
""" TODO: use the MPI class for each entry in the table
1818
"""
19-
class MPI():
19+
class multivariate():
2020
def __init__(self, procnames, S=None):
2121
if S is None:
2222
S = np.full((len(procnames),len(procnames)),np.nan)
@@ -158,14 +158,14 @@ def compute(self,replication=None):
158158
replication = 0
159159

160160
pbar = tqdm(self.spis.keys())
161-
for m in pbar:
162-
pbar.set_description(f'Processing [{self._name}: {m}]')
161+
for spi in pbar:
162+
pbar.set_description(f'Processing [{self._name}: {spi}]')
163163
start_time = time.time()
164164
try:
165-
self._table[m] = self._spis[m].mpi(self.dataset)
165+
self._table[spi] = self._spis[spi].multivariate(self.dataset)
166166
except Exception as err:
167-
warnings.warn(f'Caught {type(err)} for SPI "{self._statnames[m]}": {err}')
168-
self._table[m] = np.NaN
167+
warnings.warn(f'Caught {type(err)} for SPI "{spi}": {err}')
168+
self._table[spi] = np.NaN
169169
pbar.close()
170170

171171
def rmmin(self):

pyspi/statistics/basic.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ def _from_cache(self,data):
4242
return mycov
4343

4444
@parse_multivariate
45-
def mpi(self,data):
45+
def multivariate(self,data):
4646
mycov = self._from_cache(data)
4747
matrix = getattr(mycov,self._kind+'_')
4848
np.fill_diagonal(matrix,np.nan)

pyspi/statistics/causal.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ def _from_cache(self,data):
120120
return ccmf
121121

122122
@parse_multivariate
123-
def mpi(self,data):
123+
def multivariate(self,data):
124124
ccmf = self._from_cache(data)
125125

126126
if self._statistic == 'mean':

pyspi/statistics/distance.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ def __init__(self,metric='euclidean',**kwargs):
2020
self.name += f'_{metric}'
2121

2222
@parse_multivariate
23-
def adjacency(self,data):
23+
def multivariate(self,data):
2424
return pairwise_distances(data.to_numpy(squeeze=True),metric=self._metric)
2525

2626
""" TODO: include optional kernels in each method

pyspi/statistics/spectral.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ def _get_cache(self,data):
7777
return res, freq
7878

7979
@parse_multivariate
80-
def mpi(self, data):
80+
def multivariate(self, data):
8181
adj_freq, freq = self._get_cache(data)
8282
freq_id = np.where((freq >= self._fmin) * (freq <= self._fmax))[0]
8383

@@ -299,7 +299,7 @@ def _get_statistic(self,C):
299299
# self.name = self.name + paramstr
300300

301301
# @parse_multivariate
302-
# def mpi(self,data):
302+
# def multivariate(self,data):
303303
# # This should be changed to conditioning on all, rather than averaging all conditionals
304304
# if not hasattr(data,'pcoh'):
305305
# z = np.squeeze(data.to_numpy())
@@ -372,7 +372,7 @@ def _get_cache(self,data):
372372
return F, freq
373373

374374
@parse_multivariate
375-
def mpi(self,data):
375+
def multivariate(self,data):
376376
try:
377377
F, freq = self._get_cache(data)
378378
freq_id = np.where((freq >= self._fmin) * (freq <= self._fmax))[0]
@@ -399,7 +399,7 @@ def __init__(self,orth=False,log=False,absolute=False):
399399
self.name += '_abs'
400400

401401
@parse_multivariate
402-
def mpi(self, data):
402+
def multivariate(self, data):
403403
z = np.moveaxis(data.to_numpy(),2,0)
404404
adj = np.squeeze(mnec.envelope_correlation(z,orthogonalize=self._orth,log=self._log,absolute=self._absolute))
405405
np.fill_diagonal(adj,np.nan)

pyspi/statistics/wavelet.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ def _get_cache(self,data):
5959
return conn, freq_id
6060

6161
@parse_multivariate
62-
def mpi(self, data):
62+
def multivariate(self, data):
6363
adj_freq, freq_id = self._get_cache(data)
6464
try:
6565
adj = self._statfn(adj_freq[...,freq_id,:], axis=(2,3))
@@ -189,7 +189,7 @@ def _get_cache(self,data):
189189
return psi, freq_id
190190

191191
@parse_multivariate
192-
def mpi(self, data):
192+
def multivariate(self, data):
193193
adj_freq, freq_id = self._get_cache(data)
194194
adj = self._statfn(np.real(adj_freq[...,freq_id]), axis=(2,3))
195195

test/test_statistics.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ def test_yaml():
5353
assert calc.n_statistics == len(calc._statistics), (
5454
'Property not equal to number of statistics')
5555

56-
def test_mpi():
56+
def test_multivariate():
5757
# Load in all base statistics from the YAML file
5858

5959
data = get_data()
@@ -81,14 +81,14 @@ def test_mpi():
8181
if any([m.name == e for e in excuse_stochastic]):
8282
continue
8383

84-
m.mpi(get_more_data())
84+
m.multivariate(get_more_data())
8585

86-
scratch_adj = m.mpi(data.to_numpy())
87-
adj = m.mpi(data)
86+
scratch_adj = m.multivariate(data.to_numpy())
87+
adj = m.multivariate(data)
8888
assert np.allclose(adj,scratch_adj,rtol=1e-1,atol=1e-2,equal_nan=True), (
8989
f'{m.name} ({m.humanname}) mpi output changed between cached and strach computations: {adj} != {scratch_adj}')
9090

91-
recomp_adj = m.mpi(data)
91+
recomp_adj = m.multivariate(data)
9292
assert np.allclose(adj,recomp_adj,rtol=1e-1,atol=1e-2,equal_nan=True), (
9393
f'{m.name} ({m.humanname}) mpi output changed when recomputing.')
9494

@@ -111,7 +111,7 @@ def test_mpi():
111111
assert t_s == pytest.approx(new_t_s,rel=1e-1,abs=1e-2), (
112112
f'{m.name} ({m.humanname}) Bivariate output from cache mismatch results from scratch for computation ({j},{i}): {t_s} != {new_t_s}')
113113
except NotImplementedError:
114-
a = m.mpi(p[[i,j]])
114+
a = m.multivariate(p[[i,j]])
115115
s_t, t_s = a[0,1], a[1,0]
116116

117117
if not math.isfinite(s_t):
@@ -206,7 +206,7 @@ def test_group():
206206
test_yaml()
207207
test_load()
208208
test_group()
209-
test_mpi()
209+
test_multivariate()
210210

211211
# This was a bit tricky to implement so just ensuring it passes a test from the creator's website
212212
test_ccm()

0 commit comments

Comments
 (0)