@@ -231,11 +231,12 @@ def cluster(self, method="GaussianMixture", **kwargs):
231
231
from scipy import stats
232
232
233
233
clu = getattr (mixture , method )
234
- burnin_ind = self .burnin // ( self .g * self . gskip )
234
+ burnin_ind = self .burnin // self .g
235
235
data_len = len (self .times )
236
236
wcutoff = 10 / data_len
237
237
238
- weights , rates = self .mcweights [burnin_ind :], self .mcrates [burnin_ind :]
238
+ weights = self .mcweights [burnin_ind ::self .gskip ]
239
+ rates = self .mcrates [burnin_ind ::self .gskip ]
239
240
lens = np .array ([len (row [row > wcutoff ]) for row in weights ])
240
241
lmode = stats .mode (lens ).mode
241
242
train_param = lmode
@@ -258,7 +259,7 @@ def cluster(self, method="GaussianMixture", **kwargs):
258
259
all_labels = r .predict (np .log (data ))
259
260
260
261
if self .indicator is not None :
261
- indicator = self .indicator [burnin_ind :]
262
+ indicator = self .indicator [burnin_ind :: self . gskip ]
262
263
else :
263
264
indicator = self ._sample_indicator ()
264
265
@@ -285,13 +286,15 @@ def process_gibbs(self):
285
286
data_len = len (self .times )
286
287
wcutoff = 10 / data_len
287
288
burnin_ind = self .burnin // self .g
288
- inds = np .where (self .mcweights [burnin_ind :] > wcutoff )
289
+ inds = np .where (self .mcweights [burnin_ind :: self . gskip ] > wcutoff )
289
290
indices = (np .arange (self .burnin , self .niter + 1 , self .g * self .gskip )
290
291
[inds [0 ]] // self .g )
291
- weights , rates = self .mcweights [burnin_ind :], self .mcrates [burnin_ind :]
292
+ weights = self .mcweights [burnin_ind ::self .gskip ]
293
+ rates = self .mcrates [burnin_ind ::self .gskip ]
292
294
fweights , frates = weights [inds ], rates [inds ]
293
295
294
- lens = [len (row [row > wcutoff ]) for row in self .mcweights [burnin_ind :]]
296
+ lens = [len (row [row > wcutoff ]) for row in
297
+ self .mcweights [burnin_ind ::self .gskip ]]
295
298
lmode = stats .mode (lens ).mode
296
299
297
300
self .cluster (n_init = 117 , n_components = lmode )
@@ -320,8 +323,8 @@ def result_plot(self, remove_noise=False, **kwargs):
320
323
mixture_and_plot (self , remove_noise = remove_noise , ** kwargs )
321
324
322
325
def _sample_indicator (self ):
323
- indicator = np .zeros (((self .niter + 1 )// self .g , self .times . shape [ 0 ]),
324
- dtype = np .uint8 )
326
+ indicator = np .zeros (((self .niter + 1 )// ( self .g * self .gskip ),
327
+ self . times . shape [ 0 ]), dtype = np .uint8 )
325
328
burnin_ind = self .burnin // self .g
326
329
for i , (w , r ) in enumerate (zip (self .mcweights , self .mcrates )):
327
330
# compute probabilities
@@ -332,7 +335,7 @@ def _sample_indicator(self):
332
335
s = np .argmax (rng .multinomial (1 , z ), axis = 1 )
333
336
indicator [i ] = s
334
337
setattr (self , 'indicator' , indicator )
335
- return indicator [burnin_ind :]
338
+ return indicator [burnin_ind :: self . gskip ]
336
339
337
340
def save (self ):
338
341
"""
0 commit comments