Skip to content

Commit a093ed4

Browse files
committed
addedd option for gskip
1 parent ea4e514 commit a093ed4

File tree

3 files changed

+17
-11
lines changed

3 files changed

+17
-11
lines changed

basicrta/cluster.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,11 +28,12 @@ class ProcessProtein(object):
2828
:type cutoff: float
2929
"""
3030

31-
def __init__(self, niter, prot, cutoff):
31+
def __init__(self, niter, prot, cutoff, gskip):
3232
self.residues = {}
3333
self.niter = niter
3434
self.prot = prot
3535
self.cutoff = cutoff
36+
self.gskip = gskip
3637

3738
def __getitem__(self, item):
3839
return getattr(self, item)
@@ -43,6 +44,7 @@ def _single_residue(self, adir, process=False):
4344
result = f'{adir}/gibbs_{self.niter}.pkl'
4445
g = Gibbs().load(result)
4546
if process:
47+
g.gskip = self.gskip
4648
g.process_gibbs()
4749
except ValueError:
4850
result = None

basicrta/gibbs.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -231,11 +231,12 @@ def cluster(self, method="GaussianMixture", **kwargs):
231231
from scipy import stats
232232

233233
clu = getattr(mixture, method)
234-
burnin_ind = self.burnin // (self.g*self.gskip)
234+
burnin_ind = self.burnin // self.g
235235
data_len = len(self.times)
236236
wcutoff = 10 / data_len
237237

238-
weights, rates = self.mcweights[burnin_ind:], self.mcrates[burnin_ind:]
238+
weights = self.mcweights[burnin_ind::self.gskip]
239+
rates = self.mcrates[burnin_ind::self.gskip]
239240
lens = np.array([len(row[row > wcutoff]) for row in weights])
240241
lmode = stats.mode(lens).mode
241242
train_param = lmode
@@ -258,7 +259,7 @@ def cluster(self, method="GaussianMixture", **kwargs):
258259
all_labels = r.predict(np.log(data))
259260

260261
if self.indicator is not None:
261-
indicator = self.indicator[burnin_ind:]
262+
indicator = self.indicator[burnin_ind::self.gskip]
262263
else:
263264
indicator = self._sample_indicator()
264265

@@ -285,13 +286,15 @@ def process_gibbs(self):
285286
data_len = len(self.times)
286287
wcutoff = 10/data_len
287288
burnin_ind = self.burnin//self.g
288-
inds = np.where(self.mcweights[burnin_ind:] > wcutoff)
289+
inds = np.where(self.mcweights[burnin_ind::self.gskip] > wcutoff)
289290
indices = (np.arange(self.burnin, self.niter + 1, self.g*self.gskip)
290291
[inds[0]] // self.g)
291-
weights, rates = self.mcweights[burnin_ind:], self.mcrates[burnin_ind:]
292+
weights = self.mcweights[burnin_ind::self.gskip]
293+
rates = self.mcrates[burnin_ind::self.gskip]
292294
fweights, frates = weights[inds], rates[inds]
293295

294-
lens = [len(row[row > wcutoff]) for row in self.mcweights[burnin_ind:]]
296+
lens = [len(row[row > wcutoff]) for row in
297+
self.mcweights[burnin_ind::self.gskip]]
295298
lmode = stats.mode(lens).mode
296299

297300
self.cluster(n_init=117, n_components=lmode)
@@ -320,8 +323,8 @@ def result_plot(self, remove_noise=False, **kwargs):
320323
mixture_and_plot(self, remove_noise=remove_noise, **kwargs)
321324

322325
def _sample_indicator(self):
323-
indicator = np.zeros(((self.niter+1)//self.g, self.times.shape[0]),
324-
dtype=np.uint8)
326+
indicator = np.zeros(((self.niter+1)//(self.g*self.gskip),
327+
self.times.shape[0]), dtype=np.uint8)
325328
burnin_ind = self.burnin//self.g
326329
for i, (w, r) in enumerate(zip(self.mcweights, self.mcrates)):
327330
# compute probabilities
@@ -332,7 +335,7 @@ def _sample_indicator(self):
332335
s = np.argmax(rng.multinomial(1, z), axis=1)
333336
indicator[i] = s
334337
setattr(self, 'indicator', indicator)
335-
return indicator[burnin_ind:]
338+
return indicator[burnin_ind::self.gskip]
336339

337340
def save(self):
338341
"""

basicrta/util.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -704,7 +704,8 @@ def mixture_and_plot(gibbs, scale=2, sparse=1, remove_noise=False, wlim=None,
704704
else:
705705
wmin, wmax = wcutoff, 2
706706

707-
weights, rates = gibbs.mcweights[burnin_ind:], gibbs.mcrates[burnin_ind:]
707+
weights = gibbs.mcweights[burnin_ind::gibbs.gskip]
708+
rates = gibbs.mcrates[burnin_ind::gibbs.gskip]
708709
lens = np.array([len(row[row > wcutoff]) for row in weights])
709710
lmode = stats.mode(lens).mode
710711
train_param = lmode

0 commit comments

Comments
 (0)