Cythonized sslm_counts_init method of sslm class
Reduced the time for sslm initialization (especially critical for large datasets) and removed duplicated code.
Sharganov committed Jun 17, 2021
1 parent c3efd57 commit 2015bdb
Showing 2 changed files with 63 additions and 174 deletions.
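A quick way to observe the speed-up described in the commit message (a sketch, not part of the commit: the constructor arguments, sizes and variance values below are assumptions) is to time the initializer on a single topic chain:

    import time
    import numpy as np
    from gensim.models.ldaseqmodel_optimized import sslm

    W, T, K = 50000, 10, 20      # vocabulary size, time slices, topics (made up)
    chain = sslm(vocab_len=W, num_time_slices=T, num_topics=K)
    sstats = np.random.rand(W)   # one topic's column of LDA sufficient statistics

    start = time.perf_counter()
    chain.sslm_counts_init(obs_variance=0.5, chain_variance=0.005, sstats=sstats)
    print('init took %.3fs' % (time.perf_counter() - start))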
62 changes: 60 additions & 2 deletions gensim/models/ldaseq_sslm_inner.pyx
@@ -42,6 +42,62 @@ import numpy as np
from scipy import optimize


def sslm_counts_init(model, obs_variance, chain_variance, sstats):
"""Initialize the State Space Language Model with LDA sufficient statistics.
Called for each topic-chain and initializes initial mean, variance and Topic-Word probabilities
for the first time-slice.
Parameters
----------
obs_variance : float, optional
Observed variance used to approximate the true and forward variance.
chain_variance : float
Gaussian parameter defined in the beta distribution to dictate how the beta values evolve over time.
sstats : numpy.ndarray
Sufficient statistics of the LDA model. Corresponds to matrix beta in the linked paper for time slice 0,
expected shape (`self.vocab_len`, `num_topics`).
"""
W = model.vocab_len
T = model.num_time_slices

log_norm_counts = np.copy(sstats)
log_norm_counts /= sum(log_norm_counts)
log_norm_counts += 1.0 / W
log_norm_counts /= sum(log_norm_counts)
log_norm_counts = np.log(log_norm_counts)
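    # Illustrative check of the transform above (numbers are made up, not from
    # the source): for sstats = [2., 6.] and W = 2, dividing by the sum gives
    # [0.25, 0.75]; adding 1 / W gives [0.75, 1.25]; renormalizing gives
    # [0.375, 0.625]; taking logs gives roughly [-0.98, -0.47].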

cdef StateSpaceLanguageModelConfig * config = <StateSpaceLanguageModelConfig *> malloc(
sizeof(StateSpaceLanguageModelConfig))


# setting variational observations to transformed counts
model.obs = (np.repeat(log_norm_counts, T, axis=0)).reshape(W, T)
# set variational parameters
model.obs_variance = obs_variance
model.chain_variance = chain_variance

init_sslm_config(config, model)

cdef int w
cdef int vocab_len = model.vocab_len

    # compute posterior variance and mean
for w in range(vocab_len):
compute_post_variance(config.variance, config.fwd_variance,
config.obs_variance, config.chain_variance,
w, config.num_time_slices)

compute_post_mean(config.mean, config.fwd_mean, config.fwd_variance,
config.obs, w, config.num_time_slices,
config.obs_variance, config.chain_variance)

update_zeta(config.zeta, config.mean, config.variance, config.num_time_slices, config.vocab_len)
compute_expected_log_prob(config.e_log_prob, config.zeta, config.mean,
config.vocab_len, config.num_time_slices)
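    # Cache the address of the malloc'ed C struct on the Python model as a
    # plain integer, so a later fit_sslm() call can reuse the same allocation.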
model.config_c_address = <uintptr_t>(config)

cdef compute_post_mean(REAL_t *mean, REAL_t *fwd_mean, const REAL_t *fwd_variance, const REAL_t *obs,
const int word, const int num_time_slices,
const REAL_t obs_variance, const REAL_t chain_variance):
@@ -698,7 +754,8 @@ def fit_sslm(model, np_sstats):
"""

# Initialize C structures based on Python instance of the model
cdef StateSpaceLanguageModelConfig* config = <StateSpaceLanguageModelConfig *>malloc(sizeof(StateSpaceLanguageModelConfig))
cdef StateSpaceLanguageModelConfig * config = <StateSpaceLanguageModelConfig *> (<uintptr_t>(model.config_c_address))

init_sslm_config(config, model)

cdef int W = config[0].vocab_len
Expand Down Expand Up @@ -736,5 +793,6 @@ def fit_sslm(model, np_sstats):
compute_expected_log_prob(config[0].e_log_prob, config[0].zeta, config[0].mean,
W, config[0].num_time_slices)

free(config)
    # TODO find a way/place where to free the memory
# free(config)
return bound
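The TODO above leaves the malloc'ed config struct alive for the lifetime of the process. One possible cleanup hook (a sketch under assumptions, not part of this commit; `free_sslm_config` is a hypothetical helper) would release it explicitly:

    def free_sslm_config(model):
        # Hypothetical helper: release the C config that sslm_counts_init
        # cached on the Python model, then clear the stored address.
        cdef StateSpaceLanguageModelConfig * config
        if model.config_c_address != 0:
            config = <StateSpaceLanguageModelConfig *> <uintptr_t> model.config_c_address
            free(config)
            model.config_c_address = 0

The Python-side `sslm` class could then call it from `__del__`, or the chain could be wrapped in a small `cdef class` whose `__dealloc__` frees the struct.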
175 changes: 3 additions & 172 deletions gensim/models/ldaseqmodel_optimized.py
@@ -3,7 +3,7 @@
import numpy as np
from gensim import utils, matutils
from gensim.models import ldamodel
from .ldaseq_sslm_inner import fit_sslm
from .ldaseq_sslm_inner import fit_sslm, sslm_counts_init
from .ldaseq_posterior_inner import fit_lda_post

logger = logging.getLogger(__name__)
@@ -670,157 +670,8 @@ def __init__(self, vocab_len=None, num_time_slices=None, num_topics=None, obs_va
self.w_phi_sum = None
self.w_phi_l_sq = None
self.m_update_coeff_g = None
self.config_c_address = 0

def update_zeta(self):
"""Update the Zeta variational parameter.
        Zeta is described in the appendix and is equal to sum(exp(mean[word] + variance[word] / 2))
        over every time slice. It is the value of the variational parameter zeta which maximizes the lower bound.
Returns
-------
list of float
The updated zeta values for each time slice.
"""
for j, val in enumerate(self.zeta):
self.zeta[j] = np.sum(np.exp(self.mean[:, j + 1] + self.variance[:, j + 1] / 2))
return self.zeta
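        # Note the j + 1 column index above: index 0 of `mean` and `variance`
        # holds the chain's initial state, so zeta[j] summarizes time slice j.
        # The cythonized update_zeta in ldaseq_sslm_inner.pyx computes the same
        # quantity.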

def compute_post_variance(self, word, chain_variance):
r"""Get the variance, based on the `Variational Kalman Filtering approach for Approximate Inference (section 3.1)
<https://mimno.infosci.cornell.edu/info6150/readings/dynamic_topic_models.pdf>`_.
This function accepts the word to compute variance for, along with the associated sslm class object,
and returns the `variance` and the posterior approximation `fwd_variance`.
Notes
-----
This function essentially computes Var[\beta_{t,w}] for t = 1:T
        .. :math::
            fwd\_variance[t] \equiv E((\beta_{t,w} - mean_{t,w})^2 \mid \hat\beta_{1:t}) =
            (obs\_variance / (fwd\_variance[t - 1] + chain\_variance + obs\_variance)) *
            (fwd\_variance[t - 1] + chain\_variance)
        .. :math::
            variance[t] \equiv E((\beta_{t,w} - mean_{t,w})^2 \mid \hat\beta_{1:T}) =
            (fwd\_variance[t] / (fwd\_variance[t] + chain\_variance))^2 *
            (variance[t + 1] - chain\_variance) +
            (1 - (fwd\_variance[t] / (fwd\_variance[t] + chain\_variance))^2) * fwd\_variance[t]
Parameters
----------
word: int
The word's ID.
chain_variance : float
Gaussian parameter defined in the beta distribution to dictate how the beta values evolve over time.
Returns
-------
(numpy.ndarray, numpy.ndarray)
The first returned value is the variance of each word in each time slice, the second value is the
inferred posterior variance for the same pairs.
"""
INIT_VARIANCE_CONST = 1000

T = self.num_time_slices
variance = self.variance[word]
fwd_variance = self.fwd_variance[word]
# forward pass. Set initial variance very high
fwd_variance[0] = chain_variance * INIT_VARIANCE_CONST
for t in range(1, T + 1):
if self.obs_variance:
c = self.obs_variance / (fwd_variance[t - 1] + chain_variance + self.obs_variance)
else:
c = 0
fwd_variance[t] = c * (fwd_variance[t - 1] + chain_variance)

# backward pass
variance[T] = fwd_variance[T]
for t in range(T - 1, -1, -1):
if fwd_variance[t] > 0.0:
c = np.power((fwd_variance[t] / (fwd_variance[t] + chain_variance)), 2)
else:
c = 0
variance[t] = (c * (variance[t + 1] - chain_variance)) + ((1 - c) * fwd_variance[t])

return variance, fwd_variance
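        # Worked example (illustrative numbers, not from the source): with
        # obs_variance = 0.5 and chain_variance = 0.005 the forward pass starts
        # at fwd_variance[0] = 0.005 * 1000 = 5.0; then
        # c = 0.5 / (5.0 + 0.005 + 0.5) ~= 0.0908 and
        # fwd_variance[1] = 0.0908 * (5.0 + 0.005) ~= 0.455, i.e. the large
        # initial uncertainty collapses after a single observed time slice.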

def compute_post_mean(self, word, chain_variance):
"""Get the mean, based on the `Variational Kalman Filtering approach for Approximate Inference (section 3.1)
<https://mimno.infosci.cornell.edu/info6150/readings/dynamic_topic_models.pdf>`_.
Notes
-----
This function essentially computes E[\beta_{t,w}] for t = 1:T.
        .. :math::
            fwd\_mean[t] \equiv E(\beta_{t,w} \mid \hat\beta_{1:t}) =
            (obs\_variance / (fwd\_variance[t - 1] + chain\_variance + obs\_variance)) * fwd\_mean[t - 1] +
            (1 - obs\_variance / (fwd\_variance[t - 1] + chain\_variance + obs\_variance)) * obs[t - 1]
        .. :math::
            mean[t] \equiv E(\beta_{t,w} \mid \hat\beta_{1:T}) =
            (chain\_variance / (fwd\_variance[t] + chain\_variance)) * fwd\_mean[t] +
            (1 - chain\_variance / (fwd\_variance[t] + chain\_variance)) * mean[t + 1]
Parameters
----------
word: int
The word's ID.
chain_variance : float
Gaussian parameter defined in the beta distribution to dictate how the beta values evolve over time.
Returns
-------
(numpy.ndarray, numpy.ndarray)
The first returned value is the mean of each word in each time slice, the second value is the
inferred posterior mean for the same pairs.
"""
T = self.num_time_slices
obs = self.obs[word]
fwd_variance = self.fwd_variance[word]
mean = self.mean[word]
fwd_mean = self.fwd_mean[word]

# forward
fwd_mean[0] = 0
for t in range(1, T + 1):
c = self.obs_variance / (fwd_variance[t - 1] + chain_variance + self.obs_variance)
fwd_mean[t] = c * fwd_mean[t - 1] + (1 - c) * obs[t - 1]

# backward pass
mean[T] = fwd_mean[T]
for t in range(T - 1, -1, -1):
if chain_variance == 0.0:
c = 0.0
else:
c = chain_variance / (fwd_variance[t] + chain_variance)
mean[t] = c * fwd_mean[t] + (1 - c) * mean[t + 1]
return mean, fwd_mean

def compute_expected_log_prob(self):
"""Compute the expected log probability given values of m.
        The appendix describes the expectation of log probabilities in equation 5 of the DTM paper;
        the implementation below is the result of solving that equation and follows the original
        Blei DTM code.
Returns
-------
numpy.ndarray of float
The expected value for the log probabilities for each word and time slice.
"""
for (w, t), val in np.ndenumerate(self.e_log_prob):
self.e_log_prob[w][t] = self.mean[w][t + 1] - np.log(self.zeta[t])
return self.e_log_prob

def sslm_counts_init(self, obs_variance, chain_variance, sstats):
"""Initialize the State Space Language Model with LDA sufficient statistics.
@@ -839,28 +690,8 @@ def sslm_counts_init(self, obs_variance, chain_variance, sstats):
expected shape (`self.vocab_len`, `num_topics`).
"""
W = self.vocab_len
T = self.num_time_slices

log_norm_counts = np.copy(sstats)
log_norm_counts /= sum(log_norm_counts)
log_norm_counts += 1.0 / W
log_norm_counts /= sum(log_norm_counts)
log_norm_counts = np.log(log_norm_counts)

# setting variational observations to transformed counts
self.obs = (np.repeat(log_norm_counts, T, axis=0)).reshape(W, T)
# set variational parameters
self.obs_variance = obs_variance
self.chain_variance = chain_variance

        # compute posterior variance and mean
for w in range(W):
self.variance[w], self.fwd_variance[w] = self.compute_post_variance(w, self.chain_variance)
self.mean[w], self.fwd_mean[w] = self.compute_post_mean(w, self.chain_variance)

self.zeta = self.update_zeta()
self.e_log_prob = self.compute_expected_log_prob()
sslm_counts_init(self, obs_variance, chain_variance, sstats)

def fit_sslm(self, sstats):
"""Fits variational distribution.
