
Commit 6848a8e

added tests

1 parent a4e74ab

4 files changed, +70 −4 lines changed


.gitignore

+1
@@ -4,6 +4,7 @@
 *.pyo
 *.cpp
 *.so
+*.swp
 build
 \#*\#
 .\#*

lda2vec/EmbedFactor.py renamed to lda2vec/embed_factor.py

+21 −4

@@ -19,16 +19,24 @@ class EmbedMixture(chainer.Chain):
     uninterpretable until you measure the words most similar to this topic
     vector.
 
+    :math:`e = \sum_{j=0}^{n\_topics} c_j \cdot \vec{T_j}`
+
+    This is usually paired with regularization on the weights :math:`c_j`.
+    If using a Dirichlet prior with low alpha, these weights will be sparse.
+
     Args:
         n_documents (int): Total number of documents
         n_topics (int): Number of topics per document
         n_dim (int): Number of dimensions per topic vector (should match word
             vector size)
 
     Attributes:
-        weights (~chainer.links.EmbedID): Unnormalized topic weights. To
-            normalize these weights, use `F.softmax(weights)`.
-        factors (~chainer.links.Parameter): Topic vector matrix.
+        weights (~chainer.links.EmbedID): Unnormalized topic weights
+            (:math:`c_j`). To normalize these weights, use
+            `F.softmax(weights)`.
+        factors (~chainer.links.Parameter): Topic vector matrix (:math:`T_j`).
+
+    .. seealso:: :func:`lda2vec.dirichlet_likelihood`
     """
 
     def __init__(self, n_documents, n_topics, n_dim):
@@ -52,7 +60,16 @@ def to_cpu(self):
         super(EmbedMixture, self).to_cpu()
 
     def __call__(self, doc_ids):
-        """
+        """ Given an array of document integer indices, returns a vector
+        for each document. The vector is composed of topic weights projected
+        onto topic vectors.
+
+        Args:
+            doc_ids (~chainer.Variable): One-dimensional batch vector of IDs
+
+        Returns:
+            ~chainer.Variable: Batch of two-dimensional embeddings for every
+                document.
         """
         # (batchsize, ) --> (batchsize, logweights)
         w = self.weights(doc_ids)
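The formula added to the docstring above is easy to sanity-check in isolation: a document vector is the softmax-normalized topic weights dotted onto the topic matrix. Below is a minimal NumPy sketch of that composition (an illustration of the math only, not code from this commit; the variable names are made up):

    import numpy as np

    n_topics, n_dim = 2, 5
    c = np.random.randn(n_topics).astype('float32')         # raw weights c_j
    T = np.random.randn(n_topics, n_dim).astype('float32')  # topic vectors T_j

    # softmax puts the raw weights on the simplex
    proportions = np.exp(c) / np.exp(c).sum()

    # e = sum_j c_j * T_j: a convex combination of topic vectors
    e = proportions.dot(T)
    assert e.shape == (n_dim,)

With a low-alpha Dirichlet prior on the proportions, most of the weight concentrates on a single topic vector, which is what keeps the document embeddings interpretable.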

tests/test_dirichlet_likelihood.py

+27

@@ -0,0 +1,27 @@
+import numpy as np
+from chainer import Variable
+
+from lda2vec import dirichlet_likelihood
+
+
+def test_concentration():
+    """ Test that alpha > 1.0 on a dense vector has a higher likelihood
+    than alpha < 1.0 on a dense vector, and test that a sparse vector
+    has the opposite character. """
+
+    dense = np.abs(np.random.randn(5, 10)).astype('float32')
+    dense /= dense.max(axis=0)
+    weights = Variable(dense)
+    dhl_likely = dirichlet_likelihood(weights, alpha=10.0)
+    dhl_unlikely = dirichlet_likelihood(weights, alpha=0.1)
+
+    assert dhl_likely.data > dhl_unlikely.data
+
+    sparse = np.abs(np.random.randn(5, 10)).astype('float32')
+    sparse[1:, :] = 0.0
+    sparse /= sparse.max(axis=0)
+    weights = Variable(sparse)
+    dhl_unlikely = dirichlet_likelihood(weights, alpha=10.0)
+    dhl_likely = dirichlet_likelihood(weights, alpha=0.1)
+
+    assert dhl_likely.data > dhl_unlikely.data
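For context on what this test checks: up to a normalizing constant, the log-density of a symmetric Dirichlet with concentration alpha over proportions p is sum_j (alpha − 1) log p_j, so alpha > 1 rewards dense proportion vectors and alpha < 1 rewards sparse ones. A minimal NumPy sketch of that term (an illustration of the math, not lda2vec's dirichlet_likelihood implementation):

    import numpy as np

    def dirichlet_logp(p, alpha):
        # symmetric Dirichlet log-density with the normalizing
        # constant dropped: sum_j (alpha - 1) * log p_j
        return ((alpha - 1.0) * np.log(p)).sum()

    dense = np.full(10, 0.1)                # uniform proportions
    sparse = np.array([0.91] + [0.01] * 9)  # nearly one-hot proportions

    # alpha > 1 favors dense vectors; alpha < 1 favors sparse ones
    assert dirichlet_logp(dense, alpha=10.0) > dirichlet_logp(sparse, alpha=10.0)
    assert dirichlet_logp(sparse, alpha=0.1) > dirichlet_logp(dense, alpha=0.1)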

tests/test_embed_factor.py

+21

@@ -0,0 +1,21 @@
+import numpy as np
+from chainer import Variable
+
+from lda2vec import EmbedMixture
+
+
+def softmax(v):
+    return np.exp(v) / np.sum(np.exp(v))
+
+
+def test_embed_mixture():
+    """ Manually test """
+    # Ten documents, two topics, five hidden dimensions
+    em = EmbedMixture(10, 2, 5)
+    doc_ids = Variable(np.arange(1, dtype='int32'))
+    doc_vector = em(doc_ids)
+    # weights -- (n_topics)
+    weights = softmax(em.weights.data[0, :])
+    # (n_hidden) = (n_topics) . (n_topics, n_hidden)
+    doc_vector_test = np.dot(weights, em.factors.data)
+    assert np.allclose(doc_vector.data, doc_vector_test)
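Assuming a conventional pytest layout (an assumption; this commit does not configure a test runner), both new test modules can be run from the repository root with:

    py.test tests/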

0 commit comments
