Commit 93f2777

Including simplex dropout
1 parent a38f5e8

File tree

4 files changed, +25 -11 lines changed


examples/twenty_newsgroups/lda2vec/lda2vec_model.py (+2 -2)
@@ -13,9 +13,9 @@ class LDA2Vec(Chain):
     def __init__(self, n_documents=100, n_document_topics=10,
                  n_units=256, n_vocab=1000, dropout_ratio=0.5, train=True,
                  counts=None, n_samples=15, word_dropout_ratio=0.0,
-                 power=0.75):
+                 power=0.75, temperature=1.0):
         em = EmbedMixture(n_documents, n_document_topics, n_units,
-                          dropout_ratio=dropout_ratio)
+                          dropout_ratio=dropout_ratio, temperature=temperature)
         kwargs = {}
         kwargs['mixture'] = em
         kwargs['sampler'] = L.NegativeSampling(n_units, counts, n_samples,
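
The constructor change only threads the new keyword through to EmbedMixture. A minimal usage sketch of the widened signature (the toy sizes and random term_frequency are placeholders, not values from the example script):

import numpy as np
from lda2vec_model import LDA2Vec

# Toy corpus statistics; the real script derives these from the data.
n_docs, n_vocab = 500, 1000
term_frequency = np.random.randint(1, 100, size=n_vocab).astype('int32')

# `temperature` is forwarded to EmbedMixture, where it scales the topic
# logits before the softmax over per-document topic proportions.
model = LDA2Vec(n_documents=n_docs, n_document_topics=20,
                n_units=256, n_vocab=n_vocab, counts=term_frequency,
                n_samples=15, power=0.75, temperature=1.0)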

examples/twenty_newsgroups/lda2vec/lda2vec_run.py (+9 -6)
@@ -49,6 +49,8 @@
 power = float(os.getenv('power', 0.75))
 # Intialize with pretrained word vectors
 pretrained = bool(int(os.getenv('pretrained', True)))
+# Sampling temperature
+temperature = float(os.getenv('temperature', 1.0))
 # Number of dimensions in a single word vector
 n_units = int(os.getenv('n_units', 300))
 # Get the string representation for every compact key
@@ -69,7 +71,7 @@
 
 model = LDA2Vec(n_documents=n_docs, n_document_topics=n_topics,
                 n_units=n_units, n_vocab=n_vocab, counts=term_frequency,
-                n_samples=15, power=power)
+                n_samples=15, power=power, temperature=temperature)
 if os.path.exists('lda2vec.hdf5'):
     print "Reloading from saved"
     serializers.load_hdf5("lda2vec.hdf5", model)
@@ -91,11 +93,12 @@
                                 cuda.to_cpu(model.sampler.W.data).copy(),
                                 words)
         top_words = print_top_words_per_topic(data)
-        coherence = topic_coherence(top_words)
-        for j in range(n_topics):
-            print j, coherence[(j, 'cv')]
-        kw = dict(top_words=top_words, coherence=coherence, epoch=epoch)
-        progress[str(epoch)] = pickle.dumps(kw)
+        if j % 100 == 0 and j > 100:
+            coherence = topic_coherence(top_words)
+            for j in range(n_topics):
+                print j, coherence[(j, 'cv')]
+            kw = dict(top_words=top_words, coherence=coherence, epoch=epoch)
+            progress[str(epoch)] = pickle.dumps(kw)
 data['doc_lengths'] = doc_lengths
 data['term_frequency'] = term_frequency
 np.savez('topics.pyldavis', **data)
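
Like power and pretrained above it, the new temperature knob is read from the environment rather than the command line, so a run can be configured without editing the script. A small sketch (the value 2.0 is arbitrary):

import os

# Equivalent to launching with `temperature=2.0 python lda2vec_run.py`.
os.environ['temperature'] = '2.0'

# Mirrors the line added in the first hunk above.
temperature = float(os.getenv('temperature', 1.0))
print(temperature)  # 2.0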
Binary file changed (contents not shown).

lda2vec/embed_mixture.py (+14 -3)
@@ -3,6 +3,7 @@
 import chainer
 import chainer.links as L
 import chainer.functions as F
+from chainer import Variable
 
 
 def _orthogonal_matrix(shape):
@@ -60,7 +61,8 @@ class EmbedMixture(chainer.Chain):
     .. seealso:: :func:`lda2vec.dirichlet_likelihood`
     """
 
-    def __init__(self, n_documents, n_topics, n_dim, dropout_ratio=0.2):
+    def __init__(self, n_documents, n_topics, n_dim, dropout_ratio=0.2,
+                 temperature=1.0):
         self.n_documents = n_documents
         self.n_topics = n_topics
         self.n_dim = n_dim
@@ -70,6 +72,7 @@ def __init__(self, n_documents, n_topics, n_dim, dropout_ratio=0.2):
         super(EmbedMixture, self).__init__(
             weights=L.EmbedID(n_documents, n_topics),
             factors=L.Parameter(factors))
+        self.temperature = temperature
         self.weights.W.data[...] /= np.sqrt(n_documents + n_topics)
 
     def __call__(self, doc_ids, update_only_docs=False):
@@ -102,5 +105,13 @@ def proportions(self, doc_ids, softmax=False):
         doc_weights : chainer.Variable
             Two dimensional topic weights of each document.
         """
-        w = F.dropout(self.weights(doc_ids), ratio=self.dropout_ratio)
-        return F.softmax(w) if softmax else w
+        w = self.weights(doc_ids)
+        if softmax:
+            size = w.data.shape
+            mask = self.xp.random.random_integers(0, 1, size=size)
+            y = (F.softmax(w * self.temperature) *
+                 Variable(mask.astype('float32')))
+            norm, y = F.broadcast(F.expand_dims(F.sum(y, axis=1), 1), y)
+            return y / (norm + 1e-7)
+        else:
+            return w
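
The rewritten proportions method is what the commit title calls simplex dropout: instead of applying ordinary dropout to the raw document weights, it tempers the logits, zeroes out whole topics from the softmax output with a random 0/1 mask, and renormalizes so the result stays on the probability simplex. A minimal NumPy sketch of the same computation (the function name is ours; the real code operates on Chainer variables via self.xp so it also runs on GPU):

import numpy as np

def simplex_dropout(logits, temperature=1.0, eps=1e-7, rng=np.random):
    # Multiplying by temperature (as in the diff) means values > 1
    # sharpen the softmax and values < 1 flatten it.
    z = logits * temperature
    z = z - z.max(axis=1, keepdims=True)    # stabilize the exponentials
    p = np.exp(z)
    p = p / p.sum(axis=1, keepdims=True)    # ordinary row-wise softmax
    mask = rng.randint(0, 2, size=p.shape)  # drop each topic with prob. 0.5
    y = p * mask
    # Renormalize so the surviving topics form a distribution again; eps
    # guards against the rare row where every topic was masked.
    return y / (y.sum(axis=1, keepdims=True) + eps)

props = simplex_dropout(np.random.randn(2, 5), temperature=2.0)
print(props.sum(axis=1))  # ~1 per row (0 only if an entire row was masked)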
