
Commit 4adfd11: Added global targets
Parent: 5755ef3

File tree: 2 files changed (+32, -21 lines)


TODO (+3)

@@ -1,4 +1,7 @@
+Add tests for target
+Add tests for global targets
 Add examples of specific documents to 20ng example
+Add better naming to categorical variables, e.g. like target variables
 Keep track of doc counts between model serializations
 Add bigramming
 Add better README

lda2vec/lda2vec.py (+29, -21)

@@ -152,7 +152,7 @@ def add_target(self, name, target_dtype, out_size, loss_type):
         loss_func = getattr(chainer.functions, loss_type)
         msg = "Added loss function %s"
         self.logger.info(msg % loss_type)
-        self.target_losses[name] = (transform, loss_func)
+        self.target_losses[name] = (transform, loss_func, target_dtype)

     def finalize(self):
         loss_func = L.NegativeSampling(self.n_hidden, self.counts,
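
With target_dtype now stored alongside the transform and loss function, downstream code can cast raw target arrays before computing a loss. A minimal self-contained sketch of that bookkeeping; the name 'rating' and the stand-in transform and loss are hypothetical, not from the repo:

import numpy as np

# Mirrors self.target_losses after this commit:
# name -> (transform, loss_func, target_dtype)
target_losses = {}

def add_target(name, target_dtype):
    transform = lambda x: x                               # stand-in for a chainer link
    loss_func = lambda pred, t: ((pred - t) ** 2).mean()  # stand-in MSE
    target_losses[name] = (transform, loss_func, target_dtype)

add_target('rating', 'float32')

# Later (see _check_input below) the stored dtype casts raw input:
raw = np.array([3, 4, 5])             # e.g. integer ratings read from disk
_, _, dtype = target_losses['rating']
cast = raw.astype(dtype)              # float32, as the loss function expects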
@@ -248,23 +248,32 @@ def _skipgram_flat(self, words, cat_feats, ignore_below=3):
         return loss.data + 0.0

     def _target(self, data_cat_feats, data_targets):
+        """ Calculate the local losses relating individual document topic
+        weights to the target prediction for those documents. Additionally
+        calculate a regression on all documents to global targets.
+        """
         losses = None
-        args = (data_cat_feats, data_targets, self.categorical_feature_names)
-        for data_cat_feat, data_target, cat_feat_name in zip(*args):
+        weights = []
+        args = (data_cat_feats, self.categorical_feature_names)
+        for data_cat_feat, cat_feat_name in zip(*args):
             cat_feat = self.categorical_features[cat_feat_name]
             embedding, transform, loss_func, penalty = cat_feat
+            weights.append(embedding.unnormalized_weights(data_cat_feat))
+            if loss_func is None:
+                continue
             # This function will input an ID and output
             # (batchsize, n_hidden)
             latent = embedding(data_cat_feat)
             # Transform (batchsize, n_hidden) -> (batchsize, n_dim)
             # n_dim is 1 for RMSE, 1 for logistic outcomes, n for softmax
             output = transform(latent)
             # Loss_func gives likelihood of data_target given output
-            l = loss_func(output, data_target)
+            l = loss_func(output, data_targets[cat_feat_name])
             losses = l if losses is None else losses + l
-        features = F.concat(data_cat_feats)
-        for name, (transform, loss_func) in self.target_losses.items():
-            prediction = transform(features)
+        # Construct the latent vectors for all doc_ids
+        feature_values = F.concat(weights)
+        for name, (transform, loss_func, dtype) in self.target_losses.items():
+            prediction = transform(feature_values)
             data_target = data_targets[name]
             l = loss_func(prediction, data_target)
             losses = l if losses is None else losses + l
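
The key ordering in this hunk is that each feature's unnormalized weights are appended before the loss_func-is-None check, so features that carry no per-document loss still contribute columns to the concatenated matrix that global targets are regressed on. A rough numpy sketch of that flow, with illustrative shapes and stand-in functions (nothing here is the repo's API):

import numpy as np

batch, n_hidden = 4, 3
mse = lambda out, t: ((out - t) ** 2).mean()
feature_losses = [mse, None]             # one feature with a local loss, one without
weights = []
for loss_func in feature_losses:
    w = np.random.rand(batch, n_hidden)  # embedding.unnormalized_weights(...)
    weights.append(w)                    # collected unconditionally
    if loss_func is None:
        continue                         # skips only the local loss
    # the local loss on transform(embedding(ids)) would be computed here
feature_values = np.concatenate(weights, axis=1)  # (batch, 2 * n_hidden)
# every global target's transform now sees all features, lossy or not
assert feature_values.shape == (batch, 2 * n_hidden)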
@@ -273,42 +282,41 @@ def _target(self, data_cat_feats, data_targets):
         return losses

     def _check_input(self, word_matrix, categorical_features, targets):
+        to_var = lambda c: Variable(self.xp.asarray(c.astype('int32')))
+        to_vard = lambda c, dt: Variable(self.xp.asarray(c.astype(dt)))
         if word_matrix is not None:
             word_matrix = word_matrix.astype('int32')
         if self._finalized is False:
             self.finalize()
         if isinstance(categorical_features, (np.ndarray, np.generic)):
             # If we pass in a single categorical feature, wrap it into a list
             categorical_features = [categorical_features]
-        if isinstance(targets, (np.ndarray, np.generic)):
-            # If we pass in a single target, wrap it into a list
-            targets = [targets]
+        msg = "target variable must be of format {'target_name': nd.array}"
+        assert not isinstance(targets, (np.ndarray, np.generic)), msg
         if categorical_features is None:
             categorical_features = []
         else:
             msg = "Number of categorical features not equal to initialized"
             test = len(categorical_features) == len(self.categorical_features)
             assert test, msg
-            to_var = lambda c: Variable(self.xp.asarray(c.astype('int32')))
             categorical_features = [to_var(c) for c in categorical_features]
         if targets is None:
-            targets = []
+            new_targets = {}
         else:
-            msg = "Number of targets not equal to initialized no. of targets"
-            vals = self.categorical_features.values()
-            assert len(targets) == sum([c[2] is not None for c in vals])
+            msg = "target %s shape not equal to other inputs"
+            new_targets = {}
+            for name, target in targets.items():
+                assert len(target) == word_matrix.shape[0], msg % name
+                _, _, dtype = self.target_losses[name]
+                vtarget = to_vard(target, dtype)
+                new_targets[name] = vtarget
         for i, categorical_feature in enumerate(categorical_features):
             msg = "Number of rows in word matrix unequal "
             msg += "to that in categorical feature #%i" % i
             if word_matrix is not None:
                 assert word_matrix.shape[0] == \
                     categorical_feature.data.shape[0], msg
-        for i, target in enumerate(targets):
-            msg = "Number of rows in word matrix unequal"
-            msg += "to that in target array %i" % i
-            if word_matrix is not None:
-                assert word_matrix.shape[0] == target.data.shape[0], msg
-        return word_matrix, categorical_features, targets
+        return word_matrix, categorical_features, new_targets

     def _log_prob_words(self, context, temperature=1.0):
         """ This calculates a softmax over the vocabulary as a function