@@ -152,7 +152,7 @@ def add_target(self, name, target_dtype, out_size, loss_type):
         loss_func = getattr(chainer.functions, loss_type)
         msg = "Added loss function %s"
         self.logger.info(msg % loss_type)
-        self.target_losses[name] = (transform, loss_func)
+        self.target_losses[name] = (transform, loss_func, target_dtype)

     def finalize(self):
         loss_func = L.NegativeSampling(self.n_hidden, self.counts,
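With this change a target's dtype is stored alongside its transform and loss function, so the input checks below can cast each named target before wrapping it in a Variable. A minimal sketch of the registration and the matching dict-based input, assuming a hypothetical `model` instance and `n_docs` count (neither appears in the diff); `loss_type` is resolved via `getattr(chainer.functions, ...)`, so it must name a real Chainer loss such as `mean_squared_error`:

    import numpy as np

    # Illustrative only: register a scalar regression target named 'sales'
    model.add_target('sales', target_dtype='float32', out_size=1,
                     loss_type='mean_squared_error')

    # Targets are now supplied as a dict keyed by name, one row per document
    targets = {'sales': np.random.rand(n_docs).astype('float32')}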
@@ -248,23 +248,32 @@ def _skipgram_flat(self, words, cat_feats, ignore_below=3):
         return loss.data + 0.0

     def _target(self, data_cat_feats, data_targets):
+        """ Calculate the local losses relating individual document topic
+        weights to the target prediction for those documents. Additionally,
+        calculate a regression from all documents to the global targets.
+        """
         losses = None
-        args = (data_cat_feats, data_targets, self.categorical_feature_names)
-        for data_cat_feat, data_target, cat_feat_name in zip(*args):
+        weights = []
+        args = (data_cat_feats, self.categorical_feature_names)
+        for data_cat_feat, cat_feat_name in zip(*args):
            cat_feat = self.categorical_features[cat_feat_name]
            embedding, transform, loss_func, penalty = cat_feat
+            weights.append(embedding.unnormalized_weights(data_cat_feat))
+            if loss_func is None:
+                continue
            # This function will input an ID and output
            # (batchsize, n_hidden)
            latent = embedding(data_cat_feat)
            # Transform (batchsize, n_hidden) -> (batchsize, n_dim)
            # n_dim is 1 for RMSE, 1 for logistic outcomes, n for softmax
            output = transform(latent)
            # loss_func gives the likelihood of data_target given output
-            l = loss_func(output, data_target)
+            l = loss_func(output, data_targets[cat_feat_name])
            losses = l if losses is None else losses + l
-        features = F.concat(data_cat_feats)
-        for name, (transform, loss_func) in self.target_losses.items():
-            prediction = transform(features)
+        # Construct the latent vectors for all doc_ids
+        feature_values = F.concat(weights)
+        for name, (transform, loss_func, dtype) in self.target_losses.items():
+            prediction = transform(feature_values)
            data_target = data_targets[name]
            l = loss_func(prediction, data_target)
            losses = l if losses is None else losses + l
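The rewritten loop collects unnormalized topic weights from every categorical embedding, and the global target regressions now run on their concatenation rather than on the raw ID features. A minimal sketch of that concatenation step, with made-up sizes (two hypothetical features with 5 and 3 topics):

    import numpy as np
    import chainer.functions as F
    from chainer import Variable

    batchsize = 4
    # Stand-ins for embedding.unnormalized_weights(data_cat_feat)
    w1 = Variable(np.random.rand(batchsize, 5).astype('float32'))
    w2 = Variable(np.random.rand(batchsize, 3).astype('float32'))

    # One feature vector per document: (batchsize, 5 + 3)
    feature_values = F.concat([w1, w2])
    assert feature_values.data.shape == (batchsize, 8)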
@@ -273,42 +282,41 @@ def _target(self, data_cat_feats, data_targets):
         return losses

     def _check_input(self, word_matrix, categorical_features, targets):
+        to_var = lambda c: Variable(self.xp.asarray(c.astype('int32')))
+        to_vard = lambda c, dt: Variable(self.xp.asarray(c.astype(dt)))
         if word_matrix is not None:
             word_matrix = word_matrix.astype('int32')
         if self._finalized is False:
             self.finalize()
         if isinstance(categorical_features, (np.ndarray, np.generic)):
             # If we pass in a single categorical feature, wrap it into a list
             categorical_features = [categorical_features]
-        if isinstance(targets, (np.ndarray, np.generic)):
-            # If we pass in a single target, wrap it into a list
-            targets = [targets]
+        msg = "target variable must be of format {'target_name': np.ndarray}"
+        assert not isinstance(targets, (np.ndarray, np.generic)), msg
         if categorical_features is None:
             categorical_features = []
         else:
             msg = "Number of categorical features not equal to initialized"
             test = len(categorical_features) == len(self.categorical_features)
             assert test, msg
-            to_var = lambda c: Variable(self.xp.asarray(c.astype('int32')))
             categorical_features = [to_var(c) for c in categorical_features]
         if targets is None:
-            targets = []
+            new_targets = {}
         else:
-            msg = "Number of targets not equal to initialized no. of targets"
-            vals = self.categorical_features.values()
-            assert len(targets) == sum([c[2] is not None for c in vals])
+            msg = "target %s shape not equal to other inputs"
+            new_targets = {}
+            for name, target in targets.items():
+                assert len(target) == word_matrix.shape[0], msg % name
+                _, _, dtype = self.target_losses[name]
+                vtarget = to_vard(target, dtype)
+                new_targets[name] = vtarget
         for i, categorical_feature in enumerate(categorical_features):
             msg = "Number of rows in word matrix unequal"
             msg += " to that in categorical feature #%i" % i
             if word_matrix is not None:
                 assert word_matrix.shape[0] == \
                     categorical_feature.data.shape[0], msg
-        for i, target in enumerate(targets):
-            msg = "Number of rows in word matrix unequal"
-            msg += "to that in target array %i" % i
-            if word_matrix is not None:
-                assert word_matrix.shape[0] == target.data.shape[0], msg
-        return word_matrix, categorical_features, targets
+        return word_matrix, categorical_features, new_targets

     def _log_prob_words(self, context, temperature=1.0):
         """ This calculates a softmax over the vocabulary as a function