Skip to content

Commit

Permalink
Update dien.py
Browse files Browse the repository at this point in the history
add a global variable initializer op in model def
  • Loading branch information
Weichen Shen committed Mar 30, 2019
1 parent 542b39d commit 4f5cd38
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions deepctr/models/dien.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ def interest_evolution(concat_behavior, deep_input_item, user_behavior_length, g
hist = AttentionSequencePoolingLayer(hidden_size=att_hidden_size, activation=att_activation, weight_normalization=att_weight_normalization, return_score=False)([
deep_input_item, rnn_outputs2, user_behavior_length])

else:#AIGRU AGRU AUGRU
else: # AIGRU AGRU AUGRU

scores = AttentionSequencePoolingLayer(hidden_size=att_hidden_size, activation=att_activation, weight_normalization=att_weight_normalization, return_score=True)([
deep_input_item, rnn_outputs, user_behavior_length])
Expand All @@ -125,7 +125,7 @@ def interest_evolution(concat_behavior, deep_input_item, user_behavior_length, g
hist = multiply([rnn_outputs, Permute([2, 1])(scores)])
final_state2 = DynamicGRU(embedding_size * 2, gru_type="GRU", return_sequence=False, name='gru2')(
[hist, user_behavior_length])
else:#AGRU AUGRU
else: # AGRU AUGRU
final_state2 = DynamicGRU(embedding_size * 2, gru_type=gru_type, return_sequence=False,
name='gru2')([rnn_outputs, user_behavior_length, Permute([2, 1])(scores)])
hist = final_state2
Expand Down Expand Up @@ -220,4 +220,5 @@ def DIEN(feature_dim_dict, seq_feature_list, embedding_size=8, hist_len_max=16,

if use_negsampling:
model.add_loss(alpha * aux_loss_1)
tf.keras.backend.get_session().run(tf.global_variables_initializer())
return model

0 comments on commit 4f5cd38

Please sign in to comment.