@@ -55,8 +55,8 @@ class GraphConv(nn.Module):
55
55
Graph Convolutional Network
56
56
"""
57
57
def __init__ (self , channel , n_hops , n_users ,
58
- n_factors , n_relations ,
59
- interact_mat , node_dropout_rate = 0.5 , mess_dropout_rate = 0.1 ):
58
+ n_factors , n_relations , interact_mat ,
59
+ ind , node_dropout_rate = 0.5 , mess_dropout_rate = 0.1 ):
60
60
super (GraphConv , self ).__init__ ()
61
61
62
62
self .convs = nn .ModuleList ()
@@ -66,6 +66,7 @@ def __init__(self, channel, n_hops, n_users,
66
66
self .n_factors = n_factors
67
67
self .node_dropout_rate = node_dropout_rate
68
68
self .mess_dropout_rate = mess_dropout_rate
69
+ self .ind = ind
69
70
70
71
self .temperature = 0.2
71
72
@@ -103,22 +104,21 @@ def _sparse_dropout(self, x, rate=0.5):
103
104
out = torch .sparse .FloatTensor (i , v , x .shape ).to (x .device )
104
105
return out * (1. / (1 - rate ))
105
106
106
- def _cul_cor_pro (self ):
107
- # disen_T: [num_factor, dimension]
108
- disen_T = self .disen_weight_att .t ()
109
-
110
- # normalized_disen_T: [num_factor, dimension]
111
- normalized_disen_T = disen_T / disen_T .norm (dim = 1 , keepdim = True )
112
-
113
- pos_scores = torch .sum (normalized_disen_T * normalized_disen_T , dim = 1 )
114
- ttl_scores = torch .sum (torch .mm (disen_T , self .disen_weight_att ), dim = 1 )
115
-
116
- pos_scores = torch .exp (pos_scores / self .temperature )
117
- ttl_scores = torch .exp (ttl_scores / self .temperature )
118
-
119
- mi_score = - torch .sum (torch .log (pos_scores / ttl_scores ))
120
- return mi_score
121
-
107
+ # def _cul_cor_pro(self):
108
+ # # disen_T: [num_factor, dimension]
109
+ # disen_T = self.disen_weight_att.t()
110
+ #
111
+ # # normalized_disen_T: [num_factor, dimension]
112
+ # normalized_disen_T = disen_T / disen_T.norm(dim=1, keepdim=True)
113
+ #
114
+ # pos_scores = torch.sum(normalized_disen_T * normalized_disen_T, dim=1)
115
+ # ttl_scores = torch.sum(torch.mm(disen_T, self.disen_weight_att), dim=1)
116
+ #
117
+ # pos_scores = torch.exp(pos_scores / self.temperature)
118
+ # ttl_scores = torch.exp(ttl_scores / self.temperature)
119
+ #
120
+ # mi_score = - torch.sum(torch.log(pos_scores / ttl_scores))
121
+ # return mi_score
122
122
123
123
def _cul_cor (self ):
124
124
def CosineSimilarity (tensor_1 , tensor_2 ):
@@ -146,11 +146,33 @@ def DistanceCorrelation(tensor_1, tensor_2):
146
146
dcov_AA = torch .sqrt (torch .max ((A * A ).sum () / channel ** 2 , zero ) + 1e-8 )
147
147
dcov_BB = torch .sqrt (torch .max ((B * B ).sum () / channel ** 2 , zero ) + 1e-8 )
148
148
return dcov_AB / torch .sqrt (dcov_AA * dcov_BB + 1e-8 )
149
- cor = 0
149
+ def MutualInformation ():
150
+ # disen_T: [num_factor, dimension]
151
+ disen_T = self .disen_weight_att .t ()
152
+
153
+ # normalized_disen_T: [num_factor, dimension]
154
+ normalized_disen_T = disen_T / disen_T .norm (dim = 1 , keepdim = True )
155
+
156
+ pos_scores = torch .sum (normalized_disen_T * normalized_disen_T , dim = 1 )
157
+ ttl_scores = torch .sum (torch .mm (disen_T , self .disen_weight_att ), dim = 1 )
158
+
159
+ pos_scores = torch .exp (pos_scores / self .temperature )
160
+ ttl_scores = torch .exp (ttl_scores / self .temperature )
161
+
162
+ mi_score = - torch .sum (torch .log (pos_scores / ttl_scores ))
163
+ return mi_score
164
+
150
165
"""cul similarity for each latent factor weight pairs"""
151
- for i in range (self .n_factors ):
152
- for j in range (i + 1 , self .n_factors ):
153
- cor += DistanceCorrelation (self .disen_weight_att [i ], self .disen_weight_att [j ])
166
+ if self .ind == 'mi' :
167
+ return MutualInformation ()
168
+ else :
169
+ cor = 0
170
+ for i in range (self .n_factors ):
171
+ for j in range (i + 1 , self .n_factors ):
172
+ if self .ind == 'distance' :
173
+ cor += DistanceCorrelation (self .disen_weight_att [i ], self .disen_weight_att [j ])
174
+ else :
175
+ cor += CosineSimilarity (self .disen_weight_att [i ], self .disen_weight_att [j ])
154
176
return cor
155
177
156
178
def forward (self , user_emb , entity_emb , latent_emb , edge_index , edge_type ,
@@ -163,8 +185,7 @@ def forward(self, user_emb, entity_emb, latent_emb, edge_index, edge_type,
163
185
164
186
entity_res_emb = entity_emb # [n_entity, channel]
165
187
user_res_emb = user_emb # [n_users, channel]
166
- # cor = self._cul_cor()
167
- cor = self ._cul_cor_pro ()
188
+ cor = self ._cul_cor ()
168
189
for i in range (len (self .convs )):
169
190
entity_emb , user_emb = self .convs [i ](entity_emb , user_emb , latent_emb ,
170
191
edge_index , edge_type , interact_mat ,
@@ -203,6 +224,7 @@ def __init__(self, data_config, args_config, graph, adj_mat):
203
224
self .node_dropout_rate = args_config .node_dropout_rate
204
225
self .mess_dropout = args_config .mess_dropout
205
226
self .mess_dropout_rate = args_config .mess_dropout_rate
227
+ self .ind = args_config .ind
206
228
self .device = torch .device ("cuda:" + str (args_config .gpu_id )) if args_config .cuda \
207
229
else torch .device ("cpu" )
208
230
@@ -231,6 +253,7 @@ def _init_model(self):
231
253
n_relations = self .n_relations ,
232
254
n_factors = self .n_factors ,
233
255
interact_mat = self .interact_mat ,
256
+ ind = self .ind ,
234
257
node_dropout_rate = self .node_dropout_rate ,
235
258
mess_dropout_rate = self .mess_dropout_rate )
236
259
0 commit comments