Skip to content

Commit e5a6a7b

Browse files
committed
update independence modeling
1 parent 1f9ec85 commit e5a6a7b

File tree

2 files changed

+48
-24
lines changed

2 files changed

+48
-24
lines changed

modules/KGIN.py

+47-24
Original file line numberDiff line numberDiff line change
@@ -55,8 +55,8 @@ class GraphConv(nn.Module):
5555
Graph Convolutional Network
5656
"""
5757
def __init__(self, channel, n_hops, n_users,
58-
n_factors, n_relations,
59-
interact_mat, node_dropout_rate=0.5, mess_dropout_rate=0.1):
58+
n_factors, n_relations, interact_mat,
59+
ind, node_dropout_rate=0.5, mess_dropout_rate=0.1):
6060
super(GraphConv, self).__init__()
6161

6262
self.convs = nn.ModuleList()
@@ -66,6 +66,7 @@ def __init__(self, channel, n_hops, n_users,
6666
self.n_factors = n_factors
6767
self.node_dropout_rate = node_dropout_rate
6868
self.mess_dropout_rate = mess_dropout_rate
69+
self.ind = ind
6970

7071
self.temperature = 0.2
7172

@@ -103,22 +104,21 @@ def _sparse_dropout(self, x, rate=0.5):
103104
out = torch.sparse.FloatTensor(i, v, x.shape).to(x.device)
104105
return out * (1. / (1 - rate))
105106

106-
def _cul_cor_pro(self):
107-
# disen_T: [num_factor, dimension]
108-
disen_T = self.disen_weight_att.t()
109-
110-
# normalized_disen_T: [num_factor, dimension]
111-
normalized_disen_T = disen_T / disen_T.norm(dim=1, keepdim=True)
112-
113-
pos_scores = torch.sum(normalized_disen_T * normalized_disen_T, dim=1)
114-
ttl_scores = torch.sum(torch.mm(disen_T, self.disen_weight_att), dim=1)
115-
116-
pos_scores = torch.exp(pos_scores / self.temperature)
117-
ttl_scores = torch.exp(ttl_scores / self.temperature)
118-
119-
mi_score = - torch.sum(torch.log(pos_scores / ttl_scores))
120-
return mi_score
121-
107+
# def _cul_cor_pro(self):
108+
# # disen_T: [num_factor, dimension]
109+
# disen_T = self.disen_weight_att.t()
110+
#
111+
# # normalized_disen_T: [num_factor, dimension]
112+
# normalized_disen_T = disen_T / disen_T.norm(dim=1, keepdim=True)
113+
#
114+
# pos_scores = torch.sum(normalized_disen_T * normalized_disen_T, dim=1)
115+
# ttl_scores = torch.sum(torch.mm(disen_T, self.disen_weight_att), dim=1)
116+
#
117+
# pos_scores = torch.exp(pos_scores / self.temperature)
118+
# ttl_scores = torch.exp(ttl_scores / self.temperature)
119+
#
120+
# mi_score = - torch.sum(torch.log(pos_scores / ttl_scores))
121+
# return mi_score
122122

123123
def _cul_cor(self):
124124
def CosineSimilarity(tensor_1, tensor_2):
@@ -146,11 +146,33 @@ def DistanceCorrelation(tensor_1, tensor_2):
146146
dcov_AA = torch.sqrt(torch.max((A * A).sum() / channel ** 2, zero) + 1e-8)
147147
dcov_BB = torch.sqrt(torch.max((B * B).sum() / channel ** 2, zero) + 1e-8)
148148
return dcov_AB / torch.sqrt(dcov_AA * dcov_BB + 1e-8)
149-
cor = 0
149+
def MutualInformation():
150+
# disen_T: [num_factor, dimension]
151+
disen_T = self.disen_weight_att.t()
152+
153+
# normalized_disen_T: [num_factor, dimension]
154+
normalized_disen_T = disen_T / disen_T.norm(dim=1, keepdim=True)
155+
156+
pos_scores = torch.sum(normalized_disen_T * normalized_disen_T, dim=1)
157+
ttl_scores = torch.sum(torch.mm(disen_T, self.disen_weight_att), dim=1)
158+
159+
pos_scores = torch.exp(pos_scores / self.temperature)
160+
ttl_scores = torch.exp(ttl_scores / self.temperature)
161+
162+
mi_score = - torch.sum(torch.log(pos_scores / ttl_scores))
163+
return mi_score
164+
150165
"""cul similarity for each latent factor weight pairs"""
151-
for i in range(self.n_factors):
152-
for j in range(i + 1, self.n_factors):
153-
cor += DistanceCorrelation(self.disen_weight_att[i], self.disen_weight_att[j])
166+
if self.ind == 'mi':
167+
return MutualInformation()
168+
else:
169+
cor = 0
170+
for i in range(self.n_factors):
171+
for j in range(i + 1, self.n_factors):
172+
if self.ind == 'distance':
173+
cor += DistanceCorrelation(self.disen_weight_att[i], self.disen_weight_att[j])
174+
else:
175+
cor += CosineSimilarity(self.disen_weight_att[i], self.disen_weight_att[j])
154176
return cor
155177

156178
def forward(self, user_emb, entity_emb, latent_emb, edge_index, edge_type,
@@ -163,8 +185,7 @@ def forward(self, user_emb, entity_emb, latent_emb, edge_index, edge_type,
163185

164186
entity_res_emb = entity_emb # [n_entity, channel]
165187
user_res_emb = user_emb # [n_users, channel]
166-
# cor = self._cul_cor()
167-
cor = self._cul_cor_pro()
188+
cor = self._cul_cor()
168189
for i in range(len(self.convs)):
169190
entity_emb, user_emb = self.convs[i](entity_emb, user_emb, latent_emb,
170191
edge_index, edge_type, interact_mat,
@@ -203,6 +224,7 @@ def __init__(self, data_config, args_config, graph, adj_mat):
203224
self.node_dropout_rate = args_config.node_dropout_rate
204225
self.mess_dropout = args_config.mess_dropout
205226
self.mess_dropout_rate = args_config.mess_dropout_rate
227+
self.ind = args_config.ind
206228
self.device = torch.device("cuda:" + str(args_config.gpu_id)) if args_config.cuda \
207229
else torch.device("cpu")
208230

@@ -231,6 +253,7 @@ def _init_model(self):
231253
n_relations=self.n_relations,
232254
n_factors=self.n_factors,
233255
interact_mat=self.interact_mat,
256+
ind=self.ind,
234257
node_dropout_rate=self.node_dropout_rate,
235258
mess_dropout_rate=self.mess_dropout_rate)
236259

utils/parser.py

+1
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ def parse_args():
3131
parser.add_argument('--test_flag', nargs='?', default='part',
3232
help='Specify the test type from {part, full}, indicating whether the reference is done in mini-batch')
3333
parser.add_argument("--n_factors", type=int, default=4, help="number of latent factor for user favour")
34+
parser.add_argument("--ind", type=str, default='distance', help="Independence modeling: mi, distance, cosine")
3435

3536
# ===== relation context ===== #
3637
parser.add_argument('--context_hops', type=int, default=3, help='number of context hops')

0 commit comments

Comments
 (0)