The error in my notebook occurred because, in old versions of PyG, input arguments could not be passed properly to the model's forward function when the model was wrapped by GNNExplainer.
Recent versions of PyG (this work used the older 2.0.4) have updated GNNExplainer and wrapped it in a more general Explainer class, which also fixes the problem. I have marked the erroneous code as deprecated and included the correct code for using the Explainer at the top of the notebook and in gnnexplainer.py.
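To make the argument-passing point concrete, here is a minimal, dependency-free sketch. `ToyExplainerWrapper` and `toy_forward` are hypothetical names made up for illustration; they are not PyG classes and do not reproduce PyG's implementation. The sketch only assumes a wrapper that forwards extra inputs to the model by keyword, in the same shape as the `self.model(x=h, edge_index=edge_index, batch=batch, **kwargs)` call visible in the traceback.

```python
from typing import Any, Callable, Dict, List


def toy_forward(x, edge_index, batch, x_cell_mut, edge_feat, edge_weight=None):
    """Hypothetical stand-in that reuses the parameter names of RGCNNet.forward."""
    # Report what was received so the forwarding can be inspected.
    return {
        "num_nodes": len(x),
        "num_edges": len(edge_index),
        "batch": batch,
        "x_cell_mut": x_cell_mut,
        "edge_feat": edge_feat,
    }


class ToyExplainerWrapper:
    """Hypothetical wrapper for illustration only; not a PyG class."""

    def __init__(self, model: Callable[..., Dict[str, Any]]) -> None:
        self.model = model

    def explain_graph(self, x: List[Any], edge_index: List[Any], **kwargs: Any) -> Dict[str, Any]:
        # Single graph: assign every node to graph 0.
        batch = [0] * len(x)
        # Extra inputs reach the model by keyword, so their order does not matter.
        return self.model(x=x, edge_index=edge_index, batch=batch, **kwargs)


wrapper = ToyExplainerWrapper(toy_forward)
x = ["n0", "n1", "n2"]
edge_index = [(0, 1), (1, 2)]

result = wrapper.explain_graph(x, edge_index, x_cell_mut="mut", edge_feat=[0, 1])

# A keyword that the forward function does not declare is rejected by Python itself.
try:
    wrapper.explain_graph(x, edge_index, x_cell_mut="mut", edge_feat=[0, 1], batch_drug=[0])
    rejected = False
except TypeError:
    rejected = True
```

With keyword forwarding, the names passed to the wrapper have to match the parameter names of the model's forward function exactly; a name such as `batch_drug` that the forward function does not accept fails immediately with a `TypeError`.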
Hi, thanks for your work. I previously ran into an issue where I couldn't use PyG's GNNExplainer directly to explain RGCN; it throws the same error as the one you encountered in gnnexplainer.ipynb. Have you found a solution to this problem when trying to explain RGCN?
AssertionError Traceback (most recent call last)
Input In [22], in <cell line: 22>()
10 # model_args = (
11 # x_cell_mut,
12 # batch_drug,
13 # edge_features
14 # )
16 kwargs = {
17 "x_cell_mut": x_cell_mut,
18 "batch_drug": batch_drug,
19 "edge_feat": edge_features
20 }
---> 22 node_feature_mask, edge_mask = explainer.explain_graph(x = x, edge_index = edge_index, x_cell_mut = x_cell_mut, edge_feat = edge_features)
File /data/conghao001/anaconda3/envs/gnndrug/lib/python3.8/site-packages/torch_geometric/nn/models/gnn_explainer.py:165, in GNNExplainer.explain_graph(self, x, edge_index, **kwargs)
163 print('debug h', h.size())
164 print('debug edge', edge_index.size())
--> 165 out = self.model(x=h, edge_index=edge_index, batch=batch, **kwargs)
166 loss = self.get_loss(out, prediction, None)
167 loss.backward()
File /data/conghao001/anaconda3/envs/gnndrug/lib/python3.8/site-packages/torch/nn/modules/module.py:1110, in Module._call_impl(self, *input, **kwargs)
1106 # If we don't have any hooks, we want to skip the rest of the logic in
1107 # this function, and just call forward.
1108 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1109 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1110 return forward_call(*input, **kwargs)
1111 # Do not call functions when jit is used
1112 full_backward_hooks, non_full_backward_hooks = [], []
File /data/conghao001/FYP/GCN_Drug/models.py:749, in RGCNNet.forward(self, x, edge_index, batch, x_cell_mut, edge_feat, edge_weight)
741 def forward(self, x, edge_index, batch, x_cell_mut, edge_feat, edge_weight=None):
742 # get graph input
743 # edge_weight is only used for decoding
744
745 # x, edge_index, batch = data.x, data.edge_index, data.batch
746 # edge_index = edge_index.long()
747 edge_feat = edge_feat.squeeze()
--> 749 x = self.conv1(x, edge_index, edge_type=edge_feat)
750 x = self.relu(x)
751 x = self.conv2(x, edge_index, edge_type=edge_feat)
File /data/conghao001/anaconda3/envs/gnndrug/lib/python3.8/site-packages/torch/nn/modules/module.py:1110, in Module._call_impl(self, *input, **kwargs)
1106 # If we don't have any hooks, we want to skip the rest of the logic in
1107 # this function, and just call forward.
1108 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1109 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1110 return forward_call(*input, **kwargs)
1111 # Do not call functions when jit is used
1112 full_backward_hooks, non_full_backward_hooks = [], []
File /data/conghao001/anaconda3/envs/gnndrug/lib/python3.8/site-packages/torch_geometric/nn/conv/rgcn_conv.py:218, in RGCNConv.forward(self, x, edge_index, edge_type)
216 out += self.propagate(tmp, x=weight[i, x_l], size=size)
217 else:
--> 218 h = self.propagate(tmp, x=x_l, size=size)
219 out = out + (h @ weight[i])
221 root = self.root
File /data/conghao001/anaconda3/envs/gnndrug/lib/python3.8/site-packages/torch_geometric/nn/conv/message_passing.py:338, in MessagePassing.propagate(self, edge_index, size, **kwargs)
336 edge_mask = torch.cat([edge_mask, loop], dim=0)
337 print(out.size(self.node_dim), edge_mask.size(0))
--> 338 assert out.size(self.node_dim) == edge_mask.size(0)
339 out = out * edge_mask.view([-1] + [1] * (out.dim() - 1))
341 aggr_kwargs = self.inspector.distribute('aggregate', coll_dict)
AssertionError:
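The failing assertion compares the number of messages (`out.size(self.node_dim)`) with the length of the edge mask (`edge_mask.size(0)`). The dependency-free sketch below mimics that size check. It rests on an assumption that this thread does not confirm: that the mask has one entry per edge of the whole graph, while the propagate call receives only the edges of a single relation type (as the per-relation `weight[i]` indexing in rgcn_conv.py suggests). All function names and values are hypothetical.

```python
from typing import List


def edges_of_relation(edge_types: List[int], relation: int) -> List[int]:
    """Return the positions of the edges whose type equals `relation`."""
    return [i for i, t in enumerate(edge_types) if t == relation]


def apply_edge_mask(messages: List[float], edge_mask: List[float]) -> List[float]:
    # Same shape requirement as the assertion in the traceback:
    # exactly one mask entry per message.
    assert len(messages) == len(edge_mask)
    return [m * w for m, w in zip(messages, edge_mask)]


# Toy graph with five edges spread over three relation types.
edge_types = [0, 1, 0, 2, 1]
messages = [1.0] * len(edge_types)
edge_mask = [0.5] * len(edge_types)  # one mask entry per edge of the full graph

# Messages computed for relation 0 only (two of the five edges).
rel0 = edges_of_relation(edge_types, 0)
rel0_messages = [messages[i] for i in rel0]

# Full-size mask against a per-relation subset: the size check fails.
try:
    apply_edge_mask(rel0_messages, edge_mask)
    mismatch = False
except AssertionError:
    mismatch = True

# The check passes only when both sides cover the same edges.
rel0_mask = [edge_mask[i] for i in rel0]
masked = apply_edge_mask(rel0_messages, rel0_mask)
```

The last two lines only show that the check passes when the sizes agree; they are not meant as a fix for PyG.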