@@ -185,7 +185,6 @@ def __init__(
185
185
fuse = self .fuse ,
186
186
)
187
187
188
- # Cache to potentially disable later on:
189
188
self .__class__ ._orig_propagate = self .__class__ .propagate
190
189
self .__class__ ._jinja_propagate = module .propagate
191
190
@@ -197,22 +196,30 @@ def __init__(
197
196
198
197
# Optimize `edge_updater()` via `*.jinja` templates (if implemented):
199
198
if (self .inspector .implements ('edge_update' )
200
- and not self .edge_updater .__module__ .startswith (jinja_prefix )
201
- and self .inspector .can_read_source ):
202
- module = module_from_template (
203
- module_name = f'{ jinja_prefix } _edge_updater' ,
204
- template_path = osp .join (root_dir , 'edge_updater.jinja' ),
205
- tmp_dirname = 'message_passing' ,
206
- # Keyword arguments:
207
- modules = self .inspector ._modules ,
208
- collect_name = 'edge_collect' ,
209
- signature = self ._get_edge_updater_signature (),
210
- collect_param_dict = self .inspector .get_param_dict (
211
- 'edge_update' ),
212
- )
199
+ and not self .edge_updater .__module__ .startswith (jinja_prefix )):
200
+ if self .inspector .can_read_source :
201
+
202
+ module = module_from_template (
203
+ module_name = f'{ jinja_prefix } _edge_updater' ,
204
+ template_path = osp .join (root_dir , 'edge_updater.jinja' ),
205
+ tmp_dirname = 'message_passing' ,
206
+ # Keyword arguments:
207
+ modules = self .inspector ._modules ,
208
+ collect_name = 'edge_collect' ,
209
+ signature = self ._get_edge_updater_signature (),
210
+ collect_param_dict = self .inspector .get_param_dict (
211
+ 'edge_update' ),
212
+ )
213
213
214
- self .__class__ .edge_updater = module .edge_updater
215
- self .__class__ .edge_collect = module .edge_collect
214
+ self .__class__ ._orig_edge_updater = self .__class__ .edge_updater
215
+ self .__class__ ._jinja_edge_updater = module .edge_updater
216
+
217
+ self .__class__ .edge_updater = module .edge_updater
218
+ self .__class__ .edge_collect = module .edge_collect
219
+ else :
220
+ self .__class__ ._orig_edge_updater = self .__class__ .edge_updater
221
+ self .__class__ ._jinja_edge_updater = (
222
+ self .__class__ .edge_updater )
216
223
217
224
# Explainability:
218
225
self ._explain : Optional [bool ] = None
0 commit comments