Skip to content

Commit 849ca0a

Browse files
authored
Make MessagePassing interface thread-safe (#9001)
Fixes #8994
1 parent 8d625c4 commit 849ca0a

File tree

4 files changed

+34
-26
lines changed

4 files changed

+34
-26
lines changed

CHANGELOG.md

+1
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
1313

1414
### Changed
1515

16+
- Made `MessagePassing` interface thread-safe ([#9001](https://github.com/pyg-team/pytorch_geometric/pull/9001))
1617
- Breaking Change: Added support for `EdgeIndex` in `cugraph` GNN layers ([#8938](https://github.com/pyg-team/pytorch_geometric/pull/8937))
1718
- Added the `dim` arg to `torch.cross` calls ([#8918](https://github.com/pyg-team/pytorch_geometric/pull/8918))
1819

benchmark/inference/README.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
```bash
88
pip install ogb
99
```
10-
1. Install `autoconf` required for `jemalloc` setup
10+
1. Install `autoconf` required for `jemalloc` setup:
1111
```bash
1212
sudo apt-get install autoconf
1313
```

torch_geometric/nn/conv/message_passing.py

+23-16
Original file line numberDiff line numberDiff line change
@@ -185,7 +185,6 @@ def __init__(
185185
fuse=self.fuse,
186186
)
187187

188-
# Cache to potentially disable later on:
189188
self.__class__._orig_propagate = self.__class__.propagate
190189
self.__class__._jinja_propagate = module.propagate
191190

@@ -197,22 +196,30 @@ def __init__(
197196

198197
# Optimize `edge_updater()` via `*.jinja` templates (if implemented):
199198
if (self.inspector.implements('edge_update')
200-
and not self.edge_updater.__module__.startswith(jinja_prefix)
201-
and self.inspector.can_read_source):
202-
module = module_from_template(
203-
module_name=f'{jinja_prefix}_edge_updater',
204-
template_path=osp.join(root_dir, 'edge_updater.jinja'),
205-
tmp_dirname='message_passing',
206-
# Keyword arguments:
207-
modules=self.inspector._modules,
208-
collect_name='edge_collect',
209-
signature=self._get_edge_updater_signature(),
210-
collect_param_dict=self.inspector.get_param_dict(
211-
'edge_update'),
212-
)
199+
and not self.edge_updater.__module__.startswith(jinja_prefix)):
200+
if self.inspector.can_read_source:
201+
202+
module = module_from_template(
203+
module_name=f'{jinja_prefix}_edge_updater',
204+
template_path=osp.join(root_dir, 'edge_updater.jinja'),
205+
tmp_dirname='message_passing',
206+
# Keyword arguments:
207+
modules=self.inspector._modules,
208+
collect_name='edge_collect',
209+
signature=self._get_edge_updater_signature(),
210+
collect_param_dict=self.inspector.get_param_dict(
211+
'edge_update'),
212+
)
213213

214-
self.__class__.edge_updater = module.edge_updater
215-
self.__class__.edge_collect = module.edge_collect
214+
self.__class__._orig_edge_updater = self.__class__.edge_updater
215+
self.__class__._jinja_edge_updater = module.edge_updater
216+
217+
self.__class__.edge_updater = module.edge_updater
218+
self.__class__.edge_collect = module.edge_collect
219+
else:
220+
self.__class__._orig_edge_updater = self.__class__.edge_updater
221+
self.__class__._jinja_edge_updater = (
222+
self.__class__.edge_updater)
216223

217224
# Explainability:
218225
self._explain: Optional[bool] = None

torch_geometric/template.py

+9-9
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,11 @@
11
import importlib
2-
import os
32
import os.path as osp
43
import sys
4+
import tempfile
55
from typing import Any
66

77
from jinja2 import Environment, FileSystemLoader
88

9-
from torch_geometric import get_home_dir
10-
119

1210
def module_from_template(
1311
module_name: str,
@@ -23,13 +21,15 @@ def module_from_template(
2321
template = env.get_template(osp.basename(template_path))
2422
module_repr = template.render(**kwargs)
2523

26-
instance_dir = osp.join(get_home_dir(), tmp_dirname)
27-
os.makedirs(instance_dir, exist_ok=True)
28-
instance_path = osp.join(instance_dir, f'{module_name}.py')
29-
with open(instance_path, 'w') as f:
30-
f.write(module_repr)
24+
with tempfile.NamedTemporaryFile(
25+
mode='w',
26+
prefix=f'{module_name}_',
27+
suffix='.py',
28+
delete=False,
29+
) as tmp:
30+
tmp.write(module_repr)
3131

32-
spec = importlib.util.spec_from_file_location(module_name, instance_path)
32+
spec = importlib.util.spec_from_file_location(module_name, tmp.name)
3333
assert spec is not None
3434
module = importlib.util.module_from_spec(spec)
3535
sys.modules[module_name] = module

0 commit comments

Comments
 (0)