-
Notifications
You must be signed in to change notification settings - Fork 7
/
Copy pathtu_dataset.py
48 lines (41 loc) · 2.08 KB
/
tu_dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
from torch_geometric.datasets import TUDataset
class TUDatasetExt(TUDataset):
    r"""A variety of graph kernel benchmark datasets, *e.g.* "IMDB-BINARY",
    "REDDIT-BINARY" or "PROTEINS", collected from the `TU Dortmund University
    <http://graphkernels.cs.tu-dortmund.de>`_.

    Extends :class:`torch_geometric.datasets.TUDataset` with a configurable
    processed-file name. This allows several processed variants of the same
    dataset (e.g. built with different :obj:`pre_transform` functions) to be
    kept under the same :obj:`root` without overwriting each other.

    Args:
        root (string): Root directory where the dataset should be saved.
        name (string): The `name <http://graphkernels.cs.tu-dortmund.de>`_ of
            the dataset.
        transform (callable, optional): A function/transform that takes in an
            :obj:`torch_geometric.data.Data` object and returns a transformed
            version. The data object will be transformed before every access.
            (default: :obj:`None`)
        pre_transform (callable, optional): A function/transform that takes in
            an :obj:`torch_geometric.data.Data` object and returns a
            transformed version. The data object will be transformed before
            being saved to disk. (default: :obj:`None`)
        pre_filter (callable, optional): A function that takes in an
            :obj:`torch_geometric.data.Data` object and returns a boolean
            value, indicating whether the data object should be included in the
            final dataset. (default: :obj:`None`)
        use_node_attr (bool, optional): If :obj:`True`, the dataset will
            contain additional continuous node features (if present).
            (default: :obj:`False`)
        processed_filename (string, optional): File name reported by
            :attr:`processed_file_names`, i.e. the processed file expected
            inside the dataset's processed directory.
            (default: :obj:`"data.pt"`)
    """

    # NOTE: ``url`` is intentionally NOT overridden. An earlier override
    # pinned the retired ``ls11-www.cs.tu-dortmund.de`` mirror; inheriting the
    # parent's ``url`` keeps downloads pointed at whichever mirror the
    # installed torch_geometric version maintains.

    def __init__(self,
                 root,
                 name,
                 transform=None,
                 pre_transform=None,
                 pre_filter=None,
                 use_node_attr=False,
                 processed_filename='data.pt'):
        # Set before calling the parent constructor: torch_geometric's
        # Dataset.__init__ may consult ``processed_file_names`` while deciding
        # whether to (re)process, so the attribute must already exist.
        self.processed_filename = processed_filename
        # Keyword arguments guard against positional drift as TUDataset's
        # signature gains parameters across torch_geometric releases.
        super().__init__(root,
                         name,
                         transform=transform,
                         pre_transform=pre_transform,
                         pre_filter=pre_filter,
                         use_node_attr=use_node_attr)

    @property
    def processed_file_names(self):
        """Return the configured processed-file name (a single string)."""
        return self.processed_filename