Skip to content

Commit 4ce49b8

Browse files
rusty1svid-kocipre-commit-ci[bot]
authored
NeighborSampler boilerplate (#413)
Co-authored-by: Vid Kocijan <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 0fd7a5f commit 4ce49b8

File tree

2 files changed

+158
-0
lines changed

2 files changed

+158
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
#include <ATen/ATen.h>
2+
#include <torch/library.h>
3+
4+
#include "pyg_lib/csrc/utils/types.h"
5+
6+
namespace pyg {
7+
namespace classes {
8+
9+
namespace {
10+
11+
struct NeighborSampler : torch::CustomClassHolder {
12+
public:
13+
NeighborSampler(const at::Tensor& rowptr,
14+
const at::Tensor& col,
15+
const std::optional<at::Tensor>& edge_weight,
16+
const std::optional<at::Tensor>& node_time,
17+
const std::optional<at::Tensor>& edge_time)
18+
: rowptr_(rowptr),
19+
col_(col),
20+
edge_weight_(edge_weight),
21+
node_time_(node_time),
22+
edge_time_(edge_time) {};
23+
24+
std::tuple<at::Tensor, // row
25+
at::Tensor, // col
26+
at::Tensor, // node_id
27+
std::optional<at::Tensor>, // edge_id,
28+
std::optional<at::Tensor>, // batch,
29+
std::vector<int64_t>, // num_sampled_nodes,
30+
std::vector<int64_t>> // num_sampled_edges,
31+
sample(const std::vector<int64_t>& num_neighbors,
32+
const at::Tensor& seed_node,
33+
const std::optional<at::Tensor>& seed_time,
34+
bool disjoint = false,
35+
std::string temporal_strategy = "uniform",
36+
bool return_edge_id = true) {
37+
// TODO
38+
auto row = at::empty(0);
39+
auto col = at::empty(0);
40+
auto node_id = at::empty(0);
41+
auto edge_id = at::empty(0);
42+
auto batch = at::empty(0);
43+
std::vector<int64_t> num_sampled_nodes;
44+
std::vector<int64_t> num_sampled_edges;
45+
return std::make_tuple(row, col, node_id, edge_id, batch, num_sampled_nodes,
46+
num_sampled_edges);
47+
}
48+
49+
private:
50+
const at::Tensor& rowptr_;
51+
const at::Tensor& col_;
52+
const std::optional<at::Tensor>& edge_weight_;
53+
const std::optional<at::Tensor>& node_time_;
54+
const std::optional<at::Tensor>& edge_time_;
55+
};
56+
57+
struct HeteroNeighborSampler : torch::CustomClassHolder {
58+
public:
59+
HeteroNeighborSampler(
60+
const std::vector<node_type>& node_types,
61+
const std::vector<edge_type>& edge_types,
62+
const c10::Dict<rel_type, at::Tensor>& rowptr,
63+
const c10::Dict<rel_type, at::Tensor>& col,
64+
const std::optional<c10::Dict<rel_type, at::Tensor>>& edge_weight,
65+
const std::optional<c10::Dict<node_type, at::Tensor>>& node_time,
66+
const std::optional<c10::Dict<rel_type, at::Tensor>>& edge_time)
67+
: node_types_(node_types),
68+
edge_types_(edge_types),
69+
rowptr_(rowptr),
70+
col_(col),
71+
edge_weight_(edge_weight),
72+
node_time_(node_time),
73+
edge_time_(edge_time) {};
74+
75+
std::tuple<c10::Dict<rel_type, at::Tensor>, // row
76+
c10::Dict<rel_type, at::Tensor>, // col
77+
c10::Dict<node_type, at::Tensor>, // node_id
78+
std::optional<c10::Dict<rel_type, at::Tensor>>, // edge_id
79+
std::optional<c10::Dict<node_type, at::Tensor>>, // batch
80+
c10::Dict<node_type, std::vector<int64_t>>, // num_sampled_nodes
81+
c10::Dict<rel_type, std::vector<int64_t>>> // num_sampled_edges
82+
sample(const c10::Dict<rel_type, std::vector<int64_t>>& num_neighbors,
83+
const c10::Dict<node_type, at::Tensor>& seed_node,
84+
const std::optional<c10::Dict<node_type, at::Tensor>>& seed_time,
85+
bool disjoint = false,
86+
std::string temporal_strategy = "uniform",
87+
bool return_edge_id = true) {
88+
// TODO
89+
c10::Dict<rel_type, at::Tensor> row;
90+
c10::Dict<rel_type, at::Tensor> col;
91+
c10::Dict<node_type, at::Tensor> node_id;
92+
c10::Dict<rel_type, at::Tensor> edge_id;
93+
c10::Dict<node_type, at::Tensor> batch;
94+
c10::Dict<node_type, std::vector<int64_t>> num_sampled_nodes;
95+
c10::Dict<rel_type, std::vector<int64_t>> num_sampled_edges;
96+
return std::make_tuple(row, col, node_id, edge_id, batch, num_sampled_nodes,
97+
num_sampled_edges);
98+
}
99+
100+
private:
101+
const std::vector<node_type>& node_types_;
102+
const std::vector<edge_type>& edge_types_;
103+
const c10::Dict<rel_type, at::Tensor>& rowptr_;
104+
const c10::Dict<rel_type, at::Tensor>& col_;
105+
const std::optional<c10::Dict<rel_type, at::Tensor>>& edge_weight_;
106+
const std::optional<c10::Dict<node_type, at::Tensor>>& node_time_;
107+
const std::optional<c10::Dict<rel_type, at::Tensor>>& edge_time_;
108+
};
109+
110+
} // namespace
111+
112+
TORCH_LIBRARY_FRAGMENT(pyg, m) {
113+
m.class_<NeighborSampler>("NeighborSampler")
114+
.def(torch::init<at::Tensor&, at::Tensor&, std::optional<at::Tensor>,
115+
std::optional<at::Tensor>, std::optional<at::Tensor>>())
116+
.def("sample", &NeighborSampler::sample);
117+
118+
m.class_<HeteroNeighborSampler>("HeteroNeighborSampler")
119+
.def(torch::init<std::vector<node_type>, std::vector<edge_type>,
120+
c10::Dict<rel_type, at::Tensor>,
121+
c10::Dict<rel_type, at::Tensor>,
122+
std::optional<c10::Dict<rel_type, at::Tensor>>,
123+
std::optional<c10::Dict<node_type, at::Tensor>>,
124+
std::optional<c10::Dict<rel_type, at::Tensor>>>())
125+
.def("sample", &HeteroNeighborSampler::sample);
126+
}
127+
128+
} // namespace classes
129+
} // namespace pyg

test/classes/test_neighbor_sampler.py

+29
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
import torch
2+
3+
import pyg_lib # noqa
4+
5+
6+
def test_neighbor_sampler() -> None:
7+
rowptr = torch.tensor([0, 2, 4, 6])
8+
col = torch.tensor([1, 2, 0, 2, 1, 0])
9+
10+
Sampler = torch.classes.pyg.NeighborSampler
11+
sampler = Sampler(rowptr, col, None, None, None)
12+
assert sampler is not None
13+
14+
15+
def test_hetero_neighbor_sampler() -> None:
16+
node_types = ['A', 'B']
17+
edge_types = [('A', 'to', 'B'), ('B', 'to', 'A')]
18+
rowptr = {
19+
'A__to__B': torch.tensor([0, 1]),
20+
'B__to__A': torch.tensor([0, 1]),
21+
}
22+
col = {
23+
'A__to__B': torch.tensor([0]),
24+
'B__to__A': torch.tensor([0]),
25+
}
26+
27+
Sampler = torch.classes.pyg.HeteroNeighborSampler
28+
sampler = Sampler(node_types, edge_types, rowptr, col, None, None, None)
29+
assert sampler is not None

0 commit comments

Comments
 (0)