
Commit e576766

[Operator] add scatter_max (#308)
* new scatter_max
* [Bugfix] Update loading datasets (#307)
* Fix scatter max
* Add max-aggr sage
* Update submodule urls
* Update submodule urls
* fix bugs for scatter_max
* use relative path
* fix bugs for multi gpus
* remove redundancy

Co-authored-by: Yukuo Cen <[email protected]>
1 parent f6f33c6 commit e576766

File tree

8 files changed: +501 -349 lines changed

.gitmodules

+3-3
@@ -1,15 +1,15 @@
 [submodule "third_party/dgNN"]
 	ignore = dirty
 	path = third_party/dgNN
-	url = https://github.com/dgSPARSE/dgNN
+	url = ../../dgSPARSE/dgNN
 	branch = main
 [submodule "third_party/actnn"]
 	ignore = dirty
 	path = third_party/actnn
-	url = https://github.com/ucbrise/actnn
+	url = ../../ucbrise/actnn
 	branch = main
 [submodule "third_party/fastmoe"]
 	ignore = dirty
 	path = third_party/fastmoe
-	url = https://github.com/laekov/fastmoe
+	url = ../../laekov/fastmoe
 	branch = master

cogdl/layers/sage_layer.py

+13
@@ -18,6 +18,17 @@ def __call__(self, graph, x):
         return x
 
 
+class MaxAggregator(object):
+    def __init__(self):
+        from cogdl.operators.scatter_max import scatter_max
+
+        self.scatter_max = scatter_max
+
+    def __call__(self, graph, x):
+        x = self.scatter_max(graph.row_indptr.int(), graph.col_indices.int(), x)
+        return x
+
+
 class SAGELayer(nn.Module):
     def __init__(
         self, in_feats, out_feats, normalize=False, aggr="mean", dropout=0.0, norm=None, activation=None, residual=False
@@ -35,6 +46,8 @@ def __init__(
             self.aggr = MeanAggregator()
         elif aggr == "sum":
            self.aggr = SumAggregator()
+        elif aggr == "max":
+            self.aggr = MaxAggregator()
         else:
             raise NotImplementedError
 
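With this branch in place, max aggregation is selected through the existing aggr argument of SAGELayer. A minimal usage sketch, not part of the commit: the layer(graph, x) call pattern and the feature sizes are illustrative assumptions, and graph is assumed to be a cogdl Graph already resident on the GPU, since the underlying scatter_max kernel only accepts CUDA tensors.

import torch
from cogdl.layers.sage_layer import SAGELayer

# `graph` is assumed to be a cogdl Graph moved to the GPU, so that
# graph.row_indptr / graph.col_indices can feed the CUDA scatter_max kernel.
layer = SAGELayer(in_feats=64, out_feats=32, aggr="max").cuda()
x = torch.randn(graph.num_nodes, 64, device="cuda")  # random node features
h = layer(graph, x)  # per-neighborhood max, then the layer's linear transform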

cogdl/operators/scatter_max.py

+37
@@ -0,0 +1,37 @@
+import os
+import numpy as np
+import torch
+from torch.utils.cpp_extension import load
+
+path = os.path.join(os.path.dirname(__file__))
+
+# SPMM
+
+try:
+    spmm_max = load(
+        name="scatter_max",
+        sources=[os.path.join(path, "scatter_max/scatter_max.cc"), os.path.join(path, "scatter_max/scatter_max.cu")],
+        verbose=True,
+    )
+
+    def scatter_max(rowptr, colind, feat):
+        return ScatterMaxFunction.apply(rowptr, colind, feat)
+
+
+except Exception:
+    spmm_max = None
+
+
+class ScatterMaxFunction(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, rowptr, colind, feat):
+        out, max_id = spmm_max.scatter_max_fp(rowptr, colind, feat)
+        ctx.save_for_backward(max_id)
+        return out
+
+    @staticmethod
+    def backward(ctx, grad):
+        grad = grad.contiguous()
+        max_id = ctx.saved_tensors[0]
+        out = spmm_max.scatter_max_bp(grad, max_id)
+        return None, None, out
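For intuition, and as a CPU-side sanity check when the extension is unavailable, the forward pass amounts to a per-row, per-feature max over CSR neighborhoods. The loop below is a non-authoritative reference written for this note, not part of the commit; note that the CUDA kernel seeds its accumulator with FLT_MIN rather than -FLT_MAX, so the two only agree when every non-empty row has at least one positive feature.

import torch

def scatter_max_reference(rowptr, colind, feat):
    """Reference semantics: out[i] = elementwise max of feat[j] over neighbors j of row i.

    Rows with no neighbors get 0, and max_id records which neighbor supplied
    each maximum (mirroring the max_mask returned by the CUDA kernel).
    """
    m, k = rowptr.numel() - 1, feat.size(1)
    out = torch.zeros(m, k, dtype=feat.dtype)
    max_id = torch.zeros(m, k, dtype=torch.long)
    for i in range(m):
        lb, hb = int(rowptr[i]), int(rowptr[i + 1])
        nbrs = colind[lb:hb].long()
        if nbrs.numel() > 0:
            vals, idx = feat[nbrs].max(dim=0)  # per-feature max over the neighborhood
            out[i], max_id[i] = vals, nbrs[idx]
    return out, max_id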
cogdl/operators/scatter_max/scatter_max.cc

+38
@@ -0,0 +1,38 @@
+#include <pybind11/pybind11.h>
+#include <torch/extension.h>
+#include <vector>
+
+void assertTensor(torch::Tensor &T, c10::ScalarType type) {
+  assert(T.is_contiguous());
+  assert(T.device().type() == torch::kCUDA);
+  assert(T.dtype() == type);
+}
+
+std::vector<torch::Tensor> scatter_max_fp_cuda(torch::Tensor rowptr,
+                                               torch::Tensor colind,
+                                               torch::Tensor node_feature);
+
+torch::Tensor scatter_max_bp_cuda(torch::Tensor node_feature,
+                                  torch::Tensor max_mask);
+
+std::vector<torch::Tensor> scatter_max(torch::Tensor rowptr,
+                                       torch::Tensor colind,
+                                       torch::Tensor node_feature) {
+  assertTensor(rowptr, torch::kInt32);
+  assertTensor(colind, torch::kInt32);
+  assertTensor(node_feature, torch::kFloat32);
+  return scatter_max_fp_cuda(rowptr, colind, node_feature);
+}
+
+torch::Tensor scatter_max_bp(torch::Tensor node_feature,
+                             torch::Tensor max_mask) {
+  assertTensor(node_feature, torch::kFloat32);
+  assertTensor(max_mask, torch::kInt32);
+  return scatter_max_bp_cuda(node_feature, max_mask);
+}
+
+PYBIND11_MODULE(scatter_max, m) {
+  m.doc() = "scatter max kernel";
+  m.def("scatter_max_fp", &scatter_max, "scatter max forward");
+  m.def("scatter_max_bp", &scatter_max_bp, "scatter max backward");
+}
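The bindings assert contiguous CUDA tensors, int32 for the CSR indices and float32 for the features, which is why MaxAggregator casts with .int() before calling the op. A hypothetical call sketch, not part of the commit, assuming the JIT build succeeded so that cogdl.operators.scatter_max exposes scatter_max:

import torch
from cogdl.operators.scatter_max import scatter_max  # defined only if the extension compiled

# Hypothetical 3-node CSR graph: node 0 -> {1, 2}, node 1 -> {0}, node 2 -> {}.
rowptr = torch.tensor([0, 2, 3, 3], dtype=torch.int32, device="cuda").contiguous()
colind = torch.tensor([1, 2, 0], dtype=torch.int32, device="cuda").contiguous()
feat = torch.randn(3, 8, dtype=torch.float32, device="cuda").contiguous()

out = scatter_max(rowptr, colind, feat)  # (3, 8): row-wise max over each neighborhood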
cogdl/operators/scatter_max/scatter_max.cu

+75
@@ -0,0 +1,75 @@
+#include <cuda.h>
+#include <torch/types.h>
+#include <vector>
+
+__global__ void scatter_max_forward(const int *A_indptr, const int *A_indices,
+                                    const float *B, float *C, int *max_mask) {
+  int rid = blockDim.y * blockIdx.x + threadIdx.y;
+  int m = gridDim.x;
+  int k = blockDim.x;
+  if (rid < m) {
+    int lb = A_indptr[rid];
+    int hb = A_indptr[(rid + 1)];
+    int stride = hb - lb;
+    int offset;
+    int max_id;
+    float acc = (stride > 0) ? FLT_MIN : 0;
+    for (int ptr = lb; ptr < hb; ptr++) {
+      int cid = A_indices[ptr];
+      offset = cid * k + threadIdx.x;
+      if (acc < B[offset]) {
+        acc = B[offset];
+        max_id = cid;
+      }
+    }
+    C[(rid * k + threadIdx.x)] = acc;
+    max_mask[(rid * k + threadIdx.x)] = max_id;
+  }
+}
+
+__global__ void scatter_max_backward(const float *grad, float *out,
+                                     int *max_mask) {
+  int rid = blockDim.y * blockIdx.x + threadIdx.y;
+  int m = gridDim.x;
+  int k = blockDim.x;
+  if (rid < m) {
+    int offset = rid * k + threadIdx.x;
+    int max_id;
+    max_id = max_mask[offset]; // max mapping
+    float grad_tmp = grad[offset];
+    atomicAdd(&out[max_id * k + threadIdx.x], grad_tmp);
+  }
+}
+
+std::vector<torch::Tensor> scatter_max_fp_cuda(torch::Tensor rowptr,
+                                               torch::Tensor colind,
+                                               torch::Tensor node_feature) {
+  const long m = rowptr.size(0) - 1;
+  const long k = node_feature.size(1);
+  auto devid = node_feature.device().index();
+  auto optionsI =
+      torch::TensorOptions().dtype(torch::kInt32).device(torch::kCUDA, devid);
+  auto optionsF =
+      torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA, devid);
+  auto max_mask = torch::empty({m, k}, optionsI);
+  auto out = torch::empty({m, k}, optionsF);
+  scatter_max_forward<<<m, k>>>(rowptr.data_ptr<int>(), colind.data_ptr<int>(),
+                                node_feature.data_ptr<float>(),
+                                out.data_ptr<float>(),
+                                max_mask.data_ptr<int>());
+  return {out, max_mask};
+}
+
+torch::Tensor scatter_max_bp_cuda(torch::Tensor node_feature,
+                                  torch::Tensor max_mask) {
+  const long m = node_feature.size(0);
+  const long k = node_feature.size(1);
+  auto devid = node_feature.device().index();
+  auto options =
+      torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA, devid);
+  auto out = torch::empty({m, k}, options);
+  scatter_max_backward<<<m, k>>>(node_feature.data_ptr<float>(),
+                                 out.data_ptr<float>(),
+                                 max_mask.data_ptr<int>());
+  return out;
+}
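The backward kernel routes each incoming gradient entry to the neighbor that produced the corresponding maximum, accumulating with atomicAdd. That is exactly a scatter-add keyed by max_mask; a non-authoritative PyTorch equivalent is sketched below, assuming (as the kernel's {m, k} output allocation does) that source and destination nodes come from the same node set.

import torch

def scatter_max_backward_reference(grad, max_id):
    """Route grad[i, c] into out[max_id[i, c], c], mirroring the atomicAdd in the kernel."""
    out = torch.zeros_like(grad)
    out.scatter_add_(0, max_id.long(), grad)  # dim-0 scatter keyed by the argmax neighbor ids
    return out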
