Skip to content

Commit afebb25

Browse files
authored
[SYCL][Matrix] Add initial get_coord API (#7851)
This patch adds an initial API for the retrieval of coordinates from a work item element. A `get_coord()` method is added to the intel namespace to work on `wi_element` class. Also, a relevant SPIRV op is added, which the get_coord() gets lowered to. This is recreated PR from my forked repo. The discussions are in the original (closed) PR #7037
1 parent b9122b3 commit afebb25

File tree

3 files changed

+270
-0
lines changed

3 files changed

+270
-0
lines changed

sycl/include/CL/__spirv/spirv_ops.hpp

+7
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,13 @@ template <typename T, std::size_t R, std::size_t C, __spv::MatrixUse U,
106106
extern __DPCPP_SYCL_EXTERNAL __spv::__spirv_JointMatrixINTEL<T, R, C, L, S, U> *
107107
__spirv_CompositeConstruct(const T v);
108108

109+
template <typename T, std::size_t R, std::size_t C, __spv::MatrixUse U,
110+
__spv::MatrixLayout L = __spv::MatrixLayout::RowMajor,
111+
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
112+
extern __DPCPP_SYCL_EXTERNAL __ocl_vec_t<uint32_t, 2>
113+
__spirv_JointMatrixGetElementCoordINTEL(
114+
__spv::__spirv_JointMatrixINTEL<T, R, C, L, S, U> *, size_t i);
115+
109116
template <typename T, std::size_t R, std::size_t C, __spv::MatrixUse U,
110117
__spv::MatrixLayout L = __spv::MatrixLayout::RowMajor,
111118
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>

sycl/include/sycl/ext/oneapi/matrix/matrix-intel.hpp

+28
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,20 @@ class wi_element {
8888
Group, T, Use, NumRows, NumCols, Layout> &Mat,
8989
std::size_t i)
9090
: M(Mat), idx(i) {}
91+
92+
inline __SYCL_ALWAYS_INLINE std::tuple<uint32_t, uint32_t> get_coord() {
93+
#if defined(__SYCL_DEVICE_ONLY__)
94+
__ocl_vec_t<uint32_t, 2> coord =
95+
__spirv_JointMatrixGetElementCoordINTEL(M.spvm, idx);
96+
const uint32_t row = coord[0];
97+
const uint32_t col = coord[1];
98+
return std::make_tuple(row, col);
99+
#else
100+
throw runtime_error("joint matrix is not supported on host device.",
101+
PI_ERROR_INVALID_DEVICE);
102+
#endif // __SYCL_DEVICE_ONLY__
103+
}
104+
91105
operator T() {
92106
#ifdef __SYCL_DEVICE_ONLY__
93107
return __spirv_VectorExtractDynamic(M.spvm, idx);
@@ -171,6 +185,20 @@ class wi_element<sycl::ext::oneapi::bfloat16, NumRows, NumCols, Use, Layout,
171185
Layout> &Mat,
172186
std::size_t i)
173187
: M(Mat), idx(i) {}
188+
189+
inline __SYCL_ALWAYS_INLINE std::tuple<uint32_t, uint32_t> get_coord() {
190+
#if defined(__SYCL_DEVICE_ONLY__)
191+
__ocl_vec_t<uint32_t, 2> coord =
192+
__spirv_JointMatrixGetElementCoordINTEL(M.spvm, idx);
193+
const uint32_t row = coord[0];
194+
const uint32_t col = coord[1];
195+
return std::make_tuple(row, col);
196+
#else
197+
throw runtime_error("joint matrix is not supported on host device.",
198+
PI_ERROR_INVALID_DEVICE);
199+
#endif // __SYCL_DEVICE_ONLY__
200+
}
201+
174202
operator sycl::ext::oneapi::bfloat16() {
175203
#ifdef __SYCL_DEVICE_ONLY__
176204
return __spirv_VectorExtractDynamic(M.spvm, idx);
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,235 @@
1+
// RUN: %clangxx -fsycl -O2 -DSYCL_EXT_ONEAPI_MATRIX_VERSION=4 %s -o %t.out
2+
3+
// Kernel B sum by col
4+
#include <iostream>
5+
#include <sycl/sycl.hpp>
6+
7+
using namespace sycl;
8+
using namespace sycl::ext::oneapi::experimental::matrix;
9+
10+
#define SG_SZ 16
11+
12+
#define TN SG_SZ
13+
#define TK 32
14+
15+
template <typename T, size_t NUM_ROWS, size_t NUM_COLS> struct big_matrix {
16+
public:
17+
T *mat;
18+
19+
public:
20+
T *get_data() { return mat; }
21+
void set_data(T *data) { mat = data; }
22+
big_matrix(T *data) : mat(data) {}
23+
};
24+
25+
template <typename T, size_t M, size_t N>
26+
void sum_cols_ref(host_accessor<T, 2, access::mode::read_write> B,
27+
host_accessor<int, 1, access::mode::read_write> sum_cols) {
28+
int sum_cols_ref[N] = {0};
29+
for (size_t j = 0; j < N; j++) {
30+
for (size_t i = 0; i < M; i++) {
31+
sum_cols_ref[j] += B[i][j];
32+
}
33+
auto diff = sum_cols[j] - sum_cols_ref[j];
34+
assert(std::fabs(static_cast<int>(diff)) <=
35+
std::numeric_limits<int>::epsilon());
36+
}
37+
}
38+
39+
// clang-format off
40+
/*
41+
Here is a demonstration of how matrix B will be divided across
42+
work items for this test case.
43+
< --------------- 128 ---------------------------------->
44+
x x x x x x x x x x x x x x x x .......... x x x x x x ^
45+
x x x x x x x x x x x x x x x x .......... x x x x x x 16
46+
x x x x x x x x x x x x x x x x .......... x x x x x x |
47+
..... |
48+
x x x x x x x x x x x x x x x x .......... x x x x x x |
49+
x x x x x x x x x x x x x x x x .......... x x x x x x v
50+
51+
52+
--------------- 64 ---------------->
53+
x x x x x x .......... x x x x x x ^
54+
x x x x x x .......... x x x x x x 8
55+
x x x x x x .......... x x x x x x | <-- part of (VNNI-ed)
56+
..... | original matrix each SG
57+
x x x x x x .......... x x x x x x | holds
58+
x x x x x x .......... x x x x x x v
59+
< WI0 > < WI15 >
60+
61+
62+
<-------- 16 ------------->
63+
x x x .......... x x x ^
64+
x x x .......... x x x |
65+
x x x .......... x x x | <-- part of (non-VNNI-ed) original matrix
66+
..... | each SG holds
67+
x x x .......... x x x |
68+
x x x .......... x x x |
69+
x x x .......... x x x 32
70+
x x x .......... x x x |
71+
x x x .......... x x x |
72+
x x x .......... x x x |
73+
x x x .......... x x x |
74+
x x x .......... x x x |
75+
x x x .......... x x x v
76+
77+
If we dividie the above matrix across 16 (SG_SZ) work items,
78+
each WI will hold 32 elements. And these 32 elements will be
79+
8x4 chunks as shown in the VNNI-ed matrix figure.
80+
*/
81+
82+
// The total distribution among the WIs in ALL the sub-groups is as follows:
83+
// This is useful to figure out the the global index is to be calculated
84+
85+
/*
86+
W0 --> 0 0 0 0 1 1 1 1 ... 7 7 7 7 --> total 32 elements
87+
wi [0,0] --> i=0, [0, 0] wi [0,1] --> i=0, [0, 4] wi [0,15] --> i=0, [0, 60] | wi [0,16] --> i=0, [0, 64]
88+
i=1, [0, 1] i=1, [0, 5] i=1, [0, 61] | i=1, [0, 65]
89+
i=2, [0, 2] i=2, [0, 6] i=2, [0, 62] | i=2, [0, 66]
90+
i=3, [0, 3] i=3, [0, 7] i=3, [0, 63] | i=3, [0, 67]
91+
92+
i=4, [1, 0] i=4, [1, 4] i=4, [1, 60] | ....
93+
i=5, [1, 1] i=5, [1, 5] i=5, [1, 61] |
94+
i=6, [1, 2] i=6, [1, 6] i=6, [1, 62] |
95+
i=7, [1, 3] i=7, [1, 7] i=7, [1, 63] |
96+
... ... .... |
97+
i=28,[7, 0] i=28,[7, 4] i=28,[7, 60] | i=28, [7, 124]
98+
i=29,[7, 1] i=29,[7, 5] i=29,[7, 61] | i=29, [7, 125]
99+
i=30,[7, 2] i=30,[7, 6] i=30,[7, 62] | i=30, [7, 126]
100+
i=31,[7, 3] i=31,[7, 7] i=31,[7, 63] | i=31, [7, 127]
101+
---------------------------------------------------------------------------------------- ---------------------------
102+
wi [1,0] --> i=0, [8, 0]
103+
i=1, [8, 1]
104+
i=2, [8, 2]
105+
i=3, [8, 2]
106+
...
107+
i=28, [15, 0]
108+
i=29, [15, 1]
109+
i=30, [15, 2]
110+
i=31, [15, 3]
111+
*/
112+
113+
// The following is the distribution among WIs in a SINGLE SG.
114+
/*
115+
W0 --> 0 0 0 0 1 1 1 1 ... 7 7 7 7 --> total 32 elements
116+
117+
wi [0,0] -> i=0, [0, 0] wi [0,1] --> i=0, [0, 4] wi [0,15] --> i=0, [0, 60] |
118+
i=1, [0, 1] i=1, [0, 5] i=1, [0, 61] |
119+
i=2, [0, 2] i=2, [0, 6] i=2, [0, 62] |
120+
i=3, [0, 3] i=3, [0, 7] i=3, [0, 63] |
121+
122+
i=4, [1, 0] i=4, [1, 4] i=4, [1, 60] |
123+
i=5, [1, 1] i=5, [1, 5] i=5, [1, 61] |
124+
i=6, [1, 2] i=6, [1, 6] i=6, [1, 62] |
125+
i=7, [1, 3] i=7, [1, 7] i=7, [1, 63] |
126+
... ... .... |
127+
i=28,[7, 0] i=28,[7, 4] i=28,[7, 60] |
128+
i=29,[7, 1] i=29,[7, 5] i=29,[7, 61] |
129+
i=30,[7, 2] i=30,[7, 6] i=30,[7, 62] |
130+
i=31,[7, 3] i=31,[7, 7] i=31,[7, 63] |
131+
132+
*/
133+
// clang-format on
134+
135+
template <typename T, size_t M, size_t N>
136+
void matrix_sum_cols(queue q, big_matrix<T, M, N> &B, nd_range<2> &r) {
137+
buffer<int8_t, 2> bufB(B.get_data(), range<2>(M, N));
138+
// size of vector is known because SG size of set by the user in this case
139+
int sum_cols[N] = {0};
140+
buffer<int> sum_cols_v(sum_cols, N); // there are total of tK/4 * 2, 16 rows
141+
q.submit([&](handler &cgh) {
142+
auto accB = bufB.get_access<access::mode::read_write>(cgh);
143+
144+
auto v = sum_cols_v.get_access<access::mode::atomic>(cgh);
145+
auto os = sycl::stream(100000, 6144, cgh);
146+
147+
cgh.parallel_for<class add_matrix>(
148+
r, [=](nd_item<2> spmd_item) [[intel::reqd_sub_group_size(SG_SZ)]] {
149+
const auto global_idx = spmd_item.get_global_id(0);
150+
const auto global_idy = spmd_item.get_global_id(1);
151+
const auto sg_startx = global_idx - spmd_item.get_local_id(0);
152+
const auto sg_starty = global_idy - spmd_item.get_local_id(1);
153+
154+
ext::oneapi::sub_group sg = spmd_item.get_sub_group();
155+
156+
// TK = 32, TN = 16
157+
joint_matrix<sub_group, int8_t, use::b, TK, TN,
158+
ext::intel::experimental::matrix::layout::packed>
159+
sub_b;
160+
161+
joint_matrix_load(sg, sub_b,
162+
accB.get_pointer() + (global_idx * (TK / 4) * N) +
163+
sg_starty / SG_SZ * TN * 4,
164+
N);
165+
166+
int32_t sum_local_cols[N] = {0}; // 4 local cols, N total
167+
// sub_b has 32x16 elements, 32 elements per WI, 4 per WI per row
168+
auto wiData =
169+
sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_b);
170+
171+
size_t
172+
global_index; // Index into the result array that holds the sums.
173+
174+
// Keep track of cols handled in this WI
175+
int32_t handled_cols[N] = {-1};
176+
177+
// each WI calculates local sum of cols
178+
for (int i = 0; i < wiData.length(); ++i) {
179+
// get the index of the element in the submatrix
180+
auto dataItem = wiData[i];
181+
auto [row, col] = dataItem.get_coord();
182+
183+
// Calculation of global index
184+
int sg_idx = (int)global_idy / SG_SZ;
185+
global_index = col + sg_idx * 4 /*VNNI_FACTOR*/ * SG_SZ;
186+
sum_local_cols[global_index] += wiData[i];
187+
handled_cols[global_index] = 1;
188+
}
189+
190+
for (int j = 0; j < N; j++) {
191+
if (handled_cols[j] == 1) {
192+
global_index = j;
193+
sum_local_cols[global_index] = reduce_over_group(
194+
sg, sum_local_cols[global_index], sycl::plus<>());
195+
atomic_fetch_add(v[global_index], sum_local_cols[global_index]);
196+
}
197+
}
198+
}); // parallel for
199+
}).wait();
200+
sum_cols_ref<T, M, N>(bufB.get_host_access(), sum_cols_v.get_host_access());
201+
}
202+
203+
// TK = 32, TN = 16
204+
static constexpr size_t MATRIX_K = TK / 4 * 2; // 16
205+
static constexpr size_t MATRIX_N = TN * 4 * 2; // 128
206+
int8_t B[MATRIX_K][MATRIX_N];
207+
208+
/* < --------------- 128 ---------------------------------->
209+
x x x x x x x x x x x x x x x x .......... x x x x x x ^
210+
x x x x x x x x x x x x x x x x .......... x x x x x x 16
211+
x x x x x x x x x x x x x x x x .......... x x x x x x |
212+
..... |
213+
x x x x x x x x x x x x x x x x .......... x x x x x x |
214+
x x x x x x x x x x x x x x x x .......... x x x x x x v
215+
*/
216+
int main() {
217+
big_matrix<int8_t, MATRIX_K, MATRIX_N> MB((int8_t *)&B);
218+
219+
size_t NDRangeK = MATRIX_K / (TK / 4);
220+
size_t NDRangeN = (MATRIX_N / 4) / TN;
221+
queue q;
222+
nd_range<2> r({NDRangeK, NDRangeN * SG_SZ}, {1, 1 * SG_SZ});
223+
224+
for (int i = 0; i < MATRIX_K; i++) {
225+
for (int j = 0; j < MATRIX_N; j++) {
226+
B[i][j] = i;
227+
}
228+
}
229+
230+
matrix_sum_cols<int8_t, MATRIX_K, MATRIX_N>(q, MB, r);
231+
232+
std::cout << "Passed\n";
233+
234+
return 0;
235+
}

0 commit comments

Comments
 (0)