|
| 1 | +// RUN: %clangxx -fsycl -O2 -DSYCL_EXT_ONEAPI_MATRIX_VERSION=4 %s -o %t.out |
| 2 | + |
| 3 | +// Kernel B sum by col |
| 4 | +#include <iostream> |
| 5 | +#include <sycl/sycl.hpp> |
| 6 | + |
| 7 | +using namespace sycl; |
| 8 | +using namespace sycl::ext::oneapi::experimental::matrix; |
| 9 | + |
| 10 | +#define SG_SZ 16 |
| 11 | + |
| 12 | +#define TN SG_SZ |
| 13 | +#define TK 32 |
| 14 | + |
// Lightweight, non-owning handle to a matrix stored in caller-provided memory.
// NUM_ROWS and NUM_COLS only tag the type with the matrix dimensions; the
// struct itself never reads them and never allocates or frees the storage.
template <typename T, size_t NUM_ROWS, size_t NUM_COLS> struct big_matrix {
  // Pointer to the first element of the wrapped storage.
  T *mat;

  // Wraps `data` without copying it.
  big_matrix(T *data) : mat(data) {}

  // Re-targets the handle at a different block of storage.
  void set_data(T *data) { this->mat = data; }

  // Returns the pointer currently wrapped by this handle.
  T *get_data() { return this->mat; }
};
| 24 | + |
| 25 | +template <typename T, size_t M, size_t N> |
| 26 | +void sum_cols_ref(host_accessor<T, 2, access::mode::read_write> B, |
| 27 | + host_accessor<int, 1, access::mode::read_write> sum_cols) { |
| 28 | + int sum_cols_ref[N] = {0}; |
| 29 | + for (size_t j = 0; j < N; j++) { |
| 30 | + for (size_t i = 0; i < M; i++) { |
| 31 | + sum_cols_ref[j] += B[i][j]; |
| 32 | + } |
| 33 | + auto diff = sum_cols[j] - sum_cols_ref[j]; |
| 34 | + assert(std::fabs(static_cast<int>(diff)) <= |
| 35 | + std::numeric_limits<int>::epsilon()); |
| 36 | + } |
| 37 | +} |
| 38 | + |
| 39 | +// clang-format off |
| 40 | +/* |
| 41 | + Here is a demonstration of how matrix B will be divided across |
| 42 | + work items for this test case. |
| 43 | + < --------------- 128 ----------------------------------> |
| 44 | + x x x x x x x x x x x x x x x x .......... x x x x x x ^ |
| 45 | + x x x x x x x x x x x x x x x x .......... x x x x x x 16 |
| 46 | + x x x x x x x x x x x x x x x x .......... x x x x x x | |
| 47 | + ..... | |
| 48 | + x x x x x x x x x x x x x x x x .......... x x x x x x | |
| 49 | + x x x x x x x x x x x x x x x x .......... x x x x x x v |
| 50 | +
|
| 51 | + |
| 52 | + --------------- 64 ----------------> |
| 53 | + x x x x x x .......... x x x x x x ^ |
| 54 | + x x x x x x .......... x x x x x x 8 |
| 55 | + x x x x x x .......... x x x x x x | <-- part of (VNNI-ed) |
| 56 | + ..... | original matrix each SG |
| 57 | + x x x x x x .......... x x x x x x | holds |
| 58 | + x x x x x x .......... x x x x x x v |
| 59 | + < WI0 > < WI15 > |
| 60 | +
|
| 61 | +
|
| 62 | + <-------- 16 -------------> |
| 63 | + x x x .......... x x x ^ |
| 64 | + x x x .......... x x x | |
| 65 | + x x x .......... x x x | <-- part of (non-VNNI-ed) original matrix |
| 66 | + ..... | each SG holds |
| 67 | + x x x .......... x x x | |
| 68 | + x x x .......... x x x | |
| 69 | + x x x .......... x x x 32 |
| 70 | + x x x .......... x x x | |
| 71 | + x x x .......... x x x | |
| 72 | + x x x .......... x x x | |
| 73 | + x x x .......... x x x | |
| 74 | + x x x .......... x x x | |
| 75 | + x x x .......... x x x v |
| 76 | +
|
| 77 | + If we divide the above matrix across 16 (SG_SZ) work items, |
| 78 | + each WI will hold 32 elements. And these 32 elements will be |
| 79 | + 8x4 chunks as shown in the VNNI-ed matrix figure. |
| 80 | +*/ |
| 81 | + |
| 82 | +// The total distribution among the WIs in ALL the sub-groups is as follows: |
| 83 | +// This is useful for figuring out how the global index is calculated. |
| 84 | + |
| 85 | +/* |
| 86 | +W0 --> 0 0 0 0 1 1 1 1 ... 7 7 7 7 --> total 32 elements |
| 87 | +wi [0,0] --> i=0, [0, 0] wi [0,1] --> i=0, [0, 4] wi [0,15] --> i=0, [0, 60] | wi [0,16] --> i=0, [0, 64] |
| 88 | + i=1, [0, 1] i=1, [0, 5] i=1, [0, 61] | i=1, [0, 65] |
| 89 | + i=2, [0, 2] i=2, [0, 6] i=2, [0, 62] | i=2, [0, 66] |
| 90 | + i=3, [0, 3] i=3, [0, 7] i=3, [0, 63] | i=3, [0, 67] |
| 91 | +
|
| 92 | + i=4, [1, 0] i=4, [1, 4] i=4, [1, 60] | .... |
| 93 | + i=5, [1, 1] i=5, [1, 5] i=5, [1, 61] | |
| 94 | + i=6, [1, 2] i=6, [1, 6] i=6, [1, 62] | |
| 95 | + i=7, [1, 3] i=7, [1, 7] i=7, [1, 63] | |
| 96 | + ... ... .... | |
| 97 | + i=28,[7, 0] i=28,[7, 4] i=28,[7, 60] | i=28, [7, 124] |
| 98 | + i=29,[7, 1] i=29,[7, 5] i=29,[7, 61] | i=29, [7, 125] |
| 99 | + i=30,[7, 2] i=30,[7, 6] i=30,[7, 62] | i=30, [7, 126] |
| 100 | + i=31,[7, 3] i=31,[7, 7] i=31,[7, 63] | i=31, [7, 127] |
| 101 | +---------------------------------------------------------------------------------------- --------------------------- |
| 102 | +wi [1,0] --> i=0, [8, 0] |
| 103 | + i=1, [8, 1] |
| 104 | + i=2, [8, 2] |
| 105 | + i=3, [8, 3] |
| 106 | + ... |
| 107 | + i=28, [15, 0] |
| 108 | + i=29, [15, 1] |
| 109 | + i=30, [15, 2] |
| 110 | + i=31, [15, 3] |
| 111 | +*/ |
| 112 | + |
| 113 | +// The following is the distribution among WIs in a SINGLE SG. |
| 114 | +/* |
| 115 | +W0 --> 0 0 0 0 1 1 1 1 ... 7 7 7 7 --> total 32 elements |
| 116 | +
|
| 117 | +wi [0,0] -> i=0, [0, 0] wi [0,1] --> i=0, [0, 4] wi [0,15] --> i=0, [0, 60] | |
| 118 | + i=1, [0, 1] i=1, [0, 5] i=1, [0, 61] | |
| 119 | + i=2, [0, 2] i=2, [0, 6] i=2, [0, 62] | |
| 120 | + i=3, [0, 3] i=3, [0, 7] i=3, [0, 63] | |
| 121 | +
|
| 122 | + i=4, [1, 0] i=4, [1, 4] i=4, [1, 60] | |
| 123 | + i=5, [1, 1] i=5, [1, 5] i=5, [1, 61] | |
| 124 | + i=6, [1, 2] i=6, [1, 6] i=6, [1, 62] | |
| 125 | + i=7, [1, 3] i=7, [1, 7] i=7, [1, 63] | |
| 126 | + ... ... .... | |
| 127 | + i=28,[7, 0] i=28,[7, 4] i=28,[7, 60] | |
| 128 | + i=29,[7, 1] i=29,[7, 5] i=29,[7, 61] | |
| 129 | + i=30,[7, 2] i=30,[7, 6] i=30,[7, 62] | |
| 130 | + i=31,[7, 3] i=31,[7, 7] i=31,[7, 63] | |
| 131 | +
|
| 132 | +*/ |
| 133 | +// clang-format on |
| 134 | + |
| 135 | +template <typename T, size_t M, size_t N> |
| 136 | +void matrix_sum_cols(queue q, big_matrix<T, M, N> &B, nd_range<2> &r) { |
| 137 | + buffer<int8_t, 2> bufB(B.get_data(), range<2>(M, N)); |
| 138 | + // size of vector is known because SG size of set by the user in this case |
| 139 | + int sum_cols[N] = {0}; |
| 140 | + buffer<int> sum_cols_v(sum_cols, N); // there are total of tK/4 * 2, 16 rows |
| 141 | + q.submit([&](handler &cgh) { |
| 142 | + auto accB = bufB.get_access<access::mode::read_write>(cgh); |
| 143 | + |
| 144 | + auto v = sum_cols_v.get_access<access::mode::atomic>(cgh); |
| 145 | + auto os = sycl::stream(100000, 6144, cgh); |
| 146 | + |
| 147 | + cgh.parallel_for<class add_matrix>( |
| 148 | + r, [=](nd_item<2> spmd_item) [[intel::reqd_sub_group_size(SG_SZ)]] { |
| 149 | + const auto global_idx = spmd_item.get_global_id(0); |
| 150 | + const auto global_idy = spmd_item.get_global_id(1); |
| 151 | + const auto sg_startx = global_idx - spmd_item.get_local_id(0); |
| 152 | + const auto sg_starty = global_idy - spmd_item.get_local_id(1); |
| 153 | + |
| 154 | + ext::oneapi::sub_group sg = spmd_item.get_sub_group(); |
| 155 | + |
| 156 | + // TK = 32, TN = 16 |
| 157 | + joint_matrix<sub_group, int8_t, use::b, TK, TN, |
| 158 | + ext::intel::experimental::matrix::layout::packed> |
| 159 | + sub_b; |
| 160 | + |
| 161 | + joint_matrix_load(sg, sub_b, |
| 162 | + accB.get_pointer() + (global_idx * (TK / 4) * N) + |
| 163 | + sg_starty / SG_SZ * TN * 4, |
| 164 | + N); |
| 165 | + |
| 166 | + int32_t sum_local_cols[N] = {0}; // 4 local cols, N total |
| 167 | + // sub_b has 32x16 elements, 32 elements per WI, 4 per WI per row |
| 168 | + auto wiData = |
| 169 | + sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_b); |
| 170 | + |
| 171 | + size_t |
| 172 | + global_index; // Index into the result array that holds the sums. |
| 173 | + |
| 174 | + // Keep track of cols handled in this WI |
| 175 | + int32_t handled_cols[N] = {-1}; |
| 176 | + |
| 177 | + // each WI calculates local sum of cols |
| 178 | + for (int i = 0; i < wiData.length(); ++i) { |
| 179 | + // get the index of the element in the submatrix |
| 180 | + auto dataItem = wiData[i]; |
| 181 | + auto [row, col] = dataItem.get_coord(); |
| 182 | + |
| 183 | + // Calculation of global index |
| 184 | + int sg_idx = (int)global_idy / SG_SZ; |
| 185 | + global_index = col + sg_idx * 4 /*VNNI_FACTOR*/ * SG_SZ; |
| 186 | + sum_local_cols[global_index] += wiData[i]; |
| 187 | + handled_cols[global_index] = 1; |
| 188 | + } |
| 189 | + |
| 190 | + for (int j = 0; j < N; j++) { |
| 191 | + if (handled_cols[j] == 1) { |
| 192 | + global_index = j; |
| 193 | + sum_local_cols[global_index] = reduce_over_group( |
| 194 | + sg, sum_local_cols[global_index], sycl::plus<>()); |
| 195 | + atomic_fetch_add(v[global_index], sum_local_cols[global_index]); |
| 196 | + } |
| 197 | + } |
| 198 | + }); // parallel for |
| 199 | + }).wait(); |
| 200 | + sum_cols_ref<T, M, N>(bufB.get_host_access(), sum_cols_v.get_host_access()); |
| 201 | +} |
| 202 | + |
| 203 | +// TK = 32, TN = 16 |
| 204 | +static constexpr size_t MATRIX_K = TK / 4 * 2; // 16 |
| 205 | +static constexpr size_t MATRIX_N = TN * 4 * 2; // 128 |
| 206 | +int8_t B[MATRIX_K][MATRIX_N]; |
| 207 | + |
| 208 | +/* < --------------- 128 ----------------------------------> |
| 209 | + x x x x x x x x x x x x x x x x .......... x x x x x x ^ |
| 210 | + x x x x x x x x x x x x x x x x .......... x x x x x x 16 |
| 211 | + x x x x x x x x x x x x x x x x .......... x x x x x x | |
| 212 | + ..... | |
| 213 | + x x x x x x x x x x x x x x x x .......... x x x x x x | |
| 214 | + x x x x x x x x x x x x x x x x .......... x x x x x x v |
| 215 | +*/ |
| 216 | +int main() { |
| 217 | + big_matrix<int8_t, MATRIX_K, MATRIX_N> MB((int8_t *)&B); |
| 218 | + |
| 219 | + size_t NDRangeK = MATRIX_K / (TK / 4); |
| 220 | + size_t NDRangeN = (MATRIX_N / 4) / TN; |
| 221 | + queue q; |
| 222 | + nd_range<2> r({NDRangeK, NDRangeN * SG_SZ}, {1, 1 * SG_SZ}); |
| 223 | + |
| 224 | + for (int i = 0; i < MATRIX_K; i++) { |
| 225 | + for (int j = 0; j < MATRIX_N; j++) { |
| 226 | + B[i][j] = i; |
| 227 | + } |
| 228 | + } |
| 229 | + |
| 230 | + matrix_sum_cols<int8_t, MATRIX_K, MATRIX_N>(q, MB, r); |
| 231 | + |
| 232 | + std::cout << "Passed\n"; |
| 233 | + |
| 234 | + return 0; |
| 235 | +} |
0 commit comments