Skip to content

Commit aba6d85

Browse files
[SYCL][Matrix] Add support for tf32 type using the unified interface (#8702)
Co-authored-by: Bing1 Yu <[email protected]>
1 parent 3e8f937 commit aba6d85

12 files changed

+805
-56
lines changed

sycl/include/CL/__spirv/spirv_ops.hpp

+22-14
Original file line numberDiff line numberDiff line change
@@ -25,19 +25,23 @@
2525
#ifdef __SYCL_DEVICE_ONLY__
2626

2727
#if (SYCL_EXT_ONEAPI_MATRIX_VERSION > 1)
28-
template <typename T, std::size_t R, std::size_t C, __spv::MatrixUse U,
28+
extern __DPCPP_SYCL_EXTERNAL float __spirv_RoundFToTF32INTEL(float a);
29+
template <typename T, typename Tp, std::size_t R, std::size_t C,
30+
__spv::MatrixUse U,
2931
__spv::MatrixLayout L = __spv::MatrixLayout::RowMajor,
3032
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
31-
extern __DPCPP_SYCL_EXTERNAL __spv::__spirv_JointMatrixINTEL<T, R, C, L, S, U> *
32-
__spirv_JointMatrixLoadINTEL(T *Ptr, std::size_t Stride,
33-
__spv::MatrixLayout Layout = L,
34-
__spv::Scope::Flag Sc = S, int MemOperand = 0);
33+
extern __DPCPP_SYCL_EXTERNAL
34+
__spv::__spirv_JointMatrixINTEL<Tp, R, C, L, S, U> *
35+
__spirv_JointMatrixLoadINTEL(T *Ptr, std::size_t Stride,
36+
__spv::MatrixLayout Layout = L,
37+
__spv::Scope::Flag Sc = S, int MemOperand = 0);
3538

36-
template <typename T, std::size_t R, std::size_t C, __spv::MatrixUse U,
39+
template <typename T, typename Tp, std::size_t R, std::size_t C,
40+
__spv::MatrixUse U,
3741
__spv::MatrixLayout L = __spv::MatrixLayout::RowMajor,
3842
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
3943
extern __DPCPP_SYCL_EXTERNAL void __spirv_JointMatrixStoreINTEL(
40-
T *Ptr, __spv::__spirv_JointMatrixINTEL<T, R, C, L, S, U> *Object,
44+
T *Ptr, __spv::__spirv_JointMatrixINTEL<Tp, R, C, L, S, U> *Object,
4145
std::size_t Stride, __spv::MatrixLayout Layout = L,
4246
__spv::Scope::Flag Sc = S, int MemOperand = 0);
4347

@@ -100,11 +104,13 @@ extern __DPCPP_SYCL_EXTERNAL
100104
__spv::__spirv_JointMatrixINTEL<T3, M, N, LC, S, UC> *C,
101105
__spv::Scope::Flag Sc = __spv::Scope::Flag::Subgroup);
102106

103-
template <typename T, std::size_t R, std::size_t C, __spv::MatrixUse U,
107+
template <typename T, typename Tp, std::size_t R, std::size_t C,
108+
__spv::MatrixUse U,
104109
__spv::MatrixLayout L = __spv::MatrixLayout::RowMajor,
105110
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
106-
extern __DPCPP_SYCL_EXTERNAL __spv::__spirv_JointMatrixINTEL<T, R, C, L, S, U> *
107-
__spirv_CompositeConstruct(const T v);
111+
extern __DPCPP_SYCL_EXTERNAL
112+
__spv::__spirv_JointMatrixINTEL<Tp, R, C, L, S, U> *
113+
__spirv_CompositeConstruct(const T v);
108114

109115
template <typename T, std::size_t R, std::size_t C, __spv::MatrixUse U,
110116
__spv::MatrixLayout L = __spv::MatrixLayout::RowMajor,
@@ -119,18 +125,20 @@ template <typename T, std::size_t R, std::size_t C, __spv::MatrixUse U,
119125
extern __DPCPP_SYCL_EXTERNAL size_t __spirv_JointMatrixWorkItemLengthINTEL(
120126
__spv::__spirv_JointMatrixINTEL<T, R, C, L, S, U> *);
121127

122-
template <typename T, std::size_t R, std::size_t C, __spv::MatrixUse U,
128+
template <typename Ts, typename T, std::size_t R, std::size_t C,
129+
__spv::MatrixUse U,
123130
__spv::MatrixLayout L = __spv::MatrixLayout::RowMajor,
124131
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
125-
extern __DPCPP_SYCL_EXTERNAL T __spirv_VectorExtractDynamic(
132+
extern __DPCPP_SYCL_EXTERNAL Ts __spirv_VectorExtractDynamic(
126133
__spv::__spirv_JointMatrixINTEL<T, R, C, L, S, U> *, size_t i);
127134

128-
template <typename T, std::size_t R, std::size_t C, __spv::MatrixUse U,
135+
template <typename Ts, typename T, std::size_t R, std::size_t C,
136+
__spv::MatrixUse U,
129137
__spv::MatrixLayout L = __spv::MatrixLayout::RowMajor,
130138
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
131139
extern __DPCPP_SYCL_EXTERNAL __spv::__spirv_JointMatrixINTEL<T, R, C, L, S, U> *
132140
__spirv_VectorInsertDynamic(__spv::__spirv_JointMatrixINTEL<T, R, C, L, S, U> *,
133-
T val, size_t i);
141+
Ts val, size_t i);
134142
#else
135143
template <typename T, std::size_t R, std::size_t C,
136144
__spv::MatrixLayout L = __spv::MatrixLayout::RowMajor,

sycl/include/sycl/ext/oneapi/matrix/matrix-intel.hpp

+103-21
Original file line numberDiff line numberDiff line change
@@ -67,10 +67,26 @@ struct joint_matrix;
6767

6868
} // namespace matrix
6969
} // namespace experimental
70+
71+
namespace detail {
72+
// Differentiating between the "element type" and the "storage element type"
73+
template <typename T> struct jm_type_interpretation_helper_trait {
74+
using element_type = T;
75+
using storage_element_type = T;
76+
};
77+
78+
template <>
79+
struct jm_type_interpretation_helper_trait<
80+
sycl::ext::oneapi::experimental::matrix::precision::tf32> {
81+
using element_type = sycl::ext::oneapi::experimental::matrix::precision::tf32;
82+
using storage_element_type = float;
83+
};
84+
} // namespace detail
7085
} // namespace oneapi
7186

7287
namespace intel::experimental::matrix {
7388

89+
using namespace sycl::ext::oneapi::experimental::matrix;
7490
// Begin wi_element definition
7591

7692
template <typename T, size_t NumRows, size_t NumCols,
@@ -84,6 +100,9 @@ class wi_element {
84100
std::size_t idx;
85101

86102
public:
103+
using storage_element_type =
104+
typename oneapi::detail::jm_type_interpretation_helper_trait<
105+
T>::storage_element_type;
87106
wi_element(sycl::ext::oneapi::experimental::matrix::joint_matrix<
88107
Group, T, Use, NumRows, NumCols, Layout> &Mat,
89108
std::size_t i)
@@ -102,9 +121,15 @@ class wi_element {
102121
#endif // __SYCL_DEVICE_ONLY__
103122
}
104123

105-
operator T() {
124+
operator storage_element_type() {
106125
#ifdef __SYCL_DEVICE_ONLY__
107-
return __spirv_VectorExtractDynamic(M.spvm, idx);
126+
storage_element_type elem =
127+
__spirv_VectorExtractDynamic<storage_element_type, T, NumRows, NumCols,
128+
spv_matrix_use_traits<Use>::value,
129+
spv_matrix_layout_traits<Layout>::value,
130+
spv_scope_traits<Group>::value>(M.spvm,
131+
idx);
132+
return elem;
108133
#else
109134
throw runtime_error("joint matrix is not supported on host device.",
110135
PI_ERROR_INVALID_DEVICE);
@@ -113,7 +138,12 @@ class wi_element {
113138

114139
explicit operator bool() {
115140
#ifdef __SYCL_DEVICE_ONLY__
116-
return __spirv_VectorExtractDynamic(M.spvm, idx) != static_cast<T>(0);
141+
return __spirv_VectorExtractDynamic<storage_element_type, T, NumRows,
142+
NumCols,
143+
spv_matrix_use_traits<Use>::value,
144+
spv_matrix_layout_traits<Layout>::value,
145+
spv_scope_traits<Group>::value>(
146+
M.spvm, idx) != static_cast<storage_element_type>(0);
117147
#else
118148
throw runtime_error("joint matrix is not supported on host device.",
119149
PI_ERROR_INVALID_DEVICE);
@@ -122,7 +152,8 @@ class wi_element {
122152

123153
template <typename T2> wi_element &operator=(const T2 &rhs) {
124154
#ifdef __SYCL_DEVICE_ONLY__
125-
M.spvm = __spirv_VectorInsertDynamic(M.spvm, static_cast<T>(rhs), idx);
155+
M.spvm = __spirv_VectorInsertDynamic(
156+
M.spvm, static_cast<storage_element_type>(rhs), idx);
126157
return *this;
127158
#else
128159
(void)rhs;
@@ -135,7 +166,13 @@ class wi_element {
135166
operator=(const wi_element<T, NumRows, NumCols, Use, Layout, Group> &rhs) {
136167
#ifdef __SYCL_DEVICE_ONLY__
137168
M.spvm = __spirv_VectorInsertDynamic(
138-
M.spvm, __spirv_VectorExtractDynamic(rhs.M.spvm, rhs.idx), idx);
169+
M.spvm,
170+
__spirv_VectorExtractDynamic<storage_element_type, T, NumRows, NumCols,
171+
spv_matrix_use_traits<Use>::value,
172+
spv_matrix_layout_traits<Layout>::value,
173+
spv_scope_traits<Group>::value>(rhs.M.spvm,
174+
rhs.idx),
175+
idx);
139176
return *this;
140177
#else
141178
(void)rhs;
@@ -149,8 +186,13 @@ class wi_element {
149186
template <typename T2> wi_element &operator op##=(const T2 &rhs) { \
150187
M.spvm = __spirv_VectorInsertDynamic( \
151188
M.spvm, \
152-
static_cast<T>(__spirv_VectorExtractDynamic(M.spvm, idx) \
153-
op static_cast<T>(rhs)), \
189+
static_cast<storage_element_type>( \
190+
__spirv_VectorExtractDynamic< \
191+
storage_element_type, T, NumRows, NumCols, \
192+
spv_matrix_use_traits<Use>::value, \
193+
spv_matrix_layout_traits<Layout>::value, \
194+
spv_scope_traits<Group>::value>(M.spvm, idx) \
195+
op static_cast<storage_element_type>(rhs)), \
154196
idx); \
155197
return *this; \
156198
}
@@ -201,7 +243,11 @@ class wi_element<sycl::ext::oneapi::bfloat16, NumRows, NumCols, Use, Layout,
201243

202244
operator sycl::ext::oneapi::bfloat16() {
203245
#ifdef __SYCL_DEVICE_ONLY__
204-
return __spirv_VectorExtractDynamic(M.spvm, idx);
246+
return __spirv_VectorExtractDynamic<
247+
sycl::ext::oneapi::bfloat16, sycl::ext::oneapi::bfloat16, NumRows,
248+
NumCols, spv_matrix_use_traits<Use>::value,
249+
spv_matrix_layout_traits<Layout>::value,
250+
spv_scope_traits<Group>::value>(M.spvm, idx);
205251
#else
206252
throw runtime_error("joint matrix is not supported on host device.",
207253
PI_ERROR_INVALID_DEVICE);
@@ -210,8 +256,13 @@ class wi_element<sycl::ext::oneapi::bfloat16, NumRows, NumCols, Use, Layout,
210256

211257
explicit operator bool() {
212258
#ifdef __SYCL_DEVICE_ONLY__
213-
return std::fabs(static_cast<float>(__spirv_VectorExtractDynamic(
214-
M.spvm, idx))) >= std::numeric_limits<float>::epsilon();
259+
return std::fabs(static_cast<float>(
260+
__spirv_VectorExtractDynamic<
261+
sycl::ext::oneapi::bfloat16, sycl::ext::oneapi::bfloat16,
262+
NumRows, NumCols, spv_matrix_use_traits<Use>::value,
263+
spv_matrix_layout_traits<Layout>::value,
264+
spv_scope_traits<Group>::value>(M.spvm, idx))) >=
265+
std::numeric_limits<float>::epsilon();
215266
#else
216267
throw runtime_error("joint matrix is not supported on host device.",
217268
PI_ERROR_INVALID_DEVICE);
@@ -233,7 +284,14 @@ class wi_element<sycl::ext::oneapi::bfloat16, NumRows, NumCols, Use, Layout,
233284
NumCols, Use, Layout, Group> &rhs) {
234285
#ifdef __SYCL_DEVICE_ONLY__
235286
M.spvm = __spirv_VectorInsertDynamic(
236-
M.spvm, __spirv_VectorExtractDynamic(rhs.M.spvm, rhs.idx), idx);
287+
M.spvm,
288+
__spirv_VectorExtractDynamic<sycl::ext::oneapi::bfloat16,
289+
sycl::ext::oneapi::bfloat16, NumRows,
290+
NumCols, spv_matrix_use_traits<Use>::value,
291+
spv_matrix_layout_traits<Layout>::value,
292+
spv_scope_traits<Group>::value>(rhs.M.spvm,
293+
rhs.idx),
294+
idx);
237295
return *this;
238296
#else
239297
(void)rhs;
@@ -246,7 +304,13 @@ class wi_element<sycl::ext::oneapi::bfloat16, NumRows, NumCols, Use, Layout,
246304
#define OP(opassign, op) \
247305
wi_element &operator opassign(const sycl::ext::oneapi::bfloat16 &rhs) { \
248306
M.spvm = __spirv_VectorInsertDynamic( \
249-
M.spvm, __spirv_VectorExtractDynamic(M.spvm, idx) op rhs, idx); \
307+
M.spvm, \
308+
__spirv_VectorExtractDynamic< \
309+
sycl::ext::oneapi::bfloat16, sycl::ext::oneapi::bfloat16, NumRows, \
310+
NumCols, spv_matrix_use_traits<Use>::value, \
311+
spv_matrix_layout_traits<Layout>::value, \
312+
spv_scope_traits<Group>::value>(M.spvm, idx) op rhs, \
313+
idx); \
250314
return *this; \
251315
}
252316
#else // __SYCL_DEVICE_ONLY__
@@ -269,13 +333,21 @@ class wi_element<sycl::ext::oneapi::bfloat16, NumRows, NumCols, Use, Layout,
269333
const wi_element<sycl::ext::oneapi::bfloat16, NumRows, NumCols, Use, \
270334
Layout, Group> &lhs, \
271335
const sycl::ext::oneapi::bfloat16 &rhs) { \
272-
return __spirv_VectorExtractDynamic(lhs.M.spvm, lhs.idx) op rhs; \
336+
return __spirv_VectorExtractDynamic< \
337+
sycl::ext::oneapi::bfloat16, sycl::ext::oneapi::bfloat16, NumRows, \
338+
NumCols, spv_matrix_use_traits<Use>::value, \
339+
spv_matrix_layout_traits<Layout>::value, \
340+
spv_scope_traits<Group>::value>(lhs.M.spvm, lhs.idx) op rhs; \
273341
} \
274342
friend type operator op( \
275343
const sycl::ext::oneapi::bfloat16 &lhs, \
276344
const wi_element<sycl::ext::oneapi::bfloat16, NumRows, NumCols, Use, \
277345
Layout, Group> &rhs) { \
278-
return __spirv_VectorExtractDynamic(rhs.M.spvm, rhs.idx) op lhs; \
346+
return __spirv_VectorExtractDynamic< \
347+
sycl::ext::oneapi::bfloat16, sycl::ext::oneapi::bfloat16, NumRows, \
348+
NumCols, spv_matrix_use_traits<Use>::value, \
349+
spv_matrix_layout_traits<Layout>::value, \
350+
spv_scope_traits<Group>::value>(rhs.M.spvm, rhs.idx) op lhs; \
279351
}
280352
OP(sycl::ext::oneapi::bfloat16, +)
281353
OP(sycl::ext::oneapi::bfloat16, -)
@@ -287,15 +359,25 @@ class wi_element<sycl::ext::oneapi::bfloat16, NumRows, NumCols, Use, Layout,
287359
const wi_element<sycl::ext::oneapi::bfloat16, NumRows, NumCols, Use, \
288360
Layout, Group> &lhs, \
289361
const sycl::ext::oneapi::bfloat16 &rhs) { \
290-
return type{static_cast<float>(__spirv_VectorExtractDynamic( \
291-
lhs.M.spvm, lhs.idx)) op static_cast<float>(rhs)}; \
362+
return type{static_cast<float>( \
363+
__spirv_VectorExtractDynamic< \
364+
sycl::ext::oneapi::bfloat16, sycl::ext::oneapi::bfloat16, NumRows, \
365+
NumCols, spv_matrix_use_traits<Use>::value, \
366+
spv_matrix_layout_traits<Layout>::value, \
367+
spv_scope_traits<Group>::value>(lhs.M.spvm, lhs.idx)) \
368+
op static_cast<float>(rhs)}; \
292369
} \
293370
friend type operator op( \
294371
const sycl::ext::oneapi::bfloat16 &lhs, \
295372
const wi_element<sycl::ext::oneapi::bfloat16, NumRows, NumCols, Use, \
296373
Layout, Group> &rhs) { \
297-
return type{static_cast<float>(__spirv_VectorExtractDynamic( \
298-
rhs.M.spvm, rhs.idx)) op static_cast<float>(lhs)}; \
374+
return type{static_cast<float>( \
375+
__spirv_VectorExtractDynamic< \
376+
sycl::ext::oneapi::bfloat16, sycl::ext::oneapi::bfloat16, NumRows, \
377+
NumCols, spv_matrix_use_traits<Use>::value, \
378+
spv_matrix_layout_traits<Layout>::value, \
379+
spv_scope_traits<Group>::value>(rhs.M.spvm, rhs.idx)) \
380+
op static_cast<float>(lhs)}; \
299381
}
300382
OP(bool, ==)
301383
OP(bool, !=)
@@ -386,7 +468,7 @@ get_wi_data(Group sg, sycl::ext::oneapi::experimental::matrix::joint_matrix<
386468
// End wi_data definition
387469

388470
template <
389-
typename Group, typename T,
471+
typename Group, typename T, typename Tp,
390472
sycl::ext::oneapi::experimental::matrix::use Use, size_t NumRows,
391473
size_t NumCols, sycl::ext::oneapi::experimental::matrix::layout Layout,
392474
access::address_space Space, access::decorated IsDecorated,
@@ -396,7 +478,7 @@ template <
396478
inline __SYCL_ALWAYS_INLINE void
397479
joint_matrix_store(Group sg,
398480
sycl::ext::oneapi::experimental::matrix::joint_matrix<
399-
Group, T, Use, NumRows, NumCols, Layout> &src,
481+
Group, Tp, Use, NumRows, NumCols, Layout> &src,
400482
multi_ptr<T, Space, IsDecorated> dst, size_t stride) {
401483
#if defined(__SYCL_DEVICE_ONLY__)
402484
#if defined(__NVPTX__)
@@ -411,7 +493,7 @@ joint_matrix_store(Group sg,
411493
#else
412494
// intel's impl
413495
T *Ptr = dst.get();
414-
__spirv_JointMatrixStoreINTEL<T, NumRows, NumCols,
496+
__spirv_JointMatrixStoreINTEL<T, Tp, NumRows, NumCols,
415497
sycl::ext::oneapi::experimental::matrix::
416498
spv_matrix_use_traits<Use>::value,
417499
sycl::ext::oneapi::experimental::matrix::

sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcores.hpp

-6
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,6 @@ namespace oneapi {
1818
namespace experimental {
1919
namespace matrix {
2020

21-
namespace precision {
22-
class tf32 {
23-
tf32() = delete;
24-
};
25-
} // namespace precision
26-
2721
template <typename Group, typename T, use Use, size_t Rows, size_t Cols,
2822
layout Layout = layout::dynamic>
2923
struct joint_matrix;

sycl/include/sycl/ext/oneapi/matrix/matrix-unified-utils.hpp

+6
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,12 @@ enum class use { a, b, accumulator };
1818

1919
enum class layout { row_major = 0, col_major = 1, dynamic = 3 };
2020

21+
namespace precision {
22+
class tf32 {
23+
tf32() = delete;
24+
};
25+
} // namespace precision
26+
2127
} // namespace matrix
2228
} // namespace experimental
2329
} // namespace oneapi

0 commit comments

Comments
 (0)