@@ -67,10 +67,26 @@ struct joint_matrix;
67
67
68
68
} // namespace matrix
69
69
} // namespace experimental
70
+
71
+ namespace detail {
72
+ // Differentiating between the "element type" and the "storage element type"
// of a joint_matrix element. The primary template uses T for both aliases;
// the tf32 specialization keeps tf32 as the element type but selects float as
// the storage element type.
73
+ template <typename T> struct jm_type_interpretation_helper_trait {
74
+ using element_type = T;
75
+ using storage_element_type = T; // default: same type as the element type
76
+ };
77
+
78
+ template <>
79
+ struct jm_type_interpretation_helper_trait <
80
+ sycl::ext::oneapi::experimental::matrix::precision::tf32> {
81
+ using element_type = sycl::ext::oneapi::experimental::matrix::precision::tf32;
82
+ using storage_element_type = float ; // tf32 elements use float as storage type
83
+ };
84
+ } // namespace detail
70
85
} // namespace oneapi
71
86
72
87
namespace intel ::experimental::matrix {
73
88
89
+ using namespace sycl ::ext::oneapi::experimental::matrix;
74
90
// Begin wi_element definition
75
91
76
92
template <typename T, size_t NumRows, size_t NumCols,
@@ -84,6 +100,9 @@ class wi_element {
84
100
std::size_t idx;
85
101
86
102
public:
103
// Type in which this wi_element reads and writes a single matrix element,
// looked up from oneapi::detail::jm_type_interpretation_helper_trait<T>.
+ using storage_element_type =
104
+ typename oneapi::detail::jm_type_interpretation_helper_trait<
105
+ T>::storage_element_type;
87
106
wi_element (sycl::ext::oneapi::experimental::matrix::joint_matrix<
88
107
Group, T, Use, NumRows, NumCols, Layout> &Mat,
89
108
std::size_t i)
@@ -102,9 +121,15 @@ class wi_element {
102
121
#endif // __SYCL_DEVICE_ONLY__
103
122
}
104
123
105
- operator T () {
124
+ operator storage_element_type () {
106
125
#ifdef __SYCL_DEVICE_ONLY__
107
- return __spirv_VectorExtractDynamic (M.spvm , idx);
126
+ storage_element_type elem =
127
+ __spirv_VectorExtractDynamic<storage_element_type, T, NumRows, NumCols,
128
+ spv_matrix_use_traits<Use>::value,
129
+ spv_matrix_layout_traits<Layout>::value,
130
+ spv_scope_traits<Group>::value>(M.spvm ,
131
+ idx);
132
+ return elem;
108
133
#else
109
134
throw runtime_error (" joint matrix is not supported on host device." ,
110
135
PI_ERROR_INVALID_DEVICE);
@@ -113,7 +138,12 @@ class wi_element {
113
138
114
139
explicit operator bool () {
115
140
#ifdef __SYCL_DEVICE_ONLY__
116
- return __spirv_VectorExtractDynamic (M.spvm , idx) != static_cast <T>(0 );
141
+ return __spirv_VectorExtractDynamic<storage_element_type, T, NumRows,
142
+ NumCols,
143
+ spv_matrix_use_traits<Use>::value,
144
+ spv_matrix_layout_traits<Layout>::value,
145
+ spv_scope_traits<Group>::value>(
146
+ M.spvm , idx) != static_cast <storage_element_type>(0 );
117
147
#else
118
148
throw runtime_error (" joint matrix is not supported on host device." ,
119
149
PI_ERROR_INVALID_DEVICE);
@@ -122,7 +152,8 @@ class wi_element {
122
152
123
153
template <typename T2> wi_element &operator =(const T2 &rhs) {
124
154
#ifdef __SYCL_DEVICE_ONLY__
125
- M.spvm = __spirv_VectorInsertDynamic (M.spvm , static_cast <T>(rhs), idx);
155
+ M.spvm = __spirv_VectorInsertDynamic (
156
+ M.spvm , static_cast <storage_element_type>(rhs), idx);
126
157
return *this ;
127
158
#else
128
159
(void )rhs;
@@ -135,7 +166,13 @@ class wi_element {
135
166
operator =(const wi_element<T, NumRows, NumCols, Use, Layout, Group> &rhs) {
136
167
#ifdef __SYCL_DEVICE_ONLY__
137
168
M.spvm = __spirv_VectorInsertDynamic (
138
- M.spvm , __spirv_VectorExtractDynamic (rhs.M .spvm , rhs.idx ), idx);
169
+ M.spvm ,
170
+ __spirv_VectorExtractDynamic<storage_element_type, T, NumRows, NumCols,
171
+ spv_matrix_use_traits<Use>::value,
172
+ spv_matrix_layout_traits<Layout>::value,
173
+ spv_scope_traits<Group>::value>(rhs.M .spvm ,
174
+ rhs.idx ),
175
+ idx);
139
176
return *this ;
140
177
#else
141
178
(void )rhs;
@@ -149,8 +186,13 @@ class wi_element {
149
186
template <typename T2> wi_element &operator op##=(const T2 &rhs) { \
150
187
M.spvm = __spirv_VectorInsertDynamic ( \
151
188
M.spvm , \
152
- static_cast <T>(__spirv_VectorExtractDynamic (M.spvm , idx) \
153
- op static_cast <T>(rhs)), \
189
+ static_cast <storage_element_type>( \
190
+ __spirv_VectorExtractDynamic< \
191
+ storage_element_type, T, NumRows, NumCols, \
192
+ spv_matrix_use_traits<Use>::value, \
193
+ spv_matrix_layout_traits<Layout>::value, \
194
+ spv_scope_traits<Group>::value>(M.spvm , idx) \
195
+ op static_cast <storage_element_type>(rhs)), \
154
196
idx); \
155
197
return *this ; \
156
198
}
@@ -201,7 +243,11 @@ class wi_element<sycl::ext::oneapi::bfloat16, NumRows, NumCols, Use, Layout,
201
243
202
244
operator sycl::ext::oneapi::bfloat16 () {
203
245
#ifdef __SYCL_DEVICE_ONLY__
204
- return __spirv_VectorExtractDynamic (M.spvm , idx);
246
+ return __spirv_VectorExtractDynamic<
247
+ sycl::ext::oneapi::bfloat16, sycl::ext::oneapi::bfloat16, NumRows,
248
+ NumCols, spv_matrix_use_traits<Use>::value,
249
+ spv_matrix_layout_traits<Layout>::value,
250
+ spv_scope_traits<Group>::value>(M.spvm , idx);
205
251
#else
206
252
throw runtime_error (" joint matrix is not supported on host device." ,
207
253
PI_ERROR_INVALID_DEVICE);
@@ -210,8 +256,13 @@ class wi_element<sycl::ext::oneapi::bfloat16, NumRows, NumCols, Use, Layout,
210
256
211
257
explicit operator bool () {
212
258
#ifdef __SYCL_DEVICE_ONLY__
213
- return std::fabs (static_cast <float >(__spirv_VectorExtractDynamic (
214
- M.spvm , idx))) >= std::numeric_limits<float >::epsilon ();
259
+ return std::fabs (static_cast <float >(
260
+ __spirv_VectorExtractDynamic<
261
+ sycl::ext::oneapi::bfloat16, sycl::ext::oneapi::bfloat16,
262
+ NumRows, NumCols, spv_matrix_use_traits<Use>::value,
263
+ spv_matrix_layout_traits<Layout>::value,
264
+ spv_scope_traits<Group>::value>(M.spvm , idx))) >=
265
+ std::numeric_limits<float >::epsilon ();
215
266
#else
216
267
throw runtime_error (" joint matrix is not supported on host device." ,
217
268
PI_ERROR_INVALID_DEVICE);
@@ -233,7 +284,14 @@ class wi_element<sycl::ext::oneapi::bfloat16, NumRows, NumCols, Use, Layout,
233
284
NumCols, Use, Layout, Group> &rhs) {
234
285
#ifdef __SYCL_DEVICE_ONLY__
235
286
M.spvm = __spirv_VectorInsertDynamic (
236
- M.spvm , __spirv_VectorExtractDynamic (rhs.M .spvm , rhs.idx ), idx);
287
+ M.spvm ,
288
+ __spirv_VectorExtractDynamic<sycl::ext::oneapi::bfloat16,
289
+ sycl::ext::oneapi::bfloat16, NumRows,
290
+ NumCols, spv_matrix_use_traits<Use>::value,
291
+ spv_matrix_layout_traits<Layout>::value,
292
+ spv_scope_traits<Group>::value>(rhs.M .spvm ,
293
+ rhs.idx ),
294
+ idx);
237
295
return *this ;
238
296
#else
239
297
(void )rhs;
@@ -246,7 +304,13 @@ class wi_element<sycl::ext::oneapi::bfloat16, NumRows, NumCols, Use, Layout,
246
304
// Generates a compound-assignment operator `opassign` taking a bfloat16 rhs:
// it extracts the element at idx from M.spvm, applies `op` with rhs, inserts
// the result back into M.spvm at idx, and returns *this.
// Diff note: the `-` line is the previous extract call without explicit
// template arguments; the `+` lines replace it with a call that spells them out.
#define OP (opassign, op ) \
247
305
wi_element &operator opassign (const sycl::ext::oneapi::bfloat16 &rhs) { \
248
306
M.spvm = __spirv_VectorInsertDynamic ( \
249
- M.spvm , __spirv_VectorExtractDynamic (M.spvm , idx) op rhs, idx); \
307
+ M.spvm , \
308
+ __spirv_VectorExtractDynamic< \
309
+ sycl::ext::oneapi::bfloat16, sycl::ext::oneapi::bfloat16, NumRows, \
310
+ NumCols, spv_matrix_use_traits<Use>::value, \
311
+ spv_matrix_layout_traits<Layout>::value, \
312
+ spv_scope_traits<Group>::value>(M.spvm , idx) op rhs, \
313
+ idx); \
250
314
return *this ; \
251
315
}
252
316
#else // __SYCL_DEVICE_ONLY__
@@ -269,13 +333,21 @@ class wi_element<sycl::ext::oneapi::bfloat16, NumRows, NumCols, Use, Layout,
269
333
const wi_element<sycl::ext::oneapi::bfloat16, NumRows, NumCols, Use, \
270
334
Layout, Group> &lhs, \
271
335
const sycl::ext::oneapi::bfloat16 &rhs) { \
272
- return __spirv_VectorExtractDynamic (lhs.M .spvm , lhs.idx ) op rhs; \
336
+ return __spirv_VectorExtractDynamic< \
337
+ sycl::ext::oneapi::bfloat16, sycl::ext::oneapi::bfloat16, NumRows, \
338
+ NumCols, spv_matrix_use_traits<Use>::value, \
339
+ spv_matrix_layout_traits<Layout>::value, \
340
+ spv_scope_traits<Group>::value>(lhs.M .spvm , lhs.idx ) op rhs; \
273
341
} \
274
342
friend type operator op ( \
275
343
const sycl::ext::oneapi::bfloat16 &lhs, \
276
344
const wi_element<sycl::ext::oneapi::bfloat16, NumRows, NumCols, Use, \
277
345
Layout, Group> &rhs) { \
278
- return __spirv_VectorExtractDynamic (rhs.M .spvm , rhs.idx ) op lhs; \
346
+ return __spirv_VectorExtractDynamic< \
347
+ sycl::ext::oneapi::bfloat16, sycl::ext::oneapi::bfloat16, NumRows, \
348
+ NumCols, spv_matrix_use_traits<Use>::value, \
349
+ spv_matrix_layout_traits<Layout>::value, \
350
+ spv_scope_traits<Group>::value>(rhs.M .spvm , rhs.idx ) op lhs; \
279
351
}
280
352
OP (sycl::ext::oneapi::bfloat16, +)
281
353
OP(sycl::ext::oneapi::bfloat16, -)
@@ -287,15 +359,25 @@ class wi_element<sycl::ext::oneapi::bfloat16, NumRows, NumCols, Use, Layout,
287
359
const wi_element<sycl::ext::oneapi::bfloat16, NumRows, NumCols, Use, \
288
360
Layout, Group> &lhs, \
289
361
const sycl::ext::oneapi::bfloat16 &rhs) { \
290
- return type{static_cast <float >(__spirv_VectorExtractDynamic ( \
291
- lhs.M .spvm , lhs.idx )) op static_cast <float >(rhs)}; \
362
+ return type{static_cast <float >( \
363
+ __spirv_VectorExtractDynamic< \
364
+ sycl::ext::oneapi::bfloat16, sycl::ext::oneapi::bfloat16, NumRows, \
365
+ NumCols, spv_matrix_use_traits<Use>::value, \
366
+ spv_matrix_layout_traits<Layout>::value, \
367
+ spv_scope_traits<Group>::value>(lhs.M .spvm , lhs.idx )) \
368
+ op static_cast <float >(rhs)}; \
292
369
} \
293
370
friend type operator op ( \
294
371
const sycl::ext::oneapi::bfloat16 &lhs, \
295
372
const wi_element<sycl::ext::oneapi::bfloat16, NumRows, NumCols, Use, \
296
373
Layout, Group> &rhs) { \
297
- return type{static_cast <float >(__spirv_VectorExtractDynamic ( \
298
- rhs.M .spvm , rhs.idx )) op static_cast <float >(lhs)}; \
374
+ return type{static_cast <float >( \
375
+ __spirv_VectorExtractDynamic< \
376
+ sycl::ext::oneapi::bfloat16, sycl::ext::oneapi::bfloat16, NumRows, \
377
+ NumCols, spv_matrix_use_traits<Use>::value, \
378
+ spv_matrix_layout_traits<Layout>::value, \
379
+ spv_scope_traits<Group>::value>(rhs.M .spvm , rhs.idx )) \
380
+ op static_cast <float >(lhs)}; \
299
381
}
300
382
OP (bool , ==)
301
383
OP(bool , !=)
@@ -386,7 +468,7 @@ get_wi_data(Group sg, sycl::ext::oneapi::experimental::matrix::joint_matrix<
386
468
// End wi_data definition
387
469
388
470
template <
389
- typename Group, typename T,
471
+ typename Group, typename T, typename Tp,
390
472
sycl::ext::oneapi::experimental::matrix::use Use, size_t NumRows,
391
473
size_t NumCols, sycl::ext::oneapi::experimental::matrix::layout Layout,
392
474
access::address_space Space, access::decorated IsDecorated,
@@ -396,7 +478,7 @@ template <
396
478
inline __SYCL_ALWAYS_INLINE void
397
479
joint_matrix_store (Group sg,
398
480
sycl::ext::oneapi::experimental::matrix::joint_matrix<
399
- Group, T , Use, NumRows, NumCols, Layout> &src,
481
+ Group, Tp , Use, NumRows, NumCols, Layout> &src,
400
482
multi_ptr<T, Space, IsDecorated> dst, size_t stride) {
401
483
#if defined(__SYCL_DEVICE_ONLY__)
402
484
#if defined(__NVPTX__)
@@ -411,7 +493,7 @@ joint_matrix_store(Group sg,
411
493
#else
412
494
// intel's impl
413
495
T *Ptr = dst.get ();
414
- __spirv_JointMatrixStoreINTEL<T, NumRows, NumCols,
496
+ __spirv_JointMatrixStoreINTEL<T, Tp, NumRows, NumCols,
415
497
sycl::ext::oneapi::experimental::matrix::
416
498
spv_matrix_use_traits<Use>::value,
417
499
sycl::ext::oneapi::experimental::matrix::
0 commit comments