33
44#define BF16_WIDEN_ONE
55
6+ #ifdef BF16_WIDEN_ONE
// FORCEINLINE: requests unconditional inlining through the GNU-style
// always_inline attribute.
7+ #define FORCEINLINE inline __attribute__((always_inline))
// B_UNROLL: number of elements B_CONV converts per main-loop iteration.
// B_CONV splits `count` with the masks (B_UNROLL - 1) and -B_UNROLL, which
// only partition the count correctly when this value is a power of two.
8+ #define B_UNROLL 64
9+
10+ // Convert from BF16 to FP32
// B_CONV: reads `count` bf16 values starting at BB, widens each one to a
// 32-bit float, and stores the results contiguously starting at CONV.
//   BB    - source bf16 elements (read only here)
//   CONV  - destination buffer; written with 32-bit float stores, so it must
//           have room for `count` elements
//   count - number of elements to convert
// The work is split into whole B_UNROLL-sized chunks plus one tail pass.
11+ static void FORCEINLINE B_CONV (__bf16 * BB , FLOAT * CONV , BLASLONG count )
12+ {
// Leftover elements that do not fill a whole B_UNROLL-sized chunk.
13+ BLASLONG count2 = (count & (B_UNROLL - 1 ));
// Round count down to a multiple of B_UNROLL for the main loop.
14+ count &= - B_UNROLL ;
15+ while (count ) {
// Load B_UNROLL bf16 values, widen them to fp32, and store them.
// NOTE(review): B_UNROLL is passed as the vector length directly, without a
// vsetvl query. This presumably relies on an e16/LMUL=4 register group
// holding at least 64 elements (VLEN >= 256); on a narrower implementation
// fewer elements would be processed while the pointers still advance by
// B_UNROLL -- confirm against the intended target.
16+ vbfloat16m4_t B00 = __riscv_vle16_v_bf16m4 (BB , B_UNROLL );
17+ vfloat32m8_t B0 = __riscv_vfwcvtbf16_f_f_v_f32m8 (B00 , B_UNROLL );
18+ __riscv_vse32_v_f32m8 (CONV , B0 , B_UNROLL );
19+ BB += B_UNROLL ;
20+ CONV += B_UNROLL ;
21+ count -= B_UNROLL ;
22+ }
// Tail: convert the remaining count2 (< B_UNROLL) elements in one pass.
23+ if (count2 ) {
// NOTE(review): vsetvl is called once and the tail is not looped, so this
// presumably assumes the granted length gvl2 equals count2 -- confirm.
24+ BLASLONG gvl2 = __riscv_vsetvl_e16m4 (count2 );
25+ vbfloat16m4_t B00 = __riscv_vle16_v_bf16m4 (BB , gvl2 );
26+ vfloat32m8_t B0 = __riscv_vfwcvtbf16_f_f_v_f32m8 (B00 , gvl2 );
27+ __riscv_vse32_v_f32m8 (CONV , B0 , gvl2 );
28+ }
29+ }
30+ #endif
31+
632int CNAME (BLASLONG M , BLASLONG N , BLASLONG K , FLOAT alpha , IFLOAT * A , IFLOAT * B , FLOAT * C , BLASLONG ldc )
733{
834 BLASLONG gvl = 0 ;
@@ -12,10 +38,11 @@ int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT alpha, IFLOAT *A, IFLOAT *B,
1238 __bf16 * AA = (__bf16 * )(A );
1339
1440#ifdef BF16_WIDEN_ONE
15- FLOAT * B_CONV = NULL ;
41+ FLOAT * CONV = NULL ;
1642 if ((M >= 4 ) && (N >= 4 ) && (K > 0 )) {
17- B_CONV = (FLOAT * )(malloc (K * 8 * sizeof (FLOAT )));
18- if (!B_CONV ) return 1 ;
43+ CONV = (FLOAT * )(malloc ((K * (8 + (M & -4 ))) * sizeof (FLOAT )));
44+ if (!CONV ) return 1 ;
45+ B_CONV (AA , CONV + (K * 8 ), (M & -4 ) * K );
1946 }
2047#endif
2148
@@ -24,26 +51,16 @@ int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT alpha, IFLOAT *A, IFLOAT *B,
2451 m_top = 0 ;
2552 BLASLONG gvl = __riscv_vsetvl_e16m1 (16 );
2653#ifdef BF16_WIDEN_ONE
27- BLASLONG bi2 ;
28- {
29- BLASLONG bi3 = 0 ;
30- BLASLONG gvl2 ;
31- bi2 = K * 8 ;
32- do {
33- gvl2 = __riscv_vsetvl_e16m4 (bi2 );
34- vbfloat16m4_t A00 = __riscv_vle16_v_bf16m4 (& BB [bi3 + (n_top * K )], gvl2 );
35- vfloat32m8_t A0 = __riscv_vfwcvtbf16_f_f_v_f32m8 (A00 , gvl2 );
36- __riscv_vse32_v_f32m8 (& B_CONV [bi3 ], A0 , gvl2 );
37- bi3 += gvl2 ;
38- } while (bi2 -= gvl2 );
39- }
54+ BLASLONG bi2 = K * 8 ;
55+ B_CONV (BB + (n_top * K ), CONV , bi2 );
56+ BLASLONG ai2 = K * 8 ;
4057#endif
4158
4259 for (BLASLONG i = 0 ; i < M /16 ; i += 1 ) {
43- BLASLONG ai = m_top * K ;
4460#ifdef BF16_WIDEN_ONE
4561 bi2 = 0 ;
4662#else
63+ BLASLONG ai = m_top * K ;
4764 BLASLONG bi = n_top * K ;
4865#endif
4966
@@ -58,19 +75,18 @@ int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT alpha, IFLOAT *A, IFLOAT *B,
5875
5976 for (BLASLONG k = 0 ; k < K ; k ++ ) {
6077#ifdef BF16_WIDEN_ONE
61- float B0 = B_CONV [bi2 + 0 ];
62- float B1 = B_CONV [bi2 + 1 ];
63- float B2 = B_CONV [bi2 + 2 ];
64- float B3 = B_CONV [bi2 + 3 ];
65- float B4 = B_CONV [bi2 + 4 ];
66- float B5 = B_CONV [bi2 + 5 ];
67- float B6 = B_CONV [bi2 + 6 ];
68- float B7 = B_CONV [bi2 + 7 ];
78+ float B0 = CONV [bi2 + 0 ];
79+ float B1 = CONV [bi2 + 1 ];
80+ float B2 = CONV [bi2 + 2 ];
81+ float B3 = CONV [bi2 + 3 ];
82+ float B4 = CONV [bi2 + 4 ];
83+ float B5 = CONV [bi2 + 5 ];
84+ float B6 = CONV [bi2 + 6 ];
85+ float B7 = CONV [bi2 + 7 ];
6986 bi2 += 8 ;
7087
71- vbfloat16m1_t A00 = __riscv_vle16_v_bf16m1 ( & AA [ai + 0 * gvl ], gvl );
72- vfloat32m2_t A0 = __riscv_vfwcvtbf16_f_f_v_f32m2 (A00 , gvl );
73- ai += 16 ;
88+ vfloat32m2_t A0 = __riscv_vle32_v_f32m2 (& CONV [ai2 ], gvl );
89+ ai2 += 16 ;
7490
7591 result0 = __riscv_vfmacc_vf_f32m2 (result0 , B0 , A0 , gvl );
7692 result1 = __riscv_vfmacc_vf_f32m2 (result1 , B1 , A0 , gvl );
@@ -143,10 +159,10 @@ int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT alpha, IFLOAT *A, IFLOAT *B,
143159 if ( M & 8 ) {
144160 gvl = __riscv_vsetvl_e16mf2 (8 );
145161
146- BLASLONG ai = m_top * K ;
147162#ifdef BF16_WIDEN_ONE
148163 bi2 = 0 ;
149164#else
165+ BLASLONG ai = m_top * K ;
150166 BLASLONG bi = n_top * K ;
151167#endif
152168
@@ -161,19 +177,18 @@ int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT alpha, IFLOAT *A, IFLOAT *B,
161177
162178 for (BLASLONG k = 0 ; k < K ; k ++ ) {
163179#ifdef BF16_WIDEN_ONE
164- float B0 = B_CONV [bi2 + 0 ];
165- float B1 = B_CONV [bi2 + 1 ];
166- float B2 = B_CONV [bi2 + 2 ];
167- float B3 = B_CONV [bi2 + 3 ];
168- float B4 = B_CONV [bi2 + 4 ];
169- float B5 = B_CONV [bi2 + 5 ];
170- float B6 = B_CONV [bi2 + 6 ];
171- float B7 = B_CONV [bi2 + 7 ];
180+ float B0 = CONV [bi2 + 0 ];
181+ float B1 = CONV [bi2 + 1 ];
182+ float B2 = CONV [bi2 + 2 ];
183+ float B3 = CONV [bi2 + 3 ];
184+ float B4 = CONV [bi2 + 4 ];
185+ float B5 = CONV [bi2 + 5 ];
186+ float B6 = CONV [bi2 + 6 ];
187+ float B7 = CONV [bi2 + 7 ];
172188 bi2 += 8 ;
173189
174- vbfloat16mf2_t A00 = __riscv_vle16_v_bf16mf2 ( & AA [ai + 0 * gvl ], gvl );
175- vfloat32m1_t A0 = __riscv_vfwcvtbf16_f_f_v_f32m1 (A00 , gvl );
176- ai += 8 ;
190+ vfloat32m1_t A0 = __riscv_vle32_v_f32m1 (& CONV [ai2 ], gvl );
191+ ai2 += 8 ;
177192
178193 result0 = __riscv_vfmacc_vf_f32m1 (result0 , B0 , A0 , gvl );
179194 result1 = __riscv_vfmacc_vf_f32m1 (result1 , B1 , A0 , gvl );
@@ -244,10 +259,10 @@ int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT alpha, IFLOAT *A, IFLOAT *B,
244259 if ( M & 4 ) {
245260 gvl = __riscv_vsetvl_e16mf2 (4 );
246261
247- BLASLONG ai = m_top * K ;
248262#ifdef BF16_WIDEN_ONE
249263 bi2 = 0 ;
250264#else
265+ BLASLONG ai = m_top * K ;
251266 BLASLONG bi = n_top * K ;
252267#endif
253268
@@ -262,19 +277,18 @@ int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT alpha, IFLOAT *A, IFLOAT *B,
262277
263278 for (BLASLONG k = 0 ; k < K ; ++ k ) {
264279#ifdef BF16_WIDEN_ONE
265- float B0 = B_CONV [bi2 + 0 ];
266- float B1 = B_CONV [bi2 + 1 ];
267- float B2 = B_CONV [bi2 + 2 ];
268- float B3 = B_CONV [bi2 + 3 ];
269- float B4 = B_CONV [bi2 + 4 ];
270- float B5 = B_CONV [bi2 + 5 ];
271- float B6 = B_CONV [bi2 + 6 ];
272- float B7 = B_CONV [bi2 + 7 ];
280+ float B0 = CONV [bi2 + 0 ];
281+ float B1 = CONV [bi2 + 1 ];
282+ float B2 = CONV [bi2 + 2 ];
283+ float B3 = CONV [bi2 + 3 ];
284+ float B4 = CONV [bi2 + 4 ];
285+ float B5 = CONV [bi2 + 5 ];
286+ float B6 = CONV [bi2 + 6 ];
287+ float B7 = CONV [bi2 + 7 ];
273288 bi2 += 8 ;
274289
275- vbfloat16mf4_t A00 = __riscv_vle16_v_bf16mf4 ( & AA [ai + 0 * gvl ], gvl );
276- vfloat32m1_t A0 = __riscv_vlmul_ext_v_f32mf2_f32m1 (__riscv_vfwcvtbf16_f_f_v_f32mf2 (A00 , gvl ));
277- ai += 4 ;
290+ vfloat32m1_t A0 = __riscv_vlmul_ext_v_f32mf2_f32m1 (__riscv_vle32_v_f32mf2 (& CONV [ai2 ], gvl ));
291+ ai2 += 4 ;
278292
279293 result0 = __riscv_vfmacc_vf_f32m1 (result0 , B0 , A0 , gvl );
280294 result1 = __riscv_vfmacc_vf_f32m1 (result1 , B1 , A0 , gvl );
@@ -459,26 +473,16 @@ int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT alpha, IFLOAT *A, IFLOAT *B,
459473 m_top = 0 ;
460474
461475#ifdef BF16_WIDEN_ONE
462- BLASLONG bi2 ;
463- {
464- BLASLONG bi3 = 0 ;
465- BLASLONG gvl2 ;
466- bi2 = K * 4 ;
467- do {
468- gvl2 = __riscv_vsetvl_e16m4 (bi2 );
469- vbfloat16m4_t A00 = __riscv_vle16_v_bf16m4 (& BB [bi3 + (n_top * K )], gvl2 );
470- vfloat32m8_t A0 = __riscv_vfwcvtbf16_f_f_v_f32m8 (A00 , gvl2 );
471- __riscv_vse32_v_f32m8 (& B_CONV [bi3 ], A0 , gvl2 );
472- bi3 += gvl2 ;
473- } while (bi2 -= gvl2 );
474- }
476+ BLASLONG bi2 = K * 4 ;
477+ B_CONV (BB + (n_top * K ), CONV , bi2 );
478+ BLASLONG ai2 = K * 8 ;
475479#endif
476480
477481 for (BLASLONG i = 0 ; i < M /16 ; i += 1 ) {
478- BLASLONG ai = m_top * K ;
479482#ifdef BF16_WIDEN_ONE
480483 bi2 = 0 ;
481484#else
485+ BLASLONG ai = m_top * K ;
482486 BLASLONG bi = n_top * K ;
483487#endif
484488
@@ -489,15 +493,14 @@ int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT alpha, IFLOAT *A, IFLOAT *B,
489493
490494 for (BLASLONG k = 0 ; k < K ; k ++ ) {
491495#ifdef BF16_WIDEN_ONE
492- float B0 = B_CONV [bi2 + 0 ];
493- float B1 = B_CONV [bi2 + 1 ];
494- float B2 = B_CONV [bi2 + 2 ];
495- float B3 = B_CONV [bi2 + 3 ];
496+ float B0 = CONV [bi2 + 0 ];
497+ float B1 = CONV [bi2 + 1 ];
498+ float B2 = CONV [bi2 + 2 ];
499+ float B3 = CONV [bi2 + 3 ];
496500 bi2 += 4 ;
497501
498- vbfloat16m1_t A00 = __riscv_vle16_v_bf16m1 ( & AA [ai + 0 * gvl ], gvl );
499- vfloat32m2_t A0 = __riscv_vfwcvtbf16_f_f_v_f32m2 (A00 , gvl );
500- ai += 16 ;
502+ vfloat32m2_t A0 = __riscv_vle32_v_f32m2 (& CONV [ai2 ], gvl );
503+ ai2 += 16 ;
501504
502505 result0 = __riscv_vfmacc_vf_f32m2 (result0 , B0 , A0 , gvl );
503506 result1 = __riscv_vfmacc_vf_f32m2 (result1 , B1 , A0 , gvl );
@@ -543,10 +546,10 @@ int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT alpha, IFLOAT *A, IFLOAT *B,
543546
544547 if ( M & 8 ) {
545548 gvl = __riscv_vsetvl_e16mf2 (8 );
546- BLASLONG ai = m_top * K ;
547549#ifdef BF16_WIDEN_ONE
548550 bi2 = 0 ;
549551#else
552+ BLASLONG ai = m_top * K ;
550553 BLASLONG bi = n_top * K ;
551554#endif
552555
@@ -557,15 +560,14 @@ int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT alpha, IFLOAT *A, IFLOAT *B,
557560
558561 for (BLASLONG k = 0 ; k < K ; k ++ ) {
559562#ifdef BF16_WIDEN_ONE
560- float B0 = B_CONV [bi2 + 0 ];
561- float B1 = B_CONV [bi2 + 1 ];
562- float B2 = B_CONV [bi2 + 2 ];
563- float B3 = B_CONV [bi2 + 3 ];
563+ float B0 = CONV [bi2 + 0 ];
564+ float B1 = CONV [bi2 + 1 ];
565+ float B2 = CONV [bi2 + 2 ];
566+ float B3 = CONV [bi2 + 3 ];
564567 bi2 += 4 ;
565568
566- vbfloat16mf2_t A00 = __riscv_vle16_v_bf16mf2 ( & AA [ai + 0 * gvl ], gvl );
567- vfloat32m1_t A0 = __riscv_vfwcvtbf16_f_f_v_f32m1 (A00 , gvl );
568- ai += 8 ;
569+ vfloat32m1_t A0 = __riscv_vle32_v_f32m1 (& CONV [ai2 ], gvl );
570+ ai2 += 8 ;
569571
570572 result0 = __riscv_vfmacc_vf_f32m1 (result0 , B0 , A0 , gvl );
571573 result1 = __riscv_vfmacc_vf_f32m1 (result1 , B1 , A0 , gvl );
@@ -612,10 +614,10 @@ int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT alpha, IFLOAT *A, IFLOAT *B,
612614 if ( M & 4 ) {
613615 gvl = __riscv_vsetvl_e16mf2 (4 );
614616
615- BLASLONG ai = m_top * K ;
616617#ifdef BF16_WIDEN_ONE
617618 bi2 = 0 ;
618619#else
620+ BLASLONG ai = m_top * K ;
619621 BLASLONG bi = n_top * K ;
620622#endif
621623
@@ -626,15 +628,14 @@ int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT alpha, IFLOAT *A, IFLOAT *B,
626628
627629 for (BLASLONG k = 0 ; k < K ; ++ k ) {
628630#ifdef BF16_WIDEN_ONE
629- float B0 = B_CONV [bi2 + 0 ];
630- float B1 = B_CONV [bi2 + 1 ];
631- float B2 = B_CONV [bi2 + 2 ];
632- float B3 = B_CONV [bi2 + 3 ];
631+ float B0 = CONV [bi2 + 0 ];
632+ float B1 = CONV [bi2 + 1 ];
633+ float B2 = CONV [bi2 + 2 ];
634+ float B3 = CONV [bi2 + 3 ];
633635 bi2 += 4 ;
634636
635- vbfloat16mf4_t A00 = __riscv_vle16_v_bf16mf4 ( & AA [ai + 0 * gvl ], gvl );
636- vfloat32m1_t A0 = __riscv_vlmul_ext_v_f32mf2_f32m1 (__riscv_vfwcvtbf16_f_f_v_f32mf2 (A00 , gvl ));
637- ai += 4 ;
637+ vfloat32m1_t A0 = __riscv_vlmul_ext_v_f32mf2_f32m1 (__riscv_vle32_v_f32mf2 (& CONV [ai2 ], gvl ));
638+ ai2 += 4 ;
638639
639640 result0 = __riscv_vfmacc_vf_f32m1 (result0 , B0 , A0 , gvl );
640641 result1 = __riscv_vfmacc_vf_f32m1 (result1 , B1 , A0 , gvl );
@@ -1041,7 +1042,7 @@ int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT alpha, IFLOAT *A, IFLOAT *B,
10411042 n_top += 1 ;
10421043 }
10431044#ifdef BF16_WIDEN_ONE
1044- if (B_CONV ) free (B_CONV );
1045+ if (CONV ) free (CONV );
10451046#endif
10461047 return 0 ;
10471048}
0 commit comments