Skip to content

Commit 0acb60a

Browse files
committed
Conversion from BF16 to FP32 only once.
1 parent 9701a80 commit 0acb60a

File tree

2 files changed

+164
-160
lines changed

2 files changed

+164
-160
lines changed

kernel/riscv64/sbgemm_kernel_16x8_zvl256b.c

Lines changed: 91 additions & 90 deletions
Original file line number | Diff line number | Diff line change
@@ -3,6 +3,32 @@
33

44
#define BF16_WIDEN_ONE
55

6+
#ifdef BF16_WIDEN_ONE
#define FORCEINLINE inline __attribute__((always_inline))
// Number of elements B_CONV converts per iteration of its main loop.
#define B_UNROLL 64

// Widen BF16 values to FP32.
//
//   BB    - source BF16 elements.
//   CONV  - destination buffer that receives the widened FP32 values.
//   count - number of elements to convert.
static void FORCEINLINE B_CONV(__bf16 *BB, FLOAT *CONV, BLASLONG count)
{
    // Separate the part of count that is not a multiple of B_UNROLL.
    const BLASLONG tail = count & (B_UNROLL - 1);
    BLASLONG bulk = count - tail;

    // Bulk: fixed chunks of B_UNROLL elements, each loaded, widened and stored
    // with an explicit vector length of B_UNROLL.
    for (; bulk != 0; bulk -= B_UNROLL) {
        vbfloat16m4_t narrow = __riscv_vle16_v_bf16m4(BB, B_UNROLL);
        vfloat32m8_t wide = __riscv_vfwcvtbf16_f_f_v_f32m8(narrow, B_UNROLL);
        __riscv_vse32_v_f32m8(CONV, wide, B_UNROLL);
        BB += B_UNROLL;
        CONV += B_UNROLL;
    }

    if (tail == 0) {
        return;
    }

    // Tail: a single pass over the leftover elements.
    // NOTE(review): this converts only the vector length returned by
    // __riscv_vsetvl_e16m4(tail); presumably that always equals `tail` on the
    // intended target (tail < B_UNROLL) -- confirm against the target VLEN.
    const BLASLONG tail_vl = __riscv_vsetvl_e16m4(tail);
    vbfloat16m4_t narrow_tail = __riscv_vle16_v_bf16m4(BB, tail_vl);
    vfloat32m8_t wide_tail = __riscv_vfwcvtbf16_f_f_v_f32m8(narrow_tail, tail_vl);
    __riscv_vse32_v_f32m8(CONV, wide_tail, tail_vl);
}
#endif
31+
632
int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT alpha, IFLOAT *A, IFLOAT *B, FLOAT *C, BLASLONG ldc)
733
{
834
BLASLONG gvl = 0;
@@ -12,10 +38,11 @@ int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT alpha, IFLOAT *A, IFLOAT *B,
1238
__bf16 *AA = (__bf16 *)(A);
1339

1440
#ifdef BF16_WIDEN_ONE
15-
FLOAT *B_CONV = NULL;
41+
FLOAT *CONV = NULL;
1642
if ((M >= 4) && (N >= 4) && (K > 0)) {
17-
B_CONV = (FLOAT *)(malloc(K * 8 * sizeof(FLOAT)));
18-
if (!B_CONV) return 1;
43+
CONV = (FLOAT *)(malloc((K * (8 + (M & -4))) * sizeof(FLOAT)));
44+
if (!CONV) return 1;
45+
B_CONV(AA, CONV + (K * 8), (M & -4) * K);
1946
}
2047
#endif
2148

@@ -24,26 +51,16 @@ int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT alpha, IFLOAT *A, IFLOAT *B,
2451
m_top = 0;
2552
BLASLONG gvl = __riscv_vsetvl_e16m1(16);
2653
#ifdef BF16_WIDEN_ONE
27-
BLASLONG bi2;
28-
{
29-
BLASLONG bi3 = 0;
30-
BLASLONG gvl2;
31-
bi2 = K * 8;
32-
do {
33-
gvl2 = __riscv_vsetvl_e16m4(bi2);
34-
vbfloat16m4_t A00 = __riscv_vle16_v_bf16m4(&BB[bi3 + (n_top*K)], gvl2);
35-
vfloat32m8_t A0 = __riscv_vfwcvtbf16_f_f_v_f32m8(A00, gvl2);
36-
__riscv_vse32_v_f32m8(&B_CONV[bi3], A0, gvl2);
37-
bi3 += gvl2;
38-
} while (bi2 -= gvl2);
39-
}
54+
BLASLONG bi2 = K * 8;
55+
B_CONV(BB + (n_top*K), CONV, bi2);
56+
BLASLONG ai2 = K * 8;
4057
#endif
4158

4259
for (BLASLONG i=0; i<M/16; i+=1) {
43-
BLASLONG ai=m_top*K;
4460
#ifdef BF16_WIDEN_ONE
4561
bi2 = 0;
4662
#else
63+
BLASLONG ai=m_top*K;
4764
BLASLONG bi=n_top*K;
4865
#endif
4966

@@ -58,19 +75,18 @@ int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT alpha, IFLOAT *A, IFLOAT *B,
5875

5976
for (BLASLONG k=0; k<K; k++) {
6077
#ifdef BF16_WIDEN_ONE
61-
float B0 = B_CONV[bi2+0];
62-
float B1 = B_CONV[bi2+1];
63-
float B2 = B_CONV[bi2+2];
64-
float B3 = B_CONV[bi2+3];
65-
float B4 = B_CONV[bi2+4];
66-
float B5 = B_CONV[bi2+5];
67-
float B6 = B_CONV[bi2+6];
68-
float B7 = B_CONV[bi2+7];
78+
float B0 = CONV[bi2+0];
79+
float B1 = CONV[bi2+1];
80+
float B2 = CONV[bi2+2];
81+
float B3 = CONV[bi2+3];
82+
float B4 = CONV[bi2+4];
83+
float B5 = CONV[bi2+5];
84+
float B6 = CONV[bi2+6];
85+
float B7 = CONV[bi2+7];
6986
bi2 += 8;
7087

71-
vbfloat16m1_t A00 = __riscv_vle16_v_bf16m1( &AA[ai+0*gvl], gvl );
72-
vfloat32m2_t A0 = __riscv_vfwcvtbf16_f_f_v_f32m2(A00, gvl);
73-
ai += 16;
88+
vfloat32m2_t A0 = __riscv_vle32_v_f32m2(&CONV[ai2], gvl);
89+
ai2 += 16;
7490

7591
result0 = __riscv_vfmacc_vf_f32m2(result0, B0, A0, gvl);
7692
result1 = __riscv_vfmacc_vf_f32m2(result1, B1, A0, gvl);
@@ -143,10 +159,10 @@ int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT alpha, IFLOAT *A, IFLOAT *B,
143159
if ( M & 8 ) {
144160
gvl = __riscv_vsetvl_e16mf2(8);
145161

146-
BLASLONG ai=m_top*K;
147162
#ifdef BF16_WIDEN_ONE
148163
bi2 = 0;
149164
#else
165+
BLASLONG ai=m_top*K;
150166
BLASLONG bi=n_top*K;
151167
#endif
152168

@@ -161,19 +177,18 @@ int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT alpha, IFLOAT *A, IFLOAT *B,
161177

162178
for (BLASLONG k=0; k<K; k++) {
163179
#ifdef BF16_WIDEN_ONE
164-
float B0 = B_CONV[bi2+0];
165-
float B1 = B_CONV[bi2+1];
166-
float B2 = B_CONV[bi2+2];
167-
float B3 = B_CONV[bi2+3];
168-
float B4 = B_CONV[bi2+4];
169-
float B5 = B_CONV[bi2+5];
170-
float B6 = B_CONV[bi2+6];
171-
float B7 = B_CONV[bi2+7];
180+
float B0 = CONV[bi2+0];
181+
float B1 = CONV[bi2+1];
182+
float B2 = CONV[bi2+2];
183+
float B3 = CONV[bi2+3];
184+
float B4 = CONV[bi2+4];
185+
float B5 = CONV[bi2+5];
186+
float B6 = CONV[bi2+6];
187+
float B7 = CONV[bi2+7];
172188
bi2 += 8;
173189

174-
vbfloat16mf2_t A00 = __riscv_vle16_v_bf16mf2( &AA[ai+0*gvl], gvl );
175-
vfloat32m1_t A0 = __riscv_vfwcvtbf16_f_f_v_f32m1(A00, gvl);
176-
ai += 8;
190+
vfloat32m1_t A0 = __riscv_vle32_v_f32m1(&CONV[ai2], gvl);
191+
ai2 += 8;
177192

178193
result0 = __riscv_vfmacc_vf_f32m1(result0, B0, A0, gvl);
179194
result1 = __riscv_vfmacc_vf_f32m1(result1, B1, A0, gvl);
@@ -244,10 +259,10 @@ int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT alpha, IFLOAT *A, IFLOAT *B,
244259
if ( M & 4 ) {
245260
gvl = __riscv_vsetvl_e16mf2(4);
246261

247-
BLASLONG ai=m_top*K;
248262
#ifdef BF16_WIDEN_ONE
249263
bi2 = 0;
250264
#else
265+
BLASLONG ai=m_top*K;
251266
BLASLONG bi=n_top*K;
252267
#endif
253268

@@ -262,19 +277,18 @@ int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT alpha, IFLOAT *A, IFLOAT *B,
262277

263278
for (BLASLONG k=0; k < K; ++k) {
264279
#ifdef BF16_WIDEN_ONE
265-
float B0 = B_CONV[bi2+0];
266-
float B1 = B_CONV[bi2+1];
267-
float B2 = B_CONV[bi2+2];
268-
float B3 = B_CONV[bi2+3];
269-
float B4 = B_CONV[bi2+4];
270-
float B5 = B_CONV[bi2+5];
271-
float B6 = B_CONV[bi2+6];
272-
float B7 = B_CONV[bi2+7];
280+
float B0 = CONV[bi2+0];
281+
float B1 = CONV[bi2+1];
282+
float B2 = CONV[bi2+2];
283+
float B3 = CONV[bi2+3];
284+
float B4 = CONV[bi2+4];
285+
float B5 = CONV[bi2+5];
286+
float B6 = CONV[bi2+6];
287+
float B7 = CONV[bi2+7];
273288
bi2 += 8;
274289

275-
vbfloat16mf4_t A00 = __riscv_vle16_v_bf16mf4( &AA[ai+0*gvl], gvl );
276-
vfloat32m1_t A0 = __riscv_vlmul_ext_v_f32mf2_f32m1(__riscv_vfwcvtbf16_f_f_v_f32mf2(A00, gvl));
277-
ai += 4;
290+
vfloat32m1_t A0 = __riscv_vlmul_ext_v_f32mf2_f32m1(__riscv_vle32_v_f32mf2(&CONV[ai2], gvl));
291+
ai2 += 4;
278292

279293
result0 = __riscv_vfmacc_vf_f32m1(result0, B0, A0, gvl);
280294
result1 = __riscv_vfmacc_vf_f32m1(result1, B1, A0, gvl);
@@ -459,26 +473,16 @@ int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT alpha, IFLOAT *A, IFLOAT *B,
459473
m_top = 0;
460474

461475
#ifdef BF16_WIDEN_ONE
462-
BLASLONG bi2;
463-
{
464-
BLASLONG bi3 = 0;
465-
BLASLONG gvl2;
466-
bi2 = K * 4;
467-
do {
468-
gvl2 = __riscv_vsetvl_e16m4(bi2);
469-
vbfloat16m4_t A00 = __riscv_vle16_v_bf16m4(&BB[bi3 + (n_top*K)], gvl2);
470-
vfloat32m8_t A0 = __riscv_vfwcvtbf16_f_f_v_f32m8(A00, gvl2);
471-
__riscv_vse32_v_f32m8(&B_CONV[bi3], A0, gvl2);
472-
bi3 += gvl2;
473-
} while (bi2 -= gvl2);
474-
}
476+
BLASLONG bi2 = K * 4;
477+
B_CONV(BB + (n_top*K), CONV, bi2);
478+
BLASLONG ai2 = K * 8;
475479
#endif
476480

477481
for (BLASLONG i=0; i<M/16; i+=1) {
478-
BLASLONG ai=m_top*K;
479482
#ifdef BF16_WIDEN_ONE
480483
bi2 = 0;
481484
#else
485+
BLASLONG ai=m_top*K;
482486
BLASLONG bi=n_top*K;
483487
#endif
484488

@@ -489,15 +493,14 @@ int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT alpha, IFLOAT *A, IFLOAT *B,
489493

490494
for (BLASLONG k=0; k<K; k++) {
491495
#ifdef BF16_WIDEN_ONE
492-
float B0 = B_CONV[bi2+0];
493-
float B1 = B_CONV[bi2+1];
494-
float B2 = B_CONV[bi2+2];
495-
float B3 = B_CONV[bi2+3];
496+
float B0 = CONV[bi2+0];
497+
float B1 = CONV[bi2+1];
498+
float B2 = CONV[bi2+2];
499+
float B3 = CONV[bi2+3];
496500
bi2 += 4;
497501

498-
vbfloat16m1_t A00 = __riscv_vle16_v_bf16m1( &AA[ai+0*gvl], gvl );
499-
vfloat32m2_t A0 = __riscv_vfwcvtbf16_f_f_v_f32m2(A00, gvl);
500-
ai += 16;
502+
vfloat32m2_t A0 = __riscv_vle32_v_f32m2(&CONV[ai2], gvl);
503+
ai2 += 16;
501504

502505
result0 = __riscv_vfmacc_vf_f32m2(result0, B0, A0, gvl);
503506
result1 = __riscv_vfmacc_vf_f32m2(result1, B1, A0, gvl);
@@ -543,10 +546,10 @@ int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT alpha, IFLOAT *A, IFLOAT *B,
543546

544547
if ( M & 8 ) {
545548
gvl = __riscv_vsetvl_e16mf2(8);
546-
BLASLONG ai=m_top*K;
547549
#ifdef BF16_WIDEN_ONE
548550
bi2 = 0;
549551
#else
552+
BLASLONG ai=m_top*K;
550553
BLASLONG bi=n_top*K;
551554
#endif
552555

@@ -557,15 +560,14 @@ int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT alpha, IFLOAT *A, IFLOAT *B,
557560

558561
for (BLASLONG k=0; k<K; k++) {
559562
#ifdef BF16_WIDEN_ONE
560-
float B0 = B_CONV[bi2+0];
561-
float B1 = B_CONV[bi2+1];
562-
float B2 = B_CONV[bi2+2];
563-
float B3 = B_CONV[bi2+3];
563+
float B0 = CONV[bi2+0];
564+
float B1 = CONV[bi2+1];
565+
float B2 = CONV[bi2+2];
566+
float B3 = CONV[bi2+3];
564567
bi2 += 4;
565568

566-
vbfloat16mf2_t A00 = __riscv_vle16_v_bf16mf2( &AA[ai+0*gvl], gvl );
567-
vfloat32m1_t A0 = __riscv_vfwcvtbf16_f_f_v_f32m1(A00, gvl);
568-
ai += 8;
569+
vfloat32m1_t A0 = __riscv_vle32_v_f32m1(&CONV[ai2], gvl);
570+
ai2 += 8;
569571

570572
result0 = __riscv_vfmacc_vf_f32m1(result0, B0, A0, gvl);
571573
result1 = __riscv_vfmacc_vf_f32m1(result1, B1, A0, gvl);
@@ -612,10 +614,10 @@ int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT alpha, IFLOAT *A, IFLOAT *B,
612614
if ( M & 4 ) {
613615
gvl = __riscv_vsetvl_e16mf2(4);
614616

615-
BLASLONG ai=m_top*K;
616617
#ifdef BF16_WIDEN_ONE
617618
bi2 = 0;
618619
#else
620+
BLASLONG ai=m_top*K;
619621
BLASLONG bi=n_top*K;
620622
#endif
621623

@@ -626,15 +628,14 @@ int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT alpha, IFLOAT *A, IFLOAT *B,
626628

627629
for (BLASLONG k=0; k < K; ++k) {
628630
#ifdef BF16_WIDEN_ONE
629-
float B0 = B_CONV[bi2+0];
630-
float B1 = B_CONV[bi2+1];
631-
float B2 = B_CONV[bi2+2];
632-
float B3 = B_CONV[bi2+3];
631+
float B0 = CONV[bi2+0];
632+
float B1 = CONV[bi2+1];
633+
float B2 = CONV[bi2+2];
634+
float B3 = CONV[bi2+3];
633635
bi2 += 4;
634636

635-
vbfloat16mf4_t A00 = __riscv_vle16_v_bf16mf4( &AA[ai+0*gvl], gvl );
636-
vfloat32m1_t A0 = __riscv_vlmul_ext_v_f32mf2_f32m1(__riscv_vfwcvtbf16_f_f_v_f32mf2(A00, gvl));
637-
ai += 4;
637+
vfloat32m1_t A0 = __riscv_vlmul_ext_v_f32mf2_f32m1(__riscv_vle32_v_f32mf2(&CONV[ai2], gvl));
638+
ai2 += 4;
638639

639640
result0 = __riscv_vfmacc_vf_f32m1(result0, B0, A0, gvl);
640641
result1 = __riscv_vfmacc_vf_f32m1(result1, B1, A0, gvl);
@@ -1041,7 +1042,7 @@ int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT alpha, IFLOAT *A, IFLOAT *B,
10411042
n_top += 1;
10421043
}
10431044
#ifdef BF16_WIDEN_ONE
1044-
if (B_CONV) free(B_CONV);
1045+
if (CONV) free(CONV);
10451046
#endif
10461047
return 0;
10471048
}

0 commit comments

Comments
 (0)