Skip to content

Commit 1fed338

Browse files
alankellyxnnpack-bot
authored andcommitted
Add x32 avx512 gemm microkernels. These are way faster
f32_gemm_minmax_ukernel_1x16__avx512f_broadcast/selfiesegmentation_full/M:16384/N:32/K:32/real_time 1029482 ns 1029201 ns 1364 FLOPS=32.5935G/s f32_gemm_minmax_ukernel_4x16__avx512f_broadcast/selfiesegmentation_full/M:16384/N:32/K:32/real_time 586536 ns 586291 ns 2407 FLOPS=57.2078G/s f32_gemm_minmax_ukernel_5x16__avx512f_broadcast/selfiesegmentation_full/M:16384/N:32/K:32/real_time 541205 ns 541143 ns 2581 FLOPS=61.9994G/s f32_gemm_minmax_ukernel_6x16__avx512f_broadcast/selfiesegmentation_full/M:16384/N:32/K:32/real_time 566296 ns 566175 ns 2466 FLOPS=59.2525G/s f32_gemm_minmax_ukernel_7x16__avx512f_broadcast/selfiesegmentation_full/M:16384/N:32/K:32/real_time 557455 ns 557275 ns 2464 FLOPS=60.1921G/s f32_gemm_minmax_ukernel_8x16__avx512f_broadcast/selfiesegmentation_full/M:16384/N:32/K:32/real_time 586344 ns 586240 ns 2396 FLOPS=57.2265G/s f32_gemm_minmax_ukernel_1x32__avx512f_broadcast/selfiesegmentation_full/M:16384/N:32/K:32/real_time 568475 ns 568317 ns 2494 FLOPS=59.0254G/s f32_gemm_minmax_ukernel_4x32__avx512f_broadcast/selfiesegmentation_full/M:16384/N:32/K:32/real_time 343299 ns 343171 ns 3967 FLOPS=97.7411G/s f32_gemm_minmax_ukernel_5x32__avx512f_broadcast/selfiesegmentation_full/M:16384/N:32/K:32/real_time 326208 ns 326123 ns 4317 FLOPS=102.862G/s f32_gemm_minmax_ukernel_6x32__avx512f_broadcast/selfiesegmentation_full/M:16384/N:32/K:32/real_time 324746 ns 324675 ns 4317 FLOPS=103.325G/s f32_gemm_minmax_ukernel_7x32__avx512f_broadcast/selfiesegmentation_full/M:16384/N:32/K:32/real_time 319434 ns 319344 ns 4337 FLOPS=105.043G/s f32_gemm_minmax_ukernel_8x32__avx512f_broadcast/selfiesegmentation_full/M:16384/N:32/K:32/real_time 358816 ns 358735 ns 4029 FLOPS=93.5143G/s PiperOrigin-RevId: 692901687
1 parent 9dc7da7 commit 1fed338

37 files changed

+4979
-24
lines changed

bench/f32-gemm-minmax.cc

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1352,6 +1352,72 @@
13521352
}
13531353

13541354
BENCHMARK_GEMM(f32_gemm_minmax_ukernel_8x16__avx512f_broadcast)
1355+
1356+
static void f32_gemm_minmax_ukernel_1x32__avx512f_broadcast(benchmark::State& state, const char* net) {
1357+
GEMMBenchmark(state,
1358+
xnn_f32_gemm_minmax_ukernel_1x32__avx512f_broadcast,
1359+
xnn_init_f32_minmax_scalar_params,
1360+
xnn_pack_f32_gemm_goi_w,
1361+
/*mr=*/1, /*nr=*/32, /*kr=*/1, /*sr=*/1,
1362+
benchmark::utils::CheckAVX512F);
1363+
}
1364+
1365+
BENCHMARK_GEMM(f32_gemm_minmax_ukernel_1x32__avx512f_broadcast)
1366+
1367+
static void f32_gemm_minmax_ukernel_4x32__avx512f_broadcast(benchmark::State& state, const char* net) {
1368+
GEMMBenchmark(state,
1369+
xnn_f32_gemm_minmax_ukernel_4x32__avx512f_broadcast,
1370+
xnn_init_f32_minmax_scalar_params,
1371+
xnn_pack_f32_gemm_goi_w,
1372+
/*mr=*/4, /*nr=*/32, /*kr=*/1, /*sr=*/1,
1373+
benchmark::utils::CheckAVX512F);
1374+
}
1375+
1376+
BENCHMARK_GEMM(f32_gemm_minmax_ukernel_4x32__avx512f_broadcast)
1377+
1378+
static void f32_gemm_minmax_ukernel_5x32__avx512f_broadcast(benchmark::State& state, const char* net) {
1379+
GEMMBenchmark(state,
1380+
xnn_f32_gemm_minmax_ukernel_5x32__avx512f_broadcast,
1381+
xnn_init_f32_minmax_scalar_params,
1382+
xnn_pack_f32_gemm_goi_w,
1383+
/*mr=*/5, /*nr=*/32, /*kr=*/1, /*sr=*/1,
1384+
benchmark::utils::CheckAVX512F);
1385+
}
1386+
1387+
BENCHMARK_GEMM(f32_gemm_minmax_ukernel_5x32__avx512f_broadcast)
1388+
1389+
static void f32_gemm_minmax_ukernel_6x32__avx512f_broadcast(benchmark::State& state, const char* net) {
1390+
GEMMBenchmark(state,
1391+
xnn_f32_gemm_minmax_ukernel_6x32__avx512f_broadcast,
1392+
xnn_init_f32_minmax_scalar_params,
1393+
xnn_pack_f32_gemm_goi_w,
1394+
/*mr=*/6, /*nr=*/32, /*kr=*/1, /*sr=*/1,
1395+
benchmark::utils::CheckAVX512F);
1396+
}
1397+
1398+
BENCHMARK_GEMM(f32_gemm_minmax_ukernel_6x32__avx512f_broadcast)
1399+
1400+
static void f32_gemm_minmax_ukernel_7x32__avx512f_broadcast(benchmark::State& state, const char* net) {
1401+
GEMMBenchmark(state,
1402+
xnn_f32_gemm_minmax_ukernel_7x32__avx512f_broadcast,
1403+
xnn_init_f32_minmax_scalar_params,
1404+
xnn_pack_f32_gemm_goi_w,
1405+
/*mr=*/7, /*nr=*/32, /*kr=*/1, /*sr=*/1,
1406+
benchmark::utils::CheckAVX512F);
1407+
}
1408+
1409+
BENCHMARK_GEMM(f32_gemm_minmax_ukernel_7x32__avx512f_broadcast)
1410+
1411+
static void f32_gemm_minmax_ukernel_8x32__avx512f_broadcast(benchmark::State& state, const char* net) {
1412+
GEMMBenchmark(state,
1413+
xnn_f32_gemm_minmax_ukernel_8x32__avx512f_broadcast,
1414+
xnn_init_f32_minmax_scalar_params,
1415+
xnn_pack_f32_gemm_goi_w,
1416+
/*mr=*/8, /*nr=*/32, /*kr=*/1, /*sr=*/1,
1417+
benchmark::utils::CheckAVX512F);
1418+
}
1419+
1420+
BENCHMARK_GEMM(f32_gemm_minmax_ukernel_8x32__avx512f_broadcast)
13551421
#endif // XNN_ENABLE_AVX512F && (XNN_ARCH_X86 || XNN_ARCH_X86_64)
13561422

13571423

cmake/gen/avx512f_microkernels.cmake

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,10 @@ SET(PROD_AVX512F_MICROKERNEL_SRCS
1515
src/f32-dwconv/gen/f32-dwconv-5f5m5l32c16s1r-minmax-avx512f.c
1616
src/f32-dwconv/gen/f32-dwconv-9p16c-minmax-avx512f.c
1717
src/f32-dwconv/gen/f32-dwconv-25p16c-minmax-avx512f.c
18-
src/f32-gemm/gen/f32-gemm-1x16-minmax-avx512f-broadcast.c
19-
src/f32-gemm/gen/f32-gemm-7x16-minmax-avx512f-broadcast.c
20-
src/f32-igemm/gen/f32-igemm-1x16-minmax-avx512f-broadcast.c
21-
src/f32-igemm/gen/f32-igemm-7x16-minmax-avx512f-broadcast.c
18+
src/f32-gemm/gen/f32-gemm-1x32-minmax-avx512f-broadcast.c
19+
src/f32-gemm/gen/f32-gemm-7x32-minmax-avx512f-broadcast.c
20+
src/f32-igemm/gen/f32-igemm-1x32-minmax-avx512f-broadcast.c
21+
src/f32-igemm/gen/f32-igemm-7x32-minmax-avx512f-broadcast.c
2222
src/f32-raddstoreexpminusmax/gen/f32-raddstoreexpminusmax-avx512f-rr2-p5-u64-acc2.c
2323
src/f32-rdsum/gen/f32-rdsum-7p7x-minmax-avx512f-c64.c
2424
src/f32-rminmax/gen/f32-rmax-avx512f-u64-acc4.c
@@ -66,7 +66,7 @@ SET(PROD_AVX512F_MICROKERNEL_SRCS
6666
src/f32-vunary/gen/f32-vsqr-avx512f.c
6767
src/s32-f32-vcvt/gen/s32-f32-vcvt-avx512f.c
6868
src/u32-f32-vcvt/gen/u32-f32-vcvt-avx512f.c
69-
src/x32-packw/gen/x32-packw-x16-gemm-goi-avx512f-u4-prfm.c)
69+
src/x32-packw/gen/x32-packw-x32-gemm-goi-avx512f-u4-prfm.c)
7070

7171
SET(NON_PROD_AVX512F_MICROKERNEL_SRCS
7272
src/f32-dwconv/gen/f32-dwconv-3p16c-minmax-avx512f-acc2.c
@@ -84,20 +84,32 @@ SET(NON_PROD_AVX512F_MICROKERNEL_SRCS
8484
src/f32-dwconv/gen/f32-dwconv-25p16c-minmax-avx512f-acc2.c
8585
src/f32-dwconv/gen/f32-dwconv-25p32c-minmax-avx512f-acc2.c
8686
src/f32-dwconv/gen/f32-dwconv-25p32c-minmax-avx512f.c
87+
src/f32-gemm/gen/f32-gemm-1x16-minmax-avx512f-broadcast.c
8788
src/f32-gemm/gen/f32-gemm-4x16-minmax-avx512f-broadcast.c
89+
src/f32-gemm/gen/f32-gemm-4x32-minmax-avx512f-broadcast.c
8890
src/f32-gemm/gen/f32-gemm-5x16-minmax-avx512f-broadcast.c
91+
src/f32-gemm/gen/f32-gemm-5x32-minmax-avx512f-broadcast.c
8992
src/f32-gemm/gen/f32-gemm-6x16-minmax-avx512f-broadcast.c
93+
src/f32-gemm/gen/f32-gemm-6x32-minmax-avx512f-broadcast.c
94+
src/f32-gemm/gen/f32-gemm-7x16-minmax-avx512f-broadcast.c
9095
src/f32-gemm/gen/f32-gemm-8x16-minmax-avx512f-broadcast.c
96+
src/f32-gemm/gen/f32-gemm-8x32-minmax-avx512f-broadcast.c
9197
src/f32-gemminc/gen/f32-gemminc-1x16-minmax-avx512f-broadcast.c
9298
src/f32-gemminc/gen/f32-gemminc-4x16-minmax-avx512f-broadcast.c
9399
src/f32-gemminc/gen/f32-gemminc-5x16-minmax-avx512f-broadcast.c
94100
src/f32-gemminc/gen/f32-gemminc-6x16-minmax-avx512f-broadcast.c
95101
src/f32-gemminc/gen/f32-gemminc-7x16-minmax-avx512f-broadcast.c
96102
src/f32-gemminc/gen/f32-gemminc-8x16-minmax-avx512f-broadcast.c
103+
src/f32-igemm/gen/f32-igemm-1x16-minmax-avx512f-broadcast.c
97104
src/f32-igemm/gen/f32-igemm-4x16-minmax-avx512f-broadcast.c
105+
src/f32-igemm/gen/f32-igemm-4x32-minmax-avx512f-broadcast.c
98106
src/f32-igemm/gen/f32-igemm-5x16-minmax-avx512f-broadcast.c
107+
src/f32-igemm/gen/f32-igemm-5x32-minmax-avx512f-broadcast.c
99108
src/f32-igemm/gen/f32-igemm-6x16-minmax-avx512f-broadcast.c
109+
src/f32-igemm/gen/f32-igemm-6x32-minmax-avx512f-broadcast.c
110+
src/f32-igemm/gen/f32-igemm-7x16-minmax-avx512f-broadcast.c
100111
src/f32-igemm/gen/f32-igemm-8x16-minmax-avx512f-broadcast.c
112+
src/f32-igemm/gen/f32-igemm-8x32-minmax-avx512f-broadcast.c
101113
src/f32-raddexpminusmax/gen/f32-raddexpminusmax-avx512f-p5-scalef-u64-acc2.c
102114
src/f32-raddexpminusmax/gen/f32-raddexpminusmax-avx512f-p5-scalef-u64-acc4.c
103115
src/f32-raddexpminusmax/gen/f32-raddexpminusmax-avx512f-p5-scalef-u64.c
@@ -279,6 +291,8 @@ SET(NON_PROD_AVX512F_MICROKERNEL_SRCS
279291
src/f32-vsqrt/gen/f32-vsqrt-avx512f-rsqrt-u32.c
280292
src/f32-vsqrt/gen/f32-vsqrt-avx512f-rsqrt-u48.c
281293
src/f32-vtanh/gen/f32-vtanh-avx512f-rational-9-8-div.c
282-
src/x32-packw/gen/x32-packw-x16-gemm-goi-avx512f-u4.c)
294+
src/x32-packw/gen/x32-packw-x16-gemm-goi-avx512f-u4-prfm.c
295+
src/x32-packw/gen/x32-packw-x16-gemm-goi-avx512f-u4.c
296+
src/x32-packw/gen/x32-packw-x32-gemm-goi-avx512f-u4.c)
283297

284298
SET(ALL_AVX512F_MICROKERNEL_SRCS ${PROD_AVX512F_MICROKERNEL_SRCS} + ${NON_PROD_AVX512F_MICROKERNEL_SRCS})

gen/avx512f_microkernels.bzl

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,10 @@ PROD_AVX512F_MICROKERNEL_SRCS = [
1111
"src/f32-dwconv/gen/f32-dwconv-5f5m5l32c16s1r-minmax-avx512f.c",
1212
"src/f32-dwconv/gen/f32-dwconv-9p16c-minmax-avx512f.c",
1313
"src/f32-dwconv/gen/f32-dwconv-25p16c-minmax-avx512f.c",
14-
"src/f32-gemm/gen/f32-gemm-1x16-minmax-avx512f-broadcast.c",
15-
"src/f32-gemm/gen/f32-gemm-7x16-minmax-avx512f-broadcast.c",
16-
"src/f32-igemm/gen/f32-igemm-1x16-minmax-avx512f-broadcast.c",
17-
"src/f32-igemm/gen/f32-igemm-7x16-minmax-avx512f-broadcast.c",
14+
"src/f32-gemm/gen/f32-gemm-1x32-minmax-avx512f-broadcast.c",
15+
"src/f32-gemm/gen/f32-gemm-7x32-minmax-avx512f-broadcast.c",
16+
"src/f32-igemm/gen/f32-igemm-1x32-minmax-avx512f-broadcast.c",
17+
"src/f32-igemm/gen/f32-igemm-7x32-minmax-avx512f-broadcast.c",
1818
"src/f32-raddstoreexpminusmax/gen/f32-raddstoreexpminusmax-avx512f-rr2-p5-u64-acc2.c",
1919
"src/f32-rdsum/gen/f32-rdsum-7p7x-minmax-avx512f-c64.c",
2020
"src/f32-rminmax/gen/f32-rmax-avx512f-u64-acc4.c",
@@ -62,7 +62,7 @@ PROD_AVX512F_MICROKERNEL_SRCS = [
6262
"src/f32-vunary/gen/f32-vsqr-avx512f.c",
6363
"src/s32-f32-vcvt/gen/s32-f32-vcvt-avx512f.c",
6464
"src/u32-f32-vcvt/gen/u32-f32-vcvt-avx512f.c",
65-
"src/x32-packw/gen/x32-packw-x16-gemm-goi-avx512f-u4-prfm.c",
65+
"src/x32-packw/gen/x32-packw-x32-gemm-goi-avx512f-u4-prfm.c",
6666
]
6767

6868
NON_PROD_AVX512F_MICROKERNEL_SRCS = [
@@ -81,20 +81,32 @@ NON_PROD_AVX512F_MICROKERNEL_SRCS = [
8181
"src/f32-dwconv/gen/f32-dwconv-25p16c-minmax-avx512f-acc2.c",
8282
"src/f32-dwconv/gen/f32-dwconv-25p32c-minmax-avx512f-acc2.c",
8383
"src/f32-dwconv/gen/f32-dwconv-25p32c-minmax-avx512f.c",
84+
"src/f32-gemm/gen/f32-gemm-1x16-minmax-avx512f-broadcast.c",
8485
"src/f32-gemm/gen/f32-gemm-4x16-minmax-avx512f-broadcast.c",
86+
"src/f32-gemm/gen/f32-gemm-4x32-minmax-avx512f-broadcast.c",
8587
"src/f32-gemm/gen/f32-gemm-5x16-minmax-avx512f-broadcast.c",
88+
"src/f32-gemm/gen/f32-gemm-5x32-minmax-avx512f-broadcast.c",
8689
"src/f32-gemm/gen/f32-gemm-6x16-minmax-avx512f-broadcast.c",
90+
"src/f32-gemm/gen/f32-gemm-6x32-minmax-avx512f-broadcast.c",
91+
"src/f32-gemm/gen/f32-gemm-7x16-minmax-avx512f-broadcast.c",
8792
"src/f32-gemm/gen/f32-gemm-8x16-minmax-avx512f-broadcast.c",
93+
"src/f32-gemm/gen/f32-gemm-8x32-minmax-avx512f-broadcast.c",
8894
"src/f32-gemminc/gen/f32-gemminc-1x16-minmax-avx512f-broadcast.c",
8995
"src/f32-gemminc/gen/f32-gemminc-4x16-minmax-avx512f-broadcast.c",
9096
"src/f32-gemminc/gen/f32-gemminc-5x16-minmax-avx512f-broadcast.c",
9197
"src/f32-gemminc/gen/f32-gemminc-6x16-minmax-avx512f-broadcast.c",
9298
"src/f32-gemminc/gen/f32-gemminc-7x16-minmax-avx512f-broadcast.c",
9399
"src/f32-gemminc/gen/f32-gemminc-8x16-minmax-avx512f-broadcast.c",
100+
"src/f32-igemm/gen/f32-igemm-1x16-minmax-avx512f-broadcast.c",
94101
"src/f32-igemm/gen/f32-igemm-4x16-minmax-avx512f-broadcast.c",
102+
"src/f32-igemm/gen/f32-igemm-4x32-minmax-avx512f-broadcast.c",
95103
"src/f32-igemm/gen/f32-igemm-5x16-minmax-avx512f-broadcast.c",
104+
"src/f32-igemm/gen/f32-igemm-5x32-minmax-avx512f-broadcast.c",
96105
"src/f32-igemm/gen/f32-igemm-6x16-minmax-avx512f-broadcast.c",
106+
"src/f32-igemm/gen/f32-igemm-6x32-minmax-avx512f-broadcast.c",
107+
"src/f32-igemm/gen/f32-igemm-7x16-minmax-avx512f-broadcast.c",
97108
"src/f32-igemm/gen/f32-igemm-8x16-minmax-avx512f-broadcast.c",
109+
"src/f32-igemm/gen/f32-igemm-8x32-minmax-avx512f-broadcast.c",
98110
"src/f32-raddexpminusmax/gen/f32-raddexpminusmax-avx512f-p5-scalef-u64-acc2.c",
99111
"src/f32-raddexpminusmax/gen/f32-raddexpminusmax-avx512f-p5-scalef-u64-acc4.c",
100112
"src/f32-raddexpminusmax/gen/f32-raddexpminusmax-avx512f-p5-scalef-u64.c",
@@ -276,7 +288,9 @@ NON_PROD_AVX512F_MICROKERNEL_SRCS = [
276288
"src/f32-vsqrt/gen/f32-vsqrt-avx512f-rsqrt-u32.c",
277289
"src/f32-vsqrt/gen/f32-vsqrt-avx512f-rsqrt-u48.c",
278290
"src/f32-vtanh/gen/f32-vtanh-avx512f-rational-9-8-div.c",
291+
"src/x32-packw/gen/x32-packw-x16-gemm-goi-avx512f-u4-prfm.c",
279292
"src/x32-packw/gen/x32-packw-x16-gemm-goi-avx512f-u4.c",
293+
"src/x32-packw/gen/x32-packw-x32-gemm-goi-avx512f-u4.c",
280294
]
281295

282296
ALL_AVX512F_MICROKERNEL_SRCS = PROD_AVX512F_MICROKERNEL_SRCS + NON_PROD_AVX512F_MICROKERNEL_SRCS

scripts/generate-f32-gemm.sh

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -611,6 +611,13 @@ tools/xngen src/f32-gemm/avx512-broadcast.c.in -D MR=7 -D NR=16 -D INC=1 -D DATA
611611
tools/xngen src/f32-gemm/avx512-broadcast.c.in -D MR=8 -D NR=16 -D INC=0 -D DATATYPE=F32 -o src/f32-gemm/gen/f32-gemm-8x16-minmax-avx512f-broadcast.c &
612612
tools/xngen src/f32-gemm/avx512-broadcast.c.in -D MR=8 -D NR=16 -D INC=1 -D DATATYPE=F32 -o src/f32-gemminc/gen/f32-gemminc-8x16-minmax-avx512f-broadcast.c &
613613

614+
tools/xngen src/f32-gemm/avx512-broadcast.c.in -D MR=1 -D NR=32 -D INC=0 -D DATATYPE=F32 -o src/f32-gemm/gen/f32-gemm-1x32-minmax-avx512f-broadcast.c &
615+
tools/xngen src/f32-gemm/avx512-broadcast.c.in -D MR=4 -D NR=32 -D INC=0 -D DATATYPE=F32 -o src/f32-gemm/gen/f32-gemm-4x32-minmax-avx512f-broadcast.c &
616+
tools/xngen src/f32-gemm/avx512-broadcast.c.in -D MR=5 -D NR=32 -D INC=0 -D DATATYPE=F32 -o src/f32-gemm/gen/f32-gemm-5x32-minmax-avx512f-broadcast.c &
617+
tools/xngen src/f32-gemm/avx512-broadcast.c.in -D MR=6 -D NR=32 -D INC=0 -D DATATYPE=F32 -o src/f32-gemm/gen/f32-gemm-6x32-minmax-avx512f-broadcast.c &
618+
tools/xngen src/f32-gemm/avx512-broadcast.c.in -D MR=7 -D NR=32 -D INC=0 -D DATATYPE=F32 -o src/f32-gemm/gen/f32-gemm-7x32-minmax-avx512f-broadcast.c &
619+
tools/xngen src/f32-gemm/avx512-broadcast.c.in -D MR=8 -D NR=32 -D INC=0 -D DATATYPE=F32 -o src/f32-gemm/gen/f32-gemm-8x32-minmax-avx512f-broadcast.c &
620+
614621
################################ RISC-V Vector ################################
615622
tools/xngen src/f32-gemm/MRxNRv-rvv.c.in -D MR=7 -D NR=m4 -D ACTIVATION=LINEAR -D DATATYPE=F32 -o src/f32-gemm/gen/f32-gemm-7x4v-rvv.c &
616623
tools/xngen src/f32-gemm/MRxNRv-rvv.c.in -D MR=7 -D NR=m4 -D ACTIVATION=RELU -D DATATYPE=F32 -o src/f32-gemm/gen/f32-gemm-7x4v-relu-rvv.c &

scripts/generate-f32-igemm.sh

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -361,6 +361,13 @@ tools/xngen src/f32-igemm/avx512-broadcast.c.in -D MR=6 -D NR=16 -o src/f32-igem
361361
tools/xngen src/f32-igemm/avx512-broadcast.c.in -D MR=7 -D NR=16 -o src/f32-igemm/gen/f32-igemm-7x16-minmax-avx512f-broadcast.c &
362362
tools/xngen src/f32-igemm/avx512-broadcast.c.in -D MR=8 -D NR=16 -o src/f32-igemm/gen/f32-igemm-8x16-minmax-avx512f-broadcast.c &
363363

364+
tools/xngen src/f32-igemm/avx512-broadcast.c.in -D MR=1 -D NR=32 -o src/f32-igemm/gen/f32-igemm-1x32-minmax-avx512f-broadcast.c &
365+
tools/xngen src/f32-igemm/avx512-broadcast.c.in -D MR=4 -D NR=32 -o src/f32-igemm/gen/f32-igemm-4x32-minmax-avx512f-broadcast.c &
366+
tools/xngen src/f32-igemm/avx512-broadcast.c.in -D MR=5 -D NR=32 -o src/f32-igemm/gen/f32-igemm-5x32-minmax-avx512f-broadcast.c &
367+
tools/xngen src/f32-igemm/avx512-broadcast.c.in -D MR=6 -D NR=32 -o src/f32-igemm/gen/f32-igemm-6x32-minmax-avx512f-broadcast.c &
368+
tools/xngen src/f32-igemm/avx512-broadcast.c.in -D MR=7 -D NR=32 -o src/f32-igemm/gen/f32-igemm-7x32-minmax-avx512f-broadcast.c &
369+
tools/xngen src/f32-igemm/avx512-broadcast.c.in -D MR=8 -D NR=32 -o src/f32-igemm/gen/f32-igemm-8x32-minmax-avx512f-broadcast.c &
370+
364371
################################ RISC-V Vector ################################
365372
tools/xngen src/f32-igemm/MRxNRv-rvv.c.in -D MR=1 -D NR=m4 -D ACTIVATION=LINEAR -o src/f32-igemm/gen/f32-igemm-1x4v-rvv.c &
366373
tools/xngen src/f32-igemm/MRxNRv-rvv.c.in -D MR=7 -D NR=m4 -D ACTIVATION=LINEAR -o src/f32-igemm/gen/f32-igemm-7x4v-rvv.c &

scripts/generate-x32-packw.sh

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,9 @@ tools/xngen src/x32-packw/gio-avx.c.in -D NR=8 -D PREFETCH=1 -o src/x32-packw/g
8989
tools/xngen src/x32-packw/avx512.c.in -D NR=16 -D PREFETCH=0 -D KBLOCK=4 -o src/x32-packw/gen/x32-packw-x16-gemm-goi-avx512f-u4.c &
9090
tools/xngen src/x32-packw/avx512.c.in -D NR=16 -D PREFETCH=1 -D KBLOCK=4 -o src/x32-packw/gen/x32-packw-x16-gemm-goi-avx512f-u4-prfm.c &
9191

92+
tools/xngen src/x32-packw/avx512.c.in -D NR=32 -D PREFETCH=0 -D KBLOCK=4 -o src/x32-packw/gen/x32-packw-x32-gemm-goi-avx512f-u4.c &
93+
tools/xngen src/x32-packw/avx512.c.in -D NR=32 -D PREFETCH=1 -D KBLOCK=4 -o src/x32-packw/gen/x32-packw-x32-gemm-goi-avx512f-u4-prfm.c &
94+
9295
################################## Wasm SIMD ##################################
9396
### NR multiple of 4
9497
tools/xngen src/x32-packw/wasmsimd.c.in -D NR=8 -D KBLOCK=4 -o src/x32-packw/gen/x32-packw-x8-gemm-goi-wasmsimd-u4.c &

src/configs/gemm-config.c

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -677,15 +677,15 @@ static void init_f32_gemm_config(void) {
677677
assert(hardware_config != NULL);
678678
#if XNN_ENABLE_AVX512F
679679
if (!XNN_PLATFORM_MOBILE && hardware_config->use_x86_avx512f) {
680-
f32_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_1x16__avx512f_broadcast);
681-
f32_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(7)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_7x16__avx512f_broadcast);
682-
f32_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_1x16__avx512f_broadcast);
683-
f32_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(7)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_7x16__avx512f_broadcast);
680+
f32_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_1x32__avx512f_broadcast);
681+
f32_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(7)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_7x32__avx512f_broadcast);
682+
f32_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_1x32__avx512f_broadcast);
683+
f32_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(7)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_7x32__avx512f_broadcast);
684684
f32_gemm_config.init.f32 = xnn_init_f32_minmax_scalar_params;
685685
f32_gemm_config.pack_gemm_gio = (xnn_packw_gemm_gio_ukernel_fn) xnn_pack_f32_gemm_gio_w;
686-
f32_gemm_config.pack_gemm_goi = (xnn_packw_gemm_goi_ukernel_fn) xnn_x32_packw_gemm_goi_ukernel_x16__avx512f_u4_prfm;
686+
f32_gemm_config.pack_gemm_goi = (xnn_packw_gemm_goi_ukernel_fn) xnn_x32_packw_gemm_goi_ukernel_x32__avx512f_u4_prfm;
687687
f32_gemm_config.mr = 7;
688-
f32_gemm_config.nr = 16;
688+
f32_gemm_config.nr = 32;
689689
} else
690690
#endif
691691
if (hardware_config->use_x86_fma3) {
Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
// Auto-generated file. Do not edit!
2+
// Template: src/f32-gemm/avx512-broadcast.c.in
3+
// Generator: tools/xngen
4+
//
5+
// Copyright 2019 Google LLC
6+
//
7+
// This source code is licensed under the BSD-style license found in the
8+
// LICENSE file in the root directory of this source tree.
9+
10+
#include <assert.h>
11+
12+
#include <immintrin.h>
13+
14+
#include "xnnpack/gemm.h"
15+
#include "xnnpack/intrinsics-polyfill.h"
16+
17+
18+
void xnn_f32_gemm_minmax_ukernel_1x32__avx512f_broadcast(
19+
size_t mr,
20+
size_t nc,
21+
size_t kc,
22+
const float* restrict a,
23+
size_t a_stride,
24+
const float* restrict w,
25+
float* restrict c,
26+
size_t cm_stride,
27+
size_t cn_stride,
28+
const union xnn_f32_minmax_params params[restrict XNN_MIN_ELEMENTS(1)])
29+
{
30+
assert(mr != 0);
31+
assert(mr <= 1);
32+
assert(nc != 0);
33+
assert(kc != 0);
34+
assert(kc % sizeof(float) == 0);
35+
assert(a != NULL);
36+
assert(w != NULL);
37+
assert(c != NULL);
38+
39+
const float* a0 = a;
40+
float* c0 = c;
41+
do {
42+
__m512 vacc0x0 = _mm512_load_ps(w);
43+
__m512 vacc0x1 = _mm512_load_ps(w + 16);
44+
w += 32;
45+
46+
size_t k = kc;
47+
do {
48+
const __m512 vb0 = _mm512_load_ps(w);
49+
const __m512 vb1 = _mm512_loadu_ps(w + 16);
50+
w += 32;
51+
52+
const __m512 va0 = _mm512_set1_ps(*a0);
53+
vacc0x0 = _mm512_fmadd_ps(va0, vb0, vacc0x0);
54+
vacc0x1 = _mm512_fmadd_ps(va0, vb1, vacc0x1);
55+
56+
a0 += 1;
57+
58+
k -= sizeof(float);
59+
} while (k != 0);
60+
61+
const __m512 vmin = _mm512_set1_ps(params->scalar.min);
62+
vacc0x0 = _mm512_max_ps(vmin, vacc0x0);
63+
vacc0x1 = _mm512_max_ps(vmin, vacc0x1);
64+
65+
const __m512 vmax = _mm512_set1_ps(params->scalar.max);
66+
vacc0x0 = _mm512_min_ps(vmax, vacc0x0);
67+
vacc0x1 = _mm512_min_ps(vmax, vacc0x1);
68+
69+
if XNN_LIKELY(nc >= 32) {
70+
_mm512_storeu_ps(c0, vacc0x0);
71+
_mm512_storeu_ps(c0 + 16, vacc0x1);
72+
c0 = (float*) ((uintptr_t) c0 + cn_stride);
73+
74+
a0 = (const float*) ((uintptr_t) a0 - kc);
75+
76+
nc -= 32;
77+
} else {
78+
if (nc & 16) {
79+
_mm512_storeu_ps(c0, vacc0x0);
80+
81+
vacc0x0 = vacc0x1;
82+
83+
c0 += 16;
84+
}
85+
if (nc & 15) {
86+
// Prepare mask for valid 32-bit elements (depends on nc).
87+
const __mmask16 vmask = _cvtu32_mask16((uint32_t) (UINT32_C(1) << (nc & 15)) - UINT32_C(1));
88+
_mm512_mask_storeu_ps(c0, vmask, vacc0x0);
89+
}
90+
nc = 0;
91+
}
92+
} while (nc != 0);
93+
}

0 commit comments

Comments
 (0)