Skip to content

Commit

Permalink
Add x32 avx512 gemm microkernels. These are way faster
Browse files Browse the repository at this point in the history
f32_gemm_minmax_ukernel_1x16__avx512f_broadcast/selfiesegmentation_full/M:16384/N:32/K:32/real_time        1029482 ns      1029201 ns         1364 FLOPS=32.5935G/s
f32_gemm_minmax_ukernel_4x16__avx512f_broadcast/selfiesegmentation_full/M:16384/N:32/K:32/real_time         586536 ns       586291 ns         2407 FLOPS=57.2078G/s
f32_gemm_minmax_ukernel_5x16__avx512f_broadcast/selfiesegmentation_full/M:16384/N:32/K:32/real_time         541205 ns       541143 ns         2581 FLOPS=61.9994G/s
f32_gemm_minmax_ukernel_6x16__avx512f_broadcast/selfiesegmentation_full/M:16384/N:32/K:32/real_time         566296 ns       566175 ns         2466 FLOPS=59.2525G/s
f32_gemm_minmax_ukernel_7x16__avx512f_broadcast/selfiesegmentation_full/M:16384/N:32/K:32/real_time         557455 ns       557275 ns         2464 FLOPS=60.1921G/s
f32_gemm_minmax_ukernel_8x16__avx512f_broadcast/selfiesegmentation_full/M:16384/N:32/K:32/real_time         586344 ns       586240 ns         2396 FLOPS=57.2265G/s
f32_gemm_minmax_ukernel_1x32__avx512f_broadcast/selfiesegmentation_full/M:16384/N:32/K:32/real_time         568475 ns       568317 ns         2494 FLOPS=59.0254G/s
f32_gemm_minmax_ukernel_4x32__avx512f_broadcast/selfiesegmentation_full/M:16384/N:32/K:32/real_time         343299 ns       343171 ns         3967 FLOPS=97.7411G/s
f32_gemm_minmax_ukernel_5x32__avx512f_broadcast/selfiesegmentation_full/M:16384/N:32/K:32/real_time         326208 ns       326123 ns         4317 FLOPS=102.862G/s
f32_gemm_minmax_ukernel_6x32__avx512f_broadcast/selfiesegmentation_full/M:16384/N:32/K:32/real_time         324746 ns       324675 ns         4317 FLOPS=103.325G/s
f32_gemm_minmax_ukernel_7x32__avx512f_broadcast/selfiesegmentation_full/M:16384/N:32/K:32/real_time         319434 ns       319344 ns         4337 FLOPS=105.043G/s
f32_gemm_minmax_ukernel_8x32__avx512f_broadcast/selfiesegmentation_full/M:16384/N:32/K:32/real_time         358816 ns       358735 ns         4029 FLOPS=93.5143G/s

PiperOrigin-RevId: 692901687
  • Loading branch information
alankelly authored and xnnpack-bot committed Nov 4, 2024
1 parent 9dc7da7 commit 1fed338
Show file tree
Hide file tree
Showing 37 changed files with 4,979 additions and 24 deletions.
66 changes: 66 additions & 0 deletions bench/f32-gemm-minmax.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1352,6 +1352,72 @@
}

BENCHMARK_GEMM(f32_gemm_minmax_ukernel_8x16__avx512f_broadcast)

// Benchmark entry point for the 1x32 AVX512F broadcast f32 GEMM (minmax)
// microkernel. Forwards to GEMMBenchmark with tile parameters mr=1, nr=32,
// kr=1, sr=1, the scalar minmax params initializer and the GOI weight packer.
// benchmark::utils::CheckAVX512F is passed last (presumably an ISA
// availability gate -- confirm against GEMMBenchmark's signature).
// `net` is not referenced in the body.
static void f32_gemm_minmax_ukernel_1x32__avx512f_broadcast(benchmark::State& state, const char* net) {
  GEMMBenchmark(
      state,
      xnn_f32_gemm_minmax_ukernel_1x32__avx512f_broadcast,
      xnn_init_f32_minmax_scalar_params,
      xnn_pack_f32_gemm_goi_w,
      /*mr=*/1,
      /*nr=*/32,
      /*kr=*/1,
      /*sr=*/1,
      benchmark::utils::CheckAVX512F);
}

BENCHMARK_GEMM(f32_gemm_minmax_ukernel_1x32__avx512f_broadcast)

// Benchmark entry point for the 4x32 AVX512F broadcast f32 GEMM (minmax)
// microkernel. Forwards to GEMMBenchmark with tile parameters mr=4, nr=32,
// kr=1, sr=1, the scalar minmax params initializer and the GOI weight packer.
// benchmark::utils::CheckAVX512F is passed last (presumably an ISA
// availability gate -- confirm against GEMMBenchmark's signature).
// `net` is not referenced in the body.
static void f32_gemm_minmax_ukernel_4x32__avx512f_broadcast(benchmark::State& state, const char* net) {
  GEMMBenchmark(
      state,
      xnn_f32_gemm_minmax_ukernel_4x32__avx512f_broadcast,
      xnn_init_f32_minmax_scalar_params,
      xnn_pack_f32_gemm_goi_w,
      /*mr=*/4,
      /*nr=*/32,
      /*kr=*/1,
      /*sr=*/1,
      benchmark::utils::CheckAVX512F);
}

BENCHMARK_GEMM(f32_gemm_minmax_ukernel_4x32__avx512f_broadcast)

// Benchmark entry point for the 5x32 AVX512F broadcast f32 GEMM (minmax)
// microkernel. Forwards to GEMMBenchmark with tile parameters mr=5, nr=32,
// kr=1, sr=1, the scalar minmax params initializer and the GOI weight packer.
// benchmark::utils::CheckAVX512F is passed last (presumably an ISA
// availability gate -- confirm against GEMMBenchmark's signature).
// `net` is not referenced in the body.
static void f32_gemm_minmax_ukernel_5x32__avx512f_broadcast(benchmark::State& state, const char* net) {
  GEMMBenchmark(
      state,
      xnn_f32_gemm_minmax_ukernel_5x32__avx512f_broadcast,
      xnn_init_f32_minmax_scalar_params,
      xnn_pack_f32_gemm_goi_w,
      /*mr=*/5,
      /*nr=*/32,
      /*kr=*/1,
      /*sr=*/1,
      benchmark::utils::CheckAVX512F);
}

BENCHMARK_GEMM(f32_gemm_minmax_ukernel_5x32__avx512f_broadcast)

// Benchmark entry point for the 6x32 AVX512F broadcast f32 GEMM (minmax)
// microkernel. Forwards to GEMMBenchmark with tile parameters mr=6, nr=32,
// kr=1, sr=1, the scalar minmax params initializer and the GOI weight packer.
// benchmark::utils::CheckAVX512F is passed last (presumably an ISA
// availability gate -- confirm against GEMMBenchmark's signature).
// `net` is not referenced in the body.
static void f32_gemm_minmax_ukernel_6x32__avx512f_broadcast(benchmark::State& state, const char* net) {
  GEMMBenchmark(
      state,
      xnn_f32_gemm_minmax_ukernel_6x32__avx512f_broadcast,
      xnn_init_f32_minmax_scalar_params,
      xnn_pack_f32_gemm_goi_w,
      /*mr=*/6,
      /*nr=*/32,
      /*kr=*/1,
      /*sr=*/1,
      benchmark::utils::CheckAVX512F);
}

BENCHMARK_GEMM(f32_gemm_minmax_ukernel_6x32__avx512f_broadcast)

// Benchmark entry point for the 7x32 AVX512F broadcast f32 GEMM (minmax)
// microkernel. Forwards to GEMMBenchmark with tile parameters mr=7, nr=32,
// kr=1, sr=1, the scalar minmax params initializer and the GOI weight packer.
// benchmark::utils::CheckAVX512F is passed last (presumably an ISA
// availability gate -- confirm against GEMMBenchmark's signature).
// `net` is not referenced in the body.
static void f32_gemm_minmax_ukernel_7x32__avx512f_broadcast(benchmark::State& state, const char* net) {
  GEMMBenchmark(
      state,
      xnn_f32_gemm_minmax_ukernel_7x32__avx512f_broadcast,
      xnn_init_f32_minmax_scalar_params,
      xnn_pack_f32_gemm_goi_w,
      /*mr=*/7,
      /*nr=*/32,
      /*kr=*/1,
      /*sr=*/1,
      benchmark::utils::CheckAVX512F);
}

BENCHMARK_GEMM(f32_gemm_minmax_ukernel_7x32__avx512f_broadcast)

// Benchmark entry point for the 8x32 AVX512F broadcast f32 GEMM (minmax)
// microkernel. Forwards to GEMMBenchmark with tile parameters mr=8, nr=32,
// kr=1, sr=1, the scalar minmax params initializer and the GOI weight packer.
// benchmark::utils::CheckAVX512F is passed last (presumably an ISA
// availability gate -- confirm against GEMMBenchmark's signature).
// `net` is not referenced in the body.
static void f32_gemm_minmax_ukernel_8x32__avx512f_broadcast(benchmark::State& state, const char* net) {
  GEMMBenchmark(
      state,
      xnn_f32_gemm_minmax_ukernel_8x32__avx512f_broadcast,
      xnn_init_f32_minmax_scalar_params,
      xnn_pack_f32_gemm_goi_w,
      /*mr=*/8,
      /*nr=*/32,
      /*kr=*/1,
      /*sr=*/1,
      benchmark::utils::CheckAVX512F);
}

BENCHMARK_GEMM(f32_gemm_minmax_ukernel_8x32__avx512f_broadcast)
#endif // XNN_ENABLE_AVX512F && (XNN_ARCH_X86 || XNN_ARCH_X86_64)


Expand Down
26 changes: 20 additions & 6 deletions cmake/gen/avx512f_microkernels.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,10 @@ SET(PROD_AVX512F_MICROKERNEL_SRCS
src/f32-dwconv/gen/f32-dwconv-5f5m5l32c16s1r-minmax-avx512f.c
src/f32-dwconv/gen/f32-dwconv-9p16c-minmax-avx512f.c
src/f32-dwconv/gen/f32-dwconv-25p16c-minmax-avx512f.c
src/f32-gemm/gen/f32-gemm-1x16-minmax-avx512f-broadcast.c
src/f32-gemm/gen/f32-gemm-7x16-minmax-avx512f-broadcast.c
src/f32-igemm/gen/f32-igemm-1x16-minmax-avx512f-broadcast.c
src/f32-igemm/gen/f32-igemm-7x16-minmax-avx512f-broadcast.c
src/f32-gemm/gen/f32-gemm-1x32-minmax-avx512f-broadcast.c
src/f32-gemm/gen/f32-gemm-7x32-minmax-avx512f-broadcast.c
src/f32-igemm/gen/f32-igemm-1x32-minmax-avx512f-broadcast.c
src/f32-igemm/gen/f32-igemm-7x32-minmax-avx512f-broadcast.c
src/f32-raddstoreexpminusmax/gen/f32-raddstoreexpminusmax-avx512f-rr2-p5-u64-acc2.c
src/f32-rdsum/gen/f32-rdsum-7p7x-minmax-avx512f-c64.c
src/f32-rminmax/gen/f32-rmax-avx512f-u64-acc4.c
Expand Down Expand Up @@ -66,7 +66,7 @@ SET(PROD_AVX512F_MICROKERNEL_SRCS
src/f32-vunary/gen/f32-vsqr-avx512f.c
src/s32-f32-vcvt/gen/s32-f32-vcvt-avx512f.c
src/u32-f32-vcvt/gen/u32-f32-vcvt-avx512f.c
src/x32-packw/gen/x32-packw-x16-gemm-goi-avx512f-u4-prfm.c)
src/x32-packw/gen/x32-packw-x32-gemm-goi-avx512f-u4-prfm.c)

SET(NON_PROD_AVX512F_MICROKERNEL_SRCS
src/f32-dwconv/gen/f32-dwconv-3p16c-minmax-avx512f-acc2.c
Expand All @@ -84,20 +84,32 @@ SET(NON_PROD_AVX512F_MICROKERNEL_SRCS
src/f32-dwconv/gen/f32-dwconv-25p16c-minmax-avx512f-acc2.c
src/f32-dwconv/gen/f32-dwconv-25p32c-minmax-avx512f-acc2.c
src/f32-dwconv/gen/f32-dwconv-25p32c-minmax-avx512f.c
src/f32-gemm/gen/f32-gemm-1x16-minmax-avx512f-broadcast.c
src/f32-gemm/gen/f32-gemm-4x16-minmax-avx512f-broadcast.c
src/f32-gemm/gen/f32-gemm-4x32-minmax-avx512f-broadcast.c
src/f32-gemm/gen/f32-gemm-5x16-minmax-avx512f-broadcast.c
src/f32-gemm/gen/f32-gemm-5x32-minmax-avx512f-broadcast.c
src/f32-gemm/gen/f32-gemm-6x16-minmax-avx512f-broadcast.c
src/f32-gemm/gen/f32-gemm-6x32-minmax-avx512f-broadcast.c
src/f32-gemm/gen/f32-gemm-7x16-minmax-avx512f-broadcast.c
src/f32-gemm/gen/f32-gemm-8x16-minmax-avx512f-broadcast.c
src/f32-gemm/gen/f32-gemm-8x32-minmax-avx512f-broadcast.c
src/f32-gemminc/gen/f32-gemminc-1x16-minmax-avx512f-broadcast.c
src/f32-gemminc/gen/f32-gemminc-4x16-minmax-avx512f-broadcast.c
src/f32-gemminc/gen/f32-gemminc-5x16-minmax-avx512f-broadcast.c
src/f32-gemminc/gen/f32-gemminc-6x16-minmax-avx512f-broadcast.c
src/f32-gemminc/gen/f32-gemminc-7x16-minmax-avx512f-broadcast.c
src/f32-gemminc/gen/f32-gemminc-8x16-minmax-avx512f-broadcast.c
src/f32-igemm/gen/f32-igemm-1x16-minmax-avx512f-broadcast.c
src/f32-igemm/gen/f32-igemm-4x16-minmax-avx512f-broadcast.c
src/f32-igemm/gen/f32-igemm-4x32-minmax-avx512f-broadcast.c
src/f32-igemm/gen/f32-igemm-5x16-minmax-avx512f-broadcast.c
src/f32-igemm/gen/f32-igemm-5x32-minmax-avx512f-broadcast.c
src/f32-igemm/gen/f32-igemm-6x16-minmax-avx512f-broadcast.c
src/f32-igemm/gen/f32-igemm-6x32-minmax-avx512f-broadcast.c
src/f32-igemm/gen/f32-igemm-7x16-minmax-avx512f-broadcast.c
src/f32-igemm/gen/f32-igemm-8x16-minmax-avx512f-broadcast.c
src/f32-igemm/gen/f32-igemm-8x32-minmax-avx512f-broadcast.c
src/f32-raddexpminusmax/gen/f32-raddexpminusmax-avx512f-p5-scalef-u64-acc2.c
src/f32-raddexpminusmax/gen/f32-raddexpminusmax-avx512f-p5-scalef-u64-acc4.c
src/f32-raddexpminusmax/gen/f32-raddexpminusmax-avx512f-p5-scalef-u64.c
Expand Down Expand Up @@ -279,6 +291,8 @@ SET(NON_PROD_AVX512F_MICROKERNEL_SRCS
src/f32-vsqrt/gen/f32-vsqrt-avx512f-rsqrt-u32.c
src/f32-vsqrt/gen/f32-vsqrt-avx512f-rsqrt-u48.c
src/f32-vtanh/gen/f32-vtanh-avx512f-rational-9-8-div.c
src/x32-packw/gen/x32-packw-x16-gemm-goi-avx512f-u4.c)
src/x32-packw/gen/x32-packw-x16-gemm-goi-avx512f-u4-prfm.c
src/x32-packw/gen/x32-packw-x16-gemm-goi-avx512f-u4.c
src/x32-packw/gen/x32-packw-x32-gemm-goi-avx512f-u4.c)

SET(ALL_AVX512F_MICROKERNEL_SRCS ${PROD_AVX512F_MICROKERNEL_SRCS} + ${NON_PROD_AVX512F_MICROKERNEL_SRCS})
24 changes: 19 additions & 5 deletions gen/avx512f_microkernels.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@ PROD_AVX512F_MICROKERNEL_SRCS = [
"src/f32-dwconv/gen/f32-dwconv-5f5m5l32c16s1r-minmax-avx512f.c",
"src/f32-dwconv/gen/f32-dwconv-9p16c-minmax-avx512f.c",
"src/f32-dwconv/gen/f32-dwconv-25p16c-minmax-avx512f.c",
"src/f32-gemm/gen/f32-gemm-1x16-minmax-avx512f-broadcast.c",
"src/f32-gemm/gen/f32-gemm-7x16-minmax-avx512f-broadcast.c",
"src/f32-igemm/gen/f32-igemm-1x16-minmax-avx512f-broadcast.c",
"src/f32-igemm/gen/f32-igemm-7x16-minmax-avx512f-broadcast.c",
"src/f32-gemm/gen/f32-gemm-1x32-minmax-avx512f-broadcast.c",
"src/f32-gemm/gen/f32-gemm-7x32-minmax-avx512f-broadcast.c",
"src/f32-igemm/gen/f32-igemm-1x32-minmax-avx512f-broadcast.c",
"src/f32-igemm/gen/f32-igemm-7x32-minmax-avx512f-broadcast.c",
"src/f32-raddstoreexpminusmax/gen/f32-raddstoreexpminusmax-avx512f-rr2-p5-u64-acc2.c",
"src/f32-rdsum/gen/f32-rdsum-7p7x-minmax-avx512f-c64.c",
"src/f32-rminmax/gen/f32-rmax-avx512f-u64-acc4.c",
Expand Down Expand Up @@ -62,7 +62,7 @@ PROD_AVX512F_MICROKERNEL_SRCS = [
"src/f32-vunary/gen/f32-vsqr-avx512f.c",
"src/s32-f32-vcvt/gen/s32-f32-vcvt-avx512f.c",
"src/u32-f32-vcvt/gen/u32-f32-vcvt-avx512f.c",
"src/x32-packw/gen/x32-packw-x16-gemm-goi-avx512f-u4-prfm.c",
"src/x32-packw/gen/x32-packw-x32-gemm-goi-avx512f-u4-prfm.c",
]

NON_PROD_AVX512F_MICROKERNEL_SRCS = [
Expand All @@ -81,20 +81,32 @@ NON_PROD_AVX512F_MICROKERNEL_SRCS = [
"src/f32-dwconv/gen/f32-dwconv-25p16c-minmax-avx512f-acc2.c",
"src/f32-dwconv/gen/f32-dwconv-25p32c-minmax-avx512f-acc2.c",
"src/f32-dwconv/gen/f32-dwconv-25p32c-minmax-avx512f.c",
"src/f32-gemm/gen/f32-gemm-1x16-minmax-avx512f-broadcast.c",
"src/f32-gemm/gen/f32-gemm-4x16-minmax-avx512f-broadcast.c",
"src/f32-gemm/gen/f32-gemm-4x32-minmax-avx512f-broadcast.c",
"src/f32-gemm/gen/f32-gemm-5x16-minmax-avx512f-broadcast.c",
"src/f32-gemm/gen/f32-gemm-5x32-minmax-avx512f-broadcast.c",
"src/f32-gemm/gen/f32-gemm-6x16-minmax-avx512f-broadcast.c",
"src/f32-gemm/gen/f32-gemm-6x32-minmax-avx512f-broadcast.c",
"src/f32-gemm/gen/f32-gemm-7x16-minmax-avx512f-broadcast.c",
"src/f32-gemm/gen/f32-gemm-8x16-minmax-avx512f-broadcast.c",
"src/f32-gemm/gen/f32-gemm-8x32-minmax-avx512f-broadcast.c",
"src/f32-gemminc/gen/f32-gemminc-1x16-minmax-avx512f-broadcast.c",
"src/f32-gemminc/gen/f32-gemminc-4x16-minmax-avx512f-broadcast.c",
"src/f32-gemminc/gen/f32-gemminc-5x16-minmax-avx512f-broadcast.c",
"src/f32-gemminc/gen/f32-gemminc-6x16-minmax-avx512f-broadcast.c",
"src/f32-gemminc/gen/f32-gemminc-7x16-minmax-avx512f-broadcast.c",
"src/f32-gemminc/gen/f32-gemminc-8x16-minmax-avx512f-broadcast.c",
"src/f32-igemm/gen/f32-igemm-1x16-minmax-avx512f-broadcast.c",
"src/f32-igemm/gen/f32-igemm-4x16-minmax-avx512f-broadcast.c",
"src/f32-igemm/gen/f32-igemm-4x32-minmax-avx512f-broadcast.c",
"src/f32-igemm/gen/f32-igemm-5x16-minmax-avx512f-broadcast.c",
"src/f32-igemm/gen/f32-igemm-5x32-minmax-avx512f-broadcast.c",
"src/f32-igemm/gen/f32-igemm-6x16-minmax-avx512f-broadcast.c",
"src/f32-igemm/gen/f32-igemm-6x32-minmax-avx512f-broadcast.c",
"src/f32-igemm/gen/f32-igemm-7x16-minmax-avx512f-broadcast.c",
"src/f32-igemm/gen/f32-igemm-8x16-minmax-avx512f-broadcast.c",
"src/f32-igemm/gen/f32-igemm-8x32-minmax-avx512f-broadcast.c",
"src/f32-raddexpminusmax/gen/f32-raddexpminusmax-avx512f-p5-scalef-u64-acc2.c",
"src/f32-raddexpminusmax/gen/f32-raddexpminusmax-avx512f-p5-scalef-u64-acc4.c",
"src/f32-raddexpminusmax/gen/f32-raddexpminusmax-avx512f-p5-scalef-u64.c",
Expand Down Expand Up @@ -276,7 +288,9 @@ NON_PROD_AVX512F_MICROKERNEL_SRCS = [
"src/f32-vsqrt/gen/f32-vsqrt-avx512f-rsqrt-u32.c",
"src/f32-vsqrt/gen/f32-vsqrt-avx512f-rsqrt-u48.c",
"src/f32-vtanh/gen/f32-vtanh-avx512f-rational-9-8-div.c",
"src/x32-packw/gen/x32-packw-x16-gemm-goi-avx512f-u4-prfm.c",
"src/x32-packw/gen/x32-packw-x16-gemm-goi-avx512f-u4.c",
"src/x32-packw/gen/x32-packw-x32-gemm-goi-avx512f-u4.c",
]

ALL_AVX512F_MICROKERNEL_SRCS = PROD_AVX512F_MICROKERNEL_SRCS + NON_PROD_AVX512F_MICROKERNEL_SRCS
7 changes: 7 additions & 0 deletions scripts/generate-f32-gemm.sh
Original file line number Diff line number Diff line change
Expand Up @@ -611,6 +611,13 @@ tools/xngen src/f32-gemm/avx512-broadcast.c.in -D MR=7 -D NR=16 -D INC=1 -D DATA
tools/xngen src/f32-gemm/avx512-broadcast.c.in -D MR=8 -D NR=16 -D INC=0 -D DATATYPE=F32 -o src/f32-gemm/gen/f32-gemm-8x16-minmax-avx512f-broadcast.c &
tools/xngen src/f32-gemm/avx512-broadcast.c.in -D MR=8 -D NR=16 -D INC=1 -D DATATYPE=F32 -o src/f32-gemminc/gen/f32-gemminc-8x16-minmax-avx512f-broadcast.c &

tools/xngen src/f32-gemm/avx512-broadcast.c.in -D MR=1 -D NR=32 -D INC=0 -D DATATYPE=F32 -o src/f32-gemm/gen/f32-gemm-1x32-minmax-avx512f-broadcast.c &
tools/xngen src/f32-gemm/avx512-broadcast.c.in -D MR=4 -D NR=32 -D INC=0 -D DATATYPE=F32 -o src/f32-gemm/gen/f32-gemm-4x32-minmax-avx512f-broadcast.c &
tools/xngen src/f32-gemm/avx512-broadcast.c.in -D MR=5 -D NR=32 -D INC=0 -D DATATYPE=F32 -o src/f32-gemm/gen/f32-gemm-5x32-minmax-avx512f-broadcast.c &
tools/xngen src/f32-gemm/avx512-broadcast.c.in -D MR=6 -D NR=32 -D INC=0 -D DATATYPE=F32 -o src/f32-gemm/gen/f32-gemm-6x32-minmax-avx512f-broadcast.c &
tools/xngen src/f32-gemm/avx512-broadcast.c.in -D MR=7 -D NR=32 -D INC=0 -D DATATYPE=F32 -o src/f32-gemm/gen/f32-gemm-7x32-minmax-avx512f-broadcast.c &
tools/xngen src/f32-gemm/avx512-broadcast.c.in -D MR=8 -D NR=32 -D INC=0 -D DATATYPE=F32 -o src/f32-gemm/gen/f32-gemm-8x32-minmax-avx512f-broadcast.c &

################################ RISC-V Vector ################################
tools/xngen src/f32-gemm/MRxNRv-rvv.c.in -D MR=7 -D NR=m4 -D ACTIVATION=LINEAR -D DATATYPE=F32 -o src/f32-gemm/gen/f32-gemm-7x4v-rvv.c &
tools/xngen src/f32-gemm/MRxNRv-rvv.c.in -D MR=7 -D NR=m4 -D ACTIVATION=RELU -D DATATYPE=F32 -o src/f32-gemm/gen/f32-gemm-7x4v-relu-rvv.c &
Expand Down
7 changes: 7 additions & 0 deletions scripts/generate-f32-igemm.sh
Original file line number Diff line number Diff line change
Expand Up @@ -361,6 +361,13 @@ tools/xngen src/f32-igemm/avx512-broadcast.c.in -D MR=6 -D NR=16 -o src/f32-igem
tools/xngen src/f32-igemm/avx512-broadcast.c.in -D MR=7 -D NR=16 -o src/f32-igemm/gen/f32-igemm-7x16-minmax-avx512f-broadcast.c &
tools/xngen src/f32-igemm/avx512-broadcast.c.in -D MR=8 -D NR=16 -o src/f32-igemm/gen/f32-igemm-8x16-minmax-avx512f-broadcast.c &

tools/xngen src/f32-igemm/avx512-broadcast.c.in -D MR=1 -D NR=32 -o src/f32-igemm/gen/f32-igemm-1x32-minmax-avx512f-broadcast.c &
tools/xngen src/f32-igemm/avx512-broadcast.c.in -D MR=4 -D NR=32 -o src/f32-igemm/gen/f32-igemm-4x32-minmax-avx512f-broadcast.c &
tools/xngen src/f32-igemm/avx512-broadcast.c.in -D MR=5 -D NR=32 -o src/f32-igemm/gen/f32-igemm-5x32-minmax-avx512f-broadcast.c &
tools/xngen src/f32-igemm/avx512-broadcast.c.in -D MR=6 -D NR=32 -o src/f32-igemm/gen/f32-igemm-6x32-minmax-avx512f-broadcast.c &
tools/xngen src/f32-igemm/avx512-broadcast.c.in -D MR=7 -D NR=32 -o src/f32-igemm/gen/f32-igemm-7x32-minmax-avx512f-broadcast.c &
tools/xngen src/f32-igemm/avx512-broadcast.c.in -D MR=8 -D NR=32 -o src/f32-igemm/gen/f32-igemm-8x32-minmax-avx512f-broadcast.c &

################################ RISC-V Vector ################################
tools/xngen src/f32-igemm/MRxNRv-rvv.c.in -D MR=1 -D NR=m4 -D ACTIVATION=LINEAR -o src/f32-igemm/gen/f32-igemm-1x4v-rvv.c &
tools/xngen src/f32-igemm/MRxNRv-rvv.c.in -D MR=7 -D NR=m4 -D ACTIVATION=LINEAR -o src/f32-igemm/gen/f32-igemm-7x4v-rvv.c &
Expand Down
3 changes: 3 additions & 0 deletions scripts/generate-x32-packw.sh
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,9 @@ tools/xngen src/x32-packw/gio-avx.c.in -D NR=8 -D PREFETCH=1 -o src/x32-packw/g
tools/xngen src/x32-packw/avx512.c.in -D NR=16 -D PREFETCH=0 -D KBLOCK=4 -o src/x32-packw/gen/x32-packw-x16-gemm-goi-avx512f-u4.c &
tools/xngen src/x32-packw/avx512.c.in -D NR=16 -D PREFETCH=1 -D KBLOCK=4 -o src/x32-packw/gen/x32-packw-x16-gemm-goi-avx512f-u4-prfm.c &

tools/xngen src/x32-packw/avx512.c.in -D NR=32 -D PREFETCH=0 -D KBLOCK=4 -o src/x32-packw/gen/x32-packw-x32-gemm-goi-avx512f-u4.c &
tools/xngen src/x32-packw/avx512.c.in -D NR=32 -D PREFETCH=1 -D KBLOCK=4 -o src/x32-packw/gen/x32-packw-x32-gemm-goi-avx512f-u4-prfm.c &

################################## Wasm SIMD ##################################
### NR multiple of 4
tools/xngen src/x32-packw/wasmsimd.c.in -D NR=8 -D KBLOCK=4 -o src/x32-packw/gen/x32-packw-x8-gemm-goi-wasmsimd-u4.c &
Expand Down
12 changes: 6 additions & 6 deletions src/configs/gemm-config.c
Original file line number Diff line number Diff line change
Expand Up @@ -677,15 +677,15 @@ static void init_f32_gemm_config(void) {
assert(hardware_config != NULL);
#if XNN_ENABLE_AVX512F
if (!XNN_PLATFORM_MOBILE && hardware_config->use_x86_avx512f) {
f32_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_1x16__avx512f_broadcast);
f32_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(7)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_7x16__avx512f_broadcast);
f32_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_1x16__avx512f_broadcast);
f32_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(7)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_7x16__avx512f_broadcast);
f32_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_1x32__avx512f_broadcast);
f32_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(7)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_f32_gemm_minmax_ukernel_7x32__avx512f_broadcast);
f32_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_1x32__avx512f_broadcast);
f32_gemm_config.minmax.igemm[XNN_MR_TO_INDEX(7)] = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_fn) xnn_f32_igemm_minmax_ukernel_7x32__avx512f_broadcast);
f32_gemm_config.init.f32 = xnn_init_f32_minmax_scalar_params;
f32_gemm_config.pack_gemm_gio = (xnn_packw_gemm_gio_ukernel_fn) xnn_pack_f32_gemm_gio_w;
f32_gemm_config.pack_gemm_goi = (xnn_packw_gemm_goi_ukernel_fn) xnn_x32_packw_gemm_goi_ukernel_x16__avx512f_u4_prfm;
f32_gemm_config.pack_gemm_goi = (xnn_packw_gemm_goi_ukernel_fn) xnn_x32_packw_gemm_goi_ukernel_x32__avx512f_u4_prfm;
f32_gemm_config.mr = 7;
f32_gemm_config.nr = 16;
f32_gemm_config.nr = 32;
} else
#endif
if (hardware_config->use_x86_fma3) {
Expand Down
93 changes: 93 additions & 0 deletions src/f32-gemm/gen/f32-gemm-1x32-minmax-avx512f-broadcast.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
// Auto-generated file. Do not edit!
// Template: src/f32-gemm/avx512-broadcast.c.in
// Generator: tools/xngen
//
// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <assert.h>

#include <immintrin.h>

#include "xnnpack/gemm.h"
#include "xnnpack/intrinsics-polyfill.h"


// f32 GEMM microkernel with min/max clamping: 1 row x 32 columns per tile,
// AVX512F, scalar-broadcast of the A operand.
//
// For every group of up to 32 output columns the kernel:
//   1. seeds two 16-lane accumulators from the next 32 floats at `w`
//      (the initial accumulator values; presumably the packed bias),
//   2. for each of the kc / sizeof(float) elements of the input row,
//      broadcasts the element and fuses-multiply-adds it with the next
//      32 floats read from `w`,
//   3. clamps the accumulators to [params->scalar.min, params->scalar.max],
//   4. stores 32 floats to `c` (or a 16-float block and/or a masked tail when
//      fewer than 32 columns remain).
//
// Parameters:
//   mr        - number of rows to process; must be 1 (asserted).
//   nc        - number of output columns; must be non-zero.
//   kc        - reduction length in BYTES; non-zero multiple of sizeof(float).
//   a         - pointer to the input row.
//   a_stride  - not referenced by this single-row kernel.
//   w         - packed weights stream. It is read exclusively with aligned
//               512-bit loads, so it must be 64-byte aligned.
//   c         - pointer to the output row.
//   cm_stride - not referenced by this single-row kernel.
//   cn_stride - byte offset added to the output pointer after each full
//               32-column store.
//   params    - supplies scalar.min and scalar.max used for clamping.
//
// NOTE: this file is produced by tools/xngen from
// src/f32-gemm/avx512-broadcast.c.in; any change made here should be mirrored
// in the template.
void xnn_f32_gemm_minmax_ukernel_1x32__avx512f_broadcast(
    size_t mr,
    size_t nc,
    size_t kc,
    const float* restrict a,
    size_t a_stride,
    const float* restrict w,
    float* restrict c,
    size_t cm_stride,
    size_t cn_stride,
    const union xnn_f32_minmax_params params[restrict XNN_MIN_ELEMENTS(1)])
{
  assert(mr != 0);
  assert(mr <= 1);
  assert(nc != 0);
  assert(kc != 0);
  assert(kc % sizeof(float) == 0);
  assert(a != NULL);
  assert(w != NULL);
  assert(c != NULL);

  const float* a0 = a;
  float* c0 = c;
  do {
    // Seed the accumulators from the head of the packed weights block.
    __m512 vacc0x0 = _mm512_load_ps(w);
    __m512 vacc0x1 = _mm512_load_ps(w + 16);
    w += 32;

    // k counts remaining reduction BYTES, matching the units of kc.
    size_t k = kc;
    do {
      // Both halves of the 32-float weight row are 64-byte aligned whenever
      // `w` is (w + 16 floats == w + 64 bytes), so use aligned loads for both,
      // consistent with the accumulator loads above.
      const __m512 vb0 = _mm512_load_ps(w);
      const __m512 vb1 = _mm512_load_ps(w + 16);
      w += 32;

      const __m512 va0 = _mm512_set1_ps(*a0);
      vacc0x0 = _mm512_fmadd_ps(va0, vb0, vacc0x0);
      vacc0x1 = _mm512_fmadd_ps(va0, vb1, vacc0x1);

      a0 += 1;

      k -= sizeof(float);
    } while (k != 0);

    // Clamp from below, then from above.
    const __m512 vmin = _mm512_set1_ps(params->scalar.min);
    vacc0x0 = _mm512_max_ps(vmin, vacc0x0);
    vacc0x1 = _mm512_max_ps(vmin, vacc0x1);

    const __m512 vmax = _mm512_set1_ps(params->scalar.max);
    vacc0x0 = _mm512_min_ps(vmax, vacc0x0);
    vacc0x1 = _mm512_min_ps(vmax, vacc0x1);

    if XNN_LIKELY(nc >= 32) {
      // Full tile: write all 32 columns and advance to the next column group.
      _mm512_storeu_ps(c0, vacc0x0);
      _mm512_storeu_ps(c0 + 16, vacc0x1);
      c0 = (float*) ((uintptr_t) c0 + cn_stride);

      // Rewind the input row pointer by the kc bytes consumed above so the
      // same row is reused for the next column group.
      a0 = (const float*) ((uintptr_t) a0 - kc);

      nc -= 32;
    } else {
      // Remainder: fewer than 32 columns left.
      if (nc & 16) {
        _mm512_storeu_ps(c0, vacc0x0);

        // Shift the upper accumulator down so the masked tail store below
        // always reads from vacc0x0.
        vacc0x0 = vacc0x1;

        c0 += 16;
      }
      if (nc & 15) {
        // Prepare mask for valid 32-bit elements (depends on nc).
        const __mmask16 vmask = _cvtu32_mask16((uint32_t) (UINT32_C(1) << (nc & 15)) - UINT32_C(1));
        _mm512_mask_storeu_ps(c0, vmask, vacc0x0);
      }
      nc = 0;
    }
  } while (nc != 0);
}
Loading

0 comments on commit 1fed338

Please sign in to comment.