Skip to content

Commit e74cc77

Browse files
committed
Merge pull request #8081 from NonerKao:dev-alan-contribute-spmm
PiperOrigin-RevId: 747499418
2 parents ab673dd + 7de15ab commit e74cc77

File tree

60 files changed

+15795
-109
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

60 files changed

+15795
-109
lines changed

bench/f32-conv-hwc2chw.cc

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#include "src/xnnpack/buffer.h"
1818
#include "src/xnnpack/common.h"
1919
#include "src/xnnpack/conv.h"
20+
#include "src/xnnpack/hardware-config.h"
2021
#include "src/xnnpack/microfnptr.h"
2122
#include "src/xnnpack/microparams-init.h"
2223
#include "src/xnnpack/pack.h"
@@ -155,6 +156,33 @@ static void f32_conv_hwc2chw_3x3s2p1c3x4__wasmsimd_2x2(benchmark::State& state,
155156
BENCHMARK_DCONV(f32_conv_hwc2chw_3x3s2p1c3x4__wasmsimd_2x2);
156157
#endif // XNN_ARCH_WASMSIMD || XNN_ARCH_WASMRELAXEDSIMD
157158

159+
#if XNN_ENABLE_RISCV_VECTOR && XNN_ARCH_RISCV
160+
// Benchmark entry point for xnn_f32_conv_hwc2chw_ukernel_3x3s2p1c3x2v__rvv_1x1.
// Forwards to f32_conv_hwc2chw with the scalar minmax params initializer and an
// output channel tile of 2 * vlenb / sizeof(float).
// NOTE(review): no ISA check argument is passed here; confirm that RVV
// availability is verified elsewhere before the microkernel runs.
static void f32_conv_hwc2chw_3x3s2p1c3x2v__rvv_1x1(benchmark::State& state,
    const char* net) {
  // vlenb comes from the hardware config at run time, so the tile is not a
  // compile-time constant.
  const auto output_channel_tile =
      2 * xnn_init_hardware_config()->vlenb / sizeof(float);
  f32_conv_hwc2chw(state, xnn_f32_conv_hwc2chw_ukernel_3x3s2p1c3x2v__rvv_1x1,
                   xnn_init_f32_minmax_scalar_params, output_channel_tile);
}
166+
167+
// Benchmark entry point for xnn_f32_conv_hwc2chw_ukernel_3x3s2p1c3x2v__rvv_2x1.
// Forwards to f32_conv_hwc2chw with the scalar minmax params initializer and an
// output channel tile of 2 * vlenb / sizeof(float).
// NOTE(review): no ISA check argument is passed here; confirm that RVV
// availability is verified elsewhere before the microkernel runs.
static void f32_conv_hwc2chw_3x3s2p1c3x2v__rvv_2x1(benchmark::State& state,
    const char* net) {
  // vlenb comes from the hardware config at run time, so the tile is not a
  // compile-time constant.
  const auto output_channel_tile =
      2 * xnn_init_hardware_config()->vlenb / sizeof(float);
  f32_conv_hwc2chw(state, xnn_f32_conv_hwc2chw_ukernel_3x3s2p1c3x2v__rvv_2x1,
                   xnn_init_f32_minmax_scalar_params, output_channel_tile);
}
173+
174+
// Benchmark entry point for xnn_f32_conv_hwc2chw_ukernel_3x3s2p1c3x2v__rvv_2x2.
// Forwards to f32_conv_hwc2chw with the scalar minmax params initializer and an
// output channel tile of 2 * vlenb / sizeof(float).
// NOTE(review): no ISA check argument is passed here; confirm that RVV
// availability is verified elsewhere before the microkernel runs.
static void f32_conv_hwc2chw_3x3s2p1c3x2v__rvv_2x2(benchmark::State& state,
    const char* net) {
  // vlenb comes from the hardware config at run time, so the tile is not a
  // compile-time constant.
  const auto output_channel_tile =
      2 * xnn_init_hardware_config()->vlenb / sizeof(float);
  f32_conv_hwc2chw(state, xnn_f32_conv_hwc2chw_ukernel_3x3s2p1c3x2v__rvv_2x2,
                   xnn_init_f32_minmax_scalar_params, output_channel_tile);
}
180+
181+
// Applies BENCHMARK_DCONV to each of the three RVV HWC2CHW wrappers
// (1x1, 2x1, 2x2). The macro is defined outside this view; presumably it
// registers each wrapper with the benchmark framework -- confirm.
BENCHMARK_DCONV(f32_conv_hwc2chw_3x3s2p1c3x2v__rvv_1x1);
182+
BENCHMARK_DCONV(f32_conv_hwc2chw_3x3s2p1c3x2v__rvv_2x1);
183+
BENCHMARK_DCONV(f32_conv_hwc2chw_3x3s2p1c3x2v__rvv_2x2);
184+
#endif // XNN_ENABLE_RISCV_VECTOR && XNN_ARCH_RISCV
185+
158186
static void f32_conv_hwc2chw_3x3s2p1c3x4__scalar_1x1(benchmark::State& state,
159187
const char* net) {
160188
f32_conv_hwc2chw(state, xnn_f32_conv_hwc2chw_ukernel_3x3s2p1c3x4__scalar_1x1,

bench/f32-dwconv2d-chw.cc

Lines changed: 150 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2837,6 +2837,156 @@ BENCHMARK_DWCONV(dwconv2d_chw_5x5s2p2__wasmsimd_x86_splat_2x4_acc3)
28372837
BENCHMARK_DWCONV(dwconv2d_chw_5x5s2p2__wasmsimd_x86_splat_3x4_acc2)
28382838
#endif // XNN_ARCH_WASMSIMD || XNN_ARCH_WASMRELAXEDSIMD
28392839

2840+
#if XNN_ENABLE_RISCV_VECTOR && XNN_ARCH_RISCV
2841+
// Benchmark entry point for xnn_f32_dwconv2d_chw_ukernel_3x3p1__rvv_5x1v.
// Forwards to f32_dwconv2d_chw with a 3x3 kernel, padding width 1, stride 1.
// NOTE(review): no ISA check argument is passed here; confirm that RVV
// availability is verified elsewhere before the microkernel runs.
static void dwconv2d_chw_3x3p1__rvv_5x1v(benchmark::State& state,
    const char* net) {
  const int kernel_height = 3;
  const int kernel_width = 3;
  const int padding_width = 1;
  const int stride = 1;
  f32_dwconv2d_chw(state, xnn_f32_dwconv2d_chw_ukernel_3x3p1__rvv_5x1v,
                   xnn_init_f32_minmax_scalar_params, kernel_height,
                   kernel_width, padding_width, stride);
}
2849+
// Benchmark entry point for xnn_f32_dwconv2d_chw_ukernel_3x3p1__rvv_6x1v.
// Forwards to f32_dwconv2d_chw with a 3x3 kernel, padding width 1, stride 1.
// NOTE(review): no ISA check argument is passed here; confirm that RVV
// availability is verified elsewhere before the microkernel runs.
static void dwconv2d_chw_3x3p1__rvv_6x1v(benchmark::State& state,
    const char* net) {
  const int kernel_height = 3;
  const int kernel_width = 3;
  const int padding_width = 1;
  const int stride = 1;
  f32_dwconv2d_chw(state, xnn_f32_dwconv2d_chw_ukernel_3x3p1__rvv_6x1v,
                   xnn_init_f32_minmax_scalar_params, kernel_height,
                   kernel_width, padding_width, stride);
}
2857+
// Benchmark entry point for xnn_f32_dwconv2d_chw_ukernel_3x3p1__rvv_7x1v.
// Forwards to f32_dwconv2d_chw with a 3x3 kernel, padding width 1, stride 1.
// NOTE(review): no ISA check argument is passed here; confirm that RVV
// availability is verified elsewhere before the microkernel runs.
static void dwconv2d_chw_3x3p1__rvv_7x1v(benchmark::State& state,
    const char* net) {
  const int kernel_height = 3;
  const int kernel_width = 3;
  const int padding_width = 1;
  const int stride = 1;
  f32_dwconv2d_chw(state, xnn_f32_dwconv2d_chw_ukernel_3x3p1__rvv_7x1v,
                   xnn_init_f32_minmax_scalar_params, kernel_height,
                   kernel_width, padding_width, stride);
}
2865+
// Benchmark entry point for xnn_f32_dwconv2d_chw_ukernel_3x3p1__rvv_8x1v.
// Forwards to f32_dwconv2d_chw with a 3x3 kernel, padding width 1, stride 1.
// NOTE(review): no ISA check argument is passed here; confirm that RVV
// availability is verified elsewhere before the microkernel runs.
static void dwconv2d_chw_3x3p1__rvv_8x1v(benchmark::State& state,
    const char* net) {
  const int kernel_height = 3;
  const int kernel_width = 3;
  const int padding_width = 1;
  const int stride = 1;
  f32_dwconv2d_chw(state, xnn_f32_dwconv2d_chw_ukernel_3x3p1__rvv_8x1v,
                   xnn_init_f32_minmax_scalar_params, kernel_height,
                   kernel_width, padding_width, stride);
}
2873+
// Benchmark entry point for xnn_f32_dwconv2d_chw_ukernel_3x3p1__rvv_1x2v.
// Forwards to f32_dwconv2d_chw with a 3x3 kernel, padding width 1, stride 1.
// NOTE(review): no ISA check argument is passed here; confirm that RVV
// availability is verified elsewhere before the microkernel runs.
static void dwconv2d_chw_3x3p1__rvv_1x2v(benchmark::State& state,
    const char* net) {
  const int kernel_height = 3;
  const int kernel_width = 3;
  const int padding_width = 1;
  const int stride = 1;
  f32_dwconv2d_chw(state, xnn_f32_dwconv2d_chw_ukernel_3x3p1__rvv_1x2v,
                   xnn_init_f32_minmax_scalar_params, kernel_height,
                   kernel_width, padding_width, stride);
}
2881+
// Benchmark entry point for xnn_f32_dwconv2d_chw_ukernel_3x3p1__rvv_2x2v.
// Forwards to f32_dwconv2d_chw with a 3x3 kernel, padding width 1, stride 1.
// NOTE(review): no ISA check argument is passed here; confirm that RVV
// availability is verified elsewhere before the microkernel runs.
static void dwconv2d_chw_3x3p1__rvv_2x2v(benchmark::State& state,
    const char* net) {
  const int kernel_height = 3;
  const int kernel_width = 3;
  const int padding_width = 1;
  const int stride = 1;
  f32_dwconv2d_chw(state, xnn_f32_dwconv2d_chw_ukernel_3x3p1__rvv_2x2v,
                   xnn_init_f32_minmax_scalar_params, kernel_height,
                   kernel_width, padding_width, stride);
}
2889+
// Benchmark entry point for xnn_f32_dwconv2d_chw_ukernel_3x3p1__rvv_3x2v.
// Forwards to f32_dwconv2d_chw with a 3x3 kernel, padding width 1, stride 1.
// NOTE(review): no ISA check argument is passed here; confirm that RVV
// availability is verified elsewhere before the microkernel runs.
static void dwconv2d_chw_3x3p1__rvv_3x2v(benchmark::State& state,
    const char* net) {
  const int kernel_height = 3;
  const int kernel_width = 3;
  const int padding_width = 1;
  const int stride = 1;
  f32_dwconv2d_chw(state, xnn_f32_dwconv2d_chw_ukernel_3x3p1__rvv_3x2v,
                   xnn_init_f32_minmax_scalar_params, kernel_height,
                   kernel_width, padding_width, stride);
}
2897+
// Benchmark entry point for xnn_f32_dwconv2d_chw_ukernel_3x3p1__rvv_4x2v.
// Forwards to f32_dwconv2d_chw with a 3x3 kernel, padding width 1, stride 1.
// NOTE(review): no ISA check argument is passed here; confirm that RVV
// availability is verified elsewhere before the microkernel runs.
static void dwconv2d_chw_3x3p1__rvv_4x2v(benchmark::State& state,
    const char* net) {
  const int kernel_height = 3;
  const int kernel_width = 3;
  const int padding_width = 1;
  const int stride = 1;
  f32_dwconv2d_chw(state, xnn_f32_dwconv2d_chw_ukernel_3x3p1__rvv_4x2v,
                   xnn_init_f32_minmax_scalar_params, kernel_height,
                   kernel_width, padding_width, stride);
}
2905+
2906+
// Benchmark entry point for xnn_f32_dwconv2d_chw_ukernel_3x3s2p1__rvv_5x1v.
// Forwards to f32_dwconv2d_chw with a 3x3 kernel, padding width 1, stride 2.
// NOTE(review): no ISA check argument is passed here; confirm that RVV
// availability is verified elsewhere before the microkernel runs.
static void dwconv2d_chw_3x3s2p1__rvv_5x1v(benchmark::State& state,
    const char* net) {
  const int kernel_height = 3;
  const int kernel_width = 3;
  const int padding_width = 1;
  const int stride = 2;
  f32_dwconv2d_chw(state, xnn_f32_dwconv2d_chw_ukernel_3x3s2p1__rvv_5x1v,
                   xnn_init_f32_minmax_scalar_params, kernel_height,
                   kernel_width, padding_width, stride);
}
2914+
// Benchmark entry point for xnn_f32_dwconv2d_chw_ukernel_3x3s2p1__rvv_6x1v.
// Forwards to f32_dwconv2d_chw with a 3x3 kernel, padding width 1, stride 2.
// NOTE(review): no ISA check argument is passed here; confirm that RVV
// availability is verified elsewhere before the microkernel runs.
static void dwconv2d_chw_3x3s2p1__rvv_6x1v(benchmark::State& state,
    const char* net) {
  const int kernel_height = 3;
  const int kernel_width = 3;
  const int padding_width = 1;
  const int stride = 2;
  f32_dwconv2d_chw(state, xnn_f32_dwconv2d_chw_ukernel_3x3s2p1__rvv_6x1v,
                   xnn_init_f32_minmax_scalar_params, kernel_height,
                   kernel_width, padding_width, stride);
}
2922+
// Benchmark entry point for xnn_f32_dwconv2d_chw_ukernel_3x3s2p1__rvv_7x1v.
// Forwards to f32_dwconv2d_chw with a 3x3 kernel, padding width 1, stride 2.
// NOTE(review): no ISA check argument is passed here; confirm that RVV
// availability is verified elsewhere before the microkernel runs.
static void dwconv2d_chw_3x3s2p1__rvv_7x1v(benchmark::State& state,
    const char* net) {
  const int kernel_height = 3;
  const int kernel_width = 3;
  const int padding_width = 1;
  const int stride = 2;
  f32_dwconv2d_chw(state, xnn_f32_dwconv2d_chw_ukernel_3x3s2p1__rvv_7x1v,
                   xnn_init_f32_minmax_scalar_params, kernel_height,
                   kernel_width, padding_width, stride);
}
2930+
// Benchmark entry point for xnn_f32_dwconv2d_chw_ukernel_3x3s2p1__rvv_8x1v.
// Forwards to f32_dwconv2d_chw with a 3x3 kernel, padding width 1, stride 2.
// NOTE(review): no ISA check argument is passed here; confirm that RVV
// availability is verified elsewhere before the microkernel runs.
static void dwconv2d_chw_3x3s2p1__rvv_8x1v(benchmark::State& state,
    const char* net) {
  const int kernel_height = 3;
  const int kernel_width = 3;
  const int padding_width = 1;
  const int stride = 2;
  f32_dwconv2d_chw(state, xnn_f32_dwconv2d_chw_ukernel_3x3s2p1__rvv_8x1v,
                   xnn_init_f32_minmax_scalar_params, kernel_height,
                   kernel_width, padding_width, stride);
}
2938+
// Benchmark entry point for xnn_f32_dwconv2d_chw_ukernel_3x3s2p1__rvv_1x2v.
// Forwards to f32_dwconv2d_chw with a 3x3 kernel, padding width 1, stride 2.
// NOTE(review): no ISA check argument is passed here; confirm that RVV
// availability is verified elsewhere before the microkernel runs.
static void dwconv2d_chw_3x3s2p1__rvv_1x2v(benchmark::State& state,
    const char* net) {
  const int kernel_height = 3;
  const int kernel_width = 3;
  const int padding_width = 1;
  const int stride = 2;
  f32_dwconv2d_chw(state, xnn_f32_dwconv2d_chw_ukernel_3x3s2p1__rvv_1x2v,
                   xnn_init_f32_minmax_scalar_params, kernel_height,
                   kernel_width, padding_width, stride);
}
2946+
// Benchmark entry point for xnn_f32_dwconv2d_chw_ukernel_3x3s2p1__rvv_2x2v.
// Forwards to f32_dwconv2d_chw with a 3x3 kernel, padding width 1, stride 2.
// NOTE(review): no ISA check argument is passed here; confirm that RVV
// availability is verified elsewhere before the microkernel runs.
static void dwconv2d_chw_3x3s2p1__rvv_2x2v(benchmark::State& state,
    const char* net) {
  const int kernel_height = 3;
  const int kernel_width = 3;
  const int padding_width = 1;
  const int stride = 2;
  f32_dwconv2d_chw(state, xnn_f32_dwconv2d_chw_ukernel_3x3s2p1__rvv_2x2v,
                   xnn_init_f32_minmax_scalar_params, kernel_height,
                   kernel_width, padding_width, stride);
}
2954+
// Benchmark entry point for xnn_f32_dwconv2d_chw_ukernel_3x3s2p1__rvv_3x2v.
// Forwards to f32_dwconv2d_chw with a 3x3 kernel, padding width 1, stride 2.
// NOTE(review): no ISA check argument is passed here; confirm that RVV
// availability is verified elsewhere before the microkernel runs.
static void dwconv2d_chw_3x3s2p1__rvv_3x2v(benchmark::State& state,
    const char* net) {
  const int kernel_height = 3;
  const int kernel_width = 3;
  const int padding_width = 1;
  const int stride = 2;
  f32_dwconv2d_chw(state, xnn_f32_dwconv2d_chw_ukernel_3x3s2p1__rvv_3x2v,
                   xnn_init_f32_minmax_scalar_params, kernel_height,
                   kernel_width, padding_width, stride);
}
2962+
// Benchmark entry point for xnn_f32_dwconv2d_chw_ukernel_3x3s2p1__rvv_4x2v.
// Forwards to f32_dwconv2d_chw with a 3x3 kernel, padding width 1, stride 2.
// NOTE(review): no ISA check argument is passed here; confirm that RVV
// availability is verified elsewhere before the microkernel runs.
static void dwconv2d_chw_3x3s2p1__rvv_4x2v(benchmark::State& state,
    const char* net) {
  const int kernel_height = 3;
  const int kernel_width = 3;
  const int padding_width = 1;
  const int stride = 2;
  f32_dwconv2d_chw(state, xnn_f32_dwconv2d_chw_ukernel_3x3s2p1__rvv_4x2v,
                   xnn_init_f32_minmax_scalar_params, kernel_height,
                   kernel_width, padding_width, stride);
}
2970+
2971+
// Applies BENCHMARK_DWCONV to the sixteen RVV CHW depthwise-convolution
// wrappers: eight 3x3p1 (stride 1) variants followed by eight 3x3s2p1
// (stride 2) variants. The macro is defined outside this view; presumably it
// registers each wrapper with the benchmark framework -- confirm.
BENCHMARK_DWCONV(dwconv2d_chw_3x3p1__rvv_5x1v)
2972+
BENCHMARK_DWCONV(dwconv2d_chw_3x3p1__rvv_6x1v)
2973+
BENCHMARK_DWCONV(dwconv2d_chw_3x3p1__rvv_7x1v)
2974+
BENCHMARK_DWCONV(dwconv2d_chw_3x3p1__rvv_8x1v)
2975+
BENCHMARK_DWCONV(dwconv2d_chw_3x3p1__rvv_1x2v)
2976+
BENCHMARK_DWCONV(dwconv2d_chw_3x3p1__rvv_2x2v)
2977+
BENCHMARK_DWCONV(dwconv2d_chw_3x3p1__rvv_3x2v)
2978+
BENCHMARK_DWCONV(dwconv2d_chw_3x3p1__rvv_4x2v)
2979+
2980+
BENCHMARK_DWCONV(dwconv2d_chw_3x3s2p1__rvv_5x1v)
2981+
BENCHMARK_DWCONV(dwconv2d_chw_3x3s2p1__rvv_6x1v)
2982+
BENCHMARK_DWCONV(dwconv2d_chw_3x3s2p1__rvv_7x1v)
2983+
BENCHMARK_DWCONV(dwconv2d_chw_3x3s2p1__rvv_8x1v)
2984+
BENCHMARK_DWCONV(dwconv2d_chw_3x3s2p1__rvv_1x2v)
2985+
BENCHMARK_DWCONV(dwconv2d_chw_3x3s2p1__rvv_2x2v)
2986+
BENCHMARK_DWCONV(dwconv2d_chw_3x3s2p1__rvv_3x2v)
2987+
BENCHMARK_DWCONV(dwconv2d_chw_3x3s2p1__rvv_4x2v)
2988+
#endif // XNN_ENABLE_RISCV_VECTOR && XNN_ARCH_RISCV
2989+
28402990
static void dwconv2d_chw_3x3p1__scalar_1x1(benchmark::State& state,
28412991
const char* net) {
28422992
f32_dwconv2d_chw(state, xnn_f32_dwconv2d_chw_ukernel_3x3p1__scalar_1x1,

bench/f32-spmm.cc

Lines changed: 144 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -232,6 +232,150 @@
232232
#endif // XNN_ENABLE_HVX && XNN_ARCH_HEXAGON
233233

234234

235+
#if XNN_ENABLE_RISCV_VECTOR && XNN_ARCH_RISCV
  // Benchmark entry point for xnn_f32_spmm_minmax_ukernel_1vx1__rvv.
  // Forwards to f32_spmm with tile arguments (1 * vlenb / sizeof(float)) and 1
  // (presumably MR and NR -- confirm against f32_spmm), sparsity 0.8, and
  // benchmark::utils::CheckRVV as the final argument.
  static void f32_spmm_minmax_ukernel_1vx1__rvv(benchmark::State& state, const char* net) {
    // vlenb comes from the hardware config at run time, so the tile is not a
    // compile-time constant.
    const auto vector_tile = 1 * xnn_init_hardware_config()->vlenb / sizeof(float);
    f32_spmm(state, xnn_f32_spmm_minmax_ukernel_1vx1__rvv, vector_tile, 1,
             /*sparsity=*/0.8f, xnn_init_f32_minmax_scalar_params,
             benchmark::utils::CheckRVV);
  }

  BENCHMARK_SPMM(f32_spmm_minmax_ukernel_1vx1__rvv)
#endif  // XNN_ENABLE_RISCV_VECTOR && XNN_ARCH_RISCV
245+
246+
247+
#if XNN_ENABLE_RISCV_VECTOR && XNN_ARCH_RISCV
  // Benchmark entry point for xnn_f32_spmm_minmax_ukernel_1vx2__rvv.
  // Forwards to f32_spmm with tile arguments (1 * vlenb / sizeof(float)) and 2
  // (presumably MR and NR -- confirm against f32_spmm), sparsity 0.8, and
  // benchmark::utils::CheckRVV as the final argument.
  static void f32_spmm_minmax_ukernel_1vx2__rvv(benchmark::State& state, const char* net) {
    // vlenb comes from the hardware config at run time, so the tile is not a
    // compile-time constant.
    const auto vector_tile = 1 * xnn_init_hardware_config()->vlenb / sizeof(float);
    f32_spmm(state, xnn_f32_spmm_minmax_ukernel_1vx2__rvv, vector_tile, 2,
             /*sparsity=*/0.8f, xnn_init_f32_minmax_scalar_params,
             benchmark::utils::CheckRVV);
  }

  BENCHMARK_SPMM(f32_spmm_minmax_ukernel_1vx2__rvv)
#endif  // XNN_ENABLE_RISCV_VECTOR && XNN_ARCH_RISCV
257+
258+
259+
#if XNN_ENABLE_RISCV_VECTOR && XNN_ARCH_RISCV
  // Benchmark entry point for xnn_f32_spmm_minmax_ukernel_1vx4__rvv.
  // Forwards to f32_spmm with tile arguments (1 * vlenb / sizeof(float)) and 4
  // (presumably MR and NR -- confirm against f32_spmm), sparsity 0.8, and
  // benchmark::utils::CheckRVV as the final argument.
  static void f32_spmm_minmax_ukernel_1vx4__rvv(benchmark::State& state, const char* net) {
    // vlenb comes from the hardware config at run time, so the tile is not a
    // compile-time constant.
    const auto vector_tile = 1 * xnn_init_hardware_config()->vlenb / sizeof(float);
    f32_spmm(state, xnn_f32_spmm_minmax_ukernel_1vx4__rvv, vector_tile, 4,
             /*sparsity=*/0.8f, xnn_init_f32_minmax_scalar_params,
             benchmark::utils::CheckRVV);
  }

  BENCHMARK_SPMM(f32_spmm_minmax_ukernel_1vx4__rvv)
#endif  // XNN_ENABLE_RISCV_VECTOR && XNN_ARCH_RISCV
269+
270+
271+
#if XNN_ENABLE_RISCV_VECTOR && XNN_ARCH_RISCV
  // Benchmark entry point for xnn_f32_spmm_minmax_ukernel_2vx1__rvv.
  // Forwards to f32_spmm with tile arguments (2 * vlenb / sizeof(float)) and 1
  // (presumably MR and NR -- confirm against f32_spmm), sparsity 0.8, and
  // benchmark::utils::CheckRVV as the final argument.
  static void f32_spmm_minmax_ukernel_2vx1__rvv(benchmark::State& state, const char* net) {
    // vlenb comes from the hardware config at run time, so the tile is not a
    // compile-time constant.
    const auto vector_tile = 2 * xnn_init_hardware_config()->vlenb / sizeof(float);
    f32_spmm(state, xnn_f32_spmm_minmax_ukernel_2vx1__rvv, vector_tile, 1,
             /*sparsity=*/0.8f, xnn_init_f32_minmax_scalar_params,
             benchmark::utils::CheckRVV);
  }

  BENCHMARK_SPMM(f32_spmm_minmax_ukernel_2vx1__rvv)
#endif  // XNN_ENABLE_RISCV_VECTOR && XNN_ARCH_RISCV
281+
282+
283+
#if XNN_ENABLE_RISCV_VECTOR && XNN_ARCH_RISCV
  // Benchmark entry point for xnn_f32_spmm_minmax_ukernel_2vx2__rvv.
  // Forwards to f32_spmm with tile arguments (2 * vlenb / sizeof(float)) and 2
  // (presumably MR and NR -- confirm against f32_spmm), sparsity 0.8, and
  // benchmark::utils::CheckRVV as the final argument.
  static void f32_spmm_minmax_ukernel_2vx2__rvv(benchmark::State& state, const char* net) {
    // vlenb comes from the hardware config at run time, so the tile is not a
    // compile-time constant.
    const auto vector_tile = 2 * xnn_init_hardware_config()->vlenb / sizeof(float);
    f32_spmm(state, xnn_f32_spmm_minmax_ukernel_2vx2__rvv, vector_tile, 2,
             /*sparsity=*/0.8f, xnn_init_f32_minmax_scalar_params,
             benchmark::utils::CheckRVV);
  }

  BENCHMARK_SPMM(f32_spmm_minmax_ukernel_2vx2__rvv)
#endif  // XNN_ENABLE_RISCV_VECTOR && XNN_ARCH_RISCV
293+
294+
295+
#if XNN_ENABLE_RISCV_VECTOR && XNN_ARCH_RISCV
  // Benchmark entry point for xnn_f32_spmm_minmax_ukernel_2vx4__rvv.
  // Forwards to f32_spmm with tile arguments (2 * vlenb / sizeof(float)) and 4
  // (presumably MR and NR -- confirm against f32_spmm), sparsity 0.8, and
  // benchmark::utils::CheckRVV as the final argument.
  static void f32_spmm_minmax_ukernel_2vx4__rvv(benchmark::State& state, const char* net) {
    // vlenb comes from the hardware config at run time, so the tile is not a
    // compile-time constant.
    const auto vector_tile = 2 * xnn_init_hardware_config()->vlenb / sizeof(float);
    f32_spmm(state, xnn_f32_spmm_minmax_ukernel_2vx4__rvv, vector_tile, 4,
             /*sparsity=*/0.8f, xnn_init_f32_minmax_scalar_params,
             benchmark::utils::CheckRVV);
  }

  BENCHMARK_SPMM(f32_spmm_minmax_ukernel_2vx4__rvv)
#endif  // XNN_ENABLE_RISCV_VECTOR && XNN_ARCH_RISCV
305+
306+
307+
#if XNN_ENABLE_RISCV_VECTOR && XNN_ARCH_RISCV
  // Benchmark entry point for xnn_f32_spmm_minmax_ukernel_4vx1__rvv.
  // Forwards to f32_spmm with tile arguments (4 * vlenb / sizeof(float)) and 1
  // (presumably MR and NR -- confirm against f32_spmm), sparsity 0.8, and
  // benchmark::utils::CheckRVV as the final argument.
  static void f32_spmm_minmax_ukernel_4vx1__rvv(benchmark::State& state, const char* net) {
    // vlenb comes from the hardware config at run time, so the tile is not a
    // compile-time constant.
    const auto vector_tile = 4 * xnn_init_hardware_config()->vlenb / sizeof(float);
    f32_spmm(state, xnn_f32_spmm_minmax_ukernel_4vx1__rvv, vector_tile, 1,
             /*sparsity=*/0.8f, xnn_init_f32_minmax_scalar_params,
             benchmark::utils::CheckRVV);
  }

  BENCHMARK_SPMM(f32_spmm_minmax_ukernel_4vx1__rvv)
#endif  // XNN_ENABLE_RISCV_VECTOR && XNN_ARCH_RISCV
317+
318+
319+
#if XNN_ENABLE_RISCV_VECTOR && XNN_ARCH_RISCV
  // Benchmark entry point for xnn_f32_spmm_minmax_ukernel_4vx2__rvv.
  // Forwards to f32_spmm with tile arguments (4 * vlenb / sizeof(float)) and 2
  // (presumably MR and NR -- confirm against f32_spmm), sparsity 0.8, and
  // benchmark::utils::CheckRVV as the final argument.
  static void f32_spmm_minmax_ukernel_4vx2__rvv(benchmark::State& state, const char* net) {
    // vlenb comes from the hardware config at run time, so the tile is not a
    // compile-time constant.
    const auto vector_tile = 4 * xnn_init_hardware_config()->vlenb / sizeof(float);
    f32_spmm(state, xnn_f32_spmm_minmax_ukernel_4vx2__rvv, vector_tile, 2,
             /*sparsity=*/0.8f, xnn_init_f32_minmax_scalar_params,
             benchmark::utils::CheckRVV);
  }

  BENCHMARK_SPMM(f32_spmm_minmax_ukernel_4vx2__rvv)
#endif  // XNN_ENABLE_RISCV_VECTOR && XNN_ARCH_RISCV
329+
330+
331+
#if XNN_ENABLE_RISCV_VECTOR && XNN_ARCH_RISCV
  // Benchmark entry point for xnn_f32_spmm_minmax_ukernel_4vx4__rvv.
  // Forwards to f32_spmm with tile arguments (4 * vlenb / sizeof(float)) and 4
  // (presumably MR and NR -- confirm against f32_spmm), sparsity 0.8, and
  // benchmark::utils::CheckRVV as the final argument.
  static void f32_spmm_minmax_ukernel_4vx4__rvv(benchmark::State& state, const char* net) {
    // vlenb comes from the hardware config at run time, so the tile is not a
    // compile-time constant.
    const auto vector_tile = 4 * xnn_init_hardware_config()->vlenb / sizeof(float);
    f32_spmm(state, xnn_f32_spmm_minmax_ukernel_4vx4__rvv, vector_tile, 4,
             /*sparsity=*/0.8f, xnn_init_f32_minmax_scalar_params,
             benchmark::utils::CheckRVV);
  }

  BENCHMARK_SPMM(f32_spmm_minmax_ukernel_4vx4__rvv)
#endif  // XNN_ENABLE_RISCV_VECTOR && XNN_ARCH_RISCV
341+
342+
343+
#if XNN_ENABLE_RISCV_VECTOR && XNN_ARCH_RISCV
  // Benchmark entry point for xnn_f32_spmm_minmax_ukernel_8vx1__rvv.
  // Forwards to f32_spmm with tile arguments (8 * vlenb / sizeof(float)) and 1
  // (presumably MR and NR -- confirm against f32_spmm), sparsity 0.8, and
  // benchmark::utils::CheckRVV as the final argument.
  static void f32_spmm_minmax_ukernel_8vx1__rvv(benchmark::State& state, const char* net) {
    // vlenb comes from the hardware config at run time, so the tile is not a
    // compile-time constant.
    const auto vector_tile = 8 * xnn_init_hardware_config()->vlenb / sizeof(float);
    f32_spmm(state, xnn_f32_spmm_minmax_ukernel_8vx1__rvv, vector_tile, 1,
             /*sparsity=*/0.8f, xnn_init_f32_minmax_scalar_params,
             benchmark::utils::CheckRVV);
  }

  BENCHMARK_SPMM(f32_spmm_minmax_ukernel_8vx1__rvv)
#endif  // XNN_ENABLE_RISCV_VECTOR && XNN_ARCH_RISCV
353+
354+
355+
#if XNN_ENABLE_RISCV_VECTOR && XNN_ARCH_RISCV
  // Benchmark entry point for xnn_f32_spmm_minmax_ukernel_8vx2__rvv.
  // Forwards to f32_spmm with tile arguments (8 * vlenb / sizeof(float)) and 2
  // (presumably MR and NR -- confirm against f32_spmm), sparsity 0.8, and
  // benchmark::utils::CheckRVV as the final argument.
  static void f32_spmm_minmax_ukernel_8vx2__rvv(benchmark::State& state, const char* net) {
    // vlenb comes from the hardware config at run time, so the tile is not a
    // compile-time constant.
    const auto vector_tile = 8 * xnn_init_hardware_config()->vlenb / sizeof(float);
    f32_spmm(state, xnn_f32_spmm_minmax_ukernel_8vx2__rvv, vector_tile, 2,
             /*sparsity=*/0.8f, xnn_init_f32_minmax_scalar_params,
             benchmark::utils::CheckRVV);
  }

  BENCHMARK_SPMM(f32_spmm_minmax_ukernel_8vx2__rvv)
#endif  // XNN_ENABLE_RISCV_VECTOR && XNN_ARCH_RISCV
365+
366+
367+
#if XNN_ENABLE_RISCV_VECTOR && XNN_ARCH_RISCV
  // Benchmark entry point for xnn_f32_spmm_minmax_ukernel_8vx4__rvv.
  // Forwards to f32_spmm with tile arguments (8 * vlenb / sizeof(float)) and 4
  // (presumably MR and NR -- confirm against f32_spmm), sparsity 0.8, and
  // benchmark::utils::CheckRVV as the final argument.
  static void f32_spmm_minmax_ukernel_8vx4__rvv(benchmark::State& state, const char* net) {
    // vlenb comes from the hardware config at run time, so the tile is not a
    // compile-time constant.
    const auto vector_tile = 8 * xnn_init_hardware_config()->vlenb / sizeof(float);
    f32_spmm(state, xnn_f32_spmm_minmax_ukernel_8vx4__rvv, vector_tile, 4,
             /*sparsity=*/0.8f, xnn_init_f32_minmax_scalar_params,
             benchmark::utils::CheckRVV);
  }

  BENCHMARK_SPMM(f32_spmm_minmax_ukernel_8vx4__rvv)
#endif  // XNN_ENABLE_RISCV_VECTOR && XNN_ARCH_RISCV
377+
378+
235379
#if XNN_ARCH_WASMRELAXEDSIMD
236380
static void f32_spmm_minmax_ukernel_4x1__wasmrelaxedsimd_arm(benchmark::State& state, const char* net) {
237381
f32_spmm(state, xnn_f32_spmm_minmax_ukernel_4x1__wasmrelaxedsimd_arm, 4, 1,

0 commit comments

Comments
 (0)