Skip to content

Commit

Permalink
Fix naming of mask
Browse files Browse the repository at this point in the history
  • Loading branch information
fanchenkong1 committed Oct 15, 2024
1 parent 695903a commit 087bdfc
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 30 deletions.
22 changes: 10 additions & 12 deletions src/qs8-packw/gen/qs8-packw-x8c8-gemm-goi-wasmrelaxedsimd.c
Original file line number Diff line number Diff line change
Expand Up @@ -164,21 +164,20 @@ void xnn_qs8_packw_gemm_goi_ukernel_x8c8__wasmrelaxedsimd(
 if (k != 0) {
 assert(k >= 1 && k <= 7);

-const v128_t all_one = wasm_i32x4_splat(-1);
-const v128_t mask = wasm_u64x2_shr(all_one, (8 - k) * sizeof(int8_t) * 8);
+const v128_t vmask = wasm_u64x2_shr(wasm_i32x4_splat(-1), (8 - k) * sizeof(int8_t) * 8);

 v128_t v01 = wasm_v128_load64_splat(w0);
 v01 = wasm_i64x2_shuffle(v01, wasm_v128_load64_splat(w1), 0, 3);
-v01 = wasm_v128_and(v01, mask);
+v01 = wasm_v128_and(v01, vmask);
 v128_t v23 = wasm_v128_load64_splat(w2);
 v23 = wasm_i64x2_shuffle(v23, wasm_v128_load64_splat(w3), 0, 3);
-v23 = wasm_v128_and(v23, mask);
+v23 = wasm_v128_and(v23, vmask);
 v128_t v45 = wasm_v128_load64_splat(w4);
 v45 = wasm_i64x2_shuffle(v45, wasm_v128_load64_splat(w5), 0, 3);
-v45 = wasm_v128_and(v45, mask);
+v45 = wasm_v128_and(v45, vmask);
 v128_t v67 = wasm_v128_load64_splat(w6);
 v67 = wasm_i64x2_shuffle(v67, wasm_v128_load64_splat(w7), 0, 3);
-v67 = wasm_v128_and(v67, mask);
+v67 = wasm_v128_and(v67, vmask);

 vacc01 = wasm_i32x4_relaxed_dot_i8x16_i7x16_add(v01, vone, vacc01);
 vacc23 = wasm_i32x4_relaxed_dot_i8x16_i7x16_add(v23, vone, vacc23);
Expand Down Expand Up @@ -308,21 +307,20 @@ void xnn_qs8_packw_gemm_goi_ukernel_x8c8__wasmrelaxedsimd(
 if (k != 0) {
 assert(k >= 1 && k <= 7);

-const v128_t all_one = wasm_i32x4_splat(-1);
-const v128_t mask = wasm_u64x2_shr(all_one, (8 - k) * sizeof(int8_t) * 8);
+const v128_t vmask = wasm_u64x2_shr(wasm_i32x4_splat(-1), (8 - k) * sizeof(int8_t) * 8);

 v128_t v01 = wasm_v128_load64_splat(w0);
 v01 = wasm_i64x2_shuffle(v01, wasm_v128_load64_splat(w1), 0, 3);
-v01 = wasm_v128_and(v01, mask);
+v01 = wasm_v128_and(v01, vmask);
 v128_t v23 = wasm_v128_load64_splat(w2);
 v23 = wasm_i64x2_shuffle(v23, wasm_v128_load64_splat(w3), 0, 3);
-v23 = wasm_v128_and(v23, mask);
+v23 = wasm_v128_and(v23, vmask);
 v128_t v45 = wasm_v128_load64_splat(w4);
 v45 = wasm_i64x2_shuffle(v45, wasm_v128_load64_splat(w5), 0, 3);
-v45 = wasm_v128_and(v45, mask);
+v45 = wasm_v128_and(v45, vmask);
 v128_t v67 = wasm_v128_load64_splat(w6);
 v67 = wasm_i64x2_shuffle(v67, wasm_v128_load64_splat(w7), 0, 3);
-v67 = wasm_v128_and(v67, mask);
+v67 = wasm_v128_and(v67, vmask);

 vacc01 = wasm_i32x4_relaxed_dot_i8x16_i7x16_add(v01, vone, vacc01);
 vacc23 = wasm_i32x4_relaxed_dot_i8x16_i7x16_add(v23, vone, vacc23);
Expand Down
22 changes: 10 additions & 12 deletions src/qs8-qu8-packw/gen/qs8-qu8-packw-x8c8-gemm-goi-wasmrelaxedsimd.c
Original file line number Diff line number Diff line change
Expand Up @@ -164,21 +164,20 @@ void xnn_qs8_to_qu8_packw_gemm_goi_ukernel_x8c8__wasmrelaxedsimd(
 if (k != 0) {
 assert(k >= 1 && k <= 7);

-const v128_t all_one = wasm_i32x4_splat(-1);
-const v128_t mask = wasm_u64x2_shr(all_one, (8 - k) * sizeof(int8_t) * 8);
+const v128_t vmask = wasm_u64x2_shr(wasm_i32x4_splat(-1), (8 - k) * sizeof(int8_t) * 8);

 v128_t v01 = wasm_v128_load64_splat(w0);
 v01 = wasm_i64x2_shuffle(v01, wasm_v128_load64_splat(w1), 0, 3);
-v01 = wasm_v128_and(v01, mask);
+v01 = wasm_v128_and(v01, vmask);
 v128_t v23 = wasm_v128_load64_splat(w2);
 v23 = wasm_i64x2_shuffle(v23, wasm_v128_load64_splat(w3), 0, 3);
-v23 = wasm_v128_and(v23, mask);
+v23 = wasm_v128_and(v23, vmask);
 v128_t v45 = wasm_v128_load64_splat(w4);
 v45 = wasm_i64x2_shuffle(v45, wasm_v128_load64_splat(w5), 0, 3);
-v45 = wasm_v128_and(v45, mask);
+v45 = wasm_v128_and(v45, vmask);
 v128_t v67 = wasm_v128_load64_splat(w6);
 v67 = wasm_i64x2_shuffle(v67, wasm_v128_load64_splat(w7), 0, 3);
-v67 = wasm_v128_and(v67, mask);
+v67 = wasm_v128_and(v67, vmask);

 vacc01 = wasm_i32x4_relaxed_dot_i8x16_i7x16_add(v01, vone, vacc01);
 vacc23 = wasm_i32x4_relaxed_dot_i8x16_i7x16_add(v23, vone, vacc23);
Expand Down Expand Up @@ -308,21 +307,20 @@ void xnn_qs8_to_qu8_packw_gemm_goi_ukernel_x8c8__wasmrelaxedsimd(
 if (k != 0) {
 assert(k >= 1 && k <= 7);

-const v128_t all_one = wasm_i32x4_splat(-1);
-const v128_t mask = wasm_u64x2_shr(all_one, (8 - k) * sizeof(int8_t) * 8);
+const v128_t vmask = wasm_u64x2_shr(wasm_i32x4_splat(-1), (8 - k) * sizeof(int8_t) * 8);

 v128_t v01 = wasm_v128_load64_splat(w0);
 v01 = wasm_i64x2_shuffle(v01, wasm_v128_load64_splat(w1), 0, 3);
-v01 = wasm_v128_and(v01, mask);
+v01 = wasm_v128_and(v01, vmask);
 v128_t v23 = wasm_v128_load64_splat(w2);
 v23 = wasm_i64x2_shuffle(v23, wasm_v128_load64_splat(w3), 0, 3);
-v23 = wasm_v128_and(v23, mask);
+v23 = wasm_v128_and(v23, vmask);
 v128_t v45 = wasm_v128_load64_splat(w4);
 v45 = wasm_i64x2_shuffle(v45, wasm_v128_load64_splat(w5), 0, 3);
-v45 = wasm_v128_and(v45, mask);
+v45 = wasm_v128_and(v45, vmask);
 v128_t v67 = wasm_v128_load64_splat(w6);
 v67 = wasm_i64x2_shuffle(v67, wasm_v128_load64_splat(w7), 0, 3);
-v67 = wasm_v128_and(v67, mask);
+v67 = wasm_v128_and(v67, vmask);

 vacc01 = wasm_i32x4_relaxed_dot_i8x16_i7x16_add(v01, vone, vacc01);
 vacc23 = wasm_i32x4_relaxed_dot_i8x16_i7x16_add(v23, vone, vacc23);
Expand Down
10 changes: 4 additions & 6 deletions src/x8-packw/kr-wasmdot.c.in
Original file line number Diff line number Diff line change
Expand Up @@ -118,13 +118,12 @@ void xnn_qs8${"_to_qu8" if IZP == 128 else ""}_packw_gemm_goi_ukernel_x${NR}c${K
 if (k != 0) {
 assert(k >= 1 && k <= ${KR-1});

-const v128_t all_one = wasm_i32x4_splat(-1);
-const v128_t mask = wasm_u64x2_shr(all_one, (${KR} - k) * sizeof(${WTYPE}) * 8);
+const v128_t vmask = wasm_u64x2_shr(wasm_i32x4_splat(-1), (${KR} - k) * sizeof(${WTYPE}) * 8);

 $for N in range(0, NR, 2):
 v128_t v${ABC[N:N+2]} = wasm_v128_load64_splat(w${N});
 v${ABC[N:N+2]} = wasm_i64x2_shuffle(v${ABC[N:N+2]}, wasm_v128_load64_splat(w${N+1}), 0, 3);
-v${ABC[N:N+2]} = wasm_v128_and(v${ABC[N:N+2]}, mask);
+v${ABC[N:N+2]} = wasm_v128_and(v${ABC[N:N+2]}, vmask);

 $for N in range(0, NR, 2):
 vacc${ABC[N:N+2]} = wasm_i32x4_relaxed_dot_i8x16_i7x16_add(v${ABC[N:N+2]}, vone, vacc${ABC[N:N+2]});
Expand Down Expand Up @@ -209,13 +208,12 @@ void xnn_qs8${"_to_qu8" if IZP == 128 else ""}_packw_gemm_goi_ukernel_x${NR}c${K
 if (k != 0) {
 assert(k >= 1 && k <= ${KR-1});

-const v128_t all_one = wasm_i32x4_splat(-1);
-const v128_t mask = wasm_u64x2_shr(all_one, (${KR} - k) * sizeof(${WTYPE}) * 8);
+const v128_t vmask = wasm_u64x2_shr(wasm_i32x4_splat(-1), (${KR} - k) * sizeof(${WTYPE}) * 8);

 $for N in range(0, NR, 2):
 v128_t v${ABC[N:N+2]} = wasm_v128_load64_splat(w${N});
 v${ABC[N:N+2]} = wasm_i64x2_shuffle(v${ABC[N:N+2]}, wasm_v128_load64_splat(w${N+1}), 0, 3);
-v${ABC[N:N+2]} = wasm_v128_and(v${ABC[N:N+2]}, mask);
+v${ABC[N:N+2]} = wasm_v128_and(v${ABC[N:N+2]}, vmask);

 $for N in range(0, NR, 2):
 vacc${ABC[N:N+2]} = wasm_i32x4_relaxed_dot_i8x16_i7x16_add(v${ABC[N:N+2]}, vone, vacc${ABC[N:N+2]});
Expand Down

0 comments on commit 087bdfc

Please sign in to comment.