Skip to content

Commit

Permalink
fn dav1d_load_tmvs_sse4: backport x86_64 asm function from dav1d …
Browse files Browse the repository at this point in the history
…1.3.0 (#821)

Relates to #811 and #805.
  • Loading branch information
fbossen authored Mar 19, 2024
2 parents e347a40 + ffe6a4e commit ceb83bd
Show file tree
Hide file tree
Showing 4 changed files with 445 additions and 4 deletions.
17 changes: 17 additions & 0 deletions src/refmvs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,14 @@ extern "C" {
col_start8: c_int,
row_start8: c_int,
);
fn dav1d_load_tmvs_sse4(
rf: *const refmvs_frame,
tile_row_idx: c_int,
col_start8: c_int,
col_end8: c_int,
row_start8: c_int,
row_end8: c_int,
);
}

#[cfg(all(feature = "asm", target_arch = "x86_64"))]
Expand Down Expand Up @@ -142,6 +150,9 @@ pub struct refmvs_block(pub refmvs_block_unaligned);

#[repr(C)]
pub(crate) struct refmvs_frame {
/// A pointer to a [`refmvs_frame`] may be passed to a [`load_tmvs_fn`] function.
/// However, the [`Self::frm_hdr`] pointer is not accessed in such a function (see [`load_tmvs_c`]).
/// Thus, it is safe to have a pointer to [`Rav1dFrameHeader`] instead of [`Dav1dFrameHeader`] here.
pub frm_hdr: *const Rav1dFrameHeader,
pub iw4: c_int,
pub ih4: c_int,
Expand Down Expand Up @@ -1645,8 +1656,14 @@ unsafe fn refmvs_dsp_init_x86(c: *mut Rav1dRefmvsDSPContext) {

(*c).save_tmvs = Some(dav1d_save_tmvs_ssse3);

if !flags.contains(CpuFlags::SSE41) {
return;
}

#[cfg(target_arch = "x86_64")]
{
(*c).load_tmvs = Some(dav1d_load_tmvs_sse4);

if !flags.contains(CpuFlags::AVX2) {
return;
}
Expand Down
224 changes: 224 additions & 0 deletions src/x86/refmvs.asm
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,10 @@ SECTION_RODATA 64
%endmacro

%if ARCH_X86_64
mv_proj: dw 0, 16384, 8192, 5461, 4096, 3276, 2730, 2340
dw 2048, 1820, 1638, 1489, 1365, 1260, 1170, 1092
dw 1024, 963, 910, 862, 819, 780, 744, 712
dw 682, 655, 630, 606, 585, 564, 546, 528
splat_mv_shuf: db 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 0, 1, 2, 3
db 4, 5, 6, 7, 8, 9, 10, 11, 0, 1, 2, 3, 4, 5, 6, 7
db 8, 9, 10, 11, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11
Expand All @@ -61,6 +65,7 @@ cond_shuf512: db 3, 3, 3, 3, 7, 7, 7, 7, 7, 7, 7, 7, 3, 3, 3, 3
save_cond0: db 0x80, 0x81, 0x82, 0x83, 0x89, 0x84, 0x00, 0x00
save_cond1: db 0x84, 0x85, 0x86, 0x87, 0x88, 0x80, 0x00, 0x00
pb_128: times 16 db 128
pq_8192: dq 8192

save_tmvs_ssse3_table: SAVE_TMVS_TABLE 2, 16, ssse3
SAVE_TMVS_TABLE 4, 8, ssse3
Expand Down Expand Up @@ -329,6 +334,225 @@ cglobal splat_mv, 4, 5, 3, rr, a, bx4, bw4, bh4
RET

%if ARCH_X86_64
INIT_XMM sse4
; refmvs_frame *rf, int tile_row_idx,
; int col_start8, int col_end8, int row_start8, int row_end8
cglobal load_tmvs, 6, 15, 4, -0x50, rf, tridx, xstart, xend, ystart, yend, \
stride, rp_proj, roff, troff, \
xendi, xstarti, iw8, ih8, dst
xor r14d, r14d
cmp dword [rfq+212], 1 ; n_tile_threads
mov ih8d, [rfq+20] ; rf->ih8
mov iw8d, [rfq+16] ; rf->iw8
mov xstartd, xstartd
mov xendd, xendd
cmove tridxd, r14d
lea xstartid, [xstartq-8]
lea xendid, [xendq+8]
mov strideq, [rfq+184]
mov rp_projq, [rfq+176]
cmp ih8d, yendd
mov [rsp+0x30], strideq
cmovs yendd, ih8d
test xstartid, xstartid
cmovs xstartid, r14d
cmp iw8d, xendid
cmovs xendid, iw8d
mov troffq, strideq
shl troffq, 4
imul troffq, tridxq
mov dstd, ystartd
and dstd, 15
imul dstq, strideq
add dstq, troffq ; (16 * tridx + (ystart & 15)) * stride
lea dstq, [dstq*5]
add dstq, rp_projq
lea troffq, [troffq*5] ; 16 * tridx * stride * 5
lea r13d, [xendq*5]
lea r12, [strideq*5]
DEFINE_ARGS rf, w5, xstart, xend, ystart, yend, h, x5, \
_, troff, xendi, xstarti, stride5, _, dst
lea w5d, [xstartq*5]
add r7, troffq ; rp_proj + tile_row_offset
mov hd, yendd
mov [rsp+0x28], r7
add dstq, r13
sub w5q, r13
sub hd, ystartd
.init_xloop_start:
mov x5q, w5q
test w5b, 1
jz .init_2blk
mov dword [dstq+x5q], 0x80008000
add x5q, 5
jz .init_next_row
.init_2blk:
mov dword [dstq+x5q+0], 0x80008000
mov dword [dstq+x5q+5], 0x80008000
add x5q, 10
jl .init_2blk
.init_next_row:
add dstq, stride5q
dec hd
jg .init_xloop_start
DEFINE_ARGS rf, _, xstart, xend, ystart, yend, n7, stride, \
_, _, xendi, xstarti, stride5, _, n
mov r13d, [rfq+152] ; rf->n_mfmvs
test r13d, r13d
jz .ret
mov [rsp+0x0c], r13d
mov strideq, [rsp+0x30]
movddup m3, [pq_8192]
mov r9d, ystartd
mov [rsp+0x38], yendd
mov [rsp+0x20], xstartid
xor nd, nd
xor n7d, n7d
imul r9, strideq ; ystart * stride
mov [rsp+0x48], rfq
mov [rsp+0x18], stride5q
lea r7, [r9*5]
mov [rsp+0x24], ystartd
mov [rsp+0x00], r7
.nloop:
DEFINE_ARGS y, off, xstart, xend, ystart, rf, n7, refsign, \
ref, rp_ref, xendi, xstarti, _, _, n
mov rfq, [rsp+0x48]
mov refd, [rfq+56+nq*4] ; ref2cur
cmp refd, 0x80000000
je .next_n
mov [rsp+0x40], refd
mov offq, [rsp+0x00] ; ystart * stride * 5
movzx refd, byte [rfq+53+nq] ; rf->mfmv_ref[n]
lea refsignq, [refq-4]
mov rp_refq, [rfq+168]
movq m2, refsignq
add offq, [rp_refq+refq*8] ; r = rp_ref[ref] + row_offset
mov [rsp+0x14], nd
mov yd, ystartd
.yloop:
mov r11d, [rsp+0x24] ; ystart
mov r12d, [rsp+0x38] ; yend
mov r14d, yd
and r14d, ~7 ; y_sb_align
cmp r11d, r14d
cmovs r11d, r14d ; imax(y_sb_align, ystart)
mov [rsp+0x44], r11d ; y_proj_start
add r14d, 8
cmp r12d, r14d
cmovs r14d, r12d ; imin(y_sb_align + 8, yend)
mov [rsp+0x3c], r14d ; y_proj_end
DEFINE_ARGS y, src, xstart, xend, frac, rf, n7, mv, \
ref, x, xendi, mvx, mvy, rb, ref2ref
mov xd, [rsp+0x20] ; xstarti
.xloop:
lea rbd, [xq*5]
add rbq, srcq
movsx refd, byte [rbq+4]
test refd, refd
jz .next_x_bad_ref
mov rfq, [rsp+0x48]
lea r14d, [16+n7q+refq]
mov ref2refd, [rfq+r14*4] ; rf->mfmv_ref2ref[n][b_ref-1]
test ref2refd, ref2refd
jz .next_x_bad_ref
lea fracq, [mv_proj]
movzx fracd, word [fracq+ref2refq*2]
mov mvd, [rbq]
imul fracd, [rsp+0x40] ; ref2cur
pmovsxwq m0, [rbq]
movd m1, fracd
punpcklqdq m1, m1
pmuldq m0, m1 ; mv * frac
pshufd m1, m0, q3311
paddd m0, m3
paddd m0, m1
psrad m0, 14 ; offset = (xy + (xy >> 31) + 8192) >> 14
pabsd m1, m0
packssdw m0, m0
psrld m1, 6
packuswb m1, m1
pxor m0, m2 ; offset ^ ref_sign
psignd m1, m0 ; apply_sign(abs(offset) >> 6, offset ^ refsign)
movq mvxq, m1
lea mvyd, [mvxq+yq] ; ypos
sar mvxq, 32
DEFINE_ARGS y, src, xstart, xend, _, _, n7, mv, \
ref, x, xendi, mvx, ypos, rb, ref2ref
cmp yposd, [rsp+0x44] ; y_proj_start
jl .next_x_bad_pos_y
cmp yposd, [rsp+0x3c] ; y_proj_end
jge .next_x_bad_pos_y
and yposd, 15
add mvxq, xq ; xpos
imul yposq, [rsp+0x30] ; pos = (ypos & 15) * stride
DEFINE_ARGS y, src, xstart, xend, dst, _, n7, mv, \
ref, x, xendi, xpos, pos, rb, ref2ref
mov dstq, [rsp+0x28] ; dst = rp_proj + tile_row_offset
add posq, xposq ; pos += xpos
lea posq, [posq*5]
add dstq, posq ; dst += pos5
jmp .write_loop_entry
.write_loop:
add rbq, 5
cmp refb, byte [rbq+4]
jne .xloop
cmp mvd, [rbq]
jne .xloop
add dstq, 5
inc xposd
.write_loop_entry:
mov r12d, xd
and r12d, ~7
lea r5d, [r12-8]
cmp r5d, xstartd
cmovs r5d, xstartd ; x_proj_start
cmp xposd, r5d
jl .next_xpos
add r12d, 16
cmp xendd, r12d
cmovs r12d, xendd ; x_proj_end
cmp xposd, r12d
jge .next_xpos
mov [dstq+0], mvd
mov byte [dstq+4], ref2refb
.next_xpos:
inc xd
cmp xd, xendid
jl .write_loop
.next_y:
DEFINE_ARGS y, src, xstart, xend, ystart, _, n7, _, _, x, xendi, _, _, _, n
add srcq, [rsp+0x18] ; stride5
inc yd
cmp yd, [rsp+0x38] ; yend
jne .yloop
mov nd, [rsp+0x14]
mov ystartd, [rsp+0x24]
.next_n:
add n7d, 7
inc nd
cmp nd, [rsp+0x0c] ; n_mfmvs
jne .nloop
.ret:
RET
.next_x:
DEFINE_ARGS y, src, xstart, xend, _, _, n7, mv, ref, x, xendi, _, _, rb, _
add rbq, 5
cmp refb, byte [rbq+4]
jne .xloop
cmp mvd, [rbq]
jne .xloop
.next_x_bad_pos_y:
inc xd
cmp xd, xendid
jl .next_x
jmp .next_y
.next_x_bad_ref:
inc xd
cmp xd, xendid
jl .xloop
jmp .next_y

INIT_YMM avx2
; refmvs_temporal_block *rp, ptrdiff_t stride,
; refmvs_block **rr, uint8_t *ref_sign,
Expand Down
5 changes: 5 additions & 0 deletions src/x86/refmvs.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@
#include "src/cpu.h"
#include "src/refmvs.h"

decl_load_tmvs_fn(dav1d_load_tmvs_sse4);

decl_save_tmvs_fn(dav1d_save_tmvs_ssse3);
decl_save_tmvs_fn(dav1d_save_tmvs_avx2);
decl_save_tmvs_fn(dav1d_save_tmvs_avx512icl);
Expand All @@ -47,7 +49,10 @@ static ALWAYS_INLINE void refmvs_dsp_init_x86(Dav1dRefmvsDSPContext *const c) {

c->save_tmvs = dav1d_save_tmvs_ssse3;

if (!(flags & DAV1D_X86_CPU_FLAG_SSE41)) return;
#if ARCH_X86_64
c->load_tmvs = dav1d_load_tmvs_sse4;

if (!(flags & DAV1D_X86_CPU_FLAG_AVX2)) return;

c->save_tmvs = dav1d_save_tmvs_avx2;
Expand Down
Loading

0 comments on commit ceb83bd

Please sign in to comment.