Skip to content

Commit ceb83bd

Browse files
authored
fn dav1d_load_tmvs_sse4: backport x86_64 asm function from dav1d 1.3.0 (#821)
Relates to #811 and #805.
2 parents e347a40 + ffe6a4e commit ceb83bd

File tree

4 files changed

+445
-4
lines changed

4 files changed

+445
-4
lines changed

src/refmvs.rs

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,14 @@ extern "C" {
4343
col_start8: c_int,
4444
row_start8: c_int,
4545
);
46+
fn dav1d_load_tmvs_sse4(
47+
rf: *const refmvs_frame,
48+
tile_row_idx: c_int,
49+
col_start8: c_int,
50+
col_end8: c_int,
51+
row_start8: c_int,
52+
row_end8: c_int,
53+
);
4654
}
4755

4856
#[cfg(all(feature = "asm", target_arch = "x86_64"))]
@@ -142,6 +150,9 @@ pub struct refmvs_block(pub refmvs_block_unaligned);
142150

143151
#[repr(C)]
144152
pub(crate) struct refmvs_frame {
153+
/// A pointer to a [`refmvs_frame`] may be passed to a [`load_tmvs_fn`] function.
154+
/// However, the [`Self::frm_hdr`] pointer is not accessed in such a function (see [`load_tmvs_c`]).
155+
/// Thus, it is safe to have a pointer to [`Rav1dFrameHeader`] instead of [`Dav1dFrameHeader`] here.
145156
pub frm_hdr: *const Rav1dFrameHeader,
146157
pub iw4: c_int,
147158
pub ih4: c_int,
@@ -1645,8 +1656,14 @@ unsafe fn refmvs_dsp_init_x86(c: *mut Rav1dRefmvsDSPContext) {
16451656

16461657
(*c).save_tmvs = Some(dav1d_save_tmvs_ssse3);
16471658

1659+
if !flags.contains(CpuFlags::SSE41) {
1660+
return;
1661+
}
1662+
16481663
#[cfg(target_arch = "x86_64")]
16491664
{
1665+
(*c).load_tmvs = Some(dav1d_load_tmvs_sse4);
1666+
16501667
if !flags.contains(CpuFlags::AVX2) {
16511668
return;
16521669
}

src/x86/refmvs.asm

Lines changed: 224 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,10 @@ SECTION_RODATA 64
4747
%endmacro
4848

4949
%if ARCH_X86_64
50+
mv_proj: dw 0, 16384, 8192, 5461, 4096, 3276, 2730, 2340
51+
dw 2048, 1820, 1638, 1489, 1365, 1260, 1170, 1092
52+
dw 1024, 963, 910, 862, 819, 780, 744, 712
53+
dw 682, 655, 630, 606, 585, 564, 546, 528
5054
splat_mv_shuf: db 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 0, 1, 2, 3
5155
db 4, 5, 6, 7, 8, 9, 10, 11, 0, 1, 2, 3, 4, 5, 6, 7
5256
db 8, 9, 10, 11, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11
@@ -61,6 +65,7 @@ cond_shuf512: db 3, 3, 3, 3, 7, 7, 7, 7, 7, 7, 7, 7, 3, 3, 3, 3
6165
save_cond0: db 0x80, 0x81, 0x82, 0x83, 0x89, 0x84, 0x00, 0x00
6266
save_cond1: db 0x84, 0x85, 0x86, 0x87, 0x88, 0x80, 0x00, 0x00
6367
pb_128: times 16 db 128
68+
pq_8192: dq 8192
6469

6570
save_tmvs_ssse3_table: SAVE_TMVS_TABLE 2, 16, ssse3
6671
SAVE_TMVS_TABLE 4, 8, ssse3
@@ -329,6 +334,225 @@ cglobal splat_mv, 4, 5, 3, rr, a, bx4, bw4, bh4
329334
RET
330335

331336
%if ARCH_X86_64
337+
;-----------------------------------------------------------------------
; load_tmvs (SSE4 variant, assembled only under %if ARCH_X86_64)
;
; Two phases, both operating on 5-byte entries (4 bytes of mv followed by
; 1 byte of ref; every index below is scaled by 5):
;   1. .init_*: for h = yend - ystart rows, store 0x80008000 into the mv
;      dword of the rp_proj entries for columns xstart..xend.
;   2. .nloop/.yloop/.xloop: for each of the rf->n_mfmvs references whose
;      ref2cur is not 0x80000000, walk the source rows taken from
;      rp_ref[ref], scale each mv by mv_proj[ref2ref] * ref2cur, and store
;      the original mv plus the ref2ref byte at the projected position in
;      rp_proj when that position is inside the projection window.
;
; Stack slots (relative to rsp) as used below:
;   0x00 ystart * stride * 5      0x0c n_mfmvs
;   0x14 n (saved across .yloop)  0x18 stride * 5
;   0x20 xstarti                  0x24 ystart
;   0x28 rp_proj + tile row off   0x30 stride
;   0x38 yend                     0x3c y_proj_end
;   0x40 ref2cur                  0x44 y_proj_start
;   0x48 rf
;
; NOTE(review): every [rfq+N] displacement is a hard-coded field offset;
; they presumably have to match the #[repr(C)] refmvs_frame layout on the
; Rust side - confirm whenever that struct changes.
;-----------------------------------------------------------------------
INIT_XMM sse4
338+
; refmvs_frame *rf, int tile_row_idx,
339+
; int col_start8, int col_end8, int row_start8, int row_end8
340+
cglobal load_tmvs, 6, 15, 4, -0x50, rf, tridx, xstart, xend, ystart, yend, \
341+
stride, rp_proj, roff, troff, \
342+
xendi, xstarti, iw8, ih8, dst
343+
; r14d = 0 is the constant used by the cmove/cmovs below. The xor comes
; before the cmp because the cmp flags must survive (mov/lea do not touch
; flags) until `cmove tridxd, r14d`.
xor r14d, r14d
344+
cmp dword [rfq+212], 1 ; n_tile_threads
345+
mov ih8d, [rfq+20] ; rf->ih8
346+
mov iw8d, [rfq+16] ; rf->iw8
347+
; 32-bit self-moves: writing the dword register clears the upper 32 bits,
; so xstartq/xendq can be used in the 64-bit address math below.
mov xstartd, xstartd
348+
mov xendd, xendd
349+
; tridx = 0 when n_tile_threads == 1
cmove tridxd, r14d
350+
lea xstartid, [xstartq-8]
351+
lea xendid, [xendq+8]
352+
; NOTE(review): offsets 184/176 are presumably the rp_proj stride and the
; rp_proj base pointer (going by the register names) - confirm.
mov strideq, [rfq+184]
353+
mov rp_projq, [rfq+176]
354+
cmp ih8d, yendd
355+
mov [rsp+0x30], strideq
356+
; yend = ih8 when (ih8 - yend) is negative, i.e. clamp yend to ih8
cmovs yendd, ih8d
357+
test xstartid, xstartid
358+
; xstarti = xstart - 8, clamped to 0 when negative
cmovs xstartid, r14d
359+
cmp iw8d, xendid
360+
; xendi = xend + 8, clamped to iw8
cmovs xendid, iw8d
361+
mov troffq, strideq
362+
shl troffq, 4
363+
imul troffq, tridxq
364+
mov dstd, ystartd
365+
and dstd, 15
366+
imul dstq, strideq
367+
add dstq, troffq ; (16 * tridx + (ystart & 15)) * stride
368+
lea dstq, [dstq*5]
369+
add dstq, rp_projq
370+
lea troffq, [troffq*5] ; 16 * tridx * stride * 5
371+
; r13 = xend * 5, r12 = stride * 5 (renamed stride5 by the DEFINE_ARGS below)
lea r13d, [xendq*5]
372+
lea r12, [strideq*5]
373+
DEFINE_ARGS rf, w5, xstart, xend, ystart, yend, h, x5, \
374+
_, troff, xendi, xstarti, stride5, _, dst
375+
lea w5d, [xstartq*5]
376+
; r7 still holds rp_proj here; it is only reused as x5 inside the init loop
add r7, troffq ; rp_proj + tile_row_offset
377+
mov hd, yendd
378+
mov [rsp+0x28], r7
379+
; dst points at column xend of the row; w5 = (xstart - xend) * 5 is the
; negative byte offset of column xstart, counted up towards 0.
add dstq, r13
380+
sub w5q, r13
381+
sub hd, ystartd
382+
.init_xloop_start:
383+
mov x5q, w5q
384+
; w5 is 5 * (xstart - xend), so its low bit is set exactly when the column
; count is odd: store one entry first, then continue in pairs.
test w5b, 1
385+
jz .init_2blk
386+
mov dword [dstq+x5q], 0x80008000
387+
add x5q, 5
388+
jz .init_next_row
389+
.init_2blk:
390+
mov dword [dstq+x5q+0], 0x80008000
391+
mov dword [dstq+x5q+5], 0x80008000
392+
add x5q, 10
393+
jl .init_2blk
394+
.init_next_row:
395+
add dstq, stride5q
396+
dec hd
397+
jg .init_xloop_start
398+
DEFINE_ARGS rf, _, xstart, xend, ystart, yend, n7, stride, \
399+
_, _, xendi, xstarti, stride5, _, n
400+
mov r13d, [rfq+152] ; rf->n_mfmvs
401+
test r13d, r13d
402+
jz .ret
403+
; Spill everything the projection loops need: the registers holding these
; values are re-purposed by the DEFINE_ARGS blocks further down.
mov [rsp+0x0c], r13d
404+
mov strideq, [rsp+0x30]
405+
; m3 = 8192 in both qword lanes (rounding term for the >> 14 below)
movddup m3, [pq_8192]
406+
mov r9d, ystartd
407+
mov [rsp+0x38], yendd
408+
mov [rsp+0x20], xstartid
409+
; n = reference index; n7 = n * 7 (advanced by 7 per iteration in .next_n)
xor nd, nd
410+
xor n7d, n7d
411+
imul r9, strideq ; ystart * stride
412+
mov [rsp+0x48], rfq
413+
mov [rsp+0x18], stride5q
414+
lea r7, [r9*5]
415+
mov [rsp+0x24], ystartd
416+
mov [rsp+0x00], r7
417+
.nloop:
418+
DEFINE_ARGS y, off, xstart, xend, ystart, rf, n7, refsign, \
419+
ref, rp_ref, xendi, xstarti, _, _, n
420+
; rf now lives in r5 (previously yend), so it is reloaded from the stack
mov rfq, [rsp+0x48]
421+
mov refd, [rfq+56+nq*4] ; ref2cur
422+
; a ref2cur of 0x80000000 makes this reference get skipped entirely
cmp refd, 0x80000000
423+
je .next_n
424+
mov [rsp+0x40], refd
425+
mov offq, [rsp+0x00] ; ystart * stride * 5
426+
movzx refd, byte [rfq+53+nq] ; rf->mfmv_ref[n]
427+
; refsign = ref - 4; kept in m2 for the pxor/psignd sign selection below
lea refsignq, [refq-4]
428+
mov rp_refq, [rfq+168]
429+
movq m2, refsignq
430+
add offq, [rp_refq+refq*8] ; r = rp_ref[ref] + row_offset
431+
; n shares r14 with ref2ref inside .xloop, so save it for .next_n
mov [rsp+0x14], nd
432+
mov yd, ystartd
433+
.yloop:
434+
mov r11d, [rsp+0x24] ; ystart
435+
mov r12d, [rsp+0x38] ; yend
436+
mov r14d, yd
437+
and r14d, ~7 ; y_sb_align
438+
cmp r11d, r14d
439+
cmovs r11d, r14d ; imax(y_sb_align, ystart)
440+
mov [rsp+0x44], r11d ; y_proj_start
441+
add r14d, 8
442+
cmp r12d, r14d
443+
cmovs r14d, r12d ; imin(y_sb_align + 8, yend)
444+
mov [rsp+0x3c], r14d ; y_proj_end
445+
DEFINE_ARGS y, src, xstart, xend, frac, rf, n7, mv, \
446+
ref, x, xendi, mvx, mvy, rb, ref2ref
447+
mov xd, [rsp+0x20] ; xstarti
448+
.xloop:
449+
; rb = src + x * 5: address of the 5-byte source entry for column x
lea rbd, [xq*5]
450+
add rbq, srcq
451+
; byte 4 of the entry is the (signed) ref; 0 means nothing to project
movsx refd, byte [rbq+4]
452+
test refd, refd
453+
jz .next_x_bad_ref
454+
mov rfq, [rsp+0x48]
455+
lea r14d, [16+n7q+refq]
456+
mov ref2refd, [rfq+r14*4] ; rf->mfmv_ref2ref[n][b_ref-1]
457+
test ref2refd, ref2refd
458+
jz .next_x_bad_ref
459+
; the table base is loaded separately because the lookup needs an index register
lea fracq, [mv_proj]
460+
movzx fracd, word [fracq+ref2refq*2]
461+
; mvd keeps the raw 4-byte mv: it is what gets stored to rp_proj and what
; the run-detection compares in .write_loop/.next_x test against
mov mvd, [rbq]
462+
imul fracd, [rsp+0x40] ; ref2cur
463+
; sign-extend the two 16-bit mv components into the two qword lanes
pmovsxwq m0, [rbq]
464+
movd m1, fracd
465+
punpcklqdq m1, m1
466+
pmuldq m0, m1 ; mv * frac
467+
pshufd m1, m0, q3311
468+
paddd m0, m3
469+
paddd m0, m1
470+
psrad m0, 14 ; offset = (xy + (xy >> 31) + 8192) >> 14
471+
pabsd m1, m0
472+
packssdw m0, m0
473+
psrld m1, 6
474+
packuswb m1, m1
475+
pxor m0, m2 ; offset ^ ref_sign
476+
psignd m1, m0 ; apply_sign(abs(offset) >> 6, offset ^ refsign)
477+
; low dword of m1 is the row offset (added to y), high dword the column
; offset (extracted by the arithmetic shift and added to x further down)
movq mvxq, m1
478+
lea mvyd, [mvxq+yq] ; ypos
479+
sar mvxq, 32
480+
DEFINE_ARGS y, src, xstart, xend, _, _, n7, mv, \
481+
ref, x, xendi, mvx, ypos, rb, ref2ref
482+
cmp yposd, [rsp+0x44] ; y_proj_start
483+
jl .next_x_bad_pos_y
484+
cmp yposd, [rsp+0x3c] ; y_proj_end
485+
jge .next_x_bad_pos_y
486+
and yposd, 15
487+
add mvxq, xq ; xpos
488+
imul yposq, [rsp+0x30] ; pos = (ypos & 15) * stride
489+
DEFINE_ARGS y, src, xstart, xend, dst, _, n7, mv, \
490+
ref, x, xendi, xpos, pos, rb, ref2ref
491+
mov dstq, [rsp+0x28] ; dst = rp_proj + tile_row_offset
492+
add posq, xposq ; pos += xpos
493+
lea posq, [posq*5]
494+
add dstq, posq ; dst += pos5
495+
jmp .write_loop_entry
496+
; .write_loop: the next source entry is compared against the current
; ref/mv; while both match, the projection computed above is reused with
; dst and xpos advanced by one entry, otherwise .xloop recomputes it.
.write_loop:
497+
add rbq, 5
498+
cmp refb, byte [rbq+4]
499+
jne .xloop
500+
cmp mvd, [rbq]
501+
jne .xloop
502+
add dstq, 5
503+
inc xposd
504+
.write_loop_entry:
505+
; horizontal window derived from the 8-aligned x:
; [max((x & ~7) - 8, xstart), min((x & ~7) + 16, xend))
mov r12d, xd
506+
and r12d, ~7
507+
lea r5d, [r12-8]
508+
cmp r5d, xstartd
509+
cmovs r5d, xstartd ; x_proj_start
510+
cmp xposd, r5d
511+
jl .next_xpos
512+
add r12d, 16
513+
cmp xendd, r12d
514+
cmovs r12d, xendd ; x_proj_end
515+
cmp xposd, r12d
516+
jge .next_xpos
517+
; store the unscaled source mv and the ref2ref byte into the 5-byte dst entry
mov [dstq+0], mvd
518+
mov byte [dstq+4], ref2refb
519+
.next_xpos:
520+
inc xd
521+
cmp xd, xendid
522+
jl .write_loop
523+
.next_y:
524+
DEFINE_ARGS y, src, xstart, xend, ystart, _, n7, _, _, x, xendi, _, _, _, n
525+
add srcq, [rsp+0x18] ; stride5
526+
inc yd
527+
cmp yd, [rsp+0x38] ; yend
528+
jne .yloop
529+
; restore n (r14 was ref2ref) and ystart (r4 was frac/dst) for the next reference
mov nd, [rsp+0x14]
530+
mov ystartd, [rsp+0x24]
531+
.next_n:
532+
add n7d, 7
533+
inc nd
534+
cmp nd, [rsp+0x0c] ; n_mfmvs
535+
jne .nloop
536+
.ret:
537+
RET
538+
; Out-of-line tails of .xloop, placed after RET so they stay off the
; fall-through path.
; .next_x: reached after a projection whose ypos was outside the window;
; following entries with the same ref and mv are skipped without
; recomputing, and the first differing entry goes back to .xloop.
.next_x:
539+
DEFINE_ARGS y, src, xstart, xend, _, _, n7, mv, ref, x, xendi, _, _, rb, _
540+
add rbq, 5
541+
cmp refb, byte [rbq+4]
542+
jne .xloop
543+
cmp mvd, [rbq]
544+
jne .xloop
545+
.next_x_bad_pos_y:
546+
inc xd
547+
cmp xd, xendid
548+
jl .next_x
549+
jmp .next_y
550+
; .next_x_bad_ref: entry had ref == 0 or ref2ref == 0; just advance x
.next_x_bad_ref:
551+
inc xd
552+
cmp xd, xendid
553+
jl .xloop
554+
jmp .next_y
555+
332556
INIT_YMM avx2
333557
; refmvs_temporal_block *rp, ptrdiff_t stride,
334558
; refmvs_block **rr, uint8_t *ref_sign,

src/x86/refmvs.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@
2828
#include "src/cpu.h"
2929
#include "src/refmvs.h"
3030

31+
decl_load_tmvs_fn(dav1d_load_tmvs_sse4);
32+
3133
decl_save_tmvs_fn(dav1d_save_tmvs_ssse3);
3234
decl_save_tmvs_fn(dav1d_save_tmvs_avx2);
3335
decl_save_tmvs_fn(dav1d_save_tmvs_avx512icl);
@@ -47,7 +49,10 @@ static ALWAYS_INLINE void refmvs_dsp_init_x86(Dav1dRefmvsDSPContext *const c) {
4749

4850
c->save_tmvs = dav1d_save_tmvs_ssse3;
4951

52+
if (!(flags & DAV1D_X86_CPU_FLAG_SSE41)) return;
5053
#if ARCH_X86_64
54+
c->load_tmvs = dav1d_load_tmvs_sse4;
55+
5156
if (!(flags & DAV1D_X86_CPU_FLAG_AVX2)) return;
5257

5358
c->save_tmvs = dav1d_save_tmvs_avx2;

0 commit comments

Comments
 (0)