Skip to content

Commit

Permalink
fn loopfilter_sb: Make safe (#1269)
Browse files Browse the repository at this point in the history
* Fixes #848.
* Fixes #855.
* Fixes #841.

This does add back some bounds checks since the only and final slicing
is done in `fn loopfilter_sb::Fn::call`, which is inside inner loops.
@fbossen, how does this look for perf?

Only 2 `unsafe` ops left (after #1239 is merged, too)! Though still 122
`unsafe` ops in `neon` `fn`s.

## Addendum

So it turns out the `- b4strideb` in `fn loop_filter_sb128_rust` makes
the index negative, so the initial slice needs to be larger. This is
difficult to safely do, though, because all of the other uses of
`f.lf.level` do not go through `DisjointMut`'s safe APIs since there
were overlapping `&mut`s since indices 0 and 1 are written
simultaneously to 2 and 3, and thus were done `unsafe`ly. This loses the
disjointedness checking, though, which is a problem for making other
changes.

Thus, this PR leaves the `- b4strideb` as an `unsafe`
`lvl.as_ptr().sub(b4strideb)` for now. In a follow-up PR, I'm going to
change the `[u8; 4]` `lvl` elements to just `u8`s. This will allow the
disjoint writes to `[0..2]` and `[2..4]` to be done safely with
`DisjointMut`s APIs and will remove the need for `fn
unaligned_lvl_slice`.

Fixed in #1273 now.
  • Loading branch information
kkysen authored Jul 1, 2024
2 parents d9bc0e3 + 30b4456 commit 7e0a9d8
Show file tree
Hide file tree
Showing 5 changed files with 85 additions and 68 deletions.
4 changes: 1 addition & 3 deletions src/decode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4797,9 +4797,7 @@ fn rav1d_decode_frame_main(c: &Rav1dContext, f: &mut Rav1dFrameData) -> Rav1dRes
}

// loopfilter + cdef + restoration
//
// SAFETY: Function call with all safe args, will be marked safe.
unsafe { (f.bd_fn().filter_sbrow)(c, f, &mut t, sby) };
(f.bd_fn().filter_sbrow)(c, f, &mut t, sby);
}
}

Expand Down
69 changes: 40 additions & 29 deletions src/lf_apply.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
#![deny(unsafe_op_in_unsafe_fn)]

use crate::include::common::bitdepth::BitDepth;
use crate::include::dav1d::headers::Rav1dFrameHeader;
use crate::include::dav1d::headers::Rav1dPixelLayout;
Expand Down Expand Up @@ -365,10 +367,10 @@ pub(crate) fn rav1d_copy_lpf<BD: BitDepth>(
}

#[inline]
unsafe fn filter_plane_cols_y<BD: BitDepth>(
fn filter_plane_cols_y<BD: BitDepth>(
f: &Rav1dFrameData,
have_left: bool,
lvl: &[[u8; 4]],
lvl: WithOffset<&[[u8; 4]]>,
mask: &[[[RelaxedAtomic<u16>; 2]; 3]; 32],
y_dst: Rav1dPictureDataComponentOffset,
w: usize,
Expand Down Expand Up @@ -396,16 +398,16 @@ unsafe fn filter_plane_cols_y<BD: BitDepth>(
} else {
mask.each_ref().map(|[_, b]| b.get() as u32)
};
let lvl = &lvl[x..];
let lvl = lvl + x;
lf_sb.y.h.call::<BD>(f, y_dst(x), &hmask, lvl, 0, len);
}
}

#[inline]
unsafe fn filter_plane_rows_y<BD: BitDepth>(
fn filter_plane_rows_y<BD: BitDepth>(
f: &Rav1dFrameData,
have_top: bool,
lvl: &[[u8; 4]],
lvl: WithOffset<&[[u8; 4]]>,
b4_stride: usize,
mask: &[[[RelaxedAtomic<u16>; 2]; 3]; 32],
y_dst: Rav1dPictureDataComponentOffset,
Expand All @@ -420,7 +422,6 @@ unsafe fn filter_plane_rows_y<BD: BitDepth>(
let len = endy4 - starty4;
let y_dst = |i| y_dst + (i as isize * 4 * y_dst.pixel_stride::<BD>());
y_dst(len - 1).as_ptr::<BD>(); // Bounds check
let lvl = &lvl[..1 + (len - 1) * b4_stride];
for i in 0..len {
let y = i + starty4;
if !have_top && y == 0 {
Expand All @@ -429,16 +430,16 @@ unsafe fn filter_plane_rows_y<BD: BitDepth>(
let vmask = mask[y % mask.len()] // To elide the bounds check.
.each_ref()
.map(|[a, b]| a.get() as u32 | ((b.get() as u32) << 16));
let lvl = &lvl[i * b4_stride..];
let lvl = lvl + i * b4_stride;
lf_sb.y.v.call::<BD>(f, y_dst(i), &vmask, lvl, 1, w);
}
}

#[inline]
unsafe fn filter_plane_cols_uv<BD: BitDepth>(
fn filter_plane_cols_uv<BD: BitDepth>(
f: &Rav1dFrameData,
have_left: bool,
lvl: &[[u8; 4]],
lvl: WithOffset<&[[u8; 4]]>,
mask: &[[[RelaxedAtomic<u16>; 2]; 2]; 32],
u_dst: Rav1dPictureDataComponentOffset,
v_dst: Rav1dPictureDataComponentOffset,
Expand All @@ -455,7 +456,6 @@ unsafe fn filter_plane_cols_uv<BD: BitDepth>(
u_dst(w).as_ptr::<BD>(); // Bounds check
v_dst(w).as_ptr::<BD>(); // Bounds check
let mask = &mask[..w];
let lvl = &lvl[..w];
for x in 0..w {
if !have_left && x == 0 {
continue;
Expand All @@ -472,17 +472,17 @@ unsafe fn filter_plane_cols_uv<BD: BitDepth>(
mask.each_ref().map(|[_, b]| b.get() as u32)
};
let hmask = [hmask[0], hmask[1], 0];
let lvl = &lvl[x..];
let lvl = lvl + x;
lf_sb.uv.h.call::<BD>(f, u_dst(x), &hmask, lvl, 2, len);
lf_sb.uv.h.call::<BD>(f, v_dst(x), &hmask, lvl, 3, len);
}
}

#[inline]
unsafe fn filter_plane_rows_uv<BD: BitDepth>(
fn filter_plane_rows_uv<BD: BitDepth>(
f: &Rav1dFrameData,
have_top: bool,
lvl: &[[u8; 4]],
lvl: WithOffset<&[[u8; 4]]>,
b4_stride: usize,
mask: &[[[RelaxedAtomic<u16>; 2]; 2]; 32],
u_dst: Rav1dPictureDataComponentOffset,
Expand All @@ -501,7 +501,6 @@ unsafe fn filter_plane_rows_uv<BD: BitDepth>(
let v_dst = |i| v_dst + (i as isize * 4 * v_dst.pixel_stride::<BD>());
u_dst(len - 1).as_ptr::<BD>(); // Bounds check
v_dst(len - 1).as_ptr::<BD>(); // Bounds check
let lvl = &lvl[..1 + (len - 1) * b4_stride];
for i in 0..len {
let y = i + starty4;
if !have_top && y == 0 {
Expand All @@ -511,13 +510,13 @@ unsafe fn filter_plane_rows_uv<BD: BitDepth>(
.each_ref()
.map(|[a, b]| a.get() as u32 | ((b.get() as u32) << (16 >> ss_hor)));
let vmask = [vmask[0], vmask[1], 0];
let lvl = &lvl[i * b4_stride..];
let lvl = lvl + i * b4_stride;
lf_sb.uv.v.call::<BD>(f, u_dst(i), &vmask, lvl, 2, w);
lf_sb.uv.v.call::<BD>(f, v_dst(i), &vmask, lvl, 3, w);
}
}

pub(crate) unsafe fn rav1d_loopfilter_sbrow_cols<BD: BitDepth>(
pub(crate) fn rav1d_loopfilter_sbrow_cols<BD: BitDepth>(
f: &Rav1dFrameData,
[py, pu, pv]: [Rav1dPictureDataComponentOffset; 3],
lflvl_offset: usize,
Expand Down Expand Up @@ -625,38 +624,45 @@ pub(crate) unsafe fn rav1d_loopfilter_sbrow_cols<BD: BitDepth>(
}
}
let lflvl = &f.lf.mask[lflvl_offset..];
let mut level_ptr = &*f
let lvl = &*f
.lf
.level
.index((f.b4_stride * sby as isize * sbsz as isize) as usize..);
let lvl = WithOffset {
data: lvl,
offset: 0,
};
have_left = false;
for x in 0..f.sb128w as usize {
filter_plane_cols_y::<BD>(
f,
have_left,
level_ptr,
lvl + x * 32,
&lflvl[x].filter_y[0],
py + x * 128,
cmp::min(32, f.w4 - x as c_int * 32) as usize,
starty4 as usize,
endy4 as usize,
);
have_left = true;
level_ptr = &level_ptr[32..];
}
if frame_hdr.loopfilter.level_u == 0 && frame_hdr.loopfilter.level_v == 0 {
return;
}
let mut level_ptr = &*f
let lvl = &*f
.lf
.level
.index((f.b4_stride * (sby * sbsz >> ss_ver) as isize) as usize..);
let lvl = WithOffset {
data: lvl,
offset: 0,
};
have_left = false;
for x in 0..f.sb128w as usize {
filter_plane_cols_uv::<BD>(
f,
have_left,
level_ptr,
lvl + x * (32 >> ss_hor),
&lflvl[x].filter_uv[0],
pu + x * (128 >> ss_hor),
pv + x * (128 >> ss_hor),
Expand All @@ -666,11 +672,10 @@ pub(crate) unsafe fn rav1d_loopfilter_sbrow_cols<BD: BitDepth>(
ss_ver,
);
have_left = true;
level_ptr = &level_ptr[32 >> ss_hor..];
}
}

pub(crate) unsafe fn rav1d_loopfilter_sbrow_rows<BD: BitDepth>(
pub(crate) fn rav1d_loopfilter_sbrow_rows<BD: BitDepth>(
f: &Rav1dFrameData,
p: [Rav1dPictureDataComponentOffset; 3],
lflvl_offset: usize,
Expand All @@ -689,40 +694,47 @@ pub(crate) unsafe fn rav1d_loopfilter_sbrow_rows<BD: BitDepth>(
let endy4: c_uint = (starty4 + cmp::min(f.h4 - sby * sbsz, sbsz)) as c_uint;
let uv_endy4: c_uint = endy4.wrapping_add(ss_ver as c_uint) >> ss_ver;

let mut level_ptr = &*f
let lvl = &*f
.lf
.level
.index((f.b4_stride * sby as isize * sbsz as isize) as usize..);
let lvl = WithOffset {
data: lvl,
offset: 0,
};
for x in 0..f.sb128w as usize {
filter_plane_rows_y::<BD>(
f,
have_top,
level_ptr,
lvl + x * 32,
f.b4_stride as usize,
&lflvl[x].filter_y[1],
p[0] + 128 * x,
cmp::min(32, f.w4 - x as c_int * 32) as usize,
starty4 as usize,
endy4 as usize,
);
level_ptr = &level_ptr[32..];
}

let frame_hdr = &***f.frame_hdr.as_ref().unwrap();
if frame_hdr.loopfilter.level_u == 0 && frame_hdr.loopfilter.level_v == 0 {
return;
}

let mut level_ptr = &*f
let lvl = &*f
.lf
.level
.index((f.b4_stride * (sby * sbsz >> ss_ver) as isize) as usize..);
let lvl = WithOffset {
data: lvl,
offset: 0,
};
let [_, pu, pv] = p;
for x in 0..f.sb128w as usize {
filter_plane_rows_uv::<BD>(
f,
have_top,
level_ptr,
lvl + x * (32 >> ss_hor),
f.b4_stride as usize,
&lflvl[x].filter_uv[1],
pu + (x * 128 >> ss_hor),
Expand All @@ -732,6 +744,5 @@ pub(crate) unsafe fn rav1d_loopfilter_sbrow_rows<BD: BitDepth>(
uv_endy4 as usize,
ss_hor,
);
level_ptr = &level_ptr[32 >> ss_hor..];
}
}
42 changes: 29 additions & 13 deletions src/loopfilter.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
#![deny(unsafe_op_in_unsafe_fn)]

use crate::include::common::bitdepth::AsPrimitive;
use crate::include::common::bitdepth::BitDepth;
use crate::include::common::bitdepth::DynPixel;
Expand All @@ -11,6 +13,7 @@ use crate::src::lf_mask::Av1FilterLUT;
use crate::src::strided::Strided as _;
use crate::src::unstable_extensions::as_chunks;
use crate::src::unstable_extensions::flatten;
use crate::src::with_offset::WithOffset;
use crate::src::wrap_fn_ptr::wrap_fn_ptr;
use libc::ptrdiff_t;
use std::cmp;
Expand All @@ -27,12 +30,13 @@ wrap_fn_ptr!(pub unsafe extern "C" fn loopfilter_sb(
dst_ptr: *mut DynPixel,
stride: ptrdiff_t,
mask: &[u32; 3],
lvl: *const [u8; 4],
lvl_ptr: *const [u8; 4],
b4_stride: ptrdiff_t,
lut: &Align16<Av1FilterLUT>,
w: c_int,
bitdepth_max: c_int,
_dst: *const FFISafe<Rav1dPictureDataComponentOffset>,
_lvl: *const FFISafe<WithOffset<&[[u8; 4]]>>,
) -> ());

/// Slice `[u8; 4]`s from `lvl`, but "unaligned",
Expand All @@ -48,24 +52,31 @@ fn unaligned_lvl_slice(lvl: &[[u8; 4]], y: usize) -> &[[u8; 4]] {
}

impl loopfilter_sb::Fn {
pub unsafe fn call<BD: BitDepth>(
pub fn call<BD: BitDepth>(
&self,
f: &Rav1dFrameData,
dst: Rav1dPictureDataComponentOffset,
mask: &[u32; 3],
lvl: &[[u8; 4]],
mut lvl: WithOffset<&[[u8; 4]]>,
lvl_y_offset: usize,
w: usize,
) {
let dst_ptr = dst.as_mut_ptr::<BD>().cast();
let stride = dst.stride();
let lvl = unaligned_lvl_slice(lvl, lvl_y_offset).as_ptr();
lvl.data = unaligned_lvl_slice(lvl.data, lvl_y_offset);
let lvl_ptr = lvl.data[lvl.offset..].as_ptr();
let b4_stride = f.b4_stride;
let lut = &f.lf.lim_lut;
let w = w as c_int;
let bd = f.bitdepth_max;
let dst = FFISafe::new(&dst);
self.get()(dst_ptr, stride, mask, lvl, b4_stride, lut, w, bd, dst)
let lvl = FFISafe::new(&lvl);
// SAFETY: Fallback `fn loop_filter_sb128_rust` is safe; asm is supposed to do the same.
unsafe {
self.get()(
dst_ptr, stride, mask, lvl_ptr, b4_stride, lut, w, bd, dst, lvl,
)
}
}

const fn default<BD: BitDepth, const HV: usize, const YUV: usize>() -> Self {
Expand Down Expand Up @@ -288,10 +299,10 @@ enum YUV {
UV,
}

unsafe fn loop_filter_sb128_rust<BD: BitDepth, const HV: usize, const YUV: usize>(
fn loop_filter_sb128_rust<BD: BitDepth, const HV: usize, const YUV: usize>(
mut dst: Rav1dPictureDataComponentOffset,
vmask: &[u32; 3],
mut l: *const [u8; 4],
mut lvl: WithOffset<&[[u8; 4]]>,
b4_stride: usize,
lut: &Align16<Av1FilterLUT>,
_wh: c_int,
Expand Down Expand Up @@ -320,10 +331,11 @@ unsafe fn loop_filter_sb128_rust<BD: BitDepth, const HV: usize, const YUV: usize
if vm & xy == 0 {
break 'block;
}
let L = if (*l.offset(0))[0] != 0 {
(*l.offset(0))[0]
let L = if lvl.data[lvl.offset][0] != 0 {
lvl.data[lvl.offset][0]
} else {
(*l.sub(b4_strideb))[0]
// SAFETY: TODO will make this safe
unsafe { (*lvl.data[lvl.offset..].as_ptr().sub(b4_strideb))[0] }
};
if L == 0 {
break 'block;
Expand All @@ -349,29 +361,33 @@ unsafe fn loop_filter_sb128_rust<BD: BitDepth, const HV: usize, const YUV: usize
}
xy <<= 1;
dst += 4 * stridea;
l = l.add(b4_stridea);
lvl += b4_stridea;
}
}

/// # Safety
///
/// Must be called by [`loopfilter_sb::Fn::call`].
#[deny(unsafe_op_in_unsafe_fn)]
unsafe extern "C" fn loop_filter_sb128_c_erased<BD: BitDepth, const HV: usize, const YUV: usize>(
_dst_ptr: *mut DynPixel,
_stride: ptrdiff_t,
vmask: &[u32; 3],
l: *const [u8; 4],
_lvl_ptr: *const [u8; 4],
b4_stride: isize,
lut: &Align16<Av1FilterLUT>,
wh: c_int,
bitdepth_max: c_int,
dst: *const FFISafe<Rav1dPictureDataComponentOffset>,
lvl: *const FFISafe<WithOffset<&[[u8; 4]]>>,
) {
// SAFETY: Was passed as `FFISafe::new(_)` in `loopfilter_sb::Fn::call`.
let dst = *unsafe { FFISafe::get(dst) };
// SAFETY: Was passed as `FFISafe::new(_)` in `loopfilter_sb::Fn::call`.
let lvl = *unsafe { FFISafe::get(lvl) };
let b4_stride = b4_stride as usize;
let bd = BD::from_c(bitdepth_max);
loop_filter_sb128_rust::<BD, { HV }, { YUV }>(dst, vmask, l, b4_stride, lut, wh, bd)
loop_filter_sb128_rust::<BD, { HV }, { YUV }>(dst, vmask, lvl, b4_stride, lut, wh, bd)
}

impl Rav1dLoopFilterDSPContext {
Expand Down
Loading

0 comments on commit 7e0a9d8

Please sign in to comment.