fn unaligned_lvl_slice: Safely encapsulate "unaligned" slices of lvl: &[[u8; 4]], as previous accesses were UB #731

Merged · 3 commits · Feb 9, 2024
1 change: 1 addition & 0 deletions lib.rs
@@ -46,6 +46,7 @@ pub mod src {
mod fg_apply;
mod filmgrain;
mod getbits;
+ mod unstable_extensions;
pub(crate) mod wrap_fn_ptr;
// TODO(kkysen) Temporarily `pub(crate)` due to a `pub use` until TAIT.
pub(super) mod internal;
37 changes: 27 additions & 10 deletions src/lf_apply.rs
@@ -8,6 +8,8 @@ use crate::src::lf_mask::Av1Filter;
use crate::src::lr_apply::LR_RESTORE_U;
use crate::src::lr_apply::LR_RESTORE_V;
use crate::src::lr_apply::LR_RESTORE_Y;
+ use crate::src::unstable_extensions::as_chunks;
+ use crate::src::unstable_extensions::flatten;
use libc::ptrdiff_t;
use std::cmp;
use std::ffi::c_int;
@@ -341,6 +343,18 @@ pub(crate) unsafe fn rav1d_copy_lpf<BD: BitDepth>(
}
}

+ /// Slice `[u8; 4]`s from `lvl`, but "unaligned",
+ /// meaning the `[u8; 4]`s can straddle
+ /// adjacent `[u8; 4]`s in the `lvl` slice.
+ ///
+ /// Note that this does not result in actual unaligned reads,
+ /// since `[u8; 4]` has an alignment of 1.
+ /// This optimizes to a single slice with a bounds check.
+ #[inline(always)]
+ fn unaligned_lvl_slice(lvl: &[[u8; 4]], y: usize) -> &[[u8; 4]] {
+     as_chunks(&flatten(lvl)[y..]).0
+ }
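To make the straddling behavior concrete, here is a small usage sketch (editor's illustration, not part of the diff; it assumes the `as_chunks`/`flatten` helpers added below in `src/unstable_extensions.rs`):

```rust
// `y` shifts the chunk grid by `y` bytes, so the re-chunked `[u8; 4]`s
// straddle the element boundaries of the original slice.
let lvl: &[[u8; 4]] = &[[0, 1, 2, 3], [4, 5, 6, 7]];
// The flattened bytes are [0..=7]; starting at byte 1 yields one full chunk:
assert_eq!(unaligned_lvl_slice(lvl, 1), &[[1, 2, 3, 4]]);
// The trailing bytes [5, 6, 7] don't form a complete chunk, and taking
// `.0` of `as_chunks` discards them.
```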

#[inline]
unsafe fn filter_plane_cols_y<BD: BitDepth>(
f: *const Rav1dFrameContext,
@@ -377,7 +391,7 @@ unsafe fn filter_plane_cols_y<BD: BitDepth>(
dst.offset((x * 4) as isize).cast(),
ls,
hmask.as_mut_ptr(),
- lvl[x as usize][0..].as_ptr() as *const [u8; 4],
+ &lvl[x as usize],
Collaborator
If enforcing a minimum slice size, this would be something like `lvl[x as usize..][..((endy4 - starty4 - 1) as isize * b4_stride + 1) as usize].as_ptr()`. However, index -1 of such a slice can also be read, so more work on slice definitions will be needed. But that's probably outside the scope of this PR.

Collaborator Author (@kkysen) — Feb 8, 2024

I was hoping to structure things such that, at this point in the code, we don't care about checking the length; it is up to the inner fn to correctly bounds-check when reading elements. That's of course not true for the asm fns, though it will eventually be true (once we make them safe) for the C/Rust fallbacks. If we do check the length eagerly here, the only real safety that gives us is documentation: that is, if the asm fn's docs claim it will access X elements and we check for X elements here, we verify that the actual length matches the docs, but that doesn't enforce anything on the asm. Not sure if that made sense.

Collaborator

Yes, that makes sense. If, down the road, we want the C/Rust fallbacks to check bounds, then we'll probably have to change fn interfaces to avoid casting to raw pointers and recreating slices in the fallback fns? One potential wrinkle here is that an asm fn may access memory locations that differ from the corresponding fallback fn (because asm may read a 128-bit word where the fallback reads a couple of 16-bit words; I can't immediately think of a case where that would also apply to a write). I don't know whether we need to consider that.

Collaborator Author

We can't change the existing fn ptr args without modifying the asm, which we don't want to do. But what we can do is pass extra args that the asm ignores (which we already do for passing `bitdepth_max: c_int` for 8bpc fns). So we can split a slice into its ptr and len, pass the len as an extra arg, and then recreate the slice in the fallback C/Rust fn. That will be technically unsafe, but quite safe in practice.

As for asm fns accessing more memory locations than the fallbacks, we could eagerly check the bounds, but it's still ultimately up to the asm not to access any more than that. It's always going to be inherently very unsafe.
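A minimal sketch of the ptr-plus-len scheme described in this thread (all names here are hypothetical illustrations, not code from this PR):

```rust
use std::slice::from_raw_parts;

// Hypothetical fn-ptr-compatible fallback: asm implementations would ignore
// the trailing `lvl_len` arg (as they already ignore `bitdepth_max` for
// 8bpc fns), while the Rust fallback uses it to recreate a real slice.
unsafe extern "C" fn loop_filter_fallback(lvl_ptr: *const [u8; 4], lvl_len: usize) {
    // SAFETY: the caller split `lvl_ptr`/`lvl_len` from a live slice.
    let lvl: &[[u8; 4]] = unsafe { from_raw_parts(lvl_ptr, lvl_len) };
    // ... all further indexing of `lvl` is now bounds-checked ...
    let _ = lvl;
}

// Call site: split the slice into the ptr the asm expects plus the extra len.
fn call_it(lvl: &[[u8; 4]]) {
    unsafe { loop_filter_fallback(lvl.as_ptr(), lvl.len()) }
}
```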

b4_stride,
&(*f).lf.lim_lut.0,
endy4 - starty4,
@@ -416,7 +430,7 @@ unsafe fn filter_plane_rows_y<BD: BitDepth>(
dst.cast(),
ls,
vmask.as_ptr(),
- lvl[0][1..].as_ptr() as *const [u8; 4],
+ unaligned_lvl_slice(&lvl[0..], 1).as_ptr(),
b4_stride,
&(*f).lf.lim_lut.0,
w,
@@ -462,7 +476,7 @@ unsafe fn filter_plane_cols_uv<BD: BitDepth>(
u.offset((x * 4) as isize).cast(),
ls,
hmask.as_mut_ptr(),
- lvl[x as usize][2..].as_ptr() as *const [u8; 4],
+ unaligned_lvl_slice(&lvl[x as usize..], 2).as_ptr(),
b4_stride,
&(*f).lf.lim_lut.0,
endy4 - starty4,
v.offset((x * 4) as isize).cast(),
ls,
hmask.as_mut_ptr(),
- lvl[x as usize][3..].as_ptr() as *const [u8; 4],
+ unaligned_lvl_slice(&lvl[x as usize..], 3).as_ptr(),
b4_stride,
&(*f).lf.lim_lut.0,
endy4 - starty4,
@@ -512,7 +526,7 @@ unsafe fn filter_plane_rows_uv<BD: BitDepth>(
u.offset(off_l as isize).cast(),
ls,
vmask.as_ptr(),
- lvl[0][2..].as_ptr() as *const [u8; 4],
+ unaligned_lvl_slice(&lvl[0..], 2).as_ptr(),
b4_stride,
&(*f).lf.lim_lut.0,
w,
v.offset(off_l as isize).cast(),
ls,
vmask.as_ptr(),
- lvl[0][3..].as_ptr() as *const [u8; 4],
+ unaligned_lvl_slice(&lvl[0..], 3).as_ptr(),
b4_stride,
&(*f).lf.lim_lut.0,
w,
@@ -687,10 +701,10 @@ pub(crate) unsafe fn rav1d_loopfilter_sbrow_cols<BD: BitDepth>(
}
}
let mut ptr: *mut BD::Pixel;
- let level_ptr = &(*f).lf.level[((*f).b4_stride * sby as isize * sbsz as isize) as usize..];
+ let mut level_ptr = &(*f).lf.level[((*f).b4_stride * sby as isize * sbsz as isize) as usize..];
ptr = p[0];
have_left = 0 as c_int;
- for (x, level_ptr) in (0..(*f).sb128w).zip(level_ptr.chunks(32)) {
+ for x in 0..(*f).sb128w {
folkertdev marked this conversation as resolved.
filter_plane_cols_y::<BD>(
f,
have_left,
);
have_left = 1 as c_int;
ptr = ptr.offset(128);
+ level_ptr = &level_ptr[32..];
}
if frame_hdr.loopfilter.level_u == 0 && frame_hdr.loopfilter.level_v == 0 {
return;
}
let mut uv_off: ptrdiff_t;
- let level_ptr = &(*f).lf.level[((*f).b4_stride * (sby * sbsz >> ss_ver) as isize) as usize..];
+ let mut level_ptr =
+     &(*f).lf.level[((*f).b4_stride * (sby * sbsz >> ss_ver) as isize) as usize..];
have_left = 0 as c_int;
uv_off = 0;
- for (x, level_ptr) in (0..(*f).sb128w).zip(level_ptr.chunks(32 >> ss_hor)) {
+ for x in 0..(*f).sb128w {
filter_plane_cols_uv::<BD>(
f,
have_left,
);
have_left = 1 as c_int;
uv_off += 128 >> ss_hor;
+ level_ptr = &level_ptr[32 >> ss_hor..];
}
}
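The switch above from `zip(level_ptr.chunks(32))` to manual suffix slicing matters because `chunks(32)` caps each subslice at 32 entries, while the filter fns may read further into the level buffer (per the discussion above, the bounds check is deferred to the callee). A quick illustration of the difference (editor's sketch):

```rust
let level: Vec<[u8; 4]> = vec![[0u8; 4]; 100];
// `chunks(32)` yields subslices of at most 32 entries:
assert_eq!(level.chunks(32).next().unwrap().len(), 32);
// Suffix slicing keeps the rest of the buffer reachable on each iteration:
let suffix: &[[u8; 4]] = &level[32..];
assert_eq!(suffix.len(), 68);
```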

55 changes: 55 additions & 0 deletions src/unstable_extensions.rs
@@ -0,0 +1,55 @@
//! Unstable `fn`s copied directly from `std`, with the following differences:
//! * They are free `fn`s now, not methods.
//! * `self` is replaced by `this`.
//! * Things only accessible by `std` are replaced with stable counterparts, such as:
//! * `exact_div` => `/`
//! * `.unchecked_mul` => `*`
//! * `const` `.expect` => `match` and `panic!`

use std::mem;
use std::slice::from_raw_parts;

/// From `1.75.0`.
pub const fn flatten<const N: usize, T>(this: &[[T; N]]) -> &[T] {
let len = if mem::size_of::<T>() == 0 {
match this.len().checked_mul(N) {
None => panic!("slice len overflow"),
Some(it) => it,
}
} else {
// SAFETY: `this.len() * N` cannot overflow because `self` is
// already in the address space.
/* unsafe */
this.len() * N
};
// SAFETY: `[T]` is layout-identical to `[T; N]`
unsafe { from_raw_parts(this.as_ptr().cast(), len) }
}
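A usage sketch of this vendored `flatten` (editor's illustration):

```rust
let chunks: &[[u8; 4]] = &[[0, 1, 2, 3], [4, 5, 6, 7]];
// `[[T; N]]` and `[T]` are layout-identical, so this is just a reborrow
// with a recomputed length:
assert_eq!(flatten(chunks), &[0, 1, 2, 3, 4, 5, 6, 7]);
```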

/// From `1.75.0`.
#[inline]
#[must_use]
pub const unsafe fn as_chunks_unchecked<const N: usize, T>(this: &[T]) -> &[[T; N]] {
// SAFETY: Caller must guarantee that `N` is nonzero and exactly divides the slice length
let new_len = /* unsafe */ {
assert!(N != 0 && this.len() % N == 0);
this.len() / N
};
// SAFETY: We cast a slice of `new_len * N` elements into
// a slice of `new_len` many `N` elements chunks.
unsafe { from_raw_parts(this.as_ptr().cast(), new_len) }
}

/// From `1.75.0`.
#[inline]
#[track_caller]
#[must_use]
pub const fn as_chunks<const N: usize, T>(this: &[T]) -> (&[[T; N]], &[T]) {
assert!(N != 0, "chunk size must be non-zero");
let len = this.len() / N;
let (multiple_of_n, remainder) = this.split_at(len * N);
// SAFETY: We already panicked for zero, and ensured by construction
// that the length of the subslice is a multiple of N.
let array_slice = unsafe { as_chunks_unchecked(multiple_of_n) };
(array_slice, remainder)
}
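And a usage sketch of `as_chunks` (editor's illustration): it returns the complete `N`-element chunks plus the remainder:

```rust
let bytes: &[u8] = &[1, 2, 3, 4, 5, 6, 7];
let (chunks, rest): (&[[u8; 3]], &[u8]) = as_chunks(bytes);
assert_eq!(chunks, &[[1, 2, 3], [4, 5, 6]]);
assert_eq!(rest, &[7]);
```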