Commit b796035

fn unaligned_lvl_slice: Safely encapsulate "unaligned" slices of `lvl: &[[u8; 4]]`, as previous accesses were UB (#731)

See:
* #728 (comment)
* #728 (comment)

The current "unaligned" access of `lvl`, added in #726, and a proposed
change in #728, [are UB according to
`miri`](https://play.rust-lang.org/?version=stable&mode=debug&edition=2021&gist=80849f822a890cabd324b3b898b291b8),
as they cast `&[u8; 3]` and `&u8` to `&[u8; 4]`, which is an
out-of-bounds read of a borrow. The new `fn unaligned_lvl_slice` helper
fixes this and passes `miri` by transmuting (through `.align_to`) to a
`&[u8]`, taking the slice, and then doing a checked cast back to
`[u8; 4]`. It also does correct bounds checking; the previous approaches
were off by one because they didn't account for the unaligned read
extending into the next `[u8; 4]`. Doing it this way also optimizes just
as well as the previous methods.
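
For illustration only (this is not the committed code), a minimal sketch of the old and new access patterns using stable `std` APIs; the `main` and variable names here are hypothetical:

```rust
// Hypothetical demo of the two access patterns; only the safe one runs.
fn main() {
    let lvl: [[u8; 4]; 2] = [[0, 1, 2, 3], [4, 5, 6, 7]];

    // Old pattern (UB under miri): reinterpret a shorter borrow as a
    // `&[u8; 4]`, reading past the end of what was actually borrowed.
    // let bad = unsafe { &*(lvl[0][1..].as_ptr() as *const [u8; 4]) };

    // Safe pattern: flatten to bytes first, then slice and convert back,
    // so the straddling read stays within one borrow of the whole buffer.
    let bytes = lvl.as_flattened(); // stable on recent toolchains
    let straddling: [u8; 4] = bytes[1..5].try_into().unwrap();
    assert_eq!(straddling, [1, 2, 3, 4]);
}
```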

We could also use a dependency like `zerocopy` to encapsulate the
`unsafe` `.align_to`, but as long as these kinds of "unaligned" array
reads are a one-off, doing it inline here seems better.

Update: Switched to using unstable `fn`s from `std` that make things
fully safe and correct, copied into our codebase for stability.
I had already figured out how to fix the UB while reviewing #728, so I
went ahead and made the fix independently.
kkysen authored Feb 9, 2024
2 parents a403fa5 + 180f07b commit b796035
Showing 3 changed files with 83 additions and 10 deletions.
1 change: 1 addition & 0 deletions lib.rs
@@ -46,6 +46,7 @@ pub mod src {
mod fg_apply;
mod filmgrain;
mod getbits;
mod unstable_extensions;
pub(crate) mod wrap_fn_ptr;
// TODO(kkysen) Temporarily `pub(crate)` due to a `pub use` until TAIT.
pub(super) mod internal;
37 changes: 27 additions & 10 deletions src/lf_apply.rs
@@ -8,6 +8,8 @@ use crate::src::lf_mask::Av1Filter;
use crate::src::lr_apply::LR_RESTORE_U;
use crate::src::lr_apply::LR_RESTORE_V;
use crate::src::lr_apply::LR_RESTORE_Y;
use crate::src::unstable_extensions::as_chunks;
use crate::src::unstable_extensions::flatten;
use libc::ptrdiff_t;
use std::cmp;
use std::ffi::c_int;
@@ -341,6 +343,18 @@ pub(crate) unsafe fn rav1d_copy_lpf<BD: BitDepth>(
}
}

/// Slice `[u8; 4]`s from `lvl`, but "unaligned",
/// meaning the `[u8; 4]`s can straddle
/// adjacent `[u8; 4]`s in the `lvl` slice.
///
/// Note that this does not result in actual unaligned reads,
/// since `[u8; 4]` has an alignment of 1.
/// This optimizes to a single slice with a bounds check.
#[inline(always)]
fn unaligned_lvl_slice(lvl: &[[u8; 4]], y: usize) -> &[[u8; 4]] {
as_chunks(&flatten(lvl)[y..]).0
}

#[inline]
unsafe fn filter_plane_cols_y<BD: BitDepth>(
f: *const Rav1dFrameContext,
@@ -377,7 +391,7 @@ unsafe fn filter_plane_cols_y<BD: BitDepth>(
dst.offset((x * 4) as isize).cast(),
ls,
hmask.as_mut_ptr(),
lvl[x as usize][0..].as_ptr() as *const [u8; 4],
&lvl[x as usize],
b4_stride,
&(*f).lf.lim_lut.0,
endy4 - starty4,
@@ -416,7 +430,7 @@ unsafe fn filter_plane_rows_y<BD: BitDepth>(
dst.cast(),
ls,
vmask.as_ptr(),
lvl[0][1..].as_ptr() as *const [u8; 4],
unaligned_lvl_slice(&lvl[0..], 1).as_ptr(),
b4_stride,
&(*f).lf.lim_lut.0,
w,
@@ -462,7 +476,7 @@ unsafe fn filter_plane_cols_uv<BD: BitDepth>(
u.offset((x * 4) as isize).cast(),
ls,
hmask.as_mut_ptr(),
lvl[x as usize][2..].as_ptr() as *const [u8; 4],
unaligned_lvl_slice(&lvl[x as usize..], 2).as_ptr(),
b4_stride,
&(*f).lf.lim_lut.0,
endy4 - starty4,
@@ -472,7 +486,7 @@ unsafe fn filter_plane_cols_uv<BD: BitDepth>(
v.offset((x * 4) as isize).cast(),
ls,
hmask.as_mut_ptr(),
lvl[x as usize][3..].as_ptr() as *const [u8; 4],
unaligned_lvl_slice(&lvl[x as usize..], 3).as_ptr(),
b4_stride,
&(*f).lf.lim_lut.0,
endy4 - starty4,
@@ -512,7 +526,7 @@ unsafe fn filter_plane_rows_uv<BD: BitDepth>(
u.offset(off_l as isize).cast(),
ls,
vmask.as_ptr(),
lvl[0][2..].as_ptr() as *const [u8; 4],
unaligned_lvl_slice(&lvl[0..], 2).as_ptr(),
b4_stride,
&(*f).lf.lim_lut.0,
w,
@@ -522,7 +536,7 @@ unsafe fn filter_plane_rows_uv<BD: BitDepth>(
v.offset(off_l as isize).cast(),
ls,
vmask.as_ptr(),
lvl[0][3..].as_ptr() as *const [u8; 4],
unaligned_lvl_slice(&lvl[0..], 3).as_ptr(),
b4_stride,
&(*f).lf.lim_lut.0,
w,
@@ -687,10 +701,10 @@ pub(crate) unsafe fn rav1d_loopfilter_sbrow_cols<BD: BitDepth>(
}
}
let mut ptr: *mut BD::Pixel;
let level_ptr = &(*f).lf.level[((*f).b4_stride * sby as isize * sbsz as isize) as usize..];
let mut level_ptr = &(*f).lf.level[((*f).b4_stride * sby as isize * sbsz as isize) as usize..];
ptr = p[0];
have_left = 0 as c_int;
for (x, level_ptr) in (0..(*f).sb128w).zip(level_ptr.chunks(32)) {
for x in 0..(*f).sb128w {
filter_plane_cols_y::<BD>(
f,
have_left,
@@ -705,15 +719,17 @@ pub(crate) unsafe fn rav1d_loopfilter_sbrow_cols<BD: BitDepth>(
);
have_left = 1 as c_int;
ptr = ptr.offset(128);
level_ptr = &level_ptr[32..];
}
if frame_hdr.loopfilter.level_u == 0 && frame_hdr.loopfilter.level_v == 0 {
return;
}
let mut uv_off: ptrdiff_t;
let level_ptr = &(*f).lf.level[((*f).b4_stride * (sby * sbsz >> ss_ver) as isize) as usize..];
let mut level_ptr =
&(*f).lf.level[((*f).b4_stride * (sby * sbsz >> ss_ver) as isize) as usize..];
have_left = 0 as c_int;
uv_off = 0;
for (x, level_ptr) in (0..(*f).sb128w).zip(level_ptr.chunks(32 >> ss_hor)) {
for x in 0..(*f).sb128w {
filter_plane_cols_uv::<BD>(
f,
have_left,
@@ -730,6 +746,7 @@ pub(crate) unsafe fn rav1d_loopfilter_sbrow_cols<BD: BitDepth>(
);
have_left = 1 as c_int;
uv_off += 128 >> ss_hor;
level_ptr = &level_ptr[32 >> ss_hor..];
}
}

55 changes: 55 additions & 0 deletions src/unstable_extensions.rs
@@ -0,0 +1,55 @@
//! Unstable `fn`s copied directly from `std`, with the following differences:
//! * They are free `fn`s now, not methods.
//! * `self` is replaced by `this`.
//! * Things only accessible by `std` are replaced with stable counterparts, such as:
//! * `exact_div` => `/`
//! * `.unchecked_mul` => `*`
//! * `const` `.expect` => `match` and `panic!`

use std::mem;
use std::slice::from_raw_parts;

/// From `1.75.0`.
pub const fn flatten<const N: usize, T>(this: &[[T; N]]) -> &[T] {
let len = if mem::size_of::<T>() == 0 {
match this.len().checked_mul(N) {
None => panic!("slice len overflow"),
Some(it) => it,
}
} else {
// SAFETY: `this.len() * N` cannot overflow because `self` is
// already in the address space.
/* unsafe */
this.len() * N
};
// SAFETY: `[T]` is layout-identical to `[T; N]`
unsafe { from_raw_parts(this.as_ptr().cast(), len) }
}

/// From `1.75.0`.
#[inline]
#[must_use]
pub const unsafe fn as_chunks_unchecked<const N: usize, T>(this: &[T]) -> &[[T; N]] {
// SAFETY: Caller must guarantee that `N` is nonzero and exactly divides the slice length
let new_len = /* unsafe */ {
assert!(N != 0 && this.len() % N == 0);
this.len() / N
};
// SAFETY: We cast a slice of `new_len * N` elements into
// a slice of `new_len` many `N` elements chunks.
unsafe { from_raw_parts(this.as_ptr().cast(), new_len) }
}

/// From `1.75.0`.
#[inline]
#[track_caller]
#[must_use]
pub const fn as_chunks<const N: usize, T>(this: &[T]) -> (&[[T; N]], &[T]) {
assert!(N != 0, "chunk size must be non-zero");
let len = this.len() / N;
let (multiple_of_n, remainder) = this.split_at(len * N);
// SAFETY: We already panicked for zero, and ensured by construction
// that the length of the subslice is a multiple of N.
let array_slice = unsafe { as_chunks_unchecked(multiple_of_n) };
(array_slice, remainder)
}
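
For context, a hypothetical unit test (not part of this commit) showing how `flatten` and `as_chunks` compose into the straddling access that `unaligned_lvl_slice` performs above:

```rust
// Hypothetical test that could live at the bottom of this module.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn straddling_chunks() {
        let lvl: &[[u8; 4]] = &[[0, 1, 2, 3], [4, 5, 6, 7]];
        // Flatten to bytes, skip one, and regroup: the first chunk
        // straddles the two original `[u8; 4]`s, with no `unsafe` at the call site.
        let (chunks, remainder) = as_chunks::<4, u8>(&flatten(lvl)[1..]);
        assert_eq!(chunks, &[[1u8, 2, 3, 4]]);
        assert_eq!(remainder, &[5u8, 6, 7]);
    }
}
```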
