Re-implement case set as a macro
rinon committed Jul 10, 2024
1 parent 412cd4c commit 0b8324b
Showing 6 changed files with 598 additions and 367 deletions.
4 changes: 2 additions & 2 deletions lib.rs
@@ -37,12 +37,12 @@ pub mod src {
mod cdf;
mod const_fn;
pub mod cpu;
mod ctx;
pub mod ctx;
mod cursor;
mod data;
mod decode;
mod dequant_tables;
pub(crate) mod disjoint_mut;
pub mod disjoint_mut;
pub(crate) mod enum_map;
mod env;
pub(crate) mod error;
355 changes: 235 additions & 120 deletions src/ctx.rs
@@ -1,4 +1,5 @@
//! The [`CaseSet`] API below is a safe and simplified version of the `case_set*` macros in `ctx.h`.
//! The [`case_set!`] macro is a safe and simplified version of the `case_set*`
//! macros in `ctx.h`.
//!
//! The `case_set*` macros themselves replaced `memset`s in order to further optimize them
//! (in e3b5d4d044506f9e0e95e79b3de42fd94386cc61,
@@ -13,135 +14,249 @@
//! as unaligned writes are UB, and so we'd need to check at runtime if they're aligned
//! (a runtime-determined `off`set is used, so we can't reasonably ensure this at compile-time).
//!
//! To more thoroughly check this, I ran the same benchmarks done in
//! e3b5d4d044506f9e0e95e79b3de42fd94386cc61, which introduced the `case_set*` macros:
//! We also want to avoid multiple switches when setting a group of buffers as
//! the C implementation did, which was implemented in
//! https://github.com/memorysafety/rav1d/pull/1293.
//!
//! ```sh
//! cargo build --release && hyperfine './target/release/dav1d -i ./tests/large/chimera_8b_1080p.ivf -l 1000 -o /dev/null'
//! ```
//! # Benchmarks
//!
//! for 3 implementations:
//! 1. the original `case_set*` macros translated directly to `unsafe` Rust `fn`s
//! 2. the safe [`CaseSet`] implementation below using [`small_memset`] with its small powers of 2 optimization
//! 3. a safe [`CaseSet`] implementation using [`slice::fill`]/`memset` only
//!
//! The [`small_memset`] version was ~1.27% faster than the `case_set*` one,
//! and ~3.26% faster than the `memset` one.
//! The `case_set*` macros were also faster than `memset` in C by a similar margin,
//! meaning the `memset` option is the slowest in both C and Rust,
//! and since it was replaced with `case_set*` in C, we shouldn't use it in Rust.
//! Thus, the [`small_memset`] implementation seems optimal, as it:
//! * is the fastest of the Rust implementations
//! * is completely safe
//! * employs the same small powers of 2 optimization the `case_set*` implementation did
//! * is far simpler than the `case_set*` implementation, consisting of a `match` and array writes
//! Comparing this implementation to the previous implementation of `CaseSet` we
//! see an 8.2-10.5% speedup for a single buffer, a 5.9-7.0% speedup for
//! multiple buffers, and a minor improvement to multiple [`DisjointMut`]
//! buffers (which happened to be well-optimized in the previous
//! implementation).
//!
//! [`BlockContext`]: crate::src::env::BlockContext
use crate::src::disjoint_mut::AsMutPtr;
use crate::src::disjoint_mut::DisjointMut;
use std::iter::zip;
//! [`DisjointMut`]: crate::src::disjoint_mut::DisjointMut

/// Perform a `memset` optimized for lengths that are small powers of 2.
/// Fill small ranges of buffers with a value.
///
/// This is effectively a specialized version of [`slice::fill`] for small
/// power-of-two sized ranges of buffers.
///
/// `$UP_TO` is the maximum length that will be optimized, with powers of two up
/// to 64 supported. If the buffer length is not a power of two or is greater
/// than `$UP_TO`, this macro will do nothing. See [`case_set_with_default!`] to
/// fill buffers with non-conforming lengths if needed.
///
/// # Examples
///
/// ```
/// # use rav1d::case_set;
/// let mut buf = [0u8; 32];
/// let len = 16;
/// for offset in [0, 16] {
/// case_set!(up_to = 32, len, offset, {
/// set!(&mut buf, 1u8);
/// });
/// }
/// ```
///
/// In the simplest case, `$len` is the length of the buffer range to fill
/// starting from `$offset`. The `$body` block is executed with `len` and
/// `offset` identifiers set to the given length and offset values. Within the
/// body a `set!` macro is available and must be called to set each buffer range
/// to a value. `set!` takes a buffer and a value and sets the range
/// `buf[offset..][..len]` to the value.
/// ```
/// # macro_rules! set {
/// # ($buf:expr, $val:expr) => {};
/// # }
/// set!(buf, value);
/// ```
///
/// ## Naming parameters
///
/// The identifier for either or both of `len` and `offset` can be overridden by
/// specifying `identifier=value` for those parameters:
/// ```
/// # use rav1d::case_set;
/// let mut buf = [0u8; 32];
/// let outer_len = 16;
/// for outer_offset in [0, 16] {
/// case_set!(
/// up_to = 32,
/// len=outer_len,
/// offset=outer_offset,
/// {
/// set!(&mut buf, (offset+len) as u8);
/// }
/// );
/// }
/// ```
///
/// ## `DisjointMut` buffers
///
/// [`DisjointMut`] buffers can be used in basically the same way as normal
/// buffers but using the `set_disjoint!` macro instead of `set!`.
/// ```
/// # use rav1d::case_set;
/// # use rav1d::src::disjoint_mut::DisjointMut;
/// let mut buf = DisjointMut::new([0u8; 32]);
/// let len = 16;
/// for offset in [0, 16] {
/// case_set!(up_to = 32, len, offset, {
/// set_disjoint!(&mut buf, 1u8);
/// });
/// }
/// ```
///
/// ## Multiple buffer ranges
///
/// Multiple buffers with different lengths and offsets can be filled with the
/// same body statements. In the following example, two buffers with different
/// sizes are initialized by quarters.
/// ```
/// # use rav1d::case_set;
/// let mut buf1 = [0u8; 32];
/// let mut buf2 = [0u8; 64];
/// for offset in [0, 8, 16, 24] {
/// case_set!(
/// up_to = 16,
/// buf = [&mut buf1[..], &mut buf2[..]],
/// len = [8, 16],
/// offset = [offset, offset*2],
/// {
/// set!(buf, len as u8 >> 3);
/// }
/// );
/// }
/// ```
///
/// For power of 2 lengths `<= UP_TO`,
/// the `memset` is done as an array write of exactly that (compile-time known) length.
/// If the length is not a power of 2 or `> UP_TO`,
/// then the `memset` is done by [`slice::fill`] (a `memset` call) if `WITH_DEFAULT` is `true`,
/// or else skipped if `WITH_DEFAULT` is `false`.
/// A more realistic example of filling multiple buffers with the same value is
/// initializing different struct fields at the same time (from
/// `src/decode.rs`):
/// ```ignore
/// case_set!(
/// up_to = 32,
/// ctx = [(&t.l, 1), (&f.a[t.a], 0)],
/// len = [bh4, bw4],
/// offset = [by4, bx4],
/// {
/// let (dir, dir_index) = ctx;
/// set_disjoint!(dir.seg_pred, seg_pred.into());
/// set_disjoint!(dir.skip_mode, b.skip_mode);
/// set_disjoint!(dir.intra, 0);
/// set_disjoint!(dir.skip, b.skip);
/// set_disjoint!(dir.pal_sz, 0);
/// }
/// );
/// ```
///
/// This optimizes for the common cases where `buf.len()` is a small power of 2,
/// where the array write is optimized into as few and as large stores as possible.
#[inline]
pub fn small_memset<T: Clone + Copy, const UP_TO: usize, const WITH_DEFAULT: bool>(
buf: &mut [T],
val: T,
) {
fn as_array<T: Clone + Copy, const N: usize>(buf: &mut [T]) -> &mut [T; N] {
buf.try_into().unwrap()
}
match buf.len() {
01 if UP_TO >= 01 => *as_array(buf) = [val; 01],
02 if UP_TO >= 02 => *as_array(buf) = [val; 02],
04 if UP_TO >= 04 => *as_array(buf) = [val; 04],
08 if UP_TO >= 08 => *as_array(buf) = [val; 08],
16 if UP_TO >= 16 => *as_array(buf) = [val; 16],
32 if UP_TO >= 32 => *as_array(buf) = [val; 32],
64 if UP_TO >= 64 => *as_array(buf) = [val; 64],
_ => {
if WITH_DEFAULT {
buf.fill(val)
/// [`DisjointMut`]: crate::src::disjoint_mut::DisjointMut
macro_rules! case_set {
(up_to=$UP_TO:literal, $(@DEFAULT=$WITH_DEFAULT:literal,)? $ctx:ident=[$($ctx_expr:expr),* $(,)?], $len:ident=[$($len_expr:expr),* $(,)?], $offset:ident=[$($offset_expr:expr),* $(,)?], $body:block) => {
let ctxs = [$($ctx_expr,)*];
let lens = [$($len_expr,)*];
let offsets = [$($offset_expr,)*];
assert_eq!(ctxs.len(), lens.len());
assert_eq!(ctxs.len(), offsets.len());
for (i, ctx) in ctxs.into_iter().enumerate() {
case_set!(up_to=$UP_TO, $(@DEFAULT=$WITH_DEFAULT,)? $ctx=ctx, $len=lens[i], $offset=offsets[i], $body);
}
};
(up_to=$UP_TO:literal, $(@DEFAULT=$WITH_DEFAULT:literal,)? $len:ident, $offset:ident, $body:block) => {
case_set!(up_to=$UP_TO, $(@DEFAULT=$WITH_DEFAULT,)? _ctx=(), $len=$len, $offset=$offset, $body);
};
(up_to=$UP_TO:literal, $(@DEFAULT=$WITH_DEFAULT:literal,)? $len:ident=$len_expr:expr, $offset:ident=$offset_expr:expr, $body:block) => {
case_set!(up_to=$UP_TO, $(@DEFAULT=$WITH_DEFAULT,)? _ctx=(), $len=$len_expr, $offset=$offset_expr, $body);
};
(up_to=$UP_TO:literal, $(@DEFAULT=$WITH_DEFAULT:literal,)? $ctx:ident=$ctx_expr:expr, $len:ident=$len_expr:expr, $offset:ident=$offset_expr:expr, $body:block) => {
#[allow(unused_mut)]
let mut $ctx = $ctx_expr;
let $len = $len_expr;
let $offset = $offset_expr;
{
#[allow(unused_macros)]
macro_rules! set {
($buf:expr, $val:expr) => {{
assert!($offset <= $buf.len() && $offset + $len <= $buf.len());
}};
}
#[allow(unused_imports)]
use set as set_disjoint;
#[allow(unused)]
$body
}
}
}

pub struct CaseSetter<const UP_TO: usize, const WITH_DEFAULT: bool> {
offset: usize,
len: usize,
}

impl<const UP_TO: usize, const WITH_DEFAULT: bool> CaseSetter<UP_TO, WITH_DEFAULT> {
#[inline]
pub fn set<T: Clone + Copy>(&self, buf: &mut [T], val: T) {
small_memset::<T, UP_TO, WITH_DEFAULT>(&mut buf[self.offset..][..self.len], val);
}

/// # Safety
///
/// Caller must ensure that no elements of the written range are concurrently
/// borrowed (immutably or mutably) at all during the call to `set_disjoint`.
#[inline]
pub fn set_disjoint<T, V>(&self, buf: &DisjointMut<T>, val: V)
where
T: AsMutPtr<Target = V>,
V: Clone + Copy,
{
let mut buf = buf.index_mut(self.offset..self.offset + self.len);
small_memset::<V, UP_TO, WITH_DEFAULT>(&mut *buf, val);
}
macro_rules! exec_block {
($N:literal, $block:block) => {
{
#[allow(unused_macros)]
macro_rules! set {
($buf:expr, $val:expr) => {
// SAFETY: The offset and length are checked by the
// assert outside of the match.
let buf_range = unsafe {
$buf.get_unchecked_mut($offset..$offset+$N)
};
*<&mut [_; $N]>::try_from(buf_range).unwrap() = [$val; $N];
};
}
#[allow(unused_macros)]
macro_rules! set_disjoint {
($buf:expr, $val:expr) => {{
// SAFETY: The offset and length are checked by the
// assert outside of the match.
let mut buf_range = unsafe {
$buf.index_mut_unchecked(($offset.., ..$N))
};
*<&mut [_; $N]>::try_from(&mut *buf_range).unwrap() = [$val; $N];
}};
}
$block
}
};
}
match $len {
01 if $UP_TO >= 01 => exec_block!(01, $body),
02 if $UP_TO >= 02 => exec_block!(02, $body),
04 if $UP_TO >= 04 => exec_block!(04, $body),
08 if $UP_TO >= 08 => exec_block!(08, $body),
16 if $UP_TO >= 16 => exec_block!(16, $body),
32 if $UP_TO >= 32 => exec_block!(32, $body),
64 if $UP_TO >= 64 => exec_block!(64, $body),
_ => {
if $($WITH_DEFAULT ||)? false {
#[allow(unused_macros)]
macro_rules! set {
($buf:expr, $val:expr) => {{
// SAFETY: The offset and length are checked by the
// assert outside of the match.
let buf_range = unsafe {
$buf.get_unchecked_mut($offset..$offset+$len)
};
buf_range.fill($val);
}};
}
#[allow(unused_macros)]
macro_rules! set_disjoint {
($buf:expr, $val:expr) => {{
// SAFETY: The offset and length are checked by the
// assert outside of the match.
let mut buf_range = unsafe {
$buf.index_mut_unchecked(($offset.., ..$len))
};
buf_range.fill($val);
}};
}
$body
}
}
}
};
}
pub(crate) use case_set;

/// The entrypoint to the [`CaseSet`] API.
/// Fill small ranges of buffers with a value.
///
/// `UP_TO` and `WITH_DEFAULT` are made const generic parameters rather than have multiple `case_set*` `fn`s,
/// and these are put in a separate `struct` so that these 2 generic parameters
/// can be manually specified while the ones on the methods are inferred.
pub struct CaseSet<const UP_TO: usize, const WITH_DEFAULT: bool>;

impl<const UP_TO: usize, const WITH_DEFAULT: bool> CaseSet<UP_TO, WITH_DEFAULT> {
/// Perform one case set.
///
/// This API is generic over the element type (`T`) rather than hardcoding `u8`,
/// as sometimes other types are used, though only `i8` is used currently.
///
/// The `len` and `offset` are supplied here and
/// applied to each `buf` passed to [`CaseSetter::set`] in `set_ctx`.
#[inline]
pub fn one<T, F>(ctx: T, len: usize, offset: usize, mut set_ctx: F)
where
F: FnMut(&CaseSetter<UP_TO, WITH_DEFAULT>, T),
{
set_ctx(&CaseSetter { offset, len }, ctx);
}

/// Perform many case sets in one call.
///
/// This allows specifying the `set_ctx` closure inline easily,
/// and also allows you to group the same args together.
///
/// The `lens`, `offsets`, and `dirs` are zipped and passed to [`CaseSet::one`],
/// where `dirs` can be an array of any type and whose elements are passed back to the `set_ctx` closure.
#[inline]
pub fn many<T, F, const N: usize>(
dirs: [T; N],
lens: [usize; N],
offsets: [usize; N],
mut set_ctx: F,
) where
F: FnMut(&CaseSetter<UP_TO, WITH_DEFAULT>, T),
{
for (dir, (len, offset)) in zip(dirs, zip(lens, offsets)) {
Self::one(dir, len, offset, &mut set_ctx);
}
}
/// `$UP_TO` is the maximum length that will be optimized, with powers of two up
/// to 64 supported. If the buffer length is not a power of two or is greater
/// than `$UP_TO`, this macro will still fill the buffer with a slower fallback.
///
/// See [`case_set!`] for examples and more documentation.
macro_rules! case_set_with_default {
(up_to=$UP_TO:literal, $($tt:tt)*) => {
$crate::src::ctx::case_set!(up_to=$UP_TO, @DEFAULT=true, $($tt)*);
};
}
pub(crate) use case_set_with_default;
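
The removed `small_memset` function and the new `case_set!` macro both appear to rest on the same idea: match on the runtime length and, for each small power of two, write a fixed-size array so the compiler sees a compile-time length. The following is a minimal standalone sketch of that idea only. The name `fill_small` is hypothetical and not part of rav1d, and unlike the code in this commit it always falls back to `slice::fill` and handles a single buffer per call.

```rust
/// Fill `buf` with `val`, specializing on small power-of-two lengths.
///
/// `fill_small` is an illustrative name, not the crate's API.
/// `UP_TO` is the largest length that gets a dedicated match arm.
pub fn fill_small<T: Copy, const UP_TO: usize>(buf: &mut [T], val: T) {
    // Reinterpret a slice as a fixed-size array reference.
    // Panics if the slice length is not exactly `N`.
    fn as_array<T, const N: usize>(buf: &mut [T]) -> &mut [T; N] {
        buf.try_into().unwrap()
    }

    match buf.len() {
        // Each arm writes an array whose length is a compile-time constant,
        // which the optimizer can lower to a small number of wide stores.
        1 if UP_TO >= 1 => *as_array(buf) = [val; 1],
        2 if UP_TO >= 2 => *as_array(buf) = [val; 2],
        4 if UP_TO >= 4 => *as_array(buf) = [val; 4],
        8 if UP_TO >= 8 => *as_array(buf) = [val; 8],
        16 if UP_TO >= 16 => *as_array(buf) = [val; 16],
        32 if UP_TO >= 32 => *as_array(buf) = [val; 32],
        64 if UP_TO >= 64 => *as_array(buf) = [val; 64],
        // Any other length (or one above `UP_TO`) takes the generic path.
        _ => buf.fill(val),
    }
}
```

A call such as `fill_small::<u8, 32>(&mut buf[offset..][..len], 1)` fills one range. Setting several buffers this way repeats the `match` once per buffer, which is the cost the macro in this commit avoids by expanding the whole body inside a single `match`.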