Skip to content

Commit

Permalink
Extract Rav1dTileStateContext locking
Browse files Browse the repository at this point in the history
`decode_b` is hot but only sometimes needs to lock
`Rav1dTileStateContext`. If `frame_thread.pass` is
not 2, we can take the lock higher in the call
chain and avoid repeatedly locking and unlocking
the context.
  • Loading branch information
rinon committed Jul 1, 2024
1 parent a33d74a commit 2506ce8
Showing 1 changed file with 96 additions and 61 deletions.
157 changes: 96 additions & 61 deletions src/decode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1138,6 +1138,7 @@ fn decode_b(
c: &Rav1dContext,
t: &mut Rav1dTaskContext,
f: &Rav1dFrameData,
pass: &mut FrameThreadPassState,
bl: BlockLevel,
bs: BlockSize,
bp: BlockPartition,
Expand Down Expand Up @@ -1188,7 +1189,7 @@ fn decode_b(
&& (bh4 > ss_ver || t.b.y & 1 != 0);
let frame_type = f.frame_hdr.as_ref().unwrap().frame_type;

if t.frame_thread.pass == 2 {
let FrameThreadPassState::First(ts_c) = pass else {
match &b.ii {
Av1BlockIntraInter::Intra(intra) => {
(bd_fn.recon_b_intra)(f, t, None, bs, intra_edge_flags, b, intra);
Expand Down Expand Up @@ -1316,9 +1317,9 @@ fn decode_b(
}

return Ok(());
}
};

let ts_c = &mut *ts.context.try_lock().unwrap();
let ts_c = &mut **ts_c;

let cw4 = w4 + ss_hor >> ss_hor;
let ch4 = h4 + ss_ver >> ss_ver;
Expand Down Expand Up @@ -3426,10 +3427,16 @@ fn decode_b(
Ok(())
}

enum FrameThreadPassState<'a> {
First(&'a mut Rav1dTileStateContext),
Second,
}

fn decode_sb(
c: &Rav1dContext,
t: &mut Rav1dTaskContext,
f: &Rav1dFrameData,
pass: &mut FrameThreadPassState,
bl: BlockLevel,
edge_index: EdgeIndex,
) -> Result<(), ()> {
Expand All @@ -3450,6 +3457,7 @@ fn decode_sb(
c,
t,
f,
pass,
next_bl,
intra_edge.branch(sb128, edge_index).split[0],
);
Expand All @@ -3460,24 +3468,26 @@ fn decode_sb(
let bp;
let mut bx8 = 0;
let mut by8 = 0;
let ctx = if t.frame_thread.pass == 2 {
None
} else {
if false && bl == BlockLevel::Bl64x64 {
let ts_c = ts.context.try_lock().unwrap();
println!(
"poc={},y={},x={},bl={:?},r={}",
frame_hdr.frame_offset, t.b.y, t.b.x, bl, ts_c.msac.rng,
);
let ctx = match pass {
FrameThreadPassState::First(ts_c) => {
if false && bl == BlockLevel::Bl64x64 {
println!(
"poc={},y={},x={},bl={:?},r={}",
frame_hdr.frame_offset, t.b.y, t.b.x, bl, ts_c.msac.rng,
);
}
bx8 = (t.b.x & 31) >> 1;
by8 = (t.b.y & 31) >> 1;
Some((
get_partition_ctx(&f.a[t.a], &t.l, bl, by8, bx8),
&mut **ts_c,
))
}
bx8 = (t.b.x & 31) >> 1;
by8 = (t.b.y & 31) >> 1;
Some(get_partition_ctx(&f.a[t.a], &t.l, bl, by8, bx8))
FrameThreadPassState::Second => None,
};

if have_h_split && have_v_split {
if let Some(ctx) = ctx {
let ts_c = &mut *ts.context.try_lock().unwrap();
if let Some((ctx, ts_c)) = ctx {
let pc = &mut ts_c.cdf.m.partition[bl as usize][ctx as usize];
bp = BlockPartition::from_repr(rav1d_msac_decode_symbol_adapt16(
&mut ts_c.msac,
Expand Down Expand Up @@ -3518,37 +3528,46 @@ fn decode_sb(
match bp {
BlockPartition::None => {
let node = intra_edge.node(sb128, edge_index);
decode_b(c, t, f, bl, b[0], bp, node.o)?;
decode_b(c, t, f, pass, bl, b[0], bp, node.o)?;
}
BlockPartition::H => {
let node = intra_edge.node(sb128, edge_index);
decode_b(c, t, f, bl, b[0], bp, node.h[0])?;
decode_b(c, t, f, pass, bl, b[0], bp, node.h[0])?;
t.b.y += hsz;
decode_b(c, t, f, bl, b[0], bp, node.h[1])?;
decode_b(c, t, f, pass, bl, b[0], bp, node.h[1])?;
t.b.y -= hsz;
}
BlockPartition::V => {
let node = intra_edge.node(sb128, edge_index);
decode_b(c, t, f, bl, b[0], bp, node.v[0])?;
decode_b(c, t, f, pass, bl, b[0], bp, node.v[0])?;
t.b.x += hsz;
decode_b(c, t, f, bl, b[0], bp, node.v[1])?;
decode_b(c, t, f, pass, bl, b[0], bp, node.v[1])?;
t.b.x -= hsz;
}
BlockPartition::Split => {
match bl.decrease() {
None => {
let tip = intra_edge.tip(sb128, edge_index);
assert!(hsz == 1);
decode_b(c, t, f, bl, BlockSize::Bs4x4, bp, EdgeFlags::ALL_TR_AND_BL)?;
decode_b(
c,
t,
f,
pass,
bl,
BlockSize::Bs4x4,
bp,
EdgeFlags::ALL_TR_AND_BL,
)?;
let tl_filter = t.tl_4x4_filter;
t.b.x += 1;
decode_b(c, t, f, bl, BlockSize::Bs4x4, bp, tip.split[0])?;
decode_b(c, t, f, pass, bl, BlockSize::Bs4x4, bp, tip.split[0])?;
t.b.x -= 1;
t.b.y += 1;
decode_b(c, t, f, bl, BlockSize::Bs4x4, bp, tip.split[1])?;
decode_b(c, t, f, pass, bl, BlockSize::Bs4x4, bp, tip.split[1])?;
t.b.x += 1;
t.tl_4x4_filter = tl_filter;
decode_b(c, t, f, bl, BlockSize::Bs4x4, bp, tip.split[2])?;
decode_b(c, t, f, pass, bl, BlockSize::Bs4x4, bp, tip.split[2])?;
t.b.x -= 1;
t.b.y -= 1;
if cfg!(target_arch = "x86_64") && t.frame_thread.pass != 0 {
Expand All @@ -3564,92 +3583,91 @@ fn decode_sb(
}
Some(next_bl) => {
let branch = intra_edge.branch(sb128, edge_index);
decode_sb(c, t, f, next_bl, branch.split[0])?;
decode_sb(c, t, f, pass, next_bl, branch.split[0])?;
t.b.x += hsz;
decode_sb(c, t, f, next_bl, branch.split[1])?;
decode_sb(c, t, f, pass, next_bl, branch.split[1])?;
t.b.x -= hsz;
t.b.y += hsz;
decode_sb(c, t, f, next_bl, branch.split[2])?;
decode_sb(c, t, f, pass, next_bl, branch.split[2])?;
t.b.x += hsz;
decode_sb(c, t, f, next_bl, branch.split[3])?;
decode_sb(c, t, f, pass, next_bl, branch.split[3])?;
t.b.x -= hsz;
t.b.y -= hsz;
}
}
}
BlockPartition::TopSplit => {
let node = intra_edge.node(sb128, edge_index);
decode_b(c, t, f, bl, b[0], bp, EdgeFlags::ALL_TR_AND_BL)?;
decode_b(c, t, f, pass, bl, b[0], bp, EdgeFlags::ALL_TR_AND_BL)?;
t.b.x += hsz;
decode_b(c, t, f, bl, b[0], bp, node.v[1])?;
decode_b(c, t, f, pass, bl, b[0], bp, node.v[1])?;
t.b.x -= hsz;
t.b.y += hsz;
decode_b(c, t, f, bl, b[1], bp, node.h[1])?;
decode_b(c, t, f, pass, bl, b[1], bp, node.h[1])?;
t.b.y -= hsz;
}
BlockPartition::BottomSplit => {
let node = intra_edge.node(sb128, edge_index);
decode_b(c, t, f, bl, b[0], bp, node.h[0])?;
decode_b(c, t, f, pass, bl, b[0], bp, node.h[0])?;
t.b.y += hsz;
decode_b(c, t, f, bl, b[1], bp, node.v[0])?;
decode_b(c, t, f, pass, bl, b[1], bp, node.v[0])?;
t.b.x += hsz;
decode_b(c, t, f, bl, b[1], bp, EdgeFlags::empty())?;
decode_b(c, t, f, pass, bl, b[1], bp, EdgeFlags::empty())?;
t.b.x -= hsz;
t.b.y -= hsz;
}
BlockPartition::LeftSplit => {
let node = intra_edge.node(sb128, edge_index);
decode_b(c, t, f, bl, b[0], bp, EdgeFlags::ALL_TR_AND_BL)?;
decode_b(c, t, f, pass, bl, b[0], bp, EdgeFlags::ALL_TR_AND_BL)?;
t.b.y += hsz;
decode_b(c, t, f, bl, b[0], bp, node.h[1])?;
decode_b(c, t, f, pass, bl, b[0], bp, node.h[1])?;
t.b.y -= hsz;
t.b.x += hsz;
decode_b(c, t, f, bl, b[1], bp, node.v[1])?;
decode_b(c, t, f, pass, bl, b[1], bp, node.v[1])?;
t.b.x -= hsz;
}
BlockPartition::RightSplit => {
let node = intra_edge.node(sb128, edge_index);
decode_b(c, t, f, bl, b[0], bp, node.v[0])?;
decode_b(c, t, f, pass, bl, b[0], bp, node.v[0])?;
t.b.x += hsz;
decode_b(c, t, f, bl, b[1], bp, node.h[0])?;
decode_b(c, t, f, pass, bl, b[1], bp, node.h[0])?;
t.b.y += hsz;
decode_b(c, t, f, bl, b[1], bp, EdgeFlags::empty())?;
decode_b(c, t, f, pass, bl, b[1], bp, EdgeFlags::empty())?;
t.b.y -= hsz;
t.b.x -= hsz;
}
BlockPartition::H4 => {
let branch = intra_edge.branch(sb128, edge_index);
let node = &branch.node;
decode_b(c, t, f, bl, b[0], bp, node.h[0])?;
decode_b(c, t, f, pass, bl, b[0], bp, node.h[0])?;
t.b.y += hsz >> 1;
decode_b(c, t, f, bl, b[0], bp, branch.h4)?;
decode_b(c, t, f, pass, bl, b[0], bp, branch.h4)?;
t.b.y += hsz >> 1;
decode_b(c, t, f, bl, b[0], bp, EdgeFlags::ALL_LEFT_HAS_BOTTOM)?;
decode_b(c, t, f, pass, bl, b[0], bp, EdgeFlags::ALL_LEFT_HAS_BOTTOM)?;
t.b.y += hsz >> 1;
if t.b.y < f.bh {
decode_b(c, t, f, bl, b[0], bp, node.h[1])?;
decode_b(c, t, f, pass, bl, b[0], bp, node.h[1])?;
}
t.b.y -= hsz * 3 >> 1;
}
BlockPartition::V4 => {
let branch = intra_edge.branch(sb128, edge_index);
let node = &branch.node;
decode_b(c, t, f, bl, b[0], bp, node.v[0])?;
decode_b(c, t, f, pass, bl, b[0], bp, node.v[0])?;
t.b.x += hsz >> 1;
decode_b(c, t, f, bl, b[0], bp, branch.v4)?;
decode_b(c, t, f, pass, bl, b[0], bp, branch.v4)?;
t.b.x += hsz >> 1;
decode_b(c, t, f, bl, b[0], bp, EdgeFlags::ALL_TOP_HAS_RIGHT)?;
decode_b(c, t, f, pass, bl, b[0], bp, EdgeFlags::ALL_TOP_HAS_RIGHT)?;
t.b.x += hsz >> 1;
if t.b.x < f.bw {
decode_b(c, t, f, bl, b[0], bp, node.v[1])?;
decode_b(c, t, f, pass, bl, b[0], bp, node.v[1])?;
}
t.b.x -= hsz * 3 >> 1;
}
}
} else if have_h_split {
let is_split;
if let Some(ctx) = ctx {
let ts_c = &mut *ts.context.try_lock().unwrap();
if let Some((ctx, ts_c)) = ctx {
let pc = &mut ts_c.cdf.m.partition[bl as usize][ctx as usize];
is_split = rav1d_msac_decode_bool(&mut ts_c.msac, gather_top_partition_prob(pc, bl));
if debug_block_info!(f, t.b) {
Expand Down Expand Up @@ -3683,9 +3701,9 @@ fn decode_sb(
if is_split {
let branch = intra_edge.branch(sb128, edge_index);
bp = BlockPartition::Split;
decode_sb(c, t, f, next_bl, branch.split[0])?;
decode_sb(c, t, f, pass, next_bl, branch.split[0])?;
t.b.x += hsz;
decode_sb(c, t, f, next_bl, branch.split[1])?;
decode_sb(c, t, f, pass, next_bl, branch.split[1])?;
t.b.x -= hsz;
} else {
let node = intra_edge.node(sb128, edge_index);
Expand All @@ -3694,6 +3712,7 @@ fn decode_sb(
c,
t,
f,
pass,
bl,
dav1d_block_sizes[bl as usize][bp as usize][0],
bp,
Expand All @@ -3703,8 +3722,7 @@ fn decode_sb(
} else {
assert!(have_v_split);
let is_split;
if let Some(ctx) = ctx {
let ts_c = &mut *ts.context.try_lock().unwrap();
if let Some((ctx, ts_c)) = ctx {
let pc = &mut ts_c.cdf.m.partition[bl as usize][ctx as usize];
is_split = rav1d_msac_decode_bool(&mut ts_c.msac, gather_left_partition_prob(pc, bl));
if f.cur.p.layout == Rav1dPixelLayout::I422 && !is_split {
Expand Down Expand Up @@ -3741,9 +3759,9 @@ fn decode_sb(
if is_split {
let branch = intra_edge.branch(sb128, edge_index);
bp = BlockPartition::Split;
decode_sb(c, t, f, next_bl, branch.split[0])?;
decode_sb(c, t, f, pass, next_bl, branch.split[0])?;
t.b.y += hsz;
decode_sb(c, t, f, next_bl, branch.split[2])?;
decode_sb(c, t, f, pass, next_bl, branch.split[2])?;
t.b.y -= hsz;
} else {
let node = intra_edge.node(sb128, edge_index);
Expand All @@ -3752,6 +3770,7 @@ fn decode_sb(
c,
t,
f,
pass,
bl,
dav1d_block_sizes[bl as usize][bp as usize][0],
bp,
Expand All @@ -3760,7 +3779,9 @@ fn decode_sb(
}
}

if t.frame_thread.pass != 2 && (bp != BlockPartition::Split || bl == BlockLevel::Bl8x8) {
if matches!(pass, FrameThreadPassState::First(_))
&& (bp != BlockPartition::Split || bl == BlockLevel::Bl8x8)
{
CaseSet::<16, false>::many(
[(&f.a[t.a], 0), (&t.l, 1)],
[hsz as usize; 2],
Expand Down Expand Up @@ -4111,7 +4132,14 @@ pub(crate) fn rav1d_decode_tile_sbrow(
if c.flush.load(Ordering::Acquire) {
return Err(());
}
decode_sb(c, t, f, root_bl, EdgeIndex::root())?;
decode_sb(
c,
t,
f,
&mut FrameThreadPassState::Second,
root_bl,
EdgeIndex::root(),
)?;
if t.b.x & 16 != 0 || f.seq_hdr().sb128 != 0 {
t.a += 1;
}
Expand Down Expand Up @@ -4224,7 +4252,14 @@ pub(crate) fn rav1d_decode_tile_sbrow(
read_restoration_info(ts, &mut lr, p, frame_type, debug_block_info!(f, t.b));
}
}
decode_sb(c, t, f, root_bl, EdgeIndex::root())?;
decode_sb(
c,
t,
f,
&mut FrameThreadPassState::First(&mut f.ts[t.ts].context.try_lock().unwrap()),
root_bl,
EdgeIndex::root(),
)?;
if t.b.x & 16 != 0 || f.seq_hdr().sb128 != 0 {
t.a += 1;
t.lf_mask = t.lf_mask.map(|i| i + 1);
Expand Down

0 comments on commit 2506ce8

Please sign in to comment.