From 2506ce8511fa7e91e25a6124db35483149206cbc Mon Sep 17 00:00:00 2001 From: Stephen Crane Date: Wed, 26 Jun 2024 18:51:29 -0700 Subject: [PATCH] Extract `Rav1dTileStateContext` locking `decode_b` is hot but only sometimes needs to lock `Rav1dTileStateContext`. If `frame_thread.pass` is not 2, we can take the lock higher in the call chain and avoid repeatedly locking and unlocking the context. --- src/decode.rs | 157 ++++++++++++++++++++++++++++++-------------------- 1 file changed, 96 insertions(+), 61 deletions(-) diff --git a/src/decode.rs b/src/decode.rs index 8cdd2c2d6..e0860f63e 100644 --- a/src/decode.rs +++ b/src/decode.rs @@ -1138,6 +1138,7 @@ fn decode_b( c: &Rav1dContext, t: &mut Rav1dTaskContext, f: &Rav1dFrameData, + pass: &mut FrameThreadPassState, bl: BlockLevel, bs: BlockSize, bp: BlockPartition, @@ -1188,7 +1189,7 @@ fn decode_b( && (bh4 > ss_ver || t.b.y & 1 != 0); let frame_type = f.frame_hdr.as_ref().unwrap().frame_type; - if t.frame_thread.pass == 2 { + let FrameThreadPassState::First(ts_c) = pass else { match &b.ii { Av1BlockIntraInter::Intra(intra) => { (bd_fn.recon_b_intra)(f, t, None, bs, intra_edge_flags, b, intra); @@ -1316,9 +1317,9 @@ fn decode_b( } return Ok(()); - } + }; - let ts_c = &mut *ts.context.try_lock().unwrap(); + let ts_c = &mut **ts_c; let cw4 = w4 + ss_hor >> ss_hor; let ch4 = h4 + ss_ver >> ss_ver; @@ -3426,10 +3427,16 @@ fn decode_b( Ok(()) } +enum FrameThreadPassState<'a> { + First(&'a mut Rav1dTileStateContext), + Second, +} + fn decode_sb( c: &Rav1dContext, t: &mut Rav1dTaskContext, f: &Rav1dFrameData, + pass: &mut FrameThreadPassState, bl: BlockLevel, edge_index: EdgeIndex, ) -> Result<(), ()> { @@ -3450,6 +3457,7 @@ fn decode_sb( c, t, f, + pass, next_bl, intra_edge.branch(sb128, edge_index).split[0], ); @@ -3460,24 +3468,26 @@ fn decode_sb( let bp; let mut bx8 = 0; let mut by8 = 0; - let ctx = if t.frame_thread.pass == 2 { - None - } else { - if false && bl == BlockLevel::Bl64x64 { - let ts_c = ts.context.try_lock().unwrap(); - println!( - "poc={},y={},x={},bl={:?},r={}", - frame_hdr.frame_offset, t.b.y, t.b.x, bl, ts_c.msac.rng, - ); + let ctx = match pass { + FrameThreadPassState::First(ts_c) => { + if false && bl == BlockLevel::Bl64x64 { + println!( + "poc={},y={},x={},bl={:?},r={}", + frame_hdr.frame_offset, t.b.y, t.b.x, bl, ts_c.msac.rng, + ); + } + bx8 = (t.b.x & 31) >> 1; + by8 = (t.b.y & 31) >> 1; + Some(( + get_partition_ctx(&f.a[t.a], &t.l, bl, by8, bx8), + &mut **ts_c, + )) } - bx8 = (t.b.x & 31) >> 1; - by8 = (t.b.y & 31) >> 1; - Some(get_partition_ctx(&f.a[t.a], &t.l, bl, by8, bx8)) + FrameThreadPassState::Second => None, }; if have_h_split && have_v_split { - if let Some(ctx) = ctx { - let ts_c = &mut *ts.context.try_lock().unwrap(); + if let Some((ctx, ts_c)) = ctx { let pc = &mut ts_c.cdf.m.partition[bl as usize][ctx as usize]; bp = BlockPartition::from_repr(rav1d_msac_decode_symbol_adapt16( &mut ts_c.msac, @@ -3518,20 +3528,20 @@ fn decode_sb( match bp { BlockPartition::None => { let node = intra_edge.node(sb128, edge_index); - decode_b(c, t, f, bl, b[0], bp, node.o)?; + decode_b(c, t, f, pass, bl, b[0], bp, node.o)?; } BlockPartition::H => { let node = intra_edge.node(sb128, edge_index); - decode_b(c, t, f, bl, b[0], bp, node.h[0])?; + decode_b(c, t, f, pass, bl, b[0], bp, node.h[0])?; t.b.y += hsz; - decode_b(c, t, f, bl, b[0], bp, node.h[1])?; + decode_b(c, t, f, pass, bl, b[0], bp, node.h[1])?; t.b.y -= hsz; } BlockPartition::V => { let node = intra_edge.node(sb128, edge_index); - decode_b(c, t, f, bl, b[0], bp, node.v[0])?; + decode_b(c, t, f, pass, bl, b[0], bp, node.v[0])?; t.b.x += hsz; - decode_b(c, t, f, bl, b[0], bp, node.v[1])?; + decode_b(c, t, f, pass, bl, b[0], bp, node.v[1])?; t.b.x -= hsz; } BlockPartition::Split => { @@ -3539,16 +3549,25 @@ fn decode_sb( None => { let tip = intra_edge.tip(sb128, edge_index); assert!(hsz == 1); - decode_b(c, t, f, bl, BlockSize::Bs4x4, bp, EdgeFlags::ALL_TR_AND_BL)?; + decode_b( + c, + t, + f, + pass, + bl, + BlockSize::Bs4x4, + bp, + EdgeFlags::ALL_TR_AND_BL, + )?; let tl_filter = t.tl_4x4_filter; t.b.x += 1; - decode_b(c, t, f, bl, BlockSize::Bs4x4, bp, tip.split[0])?; + decode_b(c, t, f, pass, bl, BlockSize::Bs4x4, bp, tip.split[0])?; t.b.x -= 1; t.b.y += 1; - decode_b(c, t, f, bl, BlockSize::Bs4x4, bp, tip.split[1])?; + decode_b(c, t, f, pass, bl, BlockSize::Bs4x4, bp, tip.split[1])?; t.b.x += 1; t.tl_4x4_filter = tl_filter; - decode_b(c, t, f, bl, BlockSize::Bs4x4, bp, tip.split[2])?; + decode_b(c, t, f, pass, bl, BlockSize::Bs4x4, bp, tip.split[2])?; t.b.x -= 1; t.b.y -= 1; if cfg!(target_arch = "x86_64") && t.frame_thread.pass != 0 { @@ -3564,14 +3583,14 @@ fn decode_sb( } Some(next_bl) => { let branch = intra_edge.branch(sb128, edge_index); - decode_sb(c, t, f, next_bl, branch.split[0])?; + decode_sb(c, t, f, pass, next_bl, branch.split[0])?; t.b.x += hsz; - decode_sb(c, t, f, next_bl, branch.split[1])?; + decode_sb(c, t, f, pass, next_bl, branch.split[1])?; t.b.x -= hsz; t.b.y += hsz; - decode_sb(c, t, f, next_bl, branch.split[2])?; + decode_sb(c, t, f, pass, next_bl, branch.split[2])?; t.b.x += hsz; - decode_sb(c, t, f, next_bl, branch.split[3])?; + decode_sb(c, t, f, pass, next_bl, branch.split[3])?; t.b.x -= hsz; t.b.y -= hsz; } @@ -3579,77 +3598,76 @@ fn decode_sb( } BlockPartition::TopSplit => { let node = intra_edge.node(sb128, edge_index); - decode_b(c, t, f, bl, b[0], bp, EdgeFlags::ALL_TR_AND_BL)?; + decode_b(c, t, f, pass, bl, b[0], bp, EdgeFlags::ALL_TR_AND_BL)?; t.b.x += hsz; - decode_b(c, t, f, bl, b[0], bp, node.v[1])?; + decode_b(c, t, f, pass, bl, b[0], bp, node.v[1])?; t.b.x -= hsz; t.b.y += hsz; - decode_b(c, t, f, bl, b[1], bp, node.h[1])?; + decode_b(c, t, f, pass, bl, b[1], bp, node.h[1])?; t.b.y -= hsz; } BlockPartition::BottomSplit => { let node = intra_edge.node(sb128, edge_index); - decode_b(c, t, f, bl, b[0], bp, node.h[0])?; + decode_b(c, t, f, pass, bl, b[0], bp, node.h[0])?; t.b.y += hsz; - decode_b(c, t, f, bl, b[1], bp, node.v[0])?; + decode_b(c, t, f, pass, bl, b[1], bp, node.v[0])?; t.b.x += hsz; - decode_b(c, t, f, bl, b[1], bp, EdgeFlags::empty())?; + decode_b(c, t, f, pass, bl, b[1], bp, EdgeFlags::empty())?; t.b.x -= hsz; t.b.y -= hsz; } BlockPartition::LeftSplit => { let node = intra_edge.node(sb128, edge_index); - decode_b(c, t, f, bl, b[0], bp, EdgeFlags::ALL_TR_AND_BL)?; + decode_b(c, t, f, pass, bl, b[0], bp, EdgeFlags::ALL_TR_AND_BL)?; t.b.y += hsz; - decode_b(c, t, f, bl, b[0], bp, node.h[1])?; + decode_b(c, t, f, pass, bl, b[0], bp, node.h[1])?; t.b.y -= hsz; t.b.x += hsz; - decode_b(c, t, f, bl, b[1], bp, node.v[1])?; + decode_b(c, t, f, pass, bl, b[1], bp, node.v[1])?; t.b.x -= hsz; } BlockPartition::RightSplit => { let node = intra_edge.node(sb128, edge_index); - decode_b(c, t, f, bl, b[0], bp, node.v[0])?; + decode_b(c, t, f, pass, bl, b[0], bp, node.v[0])?; t.b.x += hsz; - decode_b(c, t, f, bl, b[1], bp, node.h[0])?; + decode_b(c, t, f, pass, bl, b[1], bp, node.h[0])?; t.b.y += hsz; - decode_b(c, t, f, bl, b[1], bp, EdgeFlags::empty())?; + decode_b(c, t, f, pass, bl, b[1], bp, EdgeFlags::empty())?; t.b.y -= hsz; t.b.x -= hsz; } BlockPartition::H4 => { let branch = intra_edge.branch(sb128, edge_index); let node = &branch.node; - decode_b(c, t, f, bl, b[0], bp, node.h[0])?; + decode_b(c, t, f, pass, bl, b[0], bp, node.h[0])?; t.b.y += hsz >> 1; - decode_b(c, t, f, bl, b[0], bp, branch.h4)?; + decode_b(c, t, f, pass, bl, b[0], bp, branch.h4)?; t.b.y += hsz >> 1; - decode_b(c, t, f, bl, b[0], bp, EdgeFlags::ALL_LEFT_HAS_BOTTOM)?; + decode_b(c, t, f, pass, bl, b[0], bp, EdgeFlags::ALL_LEFT_HAS_BOTTOM)?; t.b.y += hsz >> 1; if t.b.y < f.bh { - decode_b(c, t, f, bl, b[0], bp, node.h[1])?; + decode_b(c, t, f, pass, bl, b[0], bp, node.h[1])?; } t.b.y -= hsz * 3 >> 1; } BlockPartition::V4 => { let branch = intra_edge.branch(sb128, edge_index); let node = &branch.node; - decode_b(c, t, f, bl, b[0], bp, node.v[0])?; + decode_b(c, t, f, pass, bl, b[0], bp, node.v[0])?; t.b.x += hsz >> 1; - decode_b(c, t, f, bl, b[0], bp, branch.v4)?; + decode_b(c, t, f, pass, bl, b[0], bp, branch.v4)?; t.b.x += hsz >> 1; - decode_b(c, t, f, bl, b[0], bp, EdgeFlags::ALL_TOP_HAS_RIGHT)?; + decode_b(c, t, f, pass, bl, b[0], bp, EdgeFlags::ALL_TOP_HAS_RIGHT)?; t.b.x += hsz >> 1; if t.b.x < f.bw { - decode_b(c, t, f, bl, b[0], bp, node.v[1])?; + decode_b(c, t, f, pass, bl, b[0], bp, node.v[1])?; } t.b.x -= hsz * 3 >> 1; } } } else if have_h_split { let is_split; - if let Some(ctx) = ctx { - let ts_c = &mut *ts.context.try_lock().unwrap(); + if let Some((ctx, ts_c)) = ctx { let pc = &mut ts_c.cdf.m.partition[bl as usize][ctx as usize]; is_split = rav1d_msac_decode_bool(&mut ts_c.msac, gather_top_partition_prob(pc, bl)); if debug_block_info!(f, t.b) { @@ -3683,9 +3701,9 @@ fn decode_sb( if is_split { let branch = intra_edge.branch(sb128, edge_index); bp = BlockPartition::Split; - decode_sb(c, t, f, next_bl, branch.split[0])?; + decode_sb(c, t, f, pass, next_bl, branch.split[0])?; t.b.x += hsz; - decode_sb(c, t, f, next_bl, branch.split[1])?; + decode_sb(c, t, f, pass, next_bl, branch.split[1])?; t.b.x -= hsz; } else { let node = intra_edge.node(sb128, edge_index); @@ -3694,6 +3712,7 @@ fn decode_sb( c, t, f, + pass, bl, dav1d_block_sizes[bl as usize][bp as usize][0], bp, @@ -3703,8 +3722,7 @@ fn decode_sb( } else { assert!(have_v_split); let is_split; - if let Some(ctx) = ctx { - let ts_c = &mut *ts.context.try_lock().unwrap(); + if let Some((ctx, ts_c)) = ctx { let pc = &mut ts_c.cdf.m.partition[bl as usize][ctx as usize]; is_split = rav1d_msac_decode_bool(&mut ts_c.msac, gather_left_partition_prob(pc, bl)); if f.cur.p.layout == Rav1dPixelLayout::I422 && !is_split { @@ -3741,9 +3759,9 @@ fn decode_sb( if is_split { let branch = intra_edge.branch(sb128, edge_index); bp = BlockPartition::Split; - decode_sb(c, t, f, next_bl, branch.split[0])?; + decode_sb(c, t, f, pass, next_bl, branch.split[0])?; t.b.y += hsz; - decode_sb(c, t, f, next_bl, branch.split[2])?; + decode_sb(c, t, f, pass, next_bl, branch.split[2])?; t.b.y -= hsz; } else { let node = intra_edge.node(sb128, edge_index); @@ -3752,6 +3770,7 @@ fn decode_sb( c, t, f, + pass, bl, dav1d_block_sizes[bl as usize][bp as usize][0], bp, @@ -3760,7 +3779,9 @@ fn decode_sb( } } - if t.frame_thread.pass != 2 && (bp != BlockPartition::Split || bl == BlockLevel::Bl8x8) { + if matches!(pass, FrameThreadPassState::First(_)) + && (bp != BlockPartition::Split || bl == BlockLevel::Bl8x8) + { CaseSet::<16, false>::many( [(&f.a[t.a], 0), (&t.l, 1)], [hsz as usize; 2], @@ -4111,7 +4132,14 @@ pub(crate) fn rav1d_decode_tile_sbrow( if c.flush.load(Ordering::Acquire) { return Err(()); } - decode_sb(c, t, f, root_bl, EdgeIndex::root())?; + decode_sb( + c, + t, + f, + &mut FrameThreadPassState::Second, + root_bl, + EdgeIndex::root(), + )?; if t.b.x & 16 != 0 || f.seq_hdr().sb128 != 0 { t.a += 1; } @@ -4224,7 +4252,14 @@ pub(crate) fn rav1d_decode_tile_sbrow( read_restoration_info(ts, &mut lr, p, frame_type, debug_block_info!(f, t.b)); } } - decode_sb(c, t, f, root_bl, EdgeIndex::root())?; + decode_sb( + c, + t, + f, + &mut FrameThreadPassState::First(&mut f.ts[t.ts].context.try_lock().unwrap()), + root_bl, + EdgeIndex::root(), + )?; if t.b.x & 16 != 0 || f.seq_hdr().sb128 != 0 { t.a += 1; t.lf_mask = t.lf_mask.map(|i| i + 1);