Skip to content

Commit d9bc0e3

Browse files
authored
struct Rav1dTileStateContext: Extract locking higher in the call stack (#1263)
2 parents 093a17d + 2506ce8 commit d9bc0e3

File tree

1 file changed

+96
-61
lines changed

1 file changed

+96
-61
lines changed

src/decode.rs

Lines changed: 96 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -1138,6 +1138,7 @@ fn decode_b(
11381138
c: &Rav1dContext,
11391139
t: &mut Rav1dTaskContext,
11401140
f: &Rav1dFrameData,
1141+
pass: &mut FrameThreadPassState,
11411142
bl: BlockLevel,
11421143
bs: BlockSize,
11431144
bp: BlockPartition,
@@ -1188,7 +1189,7 @@ fn decode_b(
11881189
&& (bh4 > ss_ver || t.b.y & 1 != 0);
11891190
let frame_type = f.frame_hdr.as_ref().unwrap().frame_type;
11901191

1191-
if t.frame_thread.pass == 2 {
1192+
let FrameThreadPassState::First(ts_c) = pass else {
11921193
match &b.ii {
11931194
Av1BlockIntraInter::Intra(intra) => {
11941195
(bd_fn.recon_b_intra)(f, t, None, bs, intra_edge_flags, b, intra);
@@ -1316,9 +1317,9 @@ fn decode_b(
13161317
}
13171318

13181319
return Ok(());
1319-
}
1320+
};
13201321

1321-
let ts_c = &mut *ts.context.try_lock().unwrap();
1322+
let ts_c = &mut **ts_c;
13221323

13231324
let cw4 = w4 + ss_hor >> ss_hor;
13241325
let ch4 = h4 + ss_ver >> ss_ver;
@@ -3433,10 +3434,16 @@ fn decode_b(
34333434
Ok(())
34343435
}
34353436

3437+
enum FrameThreadPassState<'a> {
3438+
First(&'a mut Rav1dTileStateContext),
3439+
Second,
3440+
}
3441+
34363442
fn decode_sb(
34373443
c: &Rav1dContext,
34383444
t: &mut Rav1dTaskContext,
34393445
f: &Rav1dFrameData,
3446+
pass: &mut FrameThreadPassState,
34403447
bl: BlockLevel,
34413448
edge_index: EdgeIndex,
34423449
) -> Result<(), ()> {
@@ -3457,6 +3464,7 @@ fn decode_sb(
34573464
c,
34583465
t,
34593466
f,
3467+
pass,
34603468
next_bl,
34613469
intra_edge.branch(sb128, edge_index).split[0],
34623470
);
@@ -3467,24 +3475,26 @@ fn decode_sb(
34673475
let bp;
34683476
let mut bx8 = 0;
34693477
let mut by8 = 0;
3470-
let ctx = if t.frame_thread.pass == 2 {
3471-
None
3472-
} else {
3473-
if false && bl == BlockLevel::Bl64x64 {
3474-
let ts_c = ts.context.try_lock().unwrap();
3475-
println!(
3476-
"poc={},y={},x={},bl={:?},r={}",
3477-
frame_hdr.frame_offset, t.b.y, t.b.x, bl, ts_c.msac.rng,
3478-
);
3478+
let ctx = match pass {
3479+
FrameThreadPassState::First(ts_c) => {
3480+
if false && bl == BlockLevel::Bl64x64 {
3481+
println!(
3482+
"poc={},y={},x={},bl={:?},r={}",
3483+
frame_hdr.frame_offset, t.b.y, t.b.x, bl, ts_c.msac.rng,
3484+
);
3485+
}
3486+
bx8 = (t.b.x & 31) >> 1;
3487+
by8 = (t.b.y & 31) >> 1;
3488+
Some((
3489+
get_partition_ctx(&f.a[t.a], &t.l, bl, by8, bx8),
3490+
&mut **ts_c,
3491+
))
34793492
}
3480-
bx8 = (t.b.x & 31) >> 1;
3481-
by8 = (t.b.y & 31) >> 1;
3482-
Some(get_partition_ctx(&f.a[t.a], &t.l, bl, by8, bx8))
3493+
FrameThreadPassState::Second => None,
34833494
};
34843495

34853496
if have_h_split && have_v_split {
3486-
if let Some(ctx) = ctx {
3487-
let ts_c = &mut *ts.context.try_lock().unwrap();
3497+
if let Some((ctx, ts_c)) = ctx {
34883498
let pc = &mut ts_c.cdf.m.partition[bl as usize][ctx as usize];
34893499
bp = BlockPartition::from_repr(rav1d_msac_decode_symbol_adapt16(
34903500
&mut ts_c.msac,
@@ -3525,37 +3535,46 @@ fn decode_sb(
35253535
match bp {
35263536
BlockPartition::None => {
35273537
let node = intra_edge.node(sb128, edge_index);
3528-
decode_b(c, t, f, bl, b[0], bp, node.o)?;
3538+
decode_b(c, t, f, pass, bl, b[0], bp, node.o)?;
35293539
}
35303540
BlockPartition::H => {
35313541
let node = intra_edge.node(sb128, edge_index);
3532-
decode_b(c, t, f, bl, b[0], bp, node.h[0])?;
3542+
decode_b(c, t, f, pass, bl, b[0], bp, node.h[0])?;
35333543
t.b.y += hsz;
3534-
decode_b(c, t, f, bl, b[0], bp, node.h[1])?;
3544+
decode_b(c, t, f, pass, bl, b[0], bp, node.h[1])?;
35353545
t.b.y -= hsz;
35363546
}
35373547
BlockPartition::V => {
35383548
let node = intra_edge.node(sb128, edge_index);
3539-
decode_b(c, t, f, bl, b[0], bp, node.v[0])?;
3549+
decode_b(c, t, f, pass, bl, b[0], bp, node.v[0])?;
35403550
t.b.x += hsz;
3541-
decode_b(c, t, f, bl, b[0], bp, node.v[1])?;
3551+
decode_b(c, t, f, pass, bl, b[0], bp, node.v[1])?;
35423552
t.b.x -= hsz;
35433553
}
35443554
BlockPartition::Split => {
35453555
match bl.decrease() {
35463556
None => {
35473557
let tip = intra_edge.tip(sb128, edge_index);
35483558
assert!(hsz == 1);
3549-
decode_b(c, t, f, bl, BlockSize::Bs4x4, bp, EdgeFlags::ALL_TR_AND_BL)?;
3559+
decode_b(
3560+
c,
3561+
t,
3562+
f,
3563+
pass,
3564+
bl,
3565+
BlockSize::Bs4x4,
3566+
bp,
3567+
EdgeFlags::ALL_TR_AND_BL,
3568+
)?;
35503569
let tl_filter = t.tl_4x4_filter;
35513570
t.b.x += 1;
3552-
decode_b(c, t, f, bl, BlockSize::Bs4x4, bp, tip.split[0])?;
3571+
decode_b(c, t, f, pass, bl, BlockSize::Bs4x4, bp, tip.split[0])?;
35533572
t.b.x -= 1;
35543573
t.b.y += 1;
3555-
decode_b(c, t, f, bl, BlockSize::Bs4x4, bp, tip.split[1])?;
3574+
decode_b(c, t, f, pass, bl, BlockSize::Bs4x4, bp, tip.split[1])?;
35563575
t.b.x += 1;
35573576
t.tl_4x4_filter = tl_filter;
3558-
decode_b(c, t, f, bl, BlockSize::Bs4x4, bp, tip.split[2])?;
3577+
decode_b(c, t, f, pass, bl, BlockSize::Bs4x4, bp, tip.split[2])?;
35593578
t.b.x -= 1;
35603579
t.b.y -= 1;
35613580
if cfg!(target_arch = "x86_64") && t.frame_thread.pass != 0 {
@@ -3571,92 +3590,91 @@ fn decode_sb(
35713590
}
35723591
Some(next_bl) => {
35733592
let branch = intra_edge.branch(sb128, edge_index);
3574-
decode_sb(c, t, f, next_bl, branch.split[0])?;
3593+
decode_sb(c, t, f, pass, next_bl, branch.split[0])?;
35753594
t.b.x += hsz;
3576-
decode_sb(c, t, f, next_bl, branch.split[1])?;
3595+
decode_sb(c, t, f, pass, next_bl, branch.split[1])?;
35773596
t.b.x -= hsz;
35783597
t.b.y += hsz;
3579-
decode_sb(c, t, f, next_bl, branch.split[2])?;
3598+
decode_sb(c, t, f, pass, next_bl, branch.split[2])?;
35803599
t.b.x += hsz;
3581-
decode_sb(c, t, f, next_bl, branch.split[3])?;
3600+
decode_sb(c, t, f, pass, next_bl, branch.split[3])?;
35823601
t.b.x -= hsz;
35833602
t.b.y -= hsz;
35843603
}
35853604
}
35863605
}
35873606
BlockPartition::TopSplit => {
35883607
let node = intra_edge.node(sb128, edge_index);
3589-
decode_b(c, t, f, bl, b[0], bp, EdgeFlags::ALL_TR_AND_BL)?;
3608+
decode_b(c, t, f, pass, bl, b[0], bp, EdgeFlags::ALL_TR_AND_BL)?;
35903609
t.b.x += hsz;
3591-
decode_b(c, t, f, bl, b[0], bp, node.v[1])?;
3610+
decode_b(c, t, f, pass, bl, b[0], bp, node.v[1])?;
35923611
t.b.x -= hsz;
35933612
t.b.y += hsz;
3594-
decode_b(c, t, f, bl, b[1], bp, node.h[1])?;
3613+
decode_b(c, t, f, pass, bl, b[1], bp, node.h[1])?;
35953614
t.b.y -= hsz;
35963615
}
35973616
BlockPartition::BottomSplit => {
35983617
let node = intra_edge.node(sb128, edge_index);
3599-
decode_b(c, t, f, bl, b[0], bp, node.h[0])?;
3618+
decode_b(c, t, f, pass, bl, b[0], bp, node.h[0])?;
36003619
t.b.y += hsz;
3601-
decode_b(c, t, f, bl, b[1], bp, node.v[0])?;
3620+
decode_b(c, t, f, pass, bl, b[1], bp, node.v[0])?;
36023621
t.b.x += hsz;
3603-
decode_b(c, t, f, bl, b[1], bp, EdgeFlags::empty())?;
3622+
decode_b(c, t, f, pass, bl, b[1], bp, EdgeFlags::empty())?;
36043623
t.b.x -= hsz;
36053624
t.b.y -= hsz;
36063625
}
36073626
BlockPartition::LeftSplit => {
36083627
let node = intra_edge.node(sb128, edge_index);
3609-
decode_b(c, t, f, bl, b[0], bp, EdgeFlags::ALL_TR_AND_BL)?;
3628+
decode_b(c, t, f, pass, bl, b[0], bp, EdgeFlags::ALL_TR_AND_BL)?;
36103629
t.b.y += hsz;
3611-
decode_b(c, t, f, bl, b[0], bp, node.h[1])?;
3630+
decode_b(c, t, f, pass, bl, b[0], bp, node.h[1])?;
36123631
t.b.y -= hsz;
36133632
t.b.x += hsz;
3614-
decode_b(c, t, f, bl, b[1], bp, node.v[1])?;
3633+
decode_b(c, t, f, pass, bl, b[1], bp, node.v[1])?;
36153634
t.b.x -= hsz;
36163635
}
36173636
BlockPartition::RightSplit => {
36183637
let node = intra_edge.node(sb128, edge_index);
3619-
decode_b(c, t, f, bl, b[0], bp, node.v[0])?;
3638+
decode_b(c, t, f, pass, bl, b[0], bp, node.v[0])?;
36203639
t.b.x += hsz;
3621-
decode_b(c, t, f, bl, b[1], bp, node.h[0])?;
3640+
decode_b(c, t, f, pass, bl, b[1], bp, node.h[0])?;
36223641
t.b.y += hsz;
3623-
decode_b(c, t, f, bl, b[1], bp, EdgeFlags::empty())?;
3642+
decode_b(c, t, f, pass, bl, b[1], bp, EdgeFlags::empty())?;
36243643
t.b.y -= hsz;
36253644
t.b.x -= hsz;
36263645
}
36273646
BlockPartition::H4 => {
36283647
let branch = intra_edge.branch(sb128, edge_index);
36293648
let node = &branch.node;
3630-
decode_b(c, t, f, bl, b[0], bp, node.h[0])?;
3649+
decode_b(c, t, f, pass, bl, b[0], bp, node.h[0])?;
36313650
t.b.y += hsz >> 1;
3632-
decode_b(c, t, f, bl, b[0], bp, branch.h4)?;
3651+
decode_b(c, t, f, pass, bl, b[0], bp, branch.h4)?;
36333652
t.b.y += hsz >> 1;
3634-
decode_b(c, t, f, bl, b[0], bp, EdgeFlags::ALL_LEFT_HAS_BOTTOM)?;
3653+
decode_b(c, t, f, pass, bl, b[0], bp, EdgeFlags::ALL_LEFT_HAS_BOTTOM)?;
36353654
t.b.y += hsz >> 1;
36363655
if t.b.y < f.bh {
3637-
decode_b(c, t, f, bl, b[0], bp, node.h[1])?;
3656+
decode_b(c, t, f, pass, bl, b[0], bp, node.h[1])?;
36383657
}
36393658
t.b.y -= hsz * 3 >> 1;
36403659
}
36413660
BlockPartition::V4 => {
36423661
let branch = intra_edge.branch(sb128, edge_index);
36433662
let node = &branch.node;
3644-
decode_b(c, t, f, bl, b[0], bp, node.v[0])?;
3663+
decode_b(c, t, f, pass, bl, b[0], bp, node.v[0])?;
36453664
t.b.x += hsz >> 1;
3646-
decode_b(c, t, f, bl, b[0], bp, branch.v4)?;
3665+
decode_b(c, t, f, pass, bl, b[0], bp, branch.v4)?;
36473666
t.b.x += hsz >> 1;
3648-
decode_b(c, t, f, bl, b[0], bp, EdgeFlags::ALL_TOP_HAS_RIGHT)?;
3667+
decode_b(c, t, f, pass, bl, b[0], bp, EdgeFlags::ALL_TOP_HAS_RIGHT)?;
36493668
t.b.x += hsz >> 1;
36503669
if t.b.x < f.bw {
3651-
decode_b(c, t, f, bl, b[0], bp, node.v[1])?;
3670+
decode_b(c, t, f, pass, bl, b[0], bp, node.v[1])?;
36523671
}
36533672
t.b.x -= hsz * 3 >> 1;
36543673
}
36553674
}
36563675
} else if have_h_split {
36573676
let is_split;
3658-
if let Some(ctx) = ctx {
3659-
let ts_c = &mut *ts.context.try_lock().unwrap();
3677+
if let Some((ctx, ts_c)) = ctx {
36603678
let pc = &mut ts_c.cdf.m.partition[bl as usize][ctx as usize];
36613679
is_split = rav1d_msac_decode_bool(&mut ts_c.msac, gather_top_partition_prob(pc, bl));
36623680
if debug_block_info!(f, t.b) {
@@ -3690,9 +3708,9 @@ fn decode_sb(
36903708
if is_split {
36913709
let branch = intra_edge.branch(sb128, edge_index);
36923710
bp = BlockPartition::Split;
3693-
decode_sb(c, t, f, next_bl, branch.split[0])?;
3711+
decode_sb(c, t, f, pass, next_bl, branch.split[0])?;
36943712
t.b.x += hsz;
3695-
decode_sb(c, t, f, next_bl, branch.split[1])?;
3713+
decode_sb(c, t, f, pass, next_bl, branch.split[1])?;
36963714
t.b.x -= hsz;
36973715
} else {
36983716
let node = intra_edge.node(sb128, edge_index);
@@ -3701,6 +3719,7 @@ fn decode_sb(
37013719
c,
37023720
t,
37033721
f,
3722+
pass,
37043723
bl,
37053724
dav1d_block_sizes[bl as usize][bp as usize][0],
37063725
bp,
@@ -3710,8 +3729,7 @@ fn decode_sb(
37103729
} else {
37113730
assert!(have_v_split);
37123731
let is_split;
3713-
if let Some(ctx) = ctx {
3714-
let ts_c = &mut *ts.context.try_lock().unwrap();
3732+
if let Some((ctx, ts_c)) = ctx {
37153733
let pc = &mut ts_c.cdf.m.partition[bl as usize][ctx as usize];
37163734
is_split = rav1d_msac_decode_bool(&mut ts_c.msac, gather_left_partition_prob(pc, bl));
37173735
if f.cur.p.layout == Rav1dPixelLayout::I422 && !is_split {
@@ -3748,9 +3766,9 @@ fn decode_sb(
37483766
if is_split {
37493767
let branch = intra_edge.branch(sb128, edge_index);
37503768
bp = BlockPartition::Split;
3751-
decode_sb(c, t, f, next_bl, branch.split[0])?;
3769+
decode_sb(c, t, f, pass, next_bl, branch.split[0])?;
37523770
t.b.y += hsz;
3753-
decode_sb(c, t, f, next_bl, branch.split[2])?;
3771+
decode_sb(c, t, f, pass, next_bl, branch.split[2])?;
37543772
t.b.y -= hsz;
37553773
} else {
37563774
let node = intra_edge.node(sb128, edge_index);
@@ -3759,6 +3777,7 @@ fn decode_sb(
37593777
c,
37603778
t,
37613779
f,
3780+
pass,
37623781
bl,
37633782
dav1d_block_sizes[bl as usize][bp as usize][0],
37643783
bp,
@@ -3767,7 +3786,9 @@ fn decode_sb(
37673786
}
37683787
}
37693788

3770-
if t.frame_thread.pass != 2 && (bp != BlockPartition::Split || bl == BlockLevel::Bl8x8) {
3789+
if matches!(pass, FrameThreadPassState::First(_))
3790+
&& (bp != BlockPartition::Split || bl == BlockLevel::Bl8x8)
3791+
{
37713792
CaseSet::<16, false>::many(
37723793
[(&f.a[t.a], 0), (&t.l, 1)],
37733794
[hsz as usize; 2],
@@ -4118,7 +4139,14 @@ pub(crate) fn rav1d_decode_tile_sbrow(
41184139
if c.flush.load(Ordering::Acquire) {
41194140
return Err(());
41204141
}
4121-
decode_sb(c, t, f, root_bl, EdgeIndex::root())?;
4142+
decode_sb(
4143+
c,
4144+
t,
4145+
f,
4146+
&mut FrameThreadPassState::Second,
4147+
root_bl,
4148+
EdgeIndex::root(),
4149+
)?;
41224150
if t.b.x & 16 != 0 || f.seq_hdr().sb128 != 0 {
41234151
t.a += 1;
41244152
}
@@ -4231,7 +4259,14 @@ pub(crate) fn rav1d_decode_tile_sbrow(
42314259
read_restoration_info(ts, &mut lr, p, frame_type, debug_block_info!(f, t.b));
42324260
}
42334261
}
4234-
decode_sb(c, t, f, root_bl, EdgeIndex::root())?;
4262+
decode_sb(
4263+
c,
4264+
t,
4265+
f,
4266+
&mut FrameThreadPassState::First(&mut f.ts[t.ts].context.try_lock().unwrap()),
4267+
root_bl,
4268+
EdgeIndex::root(),
4269+
)?;
42354270
if t.b.x & 16 != 0 || f.seq_hdr().sb128 != 0 {
42364271
t.a += 1;
42374272
t.lf_mask = t.lf_mask.map(|i| i + 1);

0 commit comments

Comments
 (0)