Skip to content
This repository was archived by the owner on Oct 31, 2025. It is now read-only.

Commit 54d98c8

Browse files
committed
builder: generalize the panic format_args! remover to handle runtime args.
1 parent e9cdb96 commit 54d98c8

File tree

1 file changed

+194
-54
lines changed

1 file changed

+194
-54
lines changed

crates/rustc_codegen_spirv/src/builder/builder_methods.rs

Lines changed: 194 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,13 @@ use rustc_codegen_ssa::traits::{
1616
BackendTypes, BuilderMethods, ConstMethods, IntrinsicCallMethods, LayoutTypeMethods, OverflowOp,
1717
};
1818
use rustc_codegen_ssa::MemFlags;
19+
use rustc_data_structures::fx::FxHashSet;
1920
use rustc_middle::bug;
2021
use rustc_middle::ty::Ty;
2122
use rustc_span::Span;
2223
use rustc_target::abi::call::FnAbi;
2324
use rustc_target::abi::{Abi, Align, Scalar, Size, WrappingRange};
25+
use smallvec::SmallVec;
2426
use std::convert::TryInto;
2527
use std::iter::{self, empty};
2628

@@ -2393,7 +2395,7 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
23932395
// nor simplified in MIR (e.g. promoted to a constant) in any way,
23942396
// so we have to try and remove the `fmt::Arguments::new` call here.
23952397
// HACK(eddyb) this is basically a `try` block.
2396-
let remove_simple_format_args_if_possible = || -> Option<()> {
2398+
let remove_format_args_if_possible = || -> Option<()> {
23972399
let format_args_id = match args {
23982400
&[SpirvValue {
23992401
kind: SpirvValueKind::Def(format_args_id),
@@ -2414,6 +2416,19 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
24142416
let func_idx = builder.selected_function().unwrap();
24152417
let block_idx = builder.selected_block().unwrap();
24162418
let func = &mut builder.module_mut().functions[func_idx];
2419+
2420+
// HACK(eddyb) this is used to check that all `Op{Store,Load}`s
2421+
// that may get removed, operate on local `OpVariable`s,
2422+
// i.e. are not externally observable.
2423+
let local_var_ids: FxHashSet<_> = func.blocks[0]
2424+
.instructions
2425+
.iter()
2426+
.take_while(|inst| inst.class.opcode == Op::Variable)
2427+
.map(|inst| inst.result_id.unwrap())
2428+
.collect();
2429+
let require_local_var =
2430+
|ptr_id| Some(()).filter(|()| local_var_ids.contains(&ptr_id));
2431+
24172432
let mut non_debug_insts = func.blocks[block_idx]
24182433
.instructions
24192434
.iter()
@@ -2426,68 +2441,193 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
24262441
.contains(&CustomOp::decode_from_ext_inst(inst));
24272442
!(is_standard_debug || is_custom_debug)
24282443
});
2429-
let mut relevant_insts_next_back = |expected_op| {
2430-
non_debug_insts
2431-
.next_back()
2432-
.filter(|(_, inst)| inst.class.opcode == expected_op)
2433-
.map(|(i, inst)| {
2434-
(
2435-
i,
2436-
inst.result_id,
2437-
inst.operands.iter().map(|operand| operand.unwrap_id_ref()),
2438-
)
2439-
})
2444+
2445+
// HACK(eddyb) to aid in pattern-matching, relevant instructions
2446+
// are decoded to values of this `enum`. For instructions that
2447+
// produce results, the result ID is the first `ID` value.
2448+
#[derive(Debug)]
2449+
enum Inst<'tcx, ID> {
2450+
Bitcast(ID, ID),
2451+
AccessChain(ID, ID, SpirvConst<'tcx>),
2452+
InBoundsAccessChain(ID, ID, SpirvConst<'tcx>),
2453+
Store(ID, ID),
2454+
Load(ID, ID),
2455+
Call(ID, ID, SmallVec<[ID; 4]>),
2456+
}
2457+
2458+
let mut taken_inst_idx_range = func.blocks[block_idx].instructions.len()..;
2459+
2460+
// Take `count` instructions, advancing backwards, but returning
2461+
// instructions in their original order (and decoded to `Inst`s).
2462+
let mut try_rev_take = |count| {
2463+
let maybe_rev_insts = (0..count).map(|_| {
2464+
let (i, inst) = non_debug_insts.next_back()?;
2465+
taken_inst_idx_range = i..;
2466+
2467+
// HACK(eddyb) all instructions accepted below
2468+
// are expected to take no more than 4 operands,
2469+
// and this is easier to use than an iterator.
2470+
let id_operands = inst
2471+
.operands
2472+
.iter()
2473+
.map(|operand| operand.id_ref_any())
2474+
.collect::<Option<SmallVec<[_; 4]>>>()?;
2475+
2476+
// Decode the instruction into one of our `Inst`s.
2477+
Some(
2478+
match (inst.class.opcode, inst.result_id, &id_operands[..]) {
2479+
(Op::Bitcast, Some(r), &[x]) => Inst::Bitcast(r, x),
2480+
(Op::AccessChain, Some(r), &[p, i]) => {
2481+
Inst::AccessChain(r, p, self.builder.lookup_const_by_id(i)?)
2482+
}
2483+
(Op::InBoundsAccessChain, Some(r), &[p, i]) => {
2484+
Inst::InBoundsAccessChain(
2485+
r,
2486+
p,
2487+
self.builder.lookup_const_by_id(i)?,
2488+
)
2489+
}
2490+
(Op::Store, None, &[p, v]) => Inst::Store(p, v),
2491+
(Op::Load, Some(r), &[p]) => Inst::Load(r, p),
2492+
(Op::FunctionCall, Some(r), [f, args @ ..]) => {
2493+
Inst::Call(r, *f, args.iter().copied().collect())
2494+
}
2495+
_ => return None,
2496+
},
2497+
)
2498+
});
2499+
let mut insts = maybe_rev_insts.collect::<Option<SmallVec<[_; 4]>>>()?;
2500+
insts.reverse();
2501+
Some(insts)
24402502
};
2441-
let (_, load_src_id) = relevant_insts_next_back(Op::Load)
2442-
.map(|(_, result_id, mut operands)| {
2443-
(result_id.unwrap(), operands.next().unwrap())
2444-
})
2445-
.filter(|&(result_id, _)| result_id == format_args_id)?;
2446-
let (_, store_val_id) = relevant_insts_next_back(Op::Store)
2447-
.map(|(_, _, mut operands)| {
2448-
(operands.next().unwrap(), operands.next().unwrap())
2449-
})
2450-
.filter(|&(store_dst_id, _)| store_dst_id == load_src_id)?;
2451-
let call_fmt_args_new_idx = relevant_insts_next_back(Op::FunctionCall)
2452-
.filter(|&(_, result_id, _)| result_id == Some(store_val_id))
2453-
.map(|(i, _, mut operands)| (i, operands.next().unwrap(), operands))
2454-
.filter(|&(_, callee, _)| self.fmt_args_new_fn_ids.borrow().contains(&callee))
2455-
.and_then(|(i, _, mut call_args)| {
2456-
if call_args.len() == 4 {
2503+
2504+
let (rt_args_slice_ptr_id, rt_args_count) = match try_rev_take(3)?[..] {
2505+
[
2506+
// HACK(eddyb) comment works around `rustfmt` array bugs.
2507+
Inst::Call(call_ret_id, callee_id, ref call_args),
2508+
Inst::Store(st_dst_id, st_val_id),
2509+
Inst::Load(ld_val_id, ld_src_id),
2510+
]
2511+
if self.fmt_args_new_fn_ids.borrow().contains(&callee_id)
2512+
&& call_ret_id == st_val_id
2513+
&& st_dst_id == ld_src_id
2514+
&& ld_val_id == format_args_id =>
2515+
{
2516+
require_local_var(st_dst_id)?;
2517+
match call_args[..] {
24572518
// `<core::fmt::Arguments>::new_v1`
2458-
let mut arg = || call_args.next().unwrap();
2459-
let [_, _, _, fmt_args_len_id] = [arg(), arg(), arg(), arg()];
2460-
// Only ever remove `fmt::Arguments` with no runtime values.
2461-
Some(i).filter(|_| {
2462-
matches!(
2463-
self.builder.lookup_const_by_id(fmt_args_len_id),
2464-
Some(SpirvConst::U32(0))
2465-
)
2466-
})
2467-
} else {
2519+
[_, _, rt_args_slice_ptr_id, rt_args_len_id] => (
2520+
Some(rt_args_slice_ptr_id),
2521+
self.builder
2522+
.lookup_const_by_id(rt_args_len_id)
2523+
.and_then(|ct| match ct {
2524+
SpirvConst::U32(x) => Some(x as usize),
2525+
_ => None,
2526+
})?,
2527+
),
2528+
24682529
// `<core::fmt::Arguments>::new_const`
2469-
assert_eq!(call_args.len(), 2);
2470-
Some(i)
2530+
[_, _] => (None, 0),
2531+
2532+
_ => return None,
24712533
}
2472-
})?;
2534+
}
2535+
_ => return None,
2536+
};
24732537

2474-
// Lastly, ensure that the `Op{Store,Load}` pair operates on
2475-
// a local `OpVariable`, i.e. is not externally observable.
2476-
let store_load_local_var = func.blocks[0]
2477-
.instructions
2478-
.iter()
2479-
.take_while(|inst| inst.class.opcode == Op::Variable)
2480-
.find(|inst| inst.result_id == Some(load_src_id));
2481-
if store_load_local_var.is_some() {
2482-
// Keep all instructions up to (but not including) the call.
2483-
func.blocks[block_idx]
2484-
.instructions
2485-
.truncate(call_fmt_args_new_idx);
2538+
// HACK(eddyb) this is the worst part: if we do have runtime
2539+
// arguments (from e.g. new `assert!`s being added to `core`),
2540+
// we have to confirm their many instructions for removal.
2541+
if rt_args_count > 0 {
2542+
let rt_args_slice_ptr_id = rt_args_slice_ptr_id.unwrap();
2543+
let rt_args_array_ptr_id = match try_rev_take(1)?[..] {
2544+
[Inst::Bitcast(out_id, in_id)] if out_id == rt_args_slice_ptr_id => in_id,
2545+
_ => return None,
2546+
};
2547+
require_local_var(rt_args_array_ptr_id);
2548+
2549+
// Each runtime argument has its own variable, 6 instructions
2550+
// to initialize it, and 9 instructions to copy it to the
2551+
// appropriate slot in the array. The groups of 6 and 9
2552+
// instructions, for all runtime args, are each separate.
2553+
let copies_from_rt_arg_vars_to_rt_args_array = try_rev_take(rt_args_count * 9)?;
2554+
let copies_from_rt_arg_vars_to_rt_args_array =
2555+
copies_from_rt_arg_vars_to_rt_args_array.chunks(9);
2556+
let inits_of_rt_arg_vars = try_rev_take(rt_args_count * 6)?;
2557+
let inits_of_rt_arg_vars = inits_of_rt_arg_vars.chunks(6);
2558+
2559+
for (
2560+
rt_arg_idx,
2561+
(init_of_rt_arg_var_insts, copy_from_rt_arg_var_to_rt_args_array_insts),
2562+
) in inits_of_rt_arg_vars
2563+
.zip(copies_from_rt_arg_vars_to_rt_args_array)
2564+
.enumerate()
2565+
{
2566+
let rt_arg_var_id = match init_of_rt_arg_var_insts[..] {
2567+
[
2568+
// HACK(eddyb) comment works around `rustfmt` array bugs.
2569+
Inst::Bitcast(b, _),
2570+
Inst::Bitcast(a, _),
2571+
Inst::AccessChain(a_ptr, a_base_ptr, SpirvConst::U32(0)),
2572+
Inst::Store(a_st_dst, a_st_val),
2573+
Inst::AccessChain(b_ptr, b_base_ptr, SpirvConst::U32(1)),
2574+
Inst::Store(b_st_dst, b_st_val),
2575+
] if a_base_ptr == b_base_ptr
2576+
&& (a, b) == (a_st_val, b_st_val)
2577+
&& (a_ptr, b_ptr) == (a_st_dst, b_st_dst) =>
2578+
{
2579+
require_local_var(a_base_ptr);
2580+
a_base_ptr
2581+
}
2582+
_ => return None,
2583+
};
2584+
2585+
// HACK(eddyb) this is only split to allow variable name reuse.
2586+
let (copy_loads, copy_stores) =
2587+
copy_from_rt_arg_var_to_rt_args_array_insts.split_at(4);
2588+
let (a, b) = match copy_loads[..] {
2589+
[
2590+
// HACK(eddyb) comment works around `rustfmt` array bugs.
2591+
Inst::AccessChain(a_ptr, a_base_ptr, SpirvConst::U32(0)),
2592+
Inst::Load(a_ld_val, a_ld_src),
2593+
Inst::AccessChain(b_ptr, b_base_ptr, SpirvConst::U32(1)),
2594+
Inst::Load(b_ld_val, b_ld_src),
2595+
] if [a_base_ptr, b_base_ptr] == [rt_arg_var_id; 2]
2596+
&& (a_ptr, b_ptr) == (a_ld_src, b_ld_src) =>
2597+
{
2598+
(a_ld_val, b_ld_val)
2599+
}
2600+
_ => return None,
2601+
};
2602+
match copy_stores[..] {
2603+
[
2604+
// HACK(eddyb) comment works around `rustfmt` array bugs.
2605+
Inst::InBoundsAccessChain(array_slot_ptr, array_base_ptr, SpirvConst::U32(array_idx)),
2606+
Inst::AccessChain(a_ptr, a_base_ptr, SpirvConst::U32(0)),
2607+
Inst::Store(a_st_dst, a_st_val),
2608+
Inst::AccessChain(b_ptr, b_base_ptr, SpirvConst::U32(1)),
2609+
Inst::Store(b_st_dst, b_st_val),
2610+
] if array_base_ptr == rt_args_array_ptr_id
2611+
&& array_idx as usize == rt_arg_idx
2612+
&& [a_base_ptr, b_base_ptr] == [array_slot_ptr; 2]
2613+
&& (a, b) == (a_st_val, b_st_val)
2614+
&& (a_ptr, b_ptr) == (a_st_dst, b_st_dst) =>
2615+
{
2616+
}
2617+
_ => return None,
2618+
}
2619+
}
24862620
}
24872621

2622+
// Keep all instructions up to (but not including) the last one
2623+
// confirmed above to be the first instruction of `format_args!`.
2624+
func.blocks[block_idx]
2625+
.instructions
2626+
.truncate(taken_inst_idx_range.start);
2627+
24882628
None
24892629
};
2490-
remove_simple_format_args_if_possible();
2630+
remove_format_args_if_possible();
24912631

24922632
// HACK(eddyb) redirect any possible panic call to an abort, to avoid
24932633
// needing to materialize `&core::panic::Location` or `format_args!`.

0 commit comments

Comments
 (0)