Skip to content

Commit d1c7945

Browse files
committed
Auto merge of #153379 - TKanX:refactor/149164-simplify-autodiff-rlib, r=ZuseZ4
refactor(autodiff): Simplify Autodiff Handling of `rlib` Dependencies ### Summary: Resolves the two FIXMEs left in #149033, per @bjorn3 guidance in [the discussion](#149033 (comment)). Closes #149164 r? @ZuseZ4 cc @bjorn3
2 parents b49ecc9 + 8a3d0f4 commit d1c7945

File tree

9 files changed

+101
-122
lines changed

9 files changed

+101
-122
lines changed

compiler/rustc_builtin_macros/src/autodiff.rs

Lines changed: 49 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ mod llvm_enzyme {
2020
};
2121
use rustc_expand::base::{Annotatable, ExtCtxt};
2222
use rustc_hir::attrs::RustcAutodiff;
23-
use rustc_span::{Ident, Span, Symbol, sym};
23+
use rustc_span::{Ident, Span, Symbol, kw, sym};
2424
use thin_vec::{ThinVec, thin_vec};
2525
use tracing::{debug, trace};
2626

@@ -197,7 +197,7 @@ mod llvm_enzyme {
197197
/// }
198198
/// #[rustc_autodiff(Reverse, Duplicated, Active)]
199199
/// fn cos_box(x: &Box<f32>, dx: &mut Box<f32>, dret: f32) -> f32 {
200-
/// std::intrinsics::autodiff(sin::<>, cos_box::<>, (x, dx, dret))
200+
/// std::intrinsics::autodiff(sin::<> as fn(..) -> .., cos_box::<>, (x, dx, dret))
201201
/// }
202202
/// ```
203203
/// FIXME(ZuseZ4): Once autodiff is enabled by default, make this a doc comment which is checked
@@ -326,6 +326,7 @@ mod llvm_enzyme {
326326
primal,
327327
first_ident(&meta_item_vec[0]),
328328
span,
329+
&sig,
329330
&d_sig,
330331
&generics,
331332
is_impl,
@@ -496,18 +497,62 @@ mod llvm_enzyme {
496497

497498
// Generate `autodiff` intrinsic call
498499
// ```
499-
// std::intrinsics::autodiff(source, diff, (args))
500+
// std::intrinsics::autodiff(source as fn(..) -> .., diff, (args))
500501
// ```
501502
fn call_autodiff(
502503
ecx: &ExtCtxt<'_>,
503504
primal: Ident,
504505
diff: Ident,
505506
span: Span,
507+
p_sig: &FnSig,
506508
d_sig: &FnSig,
507509
generics: &Generics,
508510
is_impl: bool,
509511
) -> rustc_ast::Stmt {
510512
let primal_path_expr = gen_turbofish_expr(ecx, primal, generics, span, is_impl);
513+
514+
let self_ty = || ecx.ty_path(ast::Path::from_ident(Ident::with_dummy_span(kw::SelfUpper)));
515+
let fn_ptr_params: ThinVec<ast::Param> = p_sig
516+
.decl
517+
.inputs
518+
.iter()
519+
.map(|param| {
520+
let ty = match &param.ty.kind {
521+
TyKind::ImplicitSelf => self_ty(),
522+
TyKind::Ref(lt, mt) if matches!(mt.ty.kind, TyKind::ImplicitSelf) => ecx.ty(
523+
span,
524+
TyKind::Ref(lt.clone(), ast::MutTy { ty: self_ty(), mutbl: mt.mutbl }),
525+
),
526+
TyKind::Ptr(mt) if matches!(mt.ty.kind, TyKind::ImplicitSelf) => {
527+
ecx.ty(span, TyKind::Ptr(ast::MutTy { ty: self_ty(), mutbl: mt.mutbl }))
528+
}
529+
_ => param.ty.clone(),
530+
};
531+
ast::Param {
532+
attrs: ast::AttrVec::new(),
533+
ty,
534+
pat: Box::new(ecx.pat_wild(span)),
535+
id: ast::DUMMY_NODE_ID,
536+
span,
537+
is_placeholder: false,
538+
}
539+
})
540+
.collect();
541+
let fn_ptr_ty = ecx.ty(
542+
span,
543+
TyKind::FnPtr(Box::new(ast::FnPtrTy {
544+
safety: p_sig.header.safety,
545+
ext: p_sig.header.ext,
546+
generic_params: ThinVec::new(),
547+
decl: Box::new(ast::FnDecl {
548+
inputs: fn_ptr_params,
549+
output: p_sig.decl.output.clone(),
550+
}),
551+
decl_span: span,
552+
})),
553+
);
554+
let primal_fn_ptr = ecx.expr(span, ast::ExprKind::Cast(primal_path_expr, fn_ptr_ty));
555+
511556
let diff_path_expr = gen_turbofish_expr(ecx, diff, generics, span, is_impl);
512557

513558
let tuple_expr = ecx.expr_tuple(
@@ -529,7 +574,7 @@ mod llvm_enzyme {
529574
let call_expr = ecx.expr_call(
530575
span,
531576
ecx.expr_path(enzyme_path),
532-
vec![primal_path_expr, diff_path_expr, tuple_expr].into(),
577+
vec![primal_fn_ptr, diff_path_expr, tuple_expr].into(),
533578
);
534579

535580
ecx.stmt_expr(call_expr)

compiler/rustc_codegen_llvm/src/builder/autodiff.rs

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ use rustc_codegen_ssa::common::TypeKind;
66
use rustc_codegen_ssa::traits::{BaseTypeCodegenMethods, BuilderMethods};
77
use rustc_data_structures::thin_vec::ThinVec;
88
use rustc_hir::attrs::RustcAutodiff;
9-
use rustc_middle::ty::{Instance, PseudoCanonicalInput, TyCtxt, TypingEnv};
9+
use rustc_middle::ty::{PseudoCanonicalInput, Ty, TyCtxt, TypingEnv};
1010
use rustc_middle::{bug, ty};
1111
use rustc_target::callconv::PassMode;
1212
use tracing::debug;
@@ -18,25 +18,23 @@ use crate::llvm::{self, TRUE, Type, Value};
1818

1919
pub(crate) fn adjust_activity_to_abi<'tcx>(
2020
tcx: TyCtxt<'tcx>,
21-
instance: Instance<'tcx>,
21+
fn_ptr_ty: Ty<'tcx>,
2222
typing_env: TypingEnv<'tcx>,
2323
da: &mut ThinVec<DiffActivity>,
2424
) {
25-
let fn_ty = instance.ty(tcx, typing_env);
26-
27-
if !matches!(fn_ty.kind(), ty::FnDef(..)) {
28-
bug!("expected fn def for autodiff, got {:?}", fn_ty);
25+
if !matches!(fn_ptr_ty.kind(), ty::FnPtr(..)) {
26+
bug!("expected fn ptr for autodiff, got {:?}", fn_ptr_ty);
2927
}
3028

3129
// We don't actually pass the types back into the type system.
3230
// All we do is decide how to handle the arguments.
33-
let sig = fn_ty.fn_sig(tcx).skip_binder();
31+
let fn_sig = fn_ptr_ty.fn_sig(tcx);
32+
let sig = fn_sig.skip_binder();
3433

3534
// FIXME(Sa4dUs): pass proper varargs once we have support for differentiating variadic functions
36-
let Ok(fn_abi) =
37-
tcx.fn_abi_of_instance(typing_env.as_query_input((instance, ty::List::empty())))
35+
let Ok(fn_abi) = tcx.fn_abi_of_fn_ptr(typing_env.as_query_input((fn_sig, ty::List::empty())))
3836
else {
39-
bug!("failed to get fn_abi of instance with empty varargs");
37+
bug!("failed to get fn_abi of fn_ptr with empty varargs");
4038
};
4139

4240
let mut new_activities = vec![];

compiler/rustc_codegen_llvm/src/intrinsic.rs

Lines changed: 4 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1322,29 +1322,8 @@ fn codegen_autodiff<'ll, 'tcx>(
13221322
let ret_ty = sig.output();
13231323
let llret_ty = bx.layout_of(ret_ty).llvm_type(bx);
13241324

1325-
// Get source, diff, and attrs
1326-
let (source_id, source_args) = match fn_args.into_type_list(tcx)[0].kind() {
1327-
ty::FnDef(def_id, source_params) => (def_id, source_params),
1328-
_ => bug!("invalid autodiff intrinsic args"),
1329-
};
1330-
1331-
let fn_source = match Instance::try_resolve(tcx, bx.cx.typing_env(), *source_id, source_args) {
1332-
Ok(Some(instance)) => instance,
1333-
Ok(None) => bug!(
1334-
"could not resolve ({:?}, {:?}) to a specific autodiff instance",
1335-
source_id,
1336-
source_args
1337-
),
1338-
Err(_) => {
1339-
// An error has already been emitted
1340-
return;
1341-
}
1342-
};
1343-
1344-
let source_symbol = symbol_name_for_instance_in_crate(tcx, fn_source.clone(), LOCAL_CRATE);
1345-
let Some(fn_to_diff) = bx.cx.get_function(&source_symbol) else {
1346-
bug!("could not find source function")
1347-
};
1325+
let source_fn_ptr_ty = fn_args.into_type_list(tcx)[0];
1326+
let fn_to_diff = args[0].immediate();
13481327

13491328
let (diff_id, diff_args) = match fn_args.into_type_list(tcx)[1].kind() {
13501329
ty::FnDef(def_id, diff_args) => (def_id, diff_args),
@@ -1375,13 +1354,12 @@ fn codegen_autodiff<'ll, 'tcx>(
13751354

13761355
adjust_activity_to_abi(
13771356
tcx,
1378-
fn_source,
1357+
source_fn_ptr_ty,
13791358
TypingEnv::fully_monomorphized(),
13801359
&mut diff_attrs.input_activity,
13811360
);
13821361

1383-
let fnc_tree =
1384-
rustc_middle::ty::fnc_typetrees(tcx, fn_source.ty(tcx, TypingEnv::fully_monomorphized()));
1362+
let fnc_tree = rustc_middle::ty::fnc_typetrees(tcx, source_fn_ptr_ty);
13851363

13861364
// Build body
13871365
generate_enzyme_call(

compiler/rustc_mir_transform/src/cross_crate_inline.rs

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -35,11 +35,6 @@ fn cross_crate_inlinable(tcx: TyCtxt<'_>, def_id: LocalDefId) -> bool {
3535
return true;
3636
}
3737

38-
// FIXME(autodiff): replace this as per discussion in https://github.com/rust-lang/rust/pull/149033#discussion_r2535465880
39-
if find_attr!(tcx, def_id, RustcAutodiff(..)) {
40-
return true;
41-
}
42-
4338
if find_attr!(tcx, def_id, RustcIntrinsic) {
4439
// Intrinsic fallback bodies are always cross-crate inlineable.
4540
// To ensure that the MIR inliner doesn't cluelessly try to inline fallback

compiler/rustc_monomorphize/src/collector.rs

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -205,8 +205,6 @@
205205
//! this is not implemented however: a mono item will be produced
206206
//! regardless of whether it is actually needed or not.
207207
208-
mod autodiff;
209-
210208
use std::cell::OnceCell;
211209
use std::ops::ControlFlow;
212210

@@ -240,7 +238,6 @@ use rustc_span::source_map::{Spanned, dummy_spanned, respan};
240238
use rustc_span::{DUMMY_SP, Span};
241239
use tracing::{debug, instrument, trace};
242240

243-
use crate::collector::autodiff::collect_autodiff_fn;
244241
use crate::errors::{
245242
self, EncounteredErrorWhileInstantiating, EncounteredErrorWhileInstantiatingGlobalAsm,
246243
NoOptimizedMir, RecursionLimit,
@@ -990,8 +987,6 @@ fn visit_instance_use<'tcx>(
990987
return;
991988
}
992989
if let Some(intrinsic) = tcx.intrinsic(instance.def_id()) {
993-
collect_autodiff_fn(tcx, instance, intrinsic, output);
994-
995990
if let Some(_requirement) = ValidityRequirement::from_intrinsic(intrinsic.name) {
996991
// The intrinsics assert_inhabited, assert_zero_valid, and assert_mem_uninitialized_valid will
997992
// be lowered in codegen to nothing or a call to panic_nounwind. So if we encounter any

compiler/rustc_monomorphize/src/collector/autodiff.rs

Lines changed: 0 additions & 50 deletions
This file was deleted.

0 commit comments

Comments
 (0)