Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add unpolished, experimental support for AFIDT (async fn in dyn trait) #133122

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions compiler/rustc_feature/src/unstable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -384,6 +384,8 @@ declare_features! (
(unstable, associated_type_defaults, "1.2.0", Some(29661)),
/// Allows `async || body` closures.
(unstable, async_closure, "1.37.0", Some(62290)),
/// Allows async functions to be called from `dyn Trait`.
(incomplete, async_fn_in_dyn_trait, "CURRENT_RUSTC_VERSION", Some(133119)),
/// Allows `#[track_caller]` on async functions.
(unstable, async_fn_track_caller, "1.73.0", Some(110011)),
/// Allows `for await` loops.
Expand Down
37 changes: 20 additions & 17 deletions compiler/rustc_middle/src/ty/instance.rs
Original file line number Diff line number Diff line change
Expand Up @@ -675,23 +675,26 @@ impl<'tcx> Instance<'tcx> {
//
// 1) The underlying method expects a caller location parameter
// in the ABI
if resolved.def.requires_caller_location(tcx)
// 2) The caller location parameter comes from having `#[track_caller]`
// on the implementation, and *not* on the trait method.
&& !tcx.should_inherit_track_caller(def)
// If the method implementation comes from the trait definition itself
// (e.g. `trait Foo { #[track_caller] my_fn() { /* impl */ } }`),
// then we don't need to generate a shim. This check is needed because
// `should_inherit_track_caller` returns `false` if our method
// implementation comes from the trait block, and not an impl block
&& !matches!(
tcx.opt_associated_item(def),
Some(ty::AssocItem {
container: ty::AssocItemContainer::Trait,
..
})
)
{
let needs_track_caller_shim = resolved.def.requires_caller_location(tcx)
// 2) The caller location parameter comes from having `#[track_caller]`
// on the implementation, and *not* on the trait method.
&& !tcx.should_inherit_track_caller(def)
// If the method implementation comes from the trait definition itself
// (e.g. `trait Foo { #[track_caller] my_fn() { /* impl */ } }`),
// then we don't need to generate a shim. This check is needed because
// `should_inherit_track_caller` returns `false` if our method
// implementation comes from the trait block, and not an impl block
&& !matches!(
tcx.opt_associated_item(def),
Some(ty::AssocItem {
container: ty::AssocItemContainer::Trait,
..
})
);
// We also need to generate a shim if this is an AFIT.
let needs_rpitit_shim =
tcx.return_position_impl_trait_in_trait_shim_data(def).is_some();
if needs_track_caller_shim || needs_rpitit_shim {
if tcx.is_closure_like(def) {
debug!(
" => vtable fn pointer created for closure with #[track_caller]: {:?} for method {:?} {:?}",
Expand Down
1 change: 1 addition & 0 deletions compiler/rustc_middle/src/ty/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,7 @@ mod opaque_types;
mod parameterized;
mod predicate;
mod region;
mod return_position_impl_trait_in_trait;
mod rvalue_scopes;
mod structural_impls;
#[allow(hidden_glob_reexports)]
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
use rustc_hir::def_id::DefId;

use crate::ty::{self, ExistentialPredicateStableCmpExt, TyCtxt};

impl<'tcx> TyCtxt<'tcx> {
/// Given a `def_id` of a trait or impl method, compute whether that method needs to
/// have an RPITIT shim applied to it for it to be object safe. If so, return the
/// `def_id` of the RPITIT, and also the args of trait method that returns the RPITIT.
///
/// NOTE that these args are not, in general, the same as than the RPITIT's args. They
/// are a subset of those args, since they do not include the late-bound lifetimes of
/// the RPITIT. Depending on the context, these will need to be dealt with in different
/// ways -- in codegen, it's okay to fill them with ReErased.
pub fn return_position_impl_trait_in_trait_shim_data(
self,
def_id: DefId,
) -> Option<(DefId, ty::EarlyBinder<'tcx, ty::GenericArgsRef<'tcx>>)> {
let assoc_item = self.opt_associated_item(def_id)?;

let (trait_item_def_id, opt_impl_def_id) = match assoc_item.container {
ty::AssocItemContainer::Impl => {
(assoc_item.trait_item_def_id?, Some(self.parent(def_id)))
}
ty::AssocItemContainer::Trait => (def_id, None),
};

let sig = self.fn_sig(trait_item_def_id);

let ty::Alias(ty::Projection, alias_ty) = *sig.skip_binder().skip_binder().output().kind()
else {
return None;
};

if !self.is_impl_trait_in_trait(alias_ty.def_id) {
return None;
}

let args = if let Some(impl_def_id) = opt_impl_def_id {
// Rebase the args from the RPITIT onto the impl trait ref, so we can later
// substitute them with the method args of the *impl* method, since that's
// the instance we're building a vtable shim for.
ty::GenericArgs::identity_for_item(self, trait_item_def_id).rebase_onto(
self,
self.parent(trait_item_def_id),
self.impl_trait_ref(impl_def_id)
.expect("expected impl trait ref from parent of impl item")
.instantiate_identity()
.args,
)
} else {
// This is when we have a default trait implementation.
ty::GenericArgs::identity_for_item(self, trait_item_def_id)
};

Some((alias_ty.def_id, ty::EarlyBinder::bind(args)))
}

/// Given a `DefId` of an RPITIT and its args, return the existential predicates
/// that corresponds to the RPITIT's bounds with the self type erased.
pub fn item_bounds_to_existential_predicates(
self,
def_id: DefId,
args: ty::GenericArgsRef<'tcx>,
) -> &'tcx ty::List<ty::PolyExistentialPredicate<'tcx>> {
let mut bounds: Vec<_> = self
.item_super_predicates(def_id)
.iter_instantiated(self, args)
.filter_map(|clause| {
clause
.kind()
.map_bound(|clause| match clause {
ty::ClauseKind::Trait(trait_pred) => Some(ty::ExistentialPredicate::Trait(
ty::ExistentialTraitRef::erase_self_ty(self, trait_pred.trait_ref),
)),
ty::ClauseKind::Projection(projection_pred) => {
Some(ty::ExistentialPredicate::Projection(
ty::ExistentialProjection::erase_self_ty(self, projection_pred),
))
}
_ => None,
})
.transpose()
})
.collect();
bounds.sort_by(|a, b| a.skip_binder().stable_cmp(self, &b.skip_binder()));
self.mk_poly_existential_predicates(&bounds)
}
}
60 changes: 55 additions & 5 deletions compiler/rustc_mir_transform/src/shim.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ use rustc_index::{Idx, IndexVec};
use rustc_middle::mir::patch::MirPatch;
use rustc_middle::mir::*;
use rustc_middle::query::Providers;
use rustc_middle::ty::adjustment::PointerCoercion;
use rustc_middle::ty::{
self, CoroutineArgs, CoroutineArgsExt, EarlyBinder, GenericArgs, Ty, TyCtxt,
};
Expand Down Expand Up @@ -710,6 +711,13 @@ fn build_call_shim<'tcx>(
};

let def_id = instance.def_id();

let rpitit_shim = if let ty::InstanceKind::ReifyShim(..) = instance {
tcx.return_position_impl_trait_in_trait_shim_data(def_id)
} else {
None
};

let sig = tcx.fn_sig(def_id);
let sig = sig.map_bound(|sig| tcx.instantiate_bound_regions_with_erased(sig));

Expand Down Expand Up @@ -747,8 +755,8 @@ fn build_call_shim<'tcx>(
sig.inputs_and_output = tcx.mk_type_list(&inputs_and_output);
}

// FIXME(eddyb) avoid having this snippet both here and in
// `Instance::fn_sig` (introduce `InstanceKind::fn_sig`?).
// FIXME: Avoid having to adjust the signature both here and in
// `fn_sig_for_fn_abi`.
if let ty::InstanceKind::VTableShim(..) = instance {
// Modify fn(self, ...) to fn(self: *mut Self, ...)
let mut inputs_and_output = sig.inputs_and_output.to_vec();
Expand All @@ -765,9 +773,34 @@ fn build_call_shim<'tcx>(
let mut local_decls = local_decls_for_sig(&sig, span);
let source_info = SourceInfo::outermost(span);

let mut destination = Place::return_place();
if let Some((rpitit_def_id, fn_args)) = rpitit_shim {
let rpitit_args =
fn_args.instantiate_identity().extend_to(tcx, rpitit_def_id, |param, _| {
match param.kind {
ty::GenericParamDefKind::Lifetime => tcx.lifetimes.re_erased.into(),
ty::GenericParamDefKind::Type { .. }
| ty::GenericParamDefKind::Const { .. } => {
unreachable!("rpitit should have no addition ty/ct")
}
}
});
let dyn_star_ty = Ty::new_dynamic(
tcx,
tcx.item_bounds_to_existential_predicates(rpitit_def_id, rpitit_args),
tcx.lifetimes.re_erased,
ty::DynStar,
);
destination = local_decls.push(local_decls[RETURN_PLACE].clone()).into();
local_decls[RETURN_PLACE].ty = dyn_star_ty;
let mut inputs_and_output = sig.inputs_and_output.to_vec();
*inputs_and_output.last_mut().unwrap() = dyn_star_ty;
sig.inputs_and_output = tcx.mk_type_list(&inputs_and_output);
}

let rcvr_place = || {
assert!(rcvr_adjustment.is_some());
Place::from(Local::new(1 + 0))
Place::from(Local::new(1))
};
let mut statements = vec![];

Expand Down Expand Up @@ -854,7 +887,7 @@ fn build_call_shim<'tcx>(
TerminatorKind::Call {
func: callee,
args,
destination: Place::return_place(),
destination,
target: Some(BasicBlock::new(1)),
unwind: if let Some(Adjustment::RefMut) = rcvr_adjustment {
UnwindAction::Cleanup(BasicBlock::new(3))
Expand Down Expand Up @@ -882,7 +915,24 @@ fn build_call_shim<'tcx>(
);
}
// BB #1/#2 - return
block(&mut blocks, vec![], TerminatorKind::Return, false);
// NOTE: If this is an RPITIT in dyn, we also want to coerce
// the return type of the function into a `dyn*`.
let stmts = if rpitit_shim.is_some() {
vec![Statement {
source_info,
kind: StatementKind::Assign(Box::new((
Place::return_place(),
Rvalue::Cast(
CastKind::PointerCoercion(PointerCoercion::DynStar, CoercionSource::Implicit),
Operand::Move(destination),
sig.output(),
),
))),
}]
} else {
vec![]
};
block(&mut blocks, stmts, TerminatorKind::Return, false);
if let Some(Adjustment::RefMut) = rcvr_adjustment {
// BB #3 - drop if closure panics
block(
Expand Down
5 changes: 4 additions & 1 deletion compiler/rustc_monomorphize/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,10 @@ fn custom_coerce_unsize_info<'tcx>(
..
})) => Ok(tcx.coerce_unsized_info(impl_def_id)?.custom_kind.unwrap()),
impl_source => {
bug!("invalid `CoerceUnsized` impl_source: {:?}", impl_source);
bug!(
"invalid `CoerceUnsized` from {source_ty} to {target_ty}: impl_source: {:?}",
impl_source
);
}
}
}
Expand Down
1 change: 1 addition & 0 deletions compiler/rustc_span/src/symbol.rs
Original file line number Diff line number Diff line change
Expand Up @@ -453,6 +453,7 @@ symbols! {
async_drop_slice,
async_drop_surface_drop_in_place,
async_fn,
async_fn_in_dyn_trait,
async_fn_in_trait,
async_fn_kind_helper,
async_fn_kind_upvars,
Expand Down
56 changes: 45 additions & 11 deletions compiler/rustc_trait_selection/src/traits/dyn_compatibility.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ use rustc_abi::BackendRepr;
use rustc_errors::FatalError;
use rustc_hir as hir;
use rustc_hir::def_id::DefId;
use rustc_middle::bug;
use rustc_middle::query::Providers;
use rustc_middle::ty::{
self, EarlyBinder, ExistentialPredicateStableCmpExt as _, GenericArgs, Ty, TyCtxt,
Expand Down Expand Up @@ -905,23 +906,56 @@ fn contains_illegal_impl_trait_in_trait<'tcx>(
fn_def_id: DefId,
ty: ty::Binder<'tcx, Ty<'tcx>>,
) -> Option<MethodViolationCode> {
// This would be caught below, but rendering the error as a separate
// `async-specific` message is better.
let ty = tcx.liberate_late_bound_regions(fn_def_id, ty);

if tcx.asyncness(fn_def_id).is_async() {
return Some(MethodViolationCode::AsyncFn);
if tcx.features().async_fn_in_dyn_trait() {
let ty::Alias(ty::Projection, proj) = *ty.kind() else {
bug!("expected async fn in trait to return an RPITIT");
};
assert!(tcx.is_impl_trait_in_trait(proj.def_id));

// FIXME(async_fn_in_dyn_trait): We should check that this bound is legal too,
// and stop relying on `async fn` in the definition.
for bound in tcx.item_bounds(proj.def_id).instantiate(tcx, proj.args) {
if let Some(violation) = bound
.visit_with(&mut IllegalRpititVisitor { tcx, allowed: Some(proj) })
.break_value()
{
return Some(violation);
}
}

None
} else {
// Rendering the error as a separate `async-specific` message is better.
Some(MethodViolationCode::AsyncFn)
}
} else {
ty.visit_with(&mut IllegalRpititVisitor { tcx, allowed: None }).break_value()
}
}

struct IllegalRpititVisitor<'tcx> {
tcx: TyCtxt<'tcx>,
allowed: Option<ty::AliasTy<'tcx>>,
}

impl<'tcx> TypeVisitor<TyCtxt<'tcx>> for IllegalRpititVisitor<'tcx> {
type Result = ControlFlow<MethodViolationCode>;

// FIXME(RPITIT): Perhaps we should use a visitor here?
ty.skip_binder().walk().find_map(|arg| {
if let ty::GenericArgKind::Type(ty) = arg.unpack()
&& let ty::Alias(ty::Projection, proj) = ty.kind()
&& tcx.is_impl_trait_in_trait(proj.def_id)
fn visit_ty(&mut self, ty: Ty<'tcx>) -> Self::Result {
if let ty::Alias(ty::Projection, proj) = *ty.kind()
&& Some(proj) != self.allowed
&& self.tcx.is_impl_trait_in_trait(proj.def_id)
{
Some(MethodViolationCode::ReferencesImplTraitInTrait(tcx.def_span(proj.def_id)))
ControlFlow::Break(MethodViolationCode::ReferencesImplTraitInTrait(
self.tcx.def_span(proj.def_id),
))
} else {
None
ty.super_visit_with(self)
}
})
}
}

pub(crate) fn provide(providers: &mut Providers) {
Expand Down
Loading
Loading