Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 9 additions & 5 deletions crates/hir-def/src/hir/type_ref.rs
Original file line number Diff line number Diff line change
Expand Up @@ -195,12 +195,16 @@ impl TypeRef {
TypeRef::Tuple(ThinVec::new())
}

pub fn walk(this: TypeRefId, map: &ExpressionStore, f: &mut impl FnMut(&TypeRef)) {
pub fn walk(this: TypeRefId, map: &ExpressionStore, f: &mut impl FnMut(TypeRefId, &TypeRef)) {
go(this, f, map);

fn go(type_ref: TypeRefId, f: &mut impl FnMut(&TypeRef), map: &ExpressionStore) {
let type_ref = &map[type_ref];
f(type_ref);
fn go(
type_ref_id: TypeRefId,
f: &mut impl FnMut(TypeRefId, &TypeRef),
map: &ExpressionStore,
) {
let type_ref = &map[type_ref_id];
f(type_ref_id, type_ref);
match type_ref {
TypeRef::Fn(fn_) => {
fn_.params.iter().for_each(|&(_, param_type)| go(param_type, f, map))
Expand All @@ -224,7 +228,7 @@ impl TypeRef {
};
}

fn go_path(path: &Path, f: &mut impl FnMut(&TypeRef), map: &ExpressionStore) {
fn go_path(path: &Path, f: &mut impl FnMut(TypeRefId, &TypeRef), map: &ExpressionStore) {
if let Some(type_ref) = path.type_anchor() {
go(type_ref, f, map);
}
Expand Down
43 changes: 41 additions & 2 deletions crates/hir-ty/src/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ use hir_def::{
layout::Integer,
resolver::{HasResolver, ResolveValueResult, Resolver, TypeNs, ValueNs},
signatures::{ConstSignature, StaticSignature},
type_ref::{ConstRef, LifetimeRefId, TypeRefId},
type_ref::{ConstRef, LifetimeRefId, TypeRef, TypeRefId},
};
use hir_expand::{mod_path::ModPath, name::Name};
use indexmap::IndexSet;
Expand All @@ -56,6 +56,7 @@ use triomphe::Arc;

use crate::{
ImplTraitId, IncorrectGenericsLenKind, PathLoweringDiagnostic, TargetFeatures,
collect_type_inference_vars,
db::{HirDatabase, InternedClosureId, InternedOpaqueTyId},
infer::{
coerce::{CoerceMany, DynamicCoerceMany},
Expand Down Expand Up @@ -456,6 +457,7 @@ pub struct InferenceResult<'db> {
/// unresolved or missing subpatterns or subpatterns of mismatched types.
pub(crate) type_of_pat: ArenaMap<PatId, Ty<'db>>,
pub(crate) type_of_binding: ArenaMap<BindingId, Ty<'db>>,
pub(crate) type_of_type_placeholder: ArenaMap<TypeRefId, Ty<'db>>,
pub(crate) type_of_opaque: FxHashMap<InternedOpaqueTyId, Ty<'db>>,
type_mismatches: FxHashMap<ExprOrPatId, TypeMismatch<'db>>,
/// Whether there are any type-mismatching errors in the result.
Expand Down Expand Up @@ -501,6 +503,7 @@ impl<'db> InferenceResult<'db> {
type_of_expr: Default::default(),
type_of_pat: Default::default(),
type_of_binding: Default::default(),
type_of_type_placeholder: Default::default(),
type_of_opaque: Default::default(),
type_mismatches: Default::default(),
has_errors: Default::default(),
Expand Down Expand Up @@ -565,6 +568,12 @@ impl<'db> InferenceResult<'db> {
_ => None,
})
}
pub fn placeholder_types(&self) -> impl Iterator<Item = (TypeRefId, &Ty<'db>)> {
self.type_of_type_placeholder.iter()
}
pub fn type_of_type_placeholder(&self, type_ref: TypeRefId) -> Option<Ty<'db>> {
self.type_of_type_placeholder.get(type_ref).copied()
}
pub fn closure_info(&self, closure: InternedClosureId) -> &(Vec<CapturedItem<'db>>, FnTrait) {
self.closure_info.get(&closure).unwrap()
}
Expand Down Expand Up @@ -972,6 +981,7 @@ impl<'body, 'db> InferenceContext<'body, 'db> {
type_of_expr,
type_of_pat,
type_of_binding,
type_of_type_placeholder,
type_of_opaque,
type_mismatches,
has_errors,
Expand Down Expand Up @@ -1004,6 +1014,11 @@ impl<'body, 'db> InferenceContext<'body, 'db> {
*has_errors = *has_errors || ty.references_non_lt_error();
}
type_of_binding.shrink_to_fit();
for ty in type_of_type_placeholder.values_mut() {
*ty = table.resolve_completely(*ty);
*has_errors = *has_errors || ty.references_non_lt_error();
}
type_of_type_placeholder.shrink_to_fit();
type_of_opaque.shrink_to_fit();

*has_errors |= !type_mismatches.is_empty();
Expand Down Expand Up @@ -1233,6 +1248,10 @@ impl<'body, 'db> InferenceContext<'body, 'db> {
self.result.type_of_pat.insert(pat, ty);
}

fn write_type_placeholder_ty(&mut self, type_ref: TypeRefId, ty: Ty<'db>) {
self.result.type_of_type_placeholder.insert(type_ref, ty);
}

fn write_binding_ty(&mut self, id: BindingId, ty: Ty<'db>) {
self.result.type_of_binding.insert(id, ty);
}
Expand Down Expand Up @@ -1281,7 +1300,27 @@ impl<'body, 'db> InferenceContext<'body, 'db> {
) -> Ty<'db> {
let ty = self
.with_ty_lowering(store, type_source, lifetime_elision, |ctx| ctx.lower_ty(type_ref));
self.process_user_written_ty(ty)
let ty = self.process_user_written_ty(ty);

// Record the association from placeholders' TypeRefId to type variables.
// We only record them if their number matches. This assumes TypeRef::walk and TypeVisitable process the items in the same order.
let type_variables = collect_type_inference_vars(&ty);
let mut placeholder_ids = vec![];
TypeRef::walk(type_ref, store, &mut |type_ref_id, type_ref| {
if matches!(type_ref, TypeRef::Placeholder) {
placeholder_ids.push(type_ref_id);
}
});

if placeholder_ids.len() == type_variables.len() {
for (placeholder_id, type_variable) in
placeholder_ids.into_iter().zip(type_variables.into_iter())
{
self.write_type_placeholder_ty(placeholder_id, type_variable);
}
}

ty
}

fn make_body_ty(&mut self, type_ref: TypeRefId) -> Ty<'db> {
Expand Down
29 changes: 29 additions & 0 deletions crates/hir-ty/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -570,6 +570,35 @@ where
Vec::from_iter(collector.params)
}

struct TypeInferenceVarCollector<'db> {
type_inference_vars: Vec<Ty<'db>>,
}

impl<'db> rustc_type_ir::TypeVisitor<DbInterner<'db>> for TypeInferenceVarCollector<'db> {
type Result = ();

fn visit_ty(&mut self, ty: Ty<'db>) -> Self::Result {
use crate::rustc_type_ir::Flags;
if ty.is_ty_var() {
self.type_inference_vars.push(ty);
} else if ty.flags().intersects(rustc_type_ir::TypeFlags::HAS_TY_INFER) {
ty.super_visit_with(self);
} else {
// Fast path: don't visit inner types (e.g. generic arguments) when `flags` indicate
// that there are no placeholders.
}
}
}

pub fn collect_type_inference_vars<'db, T>(value: &T) -> Vec<Ty<'db>>
where
T: ?Sized + rustc_type_ir::TypeVisitable<DbInterner<'db>>,
{
let mut collector = TypeInferenceVarCollector { type_inference_vars: vec![] };
value.visit_with(&mut collector);
collector.type_inference_vars
}

pub fn known_const_to_ast<'db>(
konst: Const<'db>,
db: &'db dyn HirDatabase,
Expand Down
33 changes: 33 additions & 0 deletions crates/hir-ty/src/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ use hir_def::{
item_scope::ItemScope,
nameres::DefMap,
src::HasSource,
type_ref::TypeRefId,
};
use hir_expand::{FileRange, InFile, db::ExpandDatabase};
use itertools::Itertools;
Expand Down Expand Up @@ -219,6 +220,24 @@ fn check_impl(
}
}
}

for (type_ref, ty) in inference_result.placeholder_types() {
let node = match type_node(&body_source_map, type_ref, &db) {
Some(value) => value,
None => continue,
};
let range = node.as_ref().original_file_range_rooted(&db);
if let Some(expected) = types.remove(&range) {
let actual = salsa::attach(&db, || {
if display_source {
ty.display_source_code(&db, def.module(&db), true).unwrap()
} else {
ty.display_test(&db, display_target).to_string()
}
});
assert_eq!(actual, expected, "type annotation differs at {:#?}", range.range);
}
}
}

let mut buf = String::new();
Expand Down Expand Up @@ -275,6 +294,20 @@ fn pat_node(
})
}

fn type_node(
body_source_map: &BodySourceMap,
type_ref: TypeRefId,
db: &TestDB,
) -> Option<InFile<SyntaxNode>> {
Some(match body_source_map.type_syntax(type_ref) {
Ok(sp) => {
let root = db.parse_or_expand(sp.file_id);
sp.map(|ptr| ptr.to_node(&root).syntax().clone())
}
Err(SyntheticSyntax) => return None,
})
}

fn infer(#[rust_analyzer::rust_fixture] ra_fixture: &str) -> String {
infer_with_mismatches(ra_fixture, false)
}
Expand Down
19 changes: 19 additions & 0 deletions crates/hir-ty/src/tests/display_source_code.rs
Original file line number Diff line number Diff line change
Expand Up @@ -246,3 +246,22 @@ fn test() {
"#,
);
}

#[test]
fn type_placeholder_type() {
check_types_source_code(
r#"
struct S<T>(T);
fn test() {
let f: S<_> = S(3);
//^ i32
let f: [_; _] = [4_u32, 5, 6];
//^ u32
let f: (_, _, _) = (1_u32, 1_i32, false);
//^ u32
//^ i32
//^ bool
}
"#,
);
}
32 changes: 30 additions & 2 deletions crates/hir/src/source_analyzer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ use hir_def::{
lang_item::LangItem,
nameres::MacroSubNs,
resolver::{HasResolver, Resolver, TypeNs, ValueNs, resolver_for_scope},
type_ref::{Mutability, TypeRefId},
type_ref::{Mutability, TypeRef, TypeRefId},
};
use hir_expand::{
HirFileId, InFile,
Expand Down Expand Up @@ -267,8 +267,11 @@ impl<'db> SourceAnalyzer<'db> {
db: &'db dyn HirDatabase,
ty: &ast::Type,
) -> Option<Type<'db>> {
let interner = DbInterner::new_with(db, None, None);

let type_ref = self.type_id(ty)?;
let ty = TyLoweringContext::new(

let mut ty = TyLoweringContext::new(
db,
&self.resolver,
self.store()?,
Expand All @@ -279,6 +282,31 @@ impl<'db> SourceAnalyzer<'db> {
LifetimeElisionKind::Infer,
)
.lower_ty(type_ref);

// Try and substitute unknown types using InferenceResult
if let Some(infer) = self.infer()
&& let Some(store) = self.store()
{
let mut inferred_types = vec![];
TypeRef::walk(type_ref, store, &mut |type_ref_id, type_ref| {
if matches!(type_ref, TypeRef::Placeholder) {
inferred_types.push(infer.type_of_type_placeholder(type_ref_id));
}
});
let mut inferred_types = inferred_types.into_iter();

let substituted_ty = hir_ty::next_solver::fold::fold_tys(interner, ty, |ty| {
if ty.is_ty_error() { inferred_types.next().flatten().unwrap_or(ty) } else { ty }
});

// Only used the result if the placeholder and unknown type counts matched
let success =
inferred_types.next().is_none() && !substituted_ty.references_non_lt_error();
if success {
ty = substituted_ty;
}
}

Some(Type::new_with_resolver(db, &self.resolver, ty))
}

Expand Down
58 changes: 57 additions & 1 deletion crates/ide-assists/src/handlers/extract_type_alias.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use either::Either;
use hir::HirDisplay;
use ide_db::syntax_helpers::node_ext::walk_ty;
use syntax::{
ast::{self, AstNode, HasGenericArgs, HasGenericParams, HasName, edit::IndentLevel, make},
Expand Down Expand Up @@ -39,6 +40,15 @@ pub(crate) fn extract_type_alias(acc: &mut Assists, ctx: &AssistContext<'_>) ->
);
let target = ty.syntax().text_range();

let resolved_ty = ctx.sema.resolve_type(&ty)?;
let resolved_ty = if !resolved_ty.contains_unknown() {
let module = ctx.sema.scope(ty.syntax())?.module();
let resolved_ty = resolved_ty.display_source_code(ctx.db(), module.into(), false).ok()?;
make::ty(&resolved_ty)
} else {
ty.clone()
};

acc.add(
AssistId::refactor_extract("extract_type_alias"),
"Extract type as type alias",
Expand Down Expand Up @@ -72,7 +82,7 @@ pub(crate) fn extract_type_alias(acc: &mut Assists, ctx: &AssistContext<'_>) ->

// Insert new alias
let ty_alias =
make::ty_alias(None, "Type", generic_params, None, None, Some((ty, None)))
make::ty_alias(None, "Type", generic_params, None, None, Some((resolved_ty, None)))
.clone_for_update();

if let Some(cap) = ctx.config.snippet_cap
Expand Down Expand Up @@ -391,4 +401,50 @@ where
"#,
);
}

#[test]
fn inferred_generic_type_parameter() {
check_assist(
extract_type_alias,
r#"
struct Wrap<T>(T);

fn main() {
let wrap: $0Wrap<_>$0 = Wrap::<_>(3i32);
}
"#,
r#"
struct Wrap<T>(T);

type $0Type = Wrap<i32>;

fn main() {
let wrap: Type = Wrap::<_>(3i32);
}
"#,
)
}

#[test]
fn inferred_type() {
check_assist(
extract_type_alias,
r#"
struct Wrap<T>(T);

fn main() {
let wrap: Wrap<$0_$0> = Wrap::<_>(3i32);
}
"#,
r#"
struct Wrap<T>(T);

type $0Type = i32;

fn main() {
let wrap: Wrap<Type> = Wrap::<_>(3i32);
}
"#,
)
}
}
Loading