diff --git a/crates/hir/src/lib.rs b/crates/hir/src/lib.rs index a7c42ca1e5..ca9a7f7fa0 100644 --- a/crates/hir/src/lib.rs +++ b/crates/hir/src/lib.rs @@ -513,9 +513,8 @@ impl Field { } /// Returns the type as in the signature of the struct (i.e., with - /// placeholder types for type parameters). This is good for showing - /// signature help, but not so good to actually get the type of the field - /// when you actually have a variable of the struct. + /// placeholder types for type parameters). Only use this in the context of + /// the field definition. pub fn ty(&self, db: &dyn HirDatabase) -> Type { let var_id = self.parent.into(); let generic_def_id: GenericDefId = match self.parent { diff --git a/crates/hir/src/semantics.rs b/crates/hir/src/semantics.rs index 1b5064b5a2..3aa467e3ca 100644 --- a/crates/hir/src/semantics.rs +++ b/crates/hir/src/semantics.rs @@ -11,7 +11,7 @@ use hir_def::{ AsMacroCall, FunctionId, TraitId, VariantId, }; use hir_expand::{name::AsName, ExpansionInfo}; -use hir_ty::associated_type_shorthand_candidates; +use hir_ty::{associated_type_shorthand_candidates, Interner}; use itertools::Itertools; use rustc_hash::{FxHashMap, FxHashSet}; use syntax::{ @@ -227,7 +227,7 @@ impl<'db, DB: HirDatabase> Semantics<'db, DB> { pub fn resolve_record_field( &self, field: &ast::RecordExprField, - ) -> Option<(Field, Option)> { + ) -> Option<(Field, Option, Type)> { self.imp.resolve_record_field(field) } @@ -501,14 +501,12 @@ impl<'db> SemanticsImpl<'db> { } fn resolve_method_call(&self, call: &ast::MethodCallExpr) -> Option { - self.analyze(call.syntax()).resolve_method_call(self.db, call) + self.analyze(call.syntax()).resolve_method_call(self.db, call).map(|(id, _)| id) } fn resolve_method_call_as_callable(&self, call: &ast::MethodCallExpr) -> Option { - // FIXME: this erases Substs, we should instead record the correct - // substitution during inference and use that - let func = self.resolve_method_call(call)?; - let ty = hir_ty::TyBuilder::value_ty(self.db, func.into()).fill_with_unknown().build(); + let (func, subst) = self.analyze(call.syntax()).resolve_method_call(self.db, call)?; + let ty = self.db.value_ty(func.into()).substitute(&Interner, &subst); let resolver = self.analyze(call.syntax()).resolver; let ty = Type::new_with_resolver(self.db, &resolver, ty)?; let mut res = ty.as_callable(self.db)?; @@ -520,7 +518,10 @@ impl<'db> SemanticsImpl<'db> { self.analyze(field.syntax()).resolve_field(self.db, field) } - fn resolve_record_field(&self, field: &ast::RecordExprField) -> Option<(Field, Option)> { + fn resolve_record_field( + &self, + field: &ast::RecordExprField, + ) -> Option<(Field, Option, Type)> { self.analyze(field.syntax()).resolve_record_field(self.db, field) } diff --git a/crates/hir/src/source_analyzer.rs b/crates/hir/src/source_analyzer.rs index 20753314da..3f940124cc 100644 --- a/crates/hir/src/source_analyzer.rs +++ b/crates/hir/src/source_analyzer.rs @@ -143,7 +143,7 @@ impl SourceAnalyzer { &self, db: &dyn HirDatabase, call: &ast::MethodCallExpr, - ) -> Option { + ) -> Option<(FunctionId, Substitution)> { let expr_id = self.expr_id(db, &call.clone().into())?; self.infer.as_ref()?.method_resolution(expr_id) } @@ -161,7 +161,7 @@ impl SourceAnalyzer { &self, db: &dyn HirDatabase, field: &ast::RecordExprField, - ) -> Option<(Field, Option)> { + ) -> Option<(Field, Option, Type)> { let record_expr = ast::RecordExpr::cast(field.syntax().parent().and_then(|p| p.parent())?)?; let expr = ast::Expr::from(record_expr); let expr_id = self.body_source_map.as_ref()?.node_expr(InFile::new(self.file_id, &expr))?; @@ -178,10 +178,13 @@ impl SourceAnalyzer { _ => None, } }; + let (_, subst) = self.infer.as_ref()?.type_of_expr.get(expr_id)?.as_adt()?; let variant = self.infer.as_ref()?.variant_resolution_for_expr(expr_id)?; let variant_data = variant.variant_data(db.upcast()); let field = FieldId { parent: variant, local_id: variant_data.field(&local_name)? }; - Some((field.into(), local)) + let field_ty = + db.field_types(variant).get(field.local_id)?.clone().substitute(&Interner, subst); + Some((field.into(), local, Type::new_with_resolver(db, &self.resolver, field_ty)?)) } pub(crate) fn resolve_record_pat_field( diff --git a/crates/hir_def/src/lib.rs b/crates/hir_def/src/lib.rs index a82ea5957e..70001cac86 100644 --- a/crates/hir_def/src/lib.rs +++ b/crates/hir_def/src/lib.rs @@ -485,6 +485,14 @@ impl VariantId { VariantId::UnionId(it) => it.lookup(db).id.file_id(), } } + + pub fn adt_id(self) -> AdtId { + match self { + VariantId::EnumVariantId(it) => it.parent.into(), + VariantId::StructId(it) => it.into(), + VariantId::UnionId(it) => it.into(), + } + } } trait Intern { diff --git a/crates/hir_ty/src/diagnostics/expr.rs b/crates/hir_ty/src/diagnostics/expr.rs index 53c4ee9da9..d1f113e7ff 100644 --- a/crates/hir_ty/src/diagnostics/expr.rs +++ b/crates/hir_ty/src/diagnostics/expr.rs @@ -181,7 +181,7 @@ impl<'a, 'b> ExprValidator<'a, 'b> { for (id, expr) in body.exprs.iter() { if let Expr::MethodCall { receiver, .. } = expr { let function_id = match self.infer.method_resolution(id) { - Some(id) => id, + Some((id, _)) => id, None => continue, }; @@ -239,15 +239,11 @@ impl<'a, 'b> ExprValidator<'a, 'b> { return; } - // FIXME: note that we erase information about substs here. This - // is not right, but, luckily, doesn't matter as we care only - // about the number of params - let callee = match self.infer.method_resolution(call_id) { - Some(callee) => callee, + let (callee, subst) = match self.infer.method_resolution(call_id) { + Some(it) => it, None => return, }; - let sig = - db.callable_item_signature(callee.into()).into_value_and_skipped_binders().0; + let sig = db.callable_item_signature(callee.into()).substitute(&Interner, &subst); (sig, args) } diff --git a/crates/hir_ty/src/diagnostics/unsafe_check.rs b/crates/hir_ty/src/diagnostics/unsafe_check.rs index ed97dc0e3f..5d13bddea3 100644 --- a/crates/hir_ty/src/diagnostics/unsafe_check.rs +++ b/crates/hir_ty/src/diagnostics/unsafe_check.rs @@ -105,7 +105,7 @@ fn walk_unsafe( Expr::MethodCall { .. } => { if infer .method_resolution(current) - .map(|func| db.function_data(func).is_unsafe()) + .map(|(func, _)| db.function_data(func).is_unsafe()) .unwrap_or(false) { unsafe_exprs.push(UnsafeExpr { expr: current, inside_unsafe_block }); diff --git a/crates/hir_ty/src/infer.rs b/crates/hir_ty/src/infer.rs index f1cebbdb98..db3c937ff5 100644 --- a/crates/hir_ty/src/infer.rs +++ b/crates/hir_ty/src/infer.rs @@ -37,8 +37,8 @@ use syntax::SmolStr; use super::{DomainGoal, InEnvironment, ProjectionTy, TraitEnvironment, TraitRef, Ty}; use crate::{ db::HirDatabase, fold_tys, infer::diagnostics::InferenceDiagnostic, - lower::ImplTraitLoweringMode, to_assoc_type_id, AliasEq, AliasTy, Goal, Interner, TyBuilder, - TyExt, TyKind, + lower::ImplTraitLoweringMode, to_assoc_type_id, AliasEq, AliasTy, Goal, Interner, Substitution, + TyBuilder, TyExt, TyKind, }; // This lint has a false positive here. See the link below for details. @@ -132,7 +132,7 @@ impl Default for InternedStandardTypes { #[derive(Clone, PartialEq, Eq, Debug, Default)] pub struct InferenceResult { /// For each method call expr, records the function it resolves to. - method_resolutions: FxHashMap, + method_resolutions: FxHashMap, /// For each field access expr, records the field it resolves to. field_resolutions: FxHashMap, /// For each struct literal or pattern, records the variant it resolves to. @@ -152,8 +152,8 @@ pub struct InferenceResult { } impl InferenceResult { - pub fn method_resolution(&self, expr: ExprId) -> Option { - self.method_resolutions.get(&expr).copied() + pub fn method_resolution(&self, expr: ExprId) -> Option<(FunctionId, Substitution)> { + self.method_resolutions.get(&expr).cloned() } pub fn field_resolution(&self, expr: ExprId) -> Option { self.field_resolutions.get(&expr).copied() @@ -284,14 +284,17 @@ impl<'a> InferenceContext<'a> { self.table.propagate_diverging_flag(); let mut result = std::mem::take(&mut self.result); for ty in result.type_of_expr.values_mut() { - *ty = self.table.resolve_ty_completely(ty.clone()); + *ty = self.table.resolve_completely(ty.clone()); } for ty in result.type_of_pat.values_mut() { - *ty = self.table.resolve_ty_completely(ty.clone()); + *ty = self.table.resolve_completely(ty.clone()); } for mismatch in result.type_mismatches.values_mut() { - mismatch.expected = self.table.resolve_ty_completely(mismatch.expected.clone()); - mismatch.actual = self.table.resolve_ty_completely(mismatch.actual.clone()); + mismatch.expected = self.table.resolve_completely(mismatch.expected.clone()); + mismatch.actual = self.table.resolve_completely(mismatch.actual.clone()); + } + for (_, subst) in result.method_resolutions.values_mut() { + *subst = self.table.resolve_completely(subst.clone()); } result } @@ -300,8 +303,8 @@ impl<'a> InferenceContext<'a> { self.result.type_of_expr.insert(expr, ty); } - fn write_method_resolution(&mut self, expr: ExprId, func: FunctionId) { - self.result.method_resolutions.insert(expr, func); + fn write_method_resolution(&mut self, expr: ExprId, func: FunctionId, subst: Substitution) { + self.result.method_resolutions.insert(expr, (func, subst)); } fn write_field_resolution(&mut self, expr: ExprId, field: FieldId) { diff --git a/crates/hir_ty/src/infer/expr.rs b/crates/hir_ty/src/infer/expr.rs index 08c05c67cc..eab8fac910 100644 --- a/crates/hir_ty/src/infer/expr.rs +++ b/crates/hir_ty/src/infer/expr.rs @@ -891,17 +891,21 @@ impl<'a> InferenceContext<'a> { method_name, ) }); - let (derefed_receiver_ty, method_ty, def_generics) = match resolved { + let (derefed_receiver_ty, method_ty, substs) = match resolved { Some((ty, func)) => { let ty = canonicalized_receiver.decanonicalize_ty(ty); - self.write_method_resolution(tgt_expr, func); - (ty, self.db.value_ty(func.into()), Some(generics(self.db.upcast(), func.into()))) + let generics = generics(self.db.upcast(), func.into()); + let substs = self.substs_for_method_call(generics, generic_args, &ty); + self.write_method_resolution(tgt_expr, func, substs.clone()); + (ty, self.db.value_ty(func.into()), substs) } - None => (receiver_ty, Binders::empty(&Interner, self.err_ty()), None), + None => ( + receiver_ty, + Binders::empty(&Interner, self.err_ty()), + Substitution::empty(&Interner), + ), }; - let substs = self.substs_for_method_call(def_generics, generic_args, &derefed_receiver_ty); let method_ty = method_ty.substitute(&Interner, &substs); - let method_ty = self.insert_type_vars(method_ty); self.register_obligations_for_call(&method_ty); let (expected_receiver_ty, param_tys, ret_ty) = match method_ty.callable_sig(self.db) { Some(sig) => { @@ -950,23 +954,21 @@ impl<'a> InferenceContext<'a> { fn substs_for_method_call( &mut self, - def_generics: Option, + def_generics: Generics, generic_args: Option<&GenericArgs>, receiver_ty: &Ty, ) -> Substitution { let (parent_params, self_params, type_params, impl_trait_params) = - def_generics.as_ref().map_or((0, 0, 0, 0), |g| g.provenance_split()); + def_generics.provenance_split(); assert_eq!(self_params, 0); // method shouldn't have another Self param let total_len = parent_params + type_params + impl_trait_params; let mut substs = Vec::with_capacity(total_len); // Parent arguments are unknown, except for the receiver type - if let Some(parent_generics) = def_generics.as_ref().map(|p| p.iter_parent()) { - for (_id, param) in parent_generics { - if param.provenance == hir_def::generics::TypeParamProvenance::TraitSelf { - substs.push(receiver_ty.clone()); - } else { - substs.push(self.err_ty()); - } + for (_id, param) in def_generics.iter_parent() { + if param.provenance == hir_def::generics::TypeParamProvenance::TraitSelf { + substs.push(receiver_ty.clone()); + } else { + substs.push(self.table.new_type_var()); } } // handle provided type arguments @@ -989,7 +991,7 @@ impl<'a> InferenceContext<'a> { }; let supplied_params = substs.len(); for _ in supplied_params..total_len { - substs.push(self.err_ty()); + substs.push(self.table.new_type_var()); } assert_eq!(substs.len(), total_len); Substitution::from_iter(&Interner, substs) diff --git a/crates/hir_ty/src/infer/unify.rs b/crates/hir_ty/src/infer/unify.rs index f8233cac39..ea5684229f 100644 --- a/crates/hir_ty/src/infer/unify.rs +++ b/crates/hir_ty/src/infer/unify.rs @@ -295,8 +295,11 @@ impl<'a> InferenceTable<'a> { .expect("fold failed unexpectedly") } - pub(crate) fn resolve_ty_completely(&mut self, ty: Ty) -> Ty { - self.resolve_with_fallback(ty, |_, _, d, _| d) + pub(crate) fn resolve_completely(&mut self, t: T) -> T::Result + where + T: HasInterner + Fold, + { + self.resolve_with_fallback(t, |_, _, d, _| d) } /// Unify two types and register new trait goals that arise from that. diff --git a/crates/ide_assists/src/handlers/fix_visibility.rs b/crates/ide_assists/src/handlers/fix_visibility.rs index 6c7824e551..89f7b2c2c3 100644 --- a/crates/ide_assists/src/handlers/fix_visibility.rs +++ b/crates/ide_assists/src/handlers/fix_visibility.rs @@ -85,7 +85,7 @@ fn add_vis_to_referenced_module_def(acc: &mut Assists, ctx: &AssistContext) -> O fn add_vis_to_referenced_record_field(acc: &mut Assists, ctx: &AssistContext) -> Option<()> { let record_field: ast::RecordExprField = ctx.find_node_at_offset()?; - let (record_field_def, _) = ctx.sema.resolve_record_field(&record_field)?; + let (record_field_def, _, _) = ctx.sema.resolve_record_field(&record_field)?; let current_module = ctx.sema.scope(record_field.syntax()).module()?; let visibility = record_field_def.visibility(ctx.db()); diff --git a/crates/ide_completion/src/context.rs b/crates/ide_completion/src/context.rs index 787eb2fd3f..1ec59ff804 100644 --- a/crates/ide_completion/src/context.rs +++ b/crates/ide_completion/src/context.rs @@ -337,25 +337,24 @@ impl<'a> CompletionContext<'a> { }, ast::RecordExprFieldList(_it) => { cov_mark::hit!(expected_type_struct_field_without_leading_char); - self.token.prev_sibling_or_token() - .and_then(|se| se.into_node()) - .and_then(|node| ast::RecordExprField::cast(node)) - .and_then(|rf| self.sema.resolve_record_field(&rf).zip(Some(rf))) - .map(|(f, rf)|( - Some(f.0.ty(self.db)), - rf.field_name().map(NameOrNameRef::NameRef), + // wouldn't try {} be nice... + (|| { + let expr_field = self.token.prev_sibling_or_token()? + .into_node() + .and_then(|node| ast::RecordExprField::cast(node))?; + let (_, _, ty) = self.sema.resolve_record_field(&expr_field)?; + Some(( + Some(ty), + expr_field.field_name().map(NameOrNameRef::NameRef), )) - .unwrap_or((None, None)) + })().unwrap_or((None, None)) }, ast::RecordExprField(it) => { cov_mark::hit!(expected_type_struct_field_with_leading_char); - self.sema - .resolve_record_field(&it) - .map(|f|( - Some(f.0.ty(self.db)), - it.field_name().map(NameOrNameRef::NameRef), - )) - .unwrap_or((None, None)) + ( + it.expr().as_ref().and_then(|e| self.sema.type_of_expr(e)), + it.field_name().map(NameOrNameRef::NameRef), + ) }, ast::MatchExpr(it) => { cov_mark::hit!(expected_type_match_arm_without_leading_char); @@ -382,6 +381,12 @@ impl<'a> CompletionContext<'a> { let def = self.sema.to_def(&it); (def.map(|def| def.ret_type(self.db)), None) }, + ast::ClosureExpr(it) => { + let ty = self.sema.type_of_expr(&it.into()); + ty.and_then(|ty| ty.as_callable(self.db)) + .map(|c| (Some(c.return_type()), None)) + .unwrap_or((None, None)) + }, ast::Stmt(_it) => (None, None), _ => { match node.parent() { @@ -784,6 +789,19 @@ fn foo() { ) } + #[test] + fn expected_type_generic_struct_field() { + check_expected_type_and_name( + r#" +struct Foo { a: T } +fn foo() -> Foo { + Foo { a: $0 } +} +"#, + expect![[r#"ty: u32, name: a"#]], + ) + } + #[test] fn expected_type_struct_field_with_leading_char() { cov_mark::check!(expected_type_struct_field_with_leading_char); @@ -895,4 +913,52 @@ fn foo() -> u32 { expect![[r#"ty: u32, name: ?"#]], ) } + + #[test] + fn expected_type_closure_param_return() { + // FIXME: make this work with `|| $0` + check_expected_type_and_name( + r#" +fn foo() { + bar(|| a$0); +} + +fn bar(f: impl FnOnce() -> u32) {} +#[lang = "fn_once"] +trait FnOnce { type Output; } +"#, + expect![[r#"ty: u32, name: ?"#]], + ); + } + + #[test] + fn expected_type_generic_function() { + check_expected_type_and_name( + r#" +fn foo() { + bar::($0); +} + +fn bar(t: T) {} +"#, + expect![[r#"ty: u32, name: t"#]], + ); + } + + #[test] + fn expected_type_generic_method() { + check_expected_type_and_name( + r#" +fn foo() { + S(1u32).bar($0); +} + +struct S(T); +impl S { + fn bar(self, t: T) {} +} +"#, + expect![[r#"ty: u32, name: t"#]], + ); + } } diff --git a/crates/ide_completion/src/render.rs b/crates/ide_completion/src/render.rs index 6b04ee1648..d7f96b8645 100644 --- a/crates/ide_completion/src/render.rs +++ b/crates/ide_completion/src/render.rs @@ -667,6 +667,13 @@ fn foo() { A { the$0 } } ), detail: "u32", deprecated: true, + relevance: CompletionRelevance { + exact_name_match: false, + type_match: Some( + CouldUnify, + ), + is_local: false, + }, }, ] "#]], diff --git a/crates/ide_db/src/call_info/tests.rs b/crates/ide_db/src/call_info/tests.rs index be1cc12de1..1aeda08e5f 100644 --- a/crates/ide_db/src/call_info/tests.rs +++ b/crates/ide_db/src/call_info/tests.rs @@ -188,6 +188,24 @@ fn main() { S.foo($0); } ); } +#[test] +fn test_fn_signature_for_generic_method() { + check( + r#" +struct S(T); +impl S { + fn foo(&self, x: T) {} +} + +fn main() { S(1u32).foo($0); } +"#, + expect![[r#" + fn foo(&self, x: u32) + () + "#]], + ); +} + #[test] fn test_fn_signature_for_method_with_arg_as_assoc_fn() { check( diff --git a/crates/ide_db/src/defs.rs b/crates/ide_db/src/defs.rs index de0dc2a40f..1dcccbb8be 100644 --- a/crates/ide_db/src/defs.rs +++ b/crates/ide_db/src/defs.rs @@ -311,7 +311,7 @@ impl NameRefClass { } if let Some(record_field) = ast::RecordExprField::for_field_name(name_ref) { - if let Some((field, local)) = sema.resolve_record_field(&record_field) { + if let Some((field, local, _)) = sema.resolve_record_field(&record_field) { let field = Definition::Field(field); let res = match local { None => NameRefClass::Definition(field),