11107: Fix generic type substitution in impl trait with assoc type r=pnevyk a=pnevyk

Fixes #11045 

The path transform now detects if a type parameter that is being substituted has an associated type. In that case it is necessary (or safe in general case) to fully qualify the substitution with a trait which the associated type belongs to.

This PR also fixes the previous wrong behavior of the substitution that could create an invalid tree `PATH_TYPE -> PATH_TYPE -> ...`.

Co-authored-by: Petr Nevyhoštěný <petr.nevyhosteny@gmail.com>
This commit is contained in:
bors[bot] 2022-01-08 09:05:09 +00:00 committed by GitHub
commit c17db9fa53
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 468 additions and 11 deletions

View file

@ -942,4 +942,366 @@ impl FooB for Foo {
"#, "#,
) )
} }
#[test]
fn test_assoc_type_when_trait_with_same_name_in_scope() {
check_assist(
add_missing_impl_members,
r#"
pub trait Foo {}
pub trait Types {
type Foo;
}
pub trait Behavior<T: Types> {
fn reproduce(&self, foo: T::Foo);
}
pub struct Impl;
impl<T: Types> Behavior<T> for Impl { $0 }"#,
r#"
pub trait Foo {}
pub trait Types {
type Foo;
}
pub trait Behavior<T: Types> {
fn reproduce(&self, foo: T::Foo);
}
pub struct Impl;
impl<T: Types> Behavior<T> for Impl {
fn reproduce(&self, foo: <T as Types>::Foo) {
${0:todo!()}
}
}"#,
);
}
#[test]
fn test_assoc_type_on_concrete_type() {
check_assist(
add_missing_impl_members,
r#"
pub trait Types {
type Foo;
}
impl Types for u32 {
type Foo = bool;
}
pub trait Behavior<T: Types> {
fn reproduce(&self, foo: T::Foo);
}
pub struct Impl;
impl Behavior<u32> for Impl { $0 }"#,
r#"
pub trait Types {
type Foo;
}
impl Types for u32 {
type Foo = bool;
}
pub trait Behavior<T: Types> {
fn reproduce(&self, foo: T::Foo);
}
pub struct Impl;
impl Behavior<u32> for Impl {
fn reproduce(&self, foo: <u32 as Types>::Foo) {
${0:todo!()}
}
}"#,
);
}
#[test]
fn test_assoc_type_on_concrete_type_qualified() {
check_assist(
add_missing_impl_members,
r#"
pub trait Types {
type Foo;
}
impl Types for std::string::String {
type Foo = bool;
}
pub trait Behavior<T: Types> {
fn reproduce(&self, foo: T::Foo);
}
pub struct Impl;
impl Behavior<std::string::String> for Impl { $0 }"#,
r#"
pub trait Types {
type Foo;
}
impl Types for std::string::String {
type Foo = bool;
}
pub trait Behavior<T: Types> {
fn reproduce(&self, foo: T::Foo);
}
pub struct Impl;
impl Behavior<std::string::String> for Impl {
fn reproduce(&self, foo: <std::string::String as Types>::Foo) {
${0:todo!()}
}
}"#,
);
}
#[test]
fn test_assoc_type_on_concrete_type_multi_option_ambiguous() {
check_assist(
add_missing_impl_members,
r#"
pub trait Types {
type Foo;
}
pub trait Types2 {
type Foo;
}
impl Types for u32 {
type Foo = bool;
}
impl Types2 for u32 {
type Foo = String;
}
pub trait Behavior<T: Types + Types2> {
fn reproduce(&self, foo: <T as Types2>::Foo);
}
pub struct Impl;
impl Behavior<u32> for Impl { $0 }"#,
r#"
pub trait Types {
type Foo;
}
pub trait Types2 {
type Foo;
}
impl Types for u32 {
type Foo = bool;
}
impl Types2 for u32 {
type Foo = String;
}
pub trait Behavior<T: Types + Types2> {
fn reproduce(&self, foo: <T as Types2>::Foo);
}
pub struct Impl;
impl Behavior<u32> for Impl {
fn reproduce(&self, foo: <u32 as Types2>::Foo) {
${0:todo!()}
}
}"#,
);
}
#[test]
fn test_assoc_type_on_concrete_type_multi_option() {
check_assist(
add_missing_impl_members,
r#"
pub trait Types {
type Foo;
}
pub trait Types2 {
type Bar;
}
impl Types for u32 {
type Foo = bool;
}
impl Types2 for u32 {
type Bar = String;
}
pub trait Behavior<T: Types + Types2> {
fn reproduce(&self, foo: T::Bar);
}
pub struct Impl;
impl Behavior<u32> for Impl { $0 }"#,
r#"
pub trait Types {
type Foo;
}
pub trait Types2 {
type Bar;
}
impl Types for u32 {
type Foo = bool;
}
impl Types2 for u32 {
type Bar = String;
}
pub trait Behavior<T: Types + Types2> {
fn reproduce(&self, foo: T::Bar);
}
pub struct Impl;
impl Behavior<u32> for Impl {
fn reproduce(&self, foo: <u32 as Types2>::Bar) {
${0:todo!()}
}
}"#,
);
}
#[test]
fn test_assoc_type_on_concrete_type_multi_option_foreign() {
check_assist(
add_missing_impl_members,
r#"
mod bar {
pub trait Types2 {
type Bar;
}
}
pub trait Types {
type Foo;
}
impl Types for u32 {
type Foo = bool;
}
impl bar::Types2 for u32 {
type Bar = String;
}
pub trait Behavior<T: Types + bar::Types2> {
fn reproduce(&self, foo: T::Bar);
}
pub struct Impl;
impl Behavior<u32> for Impl { $0 }"#,
r#"
mod bar {
pub trait Types2 {
type Bar;
}
}
pub trait Types {
type Foo;
}
impl Types for u32 {
type Foo = bool;
}
impl bar::Types2 for u32 {
type Bar = String;
}
pub trait Behavior<T: Types + bar::Types2> {
fn reproduce(&self, foo: T::Bar);
}
pub struct Impl;
impl Behavior<u32> for Impl {
fn reproduce(&self, foo: <u32 as bar::Types2>::Bar) {
${0:todo!()}
}
}"#,
);
}
#[test]
fn test_transform_path_in_path_expr() {
check_assist(
add_missing_default_members,
r#"
pub trait Const {
const FOO: u32;
}
pub trait Trait<T: Const> {
fn foo() -> bool {
match T::FOO {
0 => true,
_ => false,
}
}
}
impl Const for u32 {
const FOO: u32 = 1;
}
struct Impl;
impl Trait<u32> for Impl { $0 }"#,
r#"
pub trait Const {
const FOO: u32;
}
pub trait Trait<T: Const> {
fn foo() -> bool {
match T::FOO {
0 => true,
_ => false,
}
}
}
impl Const for u32 {
const FOO: u32 = 1;
}
struct Impl;
impl Trait<u32> for Impl {
$0fn foo() -> bool {
match <u32 as Const>::FOO {
0 => true,
_ => false,
}
}
}"#,
);
}
} }

View file

@ -118,14 +118,20 @@ struct Ctx<'a> {
impl<'a> Ctx<'a> { impl<'a> Ctx<'a> {
fn apply(&self, item: &SyntaxNode) { fn apply(&self, item: &SyntaxNode) {
for event in item.preorder() { // `transform_path` may update a node's parent and that would break the
let node = match event { // tree traversal. Thus all paths in the tree are collected into a vec
syntax::WalkEvent::Enter(_) => continue, // so that such operation is safe.
syntax::WalkEvent::Leave(it) => it, let paths = item
}; .preorder()
if let Some(path) = ast::Path::cast(node.clone()) { .filter_map(|event| match event {
self.transform_path(path); syntax::WalkEvent::Enter(_) => None,
} syntax::WalkEvent::Leave(node) => Some(node),
})
.filter_map(ast::Path::cast)
.collect::<Vec<_>>();
for path in paths {
self.transform_path(path);
} }
} }
fn transform_path(&self, path: ast::Path) -> Option<()> { fn transform_path(&self, path: ast::Path) -> Option<()> {
@ -145,10 +151,60 @@ impl<'a> Ctx<'a> {
match resolution { match resolution {
hir::PathResolution::TypeParam(tp) => { hir::PathResolution::TypeParam(tp) => {
if let Some(subst) = self.substs.get(&tp) { if let Some(subst) = self.substs.get(&tp) {
ted::replace(path.syntax(), subst.clone_subtree().clone_for_update().syntax()) let parent = path.syntax().parent()?;
if let Some(parent) = ast::Path::cast(parent.clone()) {
// Path inside path means that there is an associated
// type/constant on the type parameter. It is necessary
// to fully qualify the type with `as Trait`. Even
// though it might be unnecessary if `subst` is generic
// type, always fully qualifying the path is safer
// because of potential clash of associated types from
// multiple traits
let trait_ref = find_trait_for_assoc_item(
self.source_scope,
tp,
parent.segment()?.name_ref()?,
)
.and_then(|trait_ref| {
let found_path = self.target_module.find_use_path(
self.source_scope.db.upcast(),
hir::ModuleDef::Trait(trait_ref),
)?;
match ast::make::ty_path(mod_path_to_ast(&found_path)) {
ast::Type::PathType(path_ty) => Some(path_ty),
_ => None,
}
});
let segment = ast::make::path_segment_ty(subst.clone(), trait_ref);
let qualified =
ast::make::path_from_segments(std::iter::once(segment), false);
ted::replace(path.syntax(), qualified.clone_for_update().syntax());
} else if let Some(path_ty) = ast::PathType::cast(parent) {
ted::replace(
path_ty.syntax(),
subst.clone_subtree().clone_for_update().syntax(),
);
} else {
ted::replace(
path.syntax(),
subst.clone_subtree().clone_for_update().syntax(),
);
}
} }
} }
hir::PathResolution::Def(def) => { hir::PathResolution::Def(def) => {
if let hir::ModuleDef::Trait(_) = def {
if matches!(path.segment()?.kind()?, ast::PathSegmentKind::Type { .. }) {
// `speculative_resolve` resolves segments like `<T as
// Trait>` into `Trait`, but just the trait name should
// not be used as the replacement of the original
// segment.
return None;
}
}
let found_path = let found_path =
self.target_module.find_use_path(self.source_scope.db.upcast(), def)?; self.target_module.find_use_path(self.source_scope.db.upcast(), def)?;
let res = mod_path_to_ast(&found_path).clone_for_update(); let res = mod_path_to_ast(&found_path).clone_for_update();
@ -195,3 +251,34 @@ fn get_type_args_from_arg_list(generic_arg_list: ast::GenericArgList) -> Option<
Some(result) Some(result)
} }
fn find_trait_for_assoc_item(
scope: &SemanticsScope,
type_param: hir::TypeParam,
assoc_item: ast::NameRef,
) -> Option<hir::Trait> {
let db = scope.db;
let trait_bounds = type_param.trait_bounds(db);
let assoc_item_name = assoc_item.text();
for trait_ in trait_bounds {
let names = trait_.items(db).into_iter().filter_map(|item| match item {
hir::AssocItem::TypeAlias(ta) => Some(ta.name(db)),
hir::AssocItem::Const(cst) => cst.name(db),
_ => None,
});
for name in names {
if assoc_item_name.as_str() == name.as_text()?.as_str() {
// It is fine to return the first match because in case of
// multiple possibilities, the exact trait must be disambiguated
// in the definition of trait being implemented, so this search
// should not be needed.
return Some(trait_);
}
}
}
None
}

View file

@ -188,6 +188,14 @@ pub fn path_segment(name_ref: ast::NameRef) -> ast::PathSegment {
ast_from_text(&format!("use {};", name_ref)) ast_from_text(&format!("use {};", name_ref))
} }
pub fn path_segment_ty(type_ref: ast::Type, trait_ref: Option<ast::PathType>) -> ast::PathSegment {
let text = match trait_ref {
Some(trait_ref) => format!("fn f(x: <{} as {}>) {{}}", type_ref, trait_ref),
None => format!("fn f(x: <{}>) {{}}", type_ref),
};
ast_from_text(&text)
}
pub fn path_segment_self() -> ast::PathSegment { pub fn path_segment_self() -> ast::PathSegment {
ast_from_text("use self;") ast_from_text("use self;")
} }
@ -218,9 +226,9 @@ pub fn path_from_segments(
) -> ast::Path { ) -> ast::Path {
let segments = segments.into_iter().map(|it| it.syntax().clone()).join("::"); let segments = segments.into_iter().map(|it| it.syntax().clone()).join("::");
ast_from_text(&if is_abs { ast_from_text(&if is_abs {
format!("use ::{};", segments) format!("fn f(x: ::{}) {{}}", segments)
} else { } else {
format!("use {};", segments) format!("fn f(x: {}) {{}}", segments)
}) })
} }