mirror of
https://github.com/rust-lang/rust-analyzer
synced 2025-01-13 21:54:42 +00:00
Merge #3829
3829: Adds to SSR match for semantically equivalent call and method call r=matklad a=mikhail-m1 #3186 maybe I've missed some corner cases, but it works in general Co-authored-by: Mikhail Modin <mikhailm1@gmail.com>
This commit is contained in:
commit
4f904b2970
1 changed files with 110 additions and 12 deletions
|
@ -5,12 +5,14 @@ use ra_db::{SourceDatabase, SourceDatabaseExt};
|
|||
use ra_ide_db::symbol_index::SymbolsDatabase;
|
||||
use ra_ide_db::RootDatabase;
|
||||
use ra_syntax::ast::make::try_expr_from_text;
|
||||
use ra_syntax::ast::{AstToken, Comment, RecordField, RecordLit};
|
||||
use ra_syntax::{AstNode, SyntaxElement, SyntaxNode};
|
||||
use ra_syntax::ast::{
|
||||
ArgList, AstToken, CallExpr, Comment, Expr, MethodCallExpr, RecordField, RecordLit,
|
||||
};
|
||||
use ra_syntax::{AstNode, SyntaxElement, SyntaxKind, SyntaxNode};
|
||||
use ra_text_edit::{TextEdit, TextEditBuilder};
|
||||
use rustc_hash::FxHashMap;
|
||||
use std::collections::HashMap;
|
||||
use std::str::FromStr;
|
||||
use std::{iter::once, str::FromStr};
|
||||
|
||||
#[derive(Debug, PartialEq)]
|
||||
pub struct SsrError(String);
|
||||
|
@ -219,6 +221,50 @@ fn find(pattern: &SsrPattern, code: &SyntaxNode) -> SsrMatches {
|
|||
)
|
||||
}
|
||||
|
||||
fn check_call_and_method_call(
|
||||
pattern: CallExpr,
|
||||
code: MethodCallExpr,
|
||||
placeholders: &[Var],
|
||||
match_: Match,
|
||||
) -> Option<Match> {
|
||||
let (pattern_name, pattern_type_args) = if let Some(Expr::PathExpr(path_exr)) =
|
||||
pattern.expr()
|
||||
{
|
||||
let segment = path_exr.path().and_then(|p| p.segment());
|
||||
(segment.as_ref().and_then(|s| s.name_ref()), segment.and_then(|s| s.type_arg_list()))
|
||||
} else {
|
||||
(None, None)
|
||||
};
|
||||
let match_ = check_opt_nodes(pattern_name, code.name_ref(), placeholders, match_)?;
|
||||
let match_ =
|
||||
check_opt_nodes(pattern_type_args, code.type_arg_list(), placeholders, match_)?;
|
||||
let pattern_args = pattern.syntax().children().find_map(ArgList::cast)?.args();
|
||||
let code_args = code.syntax().children().find_map(ArgList::cast)?.args();
|
||||
let code_args = once(code.expr()?).chain(code_args);
|
||||
check_iter(pattern_args, code_args, placeholders, match_)
|
||||
}
|
||||
|
||||
fn check_method_call_and_call(
|
||||
pattern: MethodCallExpr,
|
||||
code: CallExpr,
|
||||
placeholders: &[Var],
|
||||
match_: Match,
|
||||
) -> Option<Match> {
|
||||
let (code_name, code_type_args) = if let Some(Expr::PathExpr(path_exr)) = code.expr() {
|
||||
let segment = path_exr.path().and_then(|p| p.segment());
|
||||
(segment.as_ref().and_then(|s| s.name_ref()), segment.and_then(|s| s.type_arg_list()))
|
||||
} else {
|
||||
(None, None)
|
||||
};
|
||||
let match_ = check_opt_nodes(pattern.name_ref(), code_name, placeholders, match_)?;
|
||||
let match_ =
|
||||
check_opt_nodes(pattern.type_arg_list(), code_type_args, placeholders, match_)?;
|
||||
let code_args = code.syntax().children().find_map(ArgList::cast)?.args();
|
||||
let pattern_args = pattern.syntax().children().find_map(ArgList::cast)?.args();
|
||||
let pattern_args = once(pattern.expr()?).chain(pattern_args);
|
||||
check_iter(pattern_args, code_args, placeholders, match_)
|
||||
}
|
||||
|
||||
fn check_opt_nodes(
|
||||
pattern: Option<impl AstNode>,
|
||||
code: Option<impl AstNode>,
|
||||
|
@ -227,8 +273,8 @@ fn find(pattern: &SsrPattern, code: &SyntaxNode) -> SsrMatches {
|
|||
) -> Option<Match> {
|
||||
match (pattern, code) {
|
||||
(Some(pattern), Some(code)) => check(
|
||||
&SyntaxElement::from(pattern.syntax().clone()),
|
||||
&SyntaxElement::from(code.syntax().clone()),
|
||||
&pattern.syntax().clone().into(),
|
||||
&code.syntax().clone().into(),
|
||||
placeholders,
|
||||
match_,
|
||||
),
|
||||
|
@ -237,6 +283,33 @@ fn find(pattern: &SsrPattern, code: &SyntaxNode) -> SsrMatches {
|
|||
}
|
||||
}
|
||||
|
||||
fn check_iter<T, I1, I2>(
|
||||
mut pattern: I1,
|
||||
mut code: I2,
|
||||
placeholders: &[Var],
|
||||
match_: Match,
|
||||
) -> Option<Match>
|
||||
where
|
||||
T: AstNode,
|
||||
I1: Iterator<Item = T>,
|
||||
I2: Iterator<Item = T>,
|
||||
{
|
||||
pattern
|
||||
.by_ref()
|
||||
.zip(code.by_ref())
|
||||
.fold(Some(match_), |accum, (a, b)| {
|
||||
accum.and_then(|match_| {
|
||||
check(
|
||||
&a.syntax().clone().into(),
|
||||
&b.syntax().clone().into(),
|
||||
placeholders,
|
||||
match_,
|
||||
)
|
||||
})
|
||||
})
|
||||
.filter(|_| pattern.next().is_none() && code.next().is_none())
|
||||
}
|
||||
|
||||
fn check(
|
||||
pattern: &SyntaxElement,
|
||||
code: &SyntaxElement,
|
||||
|
@ -260,6 +333,14 @@ fn find(pattern: &SsrPattern, code: &SyntaxNode) -> SsrMatches {
|
|||
(RecordLit::cast(pattern.clone()), RecordLit::cast(code.clone()))
|
||||
{
|
||||
check_record_lit(pattern, code, placeholders, match_)
|
||||
} else if let (Some(pattern), Some(code)) =
|
||||
(CallExpr::cast(pattern.clone()), MethodCallExpr::cast(code.clone()))
|
||||
{
|
||||
check_call_and_method_call(pattern, code, placeholders, match_)
|
||||
} else if let (Some(pattern), Some(code)) =
|
||||
(MethodCallExpr::cast(pattern.clone()), CallExpr::cast(code.clone()))
|
||||
{
|
||||
check_method_call_and_call(pattern, code, placeholders, match_)
|
||||
} else {
|
||||
let mut pattern_children = pattern
|
||||
.children_with_tokens()
|
||||
|
@ -290,16 +371,15 @@ fn find(pattern: &SsrPattern, code: &SyntaxNode) -> SsrMatches {
|
|||
let kind = pattern.pattern.kind();
|
||||
let matches = code
|
||||
.descendants()
|
||||
.filter(|n| n.kind() == kind)
|
||||
.filter(|n| {
|
||||
n.kind() == kind
|
||||
|| (kind == SyntaxKind::CALL_EXPR && n.kind() == SyntaxKind::METHOD_CALL_EXPR)
|
||||
|| (kind == SyntaxKind::METHOD_CALL_EXPR && n.kind() == SyntaxKind::CALL_EXPR)
|
||||
})
|
||||
.filter_map(|code| {
|
||||
let match_ =
|
||||
Match { place: code.clone(), binding: HashMap::new(), ignored_comments: vec![] };
|
||||
check(
|
||||
&SyntaxElement::from(pattern.pattern.clone()),
|
||||
&SyntaxElement::from(code),
|
||||
&pattern.vars,
|
||||
match_,
|
||||
)
|
||||
check(&pattern.pattern.clone().into(), &code.into(), &pattern.vars, match_)
|
||||
})
|
||||
.collect();
|
||||
SsrMatches { matches }
|
||||
|
@ -498,4 +578,22 @@ mod tests {
|
|||
"fn main() { foo::new(1, 2) }",
|
||||
)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn ssr_call_and_method_call() {
|
||||
assert_ssr_transform(
|
||||
"foo::<'a>($a:expr, $b:expr)) ==>> foo2($a, $b)",
|
||||
"fn main() { get().bar.foo::<'a>(1); }",
|
||||
"fn main() { foo2(get().bar, 1); }",
|
||||
)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn ssr_method_call_and_call() {
|
||||
assert_ssr_transform(
|
||||
"$o:expr.foo::<i32>($a:expr)) ==>> $o.foo2($a)",
|
||||
"fn main() { X::foo::<i32>(x, 1); }",
|
||||
"fn main() { x.foo2(1); }",
|
||||
)
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue