Adds to SSR match for semantically equivalent call and method call

This commit is contained in:
Mikhail Modin 2020-04-02 20:17:33 +01:00
parent 642f3f4bd6
commit 35a2cd08c1

View file

@ -5,12 +5,14 @@ use ra_db::{SourceDatabase, SourceDatabaseExt};
use ra_ide_db::symbol_index::SymbolsDatabase;
use ra_ide_db::RootDatabase;
use ra_syntax::ast::make::try_expr_from_text;
use ra_syntax::ast::{AstToken, Comment, RecordField, RecordLit};
use ra_syntax::{AstNode, SyntaxElement, SyntaxNode};
use ra_syntax::ast::{
ArgList, AstToken, CallExpr, Comment, Expr, MethodCallExpr, RecordField, RecordLit,
};
use ra_syntax::{AstNode, SyntaxElement, SyntaxKind, SyntaxNode};
use ra_text_edit::{TextEdit, TextEditBuilder};
use rustc_hash::FxHashMap;
use std::collections::HashMap;
use std::str::FromStr;
use std::{iter::once, str::FromStr};
#[derive(Debug, PartialEq)]
pub struct SsrError(String);
@ -219,6 +221,50 @@ fn find(pattern: &SsrPattern, code: &SyntaxNode) -> SsrMatches {
)
}
fn check_call_and_method_call(
pattern: CallExpr,
code: MethodCallExpr,
placeholders: &[Var],
match_: Match,
) -> Option<Match> {
let (pattern_name, pattern_type_args) = if let Some(Expr::PathExpr(path_exr)) =
pattern.expr()
{
let segment = path_exr.path().and_then(|p| p.segment());
(segment.as_ref().and_then(|s| s.name_ref()), segment.and_then(|s| s.type_arg_list()))
} else {
(None, None)
};
let match_ = check_opt_nodes(pattern_name, code.name_ref(), placeholders, match_)?;
let match_ =
check_opt_nodes(pattern_type_args, code.type_arg_list(), placeholders, match_)?;
let pattern_args = pattern.syntax().children().find_map(ArgList::cast)?.args();
let code_args = code.syntax().children().find_map(ArgList::cast)?.args();
let code_args = once(code.expr()?).chain(code_args);
check_iter(pattern_args, code_args, placeholders, match_)
}
fn check_method_call_and_call(
pattern: MethodCallExpr,
code: CallExpr,
placeholders: &[Var],
match_: Match,
) -> Option<Match> {
let (code_name, code_type_args) = if let Some(Expr::PathExpr(path_exr)) = code.expr() {
let segment = path_exr.path().and_then(|p| p.segment());
(segment.as_ref().and_then(|s| s.name_ref()), segment.and_then(|s| s.type_arg_list()))
} else {
(None, None)
};
let match_ = check_opt_nodes(pattern.name_ref(), code_name, placeholders, match_)?;
let match_ =
check_opt_nodes(pattern.type_arg_list(), code_type_args, placeholders, match_)?;
let code_args = code.syntax().children().find_map(ArgList::cast)?.args();
let pattern_args = pattern.syntax().children().find_map(ArgList::cast)?.args();
let pattern_args = once(pattern.expr()?).chain(pattern_args);
check_iter(pattern_args, code_args, placeholders, match_)
}
fn check_opt_nodes(
pattern: Option<impl AstNode>,
code: Option<impl AstNode>,
@ -227,8 +273,8 @@ fn find(pattern: &SsrPattern, code: &SyntaxNode) -> SsrMatches {
) -> Option<Match> {
match (pattern, code) {
(Some(pattern), Some(code)) => check(
&SyntaxElement::from(pattern.syntax().clone()),
&SyntaxElement::from(code.syntax().clone()),
&pattern.syntax().clone().into(),
&code.syntax().clone().into(),
placeholders,
match_,
),
@ -237,6 +283,33 @@ fn find(pattern: &SsrPattern, code: &SyntaxNode) -> SsrMatches {
}
}
fn check_iter<T, I1, I2>(
mut pattern: I1,
mut code: I2,
placeholders: &[Var],
match_: Match,
) -> Option<Match>
where
T: AstNode,
I1: Iterator<Item = T>,
I2: Iterator<Item = T>,
{
pattern
.by_ref()
.zip(code.by_ref())
.fold(Some(match_), |accum, (a, b)| {
accum.and_then(|match_| {
check(
&a.syntax().clone().into(),
&b.syntax().clone().into(),
placeholders,
match_,
)
})
})
.filter(|_| pattern.next().is_none() && code.next().is_none())
}
fn check(
pattern: &SyntaxElement,
code: &SyntaxElement,
@ -260,6 +333,14 @@ fn find(pattern: &SsrPattern, code: &SyntaxNode) -> SsrMatches {
(RecordLit::cast(pattern.clone()), RecordLit::cast(code.clone()))
{
check_record_lit(pattern, code, placeholders, match_)
} else if let (Some(pattern), Some(code)) =
(CallExpr::cast(pattern.clone()), MethodCallExpr::cast(code.clone()))
{
check_call_and_method_call(pattern, code, placeholders, match_)
} else if let (Some(pattern), Some(code)) =
(MethodCallExpr::cast(pattern.clone()), CallExpr::cast(code.clone()))
{
check_method_call_and_call(pattern, code, placeholders, match_)
} else {
let mut pattern_children = pattern
.children_with_tokens()
@ -290,16 +371,15 @@ fn find(pattern: &SsrPattern, code: &SyntaxNode) -> SsrMatches {
let kind = pattern.pattern.kind();
let matches = code
.descendants()
.filter(|n| n.kind() == kind)
.filter(|n| {
n.kind() == kind
|| (kind == SyntaxKind::CALL_EXPR && n.kind() == SyntaxKind::METHOD_CALL_EXPR)
|| (kind == SyntaxKind::METHOD_CALL_EXPR && n.kind() == SyntaxKind::CALL_EXPR)
})
.filter_map(|code| {
let match_ =
Match { place: code.clone(), binding: HashMap::new(), ignored_comments: vec![] };
check(
&SyntaxElement::from(pattern.pattern.clone()),
&SyntaxElement::from(code),
&pattern.vars,
match_,
)
check(&pattern.pattern.clone().into(), &code.into(), &pattern.vars, match_)
})
.collect();
SsrMatches { matches }
@ -498,4 +578,22 @@ mod tests {
"fn main() { foo::new(1, 2) }",
)
}
#[test]
fn ssr_call_and_method_call() {
assert_ssr_transform(
"foo::<'a>($a:expr, $b:expr)) ==>> foo2($a, $b)",
"fn main() { get().bar.foo::<'a>(1); }",
"fn main() { foo2(get().bar, 1); }",
)
}
#[test]
fn ssr_method_call_and_call() {
assert_ssr_transform(
"$o:expr.foo::<i32>($a:expr)) ==>> $o.foo2($a)",
"fn main() { X::foo::<i32>(x, 1); }",
"fn main() { x.foo2(1); }",
)
}
}