diff --git a/Cargo.lock b/Cargo.lock index 7abfeaaebe..ffa3851068 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1455,6 +1455,7 @@ dependencies = [ "expect", "hir", "ide_db", + "itertools", "rustc-hash", "syntax", "test_utils", diff --git a/crates/ssr/Cargo.toml b/crates/ssr/Cargo.toml index 56c1f77618..7c2090de3c 100644 --- a/crates/ssr/Cargo.toml +++ b/crates/ssr/Cargo.toml @@ -12,6 +12,7 @@ doctest = false [dependencies] rustc-hash = "1.1.0" +itertools = "0.9.0" text_edit = { path = "../text_edit" } syntax = { path = "../syntax" } diff --git a/crates/ssr/src/lib.rs b/crates/ssr/src/lib.rs index 292bd5b9a7..ba669fd56c 100644 --- a/crates/ssr/src/lib.rs +++ b/crates/ssr/src/lib.rs @@ -21,7 +21,10 @@ // code in the `foo` module, we'll insert just `Bar`. // // Inherent method calls should generally be written in UFCS form. e.g. `foo::Bar::baz($s, $a)` will -// match `$s.baz($a)`, provided the method call `baz` resolves to the method `foo::Bar::baz`. +// match `$s.baz($a)`, provided the method call `baz` resolves to the method `foo::Bar::baz`. When a +// placeholder is the receiver of a method call in the search pattern (e.g. `$s.foo()`), but not in +// the replacement template (e.g. `bar($s)`), then *, & and &mut will be added as needed to mirror +// whatever autoderef and autoref was happening implicitly in the matched code. // // The scope of the search / replace will be restricted to the current selection if any, otherwise // it will apply to the whole workspace. diff --git a/crates/ssr/src/matching.rs b/crates/ssr/src/matching.rs index ffc7202ae5..8bb5ced900 100644 --- a/crates/ssr/src/matching.rs +++ b/crates/ssr/src/matching.rs @@ -2,7 +2,7 @@ //! process of matching, placeholder values are recorded. use crate::{ - parsing::{Constraint, NodeKind, Placeholder}, + parsing::{Constraint, NodeKind, Placeholder, Var}, resolving::{ResolvedPattern, ResolvedRule, UfcsCallInfo}, SsrMatches, }; @@ -56,10 +56,6 @@ pub struct Match { pub(crate) rendered_template_paths: FxHashMap, } -/// Represents a `$var` in an SSR query. -#[derive(Debug, Clone, PartialEq, Eq, Hash)] -pub(crate) struct Var(pub String); - /// Information about a placeholder bound in a match. #[derive(Debug)] pub(crate) struct PlaceholderMatch { @@ -69,6 +65,10 @@ pub(crate) struct PlaceholderMatch { pub(crate) range: FileRange, /// More matches, found within `node`. pub(crate) inner_matches: SsrMatches, + /// How many times the code that the placeholder matched needed to be dereferenced. Will only be + /// non-zero if the placeholder matched to the receiver of a method call. + pub(crate) autoderef_count: usize, + pub(crate) autoref_kind: ast::SelfParamKind, } #[derive(Debug)] @@ -173,7 +173,7 @@ impl<'db, 'sema> Matcher<'db, 'sema> { code: &SyntaxNode, ) -> Result<(), MatchFailed> { // Handle placeholders. - if let Some(placeholder) = self.get_placeholder(&SyntaxElement::Node(pattern.clone())) { + if let Some(placeholder) = self.get_placeholder_for_node(pattern) { for constraint in &placeholder.constraints { self.check_constraint(constraint, code)?; } @@ -183,8 +183,8 @@ impl<'db, 'sema> Matcher<'db, 'sema> { // probably can't fail range validation, but just to be safe... self.validate_range(&original_range)?; matches_out.placeholder_values.insert( - Var(placeholder.ident.to_string()), - PlaceholderMatch::new(code, original_range), + placeholder.ident.clone(), + PlaceholderMatch::new(Some(code), original_range), ); } return Ok(()); @@ -487,7 +487,7 @@ impl<'db, 'sema> Matcher<'db, 'sema> { } if let Phase::Second(match_out) = phase { match_out.placeholder_values.insert( - Var(placeholder.ident.to_string()), + placeholder.ident.clone(), PlaceholderMatch::from_range(FileRange { file_id: self.sema.original_range(code).file_id, range: first_matched_token @@ -536,18 +536,40 @@ impl<'db, 'sema> Matcher<'db, 'sema> { if pattern_ufcs.function != code_resolved_function { fail_match!("Method call resolved to a different function"); } - if code_resolved_function.has_self_param(self.sema.db) { - if let (Some(pattern_type), Some(expr)) = (&pattern_ufcs.qualifier_type, &code.expr()) { - self.check_expr_type(pattern_type, expr)?; - } - } // Check arguments. let mut pattern_args = pattern_ufcs .call_expr .arg_list() .ok_or_else(|| match_error!("Pattern function call has no args"))? .args(); - self.attempt_match_opt(phase, pattern_args.next(), code.expr())?; + // If the function we're calling takes a self parameter, then we store additional + // information on the placeholder match about autoderef and autoref. This allows us to use + // the placeholder in a context where autoderef and autoref don't apply. + if code_resolved_function.has_self_param(self.sema.db) { + if let (Some(pattern_type), Some(expr)) = (&pattern_ufcs.qualifier_type, &code.expr()) { + let deref_count = self.check_expr_type(pattern_type, expr)?; + let pattern_receiver = pattern_args.next(); + self.attempt_match_opt(phase, pattern_receiver.clone(), code.expr())?; + if let Phase::Second(match_out) = phase { + if let Some(placeholder_value) = pattern_receiver + .and_then(|n| self.get_placeholder_for_node(n.syntax())) + .and_then(|placeholder| { + match_out.placeholder_values.get_mut(&placeholder.ident) + }) + { + placeholder_value.autoderef_count = deref_count; + placeholder_value.autoref_kind = self + .sema + .resolve_method_call_as_callable(code) + .and_then(|callable| callable.receiver_param(self.sema.db)) + .map(|self_param| self_param.kind()) + .unwrap_or(ast::SelfParamKind::Owned); + } + } + } + } else { + self.attempt_match_opt(phase, pattern_args.next(), code.expr())?; + } let mut code_args = code.arg_list().ok_or_else(|| match_error!("Code method call has no args"))?.args(); loop { @@ -575,26 +597,35 @@ impl<'db, 'sema> Matcher<'db, 'sema> { self.attempt_match_node_children(phase, pattern_ufcs.call_expr.syntax(), code.syntax()) } + /// Verifies that `expr` matches `pattern_type`, possibly after dereferencing some number of + /// times. Returns the number of times it needed to be dereferenced. fn check_expr_type( &self, pattern_type: &hir::Type, expr: &ast::Expr, - ) -> Result<(), MatchFailed> { + ) -> Result { use hir::HirDisplay; let code_type = self.sema.type_of_expr(&expr).ok_or_else(|| { match_error!("Failed to get receiver type for `{}`", expr.syntax().text()) })?; - if !code_type + // Temporary needed to make the borrow checker happy. + let res = code_type .autoderef(self.sema.db) - .any(|deref_code_type| *pattern_type == deref_code_type) - { - fail_match!( - "Pattern type `{}` didn't match code type `{}`", - pattern_type.display(self.sema.db), - code_type.display(self.sema.db) - ); - } - Ok(()) + .enumerate() + .find(|(_, deref_code_type)| pattern_type == deref_code_type) + .map(|(count, _)| count) + .ok_or_else(|| { + match_error!( + "Pattern type `{}` didn't match code type `{}`", + pattern_type.display(self.sema.db), + code_type.display(self.sema.db) + ) + }); + res + } + + fn get_placeholder_for_node(&self, node: &SyntaxNode) -> Option<&Placeholder> { + self.get_placeholder(&SyntaxElement::Node(node.clone())) } fn get_placeholder(&self, element: &SyntaxElement) -> Option<&Placeholder> { @@ -676,12 +707,18 @@ fn recording_match_fail_reasons() -> bool { } impl PlaceholderMatch { - fn new(node: &SyntaxNode, range: FileRange) -> Self { - Self { node: Some(node.clone()), range, inner_matches: SsrMatches::default() } + fn new(node: Option<&SyntaxNode>, range: FileRange) -> Self { + Self { + node: node.cloned(), + range, + inner_matches: SsrMatches::default(), + autoderef_count: 0, + autoref_kind: ast::SelfParamKind::Owned, + } } fn from_range(range: FileRange) -> Self { - Self { node: None, range, inner_matches: SsrMatches::default() } + Self::new(None, range) } } diff --git a/crates/ssr/src/parsing.rs b/crates/ssr/src/parsing.rs index 9570e96e36..05b66dcd78 100644 --- a/crates/ssr/src/parsing.rs +++ b/crates/ssr/src/parsing.rs @@ -8,7 +8,7 @@ use crate::errors::bail; use crate::{SsrError, SsrPattern, SsrRule}; use rustc_hash::{FxHashMap, FxHashSet}; -use std::str::FromStr; +use std::{fmt::Display, str::FromStr}; use syntax::{ast, AstNode, SmolStr, SyntaxKind, SyntaxNode, T}; use test_utils::mark; @@ -34,12 +34,16 @@ pub(crate) enum PatternElement { #[derive(Clone, Debug, PartialEq, Eq)] pub(crate) struct Placeholder { /// The name of this placeholder. e.g. for "$a", this would be "a" - pub(crate) ident: SmolStr, + pub(crate) ident: Var, /// A unique name used in place of this placeholder when we parse the pattern as Rust code. stand_in_name: String, pub(crate) constraints: Vec, } +/// Represents a `$var` in an SSR query. +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub(crate) struct Var(pub String); + #[derive(Clone, Debug, PartialEq, Eq)] pub(crate) enum Constraint { Kind(NodeKind), @@ -205,7 +209,7 @@ fn parse_pattern(pattern_str: &str) -> Result, SsrError> { if token.kind == T![$] { let placeholder = parse_placeholder(&mut tokens)?; if !placeholder_names.insert(placeholder.ident.clone()) { - bail!("Name `{}` repeats more than once", placeholder.ident); + bail!("Placeholder `{}` repeats more than once", placeholder.ident); } res.push(PatternElement::Placeholder(placeholder)); } else { @@ -228,7 +232,7 @@ fn validate_rule(rule: &SsrRule) -> Result<(), SsrError> { for p in &rule.template.tokens { if let PatternElement::Placeholder(placeholder) = p { if !defined_placeholders.contains(&placeholder.ident) { - undefined.push(format!("${}", placeholder.ident)); + undefined.push(placeholder.ident.to_string()); } if !placeholder.constraints.is_empty() { bail!("Replacement placeholders cannot have constraints"); @@ -344,7 +348,17 @@ impl NodeKind { impl Placeholder { fn new(name: SmolStr, constraints: Vec) -> Self { - Self { stand_in_name: format!("__placeholder_{}", name), constraints, ident: name } + Self { + stand_in_name: format!("__placeholder_{}", name), + constraints, + ident: Var(name.to_string()), + } + } +} + +impl Display for Var { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "${}", self.0) } } diff --git a/crates/ssr/src/replacing.rs b/crates/ssr/src/replacing.rs index 8f8fe6149a..29284e3f1c 100644 --- a/crates/ssr/src/replacing.rs +++ b/crates/ssr/src/replacing.rs @@ -1,10 +1,11 @@ //! Code for applying replacement templates for matches that have previously been found. -use crate::matching::Var; use crate::{resolving::ResolvedRule, Match, SsrMatches}; +use itertools::Itertools; use rustc_hash::{FxHashMap, FxHashSet}; use syntax::ast::{self, AstToken}; use syntax::{SyntaxElement, SyntaxKind, SyntaxNode, SyntaxToken, TextRange, TextSize}; +use test_utils::mark; use text_edit::TextEdit; /// Returns a text edit that will replace each match in `matches` with its corresponding replacement @@ -114,11 +115,33 @@ impl ReplacementRenderer<'_> { fn render_token(&mut self, token: &SyntaxToken) { if let Some(placeholder) = self.rule.get_placeholder(&token) { if let Some(placeholder_value) = - self.match_info.placeholder_values.get(&Var(placeholder.ident.to_string())) + self.match_info.placeholder_values.get(&placeholder.ident) { let range = &placeholder_value.range.range; let mut matched_text = self.file_src[usize::from(range.start())..usize::from(range.end())].to_owned(); + // If a method call is performed directly on the placeholder, then autoderef and + // autoref will apply, so we can just substitute whatever the placeholder matched to + // directly. If we're not applying a method call, then we need to add explicitly + // deref and ref in order to match whatever was being done implicitly at the match + // site. + if !token_is_method_call_receiver(token) + && (placeholder_value.autoderef_count > 0 + || placeholder_value.autoref_kind != ast::SelfParamKind::Owned) + { + mark::hit!(replace_autoref_autoderef_capture); + let ref_kind = match placeholder_value.autoref_kind { + ast::SelfParamKind::Owned => "", + ast::SelfParamKind::Ref => "&", + ast::SelfParamKind::MutRef => "&mut ", + }; + matched_text = format!( + "{}{}{}", + ref_kind, + "*".repeat(placeholder_value.autoderef_count), + matched_text + ); + } let edit = matches_to_edit_at_offset( &placeholder_value.inner_matches, self.file_src, @@ -179,6 +202,29 @@ impl ReplacementRenderer<'_> { } } +/// Returns whether token is the receiver of a method call. Note, being within the receiver of a +/// method call doesn't count. e.g. if the token is `$a`, then `$a.foo()` will return true, while +/// `($a + $b).foo()` or `x.foo($a)` will return false. +fn token_is_method_call_receiver(token: &SyntaxToken) -> bool { + use syntax::ast::AstNode; + // Find the first method call among the ancestors of `token`, then check if the only token + // within the receiver is `token`. + if let Some(receiver) = + token.ancestors().find_map(ast::MethodCallExpr::cast).and_then(|call| call.expr()) + { + let tokens = receiver.syntax().descendants_with_tokens().filter_map(|node_or_token| { + match node_or_token { + SyntaxElement::Token(t) => Some(t), + _ => None, + } + }); + if let Some((only_token,)) = tokens.collect_tuple() { + return only_token == *token; + } + } + false +} + fn parse_as_kind(code: &str, kind: SyntaxKind) -> Option { use syntax::ast::AstNode; if ast::Expr::can_cast(kind) { diff --git a/crates/ssr/src/tests.rs b/crates/ssr/src/tests.rs index 0d0a000906..e45c88864d 100644 --- a/crates/ssr/src/tests.rs +++ b/crates/ssr/src/tests.rs @@ -31,7 +31,7 @@ fn parser_two_delimiters() { fn parser_repeated_name() { assert_eq!( parse_error_text("foo($a, $a) ==>>"), - "Parse error: Name `a` repeats more than once" + "Parse error: Placeholder `$a` repeats more than once" ); } @@ -1172,3 +1172,110 @@ fn match_trait_method_call() { assert_matches("Bar::foo($a, $b)", code, &["v1.foo(1)", "Bar::foo(&v1, 3)", "v1_ref.foo(5)"]); assert_matches("Bar2::foo($a, $b)", code, &["v2.foo(2)", "Bar2::foo(&v2, 4)", "v2_ref.foo(6)"]); } + +#[test] +fn replace_autoref_autoderef_capture() { + // Here we have several calls to `$a.foo()`. In the first case autoref is applied, in the + // second, we already have a reference, so it isn't. When $a is used in a context where autoref + // doesn't apply, we need to prefix it with `&`. Finally, we have some cases where autoderef + // needs to be applied. + mark::check!(replace_autoref_autoderef_capture); + let code = r#" + struct Foo {} + impl Foo { + fn foo(&self) {} + fn foo2(&self) {} + } + fn bar(_: &Foo) {} + fn main() { + let f = Foo {}; + let fr = &f; + let fr2 = &fr; + let fr3 = &fr2; + f.foo(); + fr.foo(); + fr2.foo(); + fr3.foo(); + } + "#; + assert_ssr_transform( + "Foo::foo($a) ==>> bar($a)", + code, + expect![[r#" + struct Foo {} + impl Foo { + fn foo(&self) {} + fn foo2(&self) {} + } + fn bar(_: &Foo) {} + fn main() { + let f = Foo {}; + let fr = &f; + let fr2 = &fr; + let fr3 = &fr2; + bar(&f); + bar(&*fr); + bar(&**fr2); + bar(&***fr3); + } + "#]], + ); + // If the placeholder is used as the receiver of another method call, then we don't need to + // explicitly autoderef or autoref. + assert_ssr_transform( + "Foo::foo($a) ==>> $a.foo2()", + code, + expect![[r#" + struct Foo {} + impl Foo { + fn foo(&self) {} + fn foo2(&self) {} + } + fn bar(_: &Foo) {} + fn main() { + let f = Foo {}; + let fr = &f; + let fr2 = &fr; + let fr3 = &fr2; + f.foo2(); + fr.foo2(); + fr2.foo2(); + fr3.foo2(); + } + "#]], + ); +} + +#[test] +fn replace_autoref_mut() { + let code = r#" + struct Foo {} + impl Foo { + fn foo(&mut self) {} + } + fn bar(_: &mut Foo) {} + fn main() { + let mut f = Foo {}; + f.foo(); + let fr = &mut f; + fr.foo(); + } + "#; + assert_ssr_transform( + "Foo::foo($a) ==>> bar($a)", + code, + expect![[r#" + struct Foo {} + impl Foo { + fn foo(&mut self) {} + } + fn bar(_: &mut Foo) {} + fn main() { + let mut f = Foo {}; + bar(&mut f); + let fr = &mut f; + bar(&mut *fr); + } + "#]], + ); +}