diff --git a/crates/ra_ssr/src/lib.rs b/crates/ra_ssr/src/lib.rs index 2fb326b45f..7014a6ac66 100644 --- a/crates/ra_ssr/src/lib.rs +++ b/crates/ra_ssr/src/lib.rs @@ -51,8 +51,7 @@ pub struct MatchFinder<'db> { /// Our source of information about the user's code. sema: Semantics<'db, ra_ide_db::RootDatabase>, rules: Vec, - scope: hir::SemanticsScope<'db>, - hygiene: hir::Hygiene, + resolution_scope: resolving::ResolutionScope<'db>, } impl<'db> MatchFinder<'db> { @@ -63,21 +62,8 @@ impl<'db> MatchFinder<'db> { lookup_context: FilePosition, ) -> MatchFinder<'db> { let sema = Semantics::new(db); - let file = sema.parse(lookup_context.file_id); - // Find a node at the requested position, falling back to the whole file. - let node = file - .syntax() - .token_at_offset(lookup_context.offset) - .left_biased() - .map(|token| token.parent()) - .unwrap_or_else(|| file.syntax().clone()); - let scope = sema.scope(&node); - MatchFinder { - sema: Semantics::new(db), - rules: Vec::new(), - scope, - hygiene: hir::Hygiene::new(db, lookup_context.file_id.into()), - } + let resolution_scope = resolving::ResolutionScope::new(&sema, lookup_context); + MatchFinder { sema: Semantics::new(db), rules: Vec::new(), resolution_scope } } /// Constructs an instance using the start of the first file in `db` as the lookup context. @@ -106,8 +92,7 @@ impl<'db> MatchFinder<'db> { for parsed_rule in rule.parsed_rules { self.rules.push(ResolvedRule::new( parsed_rule, - &self.scope, - &self.hygiene, + &self.resolution_scope, self.rules.len(), )?); } @@ -140,8 +125,7 @@ impl<'db> MatchFinder<'db> { for parsed_rule in pattern.parsed_rules { self.rules.push(ResolvedRule::new( parsed_rule, - &self.scope, - &self.hygiene, + &self.resolution_scope, self.rules.len(), )?); } diff --git a/crates/ra_ssr/src/resolving.rs b/crates/ra_ssr/src/resolving.rs index 75f5567856..123bd2bb24 100644 --- a/crates/ra_ssr/src/resolving.rs +++ b/crates/ra_ssr/src/resolving.rs @@ -3,10 +3,16 @@ use crate::errors::error; use crate::{parsing, SsrError}; use parsing::Placeholder; +use ra_db::FilePosition; use ra_syntax::{ast, SmolStr, SyntaxKind, SyntaxNode, SyntaxToken}; use rustc_hash::{FxHashMap, FxHashSet}; use test_utils::mark; +pub(crate) struct ResolutionScope<'db> { + scope: hir::SemanticsScope<'db>, + hygiene: hir::Hygiene, +} + pub(crate) struct ResolvedRule { pub(crate) pattern: ResolvedPattern, pub(crate) template: Option, @@ -30,12 +36,11 @@ pub(crate) struct ResolvedPath { impl ResolvedRule { pub(crate) fn new( rule: parsing::ParsedRule, - scope: &hir::SemanticsScope, - hygiene: &hir::Hygiene, + resolution_scope: &ResolutionScope, index: usize, ) -> Result { let resolver = - Resolver { scope, hygiene, placeholders_by_stand_in: rule.placeholders_by_stand_in }; + Resolver { resolution_scope, placeholders_by_stand_in: rule.placeholders_by_stand_in }; let resolved_template = if let Some(template) = rule.template { Some(resolver.resolve_pattern_tree(template)?) } else { @@ -57,8 +62,7 @@ impl ResolvedRule { } struct Resolver<'a, 'db> { - scope: &'a hir::SemanticsScope<'db>, - hygiene: &'a hir::Hygiene, + resolution_scope: &'a ResolutionScope<'db>, placeholders_by_stand_in: FxHashMap, } @@ -104,6 +108,7 @@ impl Resolver<'_, '_> { && !self.path_contains_placeholder(&path) { let resolution = self + .resolution_scope .resolve_path(&path) .ok_or_else(|| error!("Failed to resolve path `{}`", node.text()))?; resolved_paths.insert(node, ResolvedPath { resolution, depth }); @@ -131,9 +136,32 @@ impl Resolver<'_, '_> { } false } +} + +impl<'db> ResolutionScope<'db> { + pub(crate) fn new( + sema: &hir::Semantics<'db, ra_ide_db::RootDatabase>, + lookup_context: FilePosition, + ) -> ResolutionScope<'db> { + use ra_syntax::ast::AstNode; + let file = sema.parse(lookup_context.file_id); + // Find a node at the requested position, falling back to the whole file. + let node = file + .syntax() + .token_at_offset(lookup_context.offset) + .left_biased() + .map(|token| token.parent()) + .unwrap_or_else(|| file.syntax().clone()); + let node = pick_node_for_resolution(node); + let scope = sema.scope(&node); + ResolutionScope { + scope, + hygiene: hir::Hygiene::new(sema.db, lookup_context.file_id.into()), + } + } fn resolve_path(&self, path: &ast::Path) -> Option { - let hir_path = hir::Path::from_src(path.clone(), self.hygiene)?; + let hir_path = hir::Path::from_src(path.clone(), &self.hygiene)?; // First try resolving the whole path. This will work for things like // `std::collections::HashMap`, but will fail for things like // `std::collections::HashMap::new`. @@ -158,6 +186,33 @@ impl Resolver<'_, '_> { } } +/// Returns a suitable node for resolving paths in the current scope. If we create a scope based on +/// a statement node, then we can't resolve local variables that were defined in the current scope +/// (only in parent scopes). So we find another node, ideally a child of the statement where local +/// variable resolution is permitted. +fn pick_node_for_resolution(node: SyntaxNode) -> SyntaxNode { + match node.kind() { + SyntaxKind::EXPR_STMT => { + if let Some(n) = node.first_child() { + mark::hit!(cursor_after_semicolon); + return n; + } + } + SyntaxKind::LET_STMT | SyntaxKind::BIND_PAT => { + if let Some(next) = node.next_sibling() { + return pick_node_for_resolution(next); + } + } + SyntaxKind::NAME => { + if let Some(parent) = node.parent() { + return pick_node_for_resolution(parent); + } + } + _ => {} + } + node +} + /// Returns whether `path` or any of its qualifiers contains type arguments. fn path_contains_type_arguments(path: Option) -> bool { if let Some(path) = path { diff --git a/crates/ra_ssr/src/tests.rs b/crates/ra_ssr/src/tests.rs index b38807c0f9..18ef2506af 100644 --- a/crates/ra_ssr/src/tests.rs +++ b/crates/ra_ssr/src/tests.rs @@ -885,3 +885,40 @@ fn ufcs_matches_method_call() { "#]], ); } + +#[test] +fn replace_local_variable_reference() { + // The pattern references a local variable `foo` in the block containing the cursor. We should + // only replace references to this variable `foo`, not other variables that just happen to have + // the same name. + mark::check!(cursor_after_semicolon); + assert_ssr_transform( + "foo + $a ==>> $a - foo", + r#" + fn bar1() -> i32 { + let mut res = 0; + let foo = 5; + res += foo + 1; + let foo = 10; + res += foo + 2;<|> + res += foo + 3; + let foo = 15; + res += foo + 4; + res + } + "#, + expect![[r#" + fn bar1() -> i32 { + let mut res = 0; + let foo = 5; + res += foo + 1; + let foo = 10; + res += 2 - foo; + res += 3 - foo; + let foo = 15; + res += foo + 4; + res + } + "#]], + ) +}