diff --git a/crates/ra_ssr/src/resolving.rs b/crates/ra_ssr/src/resolving.rs index d53981737a..123bd2bb24 100644 --- a/crates/ra_ssr/src/resolving.rs +++ b/crates/ra_ssr/src/resolving.rs @@ -152,6 +152,7 @@ impl<'db> ResolutionScope<'db> { .left_biased() .map(|token| token.parent()) .unwrap_or_else(|| file.syntax().clone()); + let node = pick_node_for_resolution(node); let scope = sema.scope(&node); ResolutionScope { scope, @@ -185,6 +186,33 @@ impl<'db> ResolutionScope<'db> { } } +/// Returns a suitable node for resolving paths in the current scope. If we create a scope based on +/// a statement node, then we can't resolve local variables that were defined in the current scope +/// (only in parent scopes). So we find another node, ideally a child of the statement where local +/// variable resolution is permitted. +fn pick_node_for_resolution(node: SyntaxNode) -> SyntaxNode { + match node.kind() { + SyntaxKind::EXPR_STMT => { + if let Some(n) = node.first_child() { + mark::hit!(cursor_after_semicolon); + return n; + } + } + SyntaxKind::LET_STMT | SyntaxKind::BIND_PAT => { + if let Some(next) = node.next_sibling() { + return pick_node_for_resolution(next); + } + } + SyntaxKind::NAME => { + if let Some(parent) = node.parent() { + return pick_node_for_resolution(parent); + } + } + _ => {} + } + node +} + /// Returns whether `path` or any of its qualifiers contains type arguments. fn path_contains_type_arguments(path: Option) -> bool { if let Some(path) = path { diff --git a/crates/ra_ssr/src/tests.rs b/crates/ra_ssr/src/tests.rs index b38807c0f9..18ef2506af 100644 --- a/crates/ra_ssr/src/tests.rs +++ b/crates/ra_ssr/src/tests.rs @@ -885,3 +885,40 @@ fn ufcs_matches_method_call() { "#]], ); } + +#[test] +fn replace_local_variable_reference() { + // The pattern references a local variable `foo` in the block containing the cursor. We should + // only replace references to this variable `foo`, not other variables that just happen to have + // the same name. + mark::check!(cursor_after_semicolon); + assert_ssr_transform( + "foo + $a ==>> $a - foo", + r#" + fn bar1() -> i32 { + let mut res = 0; + let foo = 5; + res += foo + 1; + let foo = 10; + res += foo + 2;<|> + res += foo + 3; + let foo = 15; + res += foo + 4; + res + } + "#, + expect![[r#" + fn bar1() -> i32 { + let mut res = 0; + let foo = 5; + res += foo + 1; + let foo = 10; + res += 2 - foo; + res += 3 - foo; + let foo = 15; + res += foo + 4; + res + } + "#]], + ) +}