8524: Fix extract function with partial block selection r=matklad a=brandondong

**Reproduction:**
```rust
fn foo() {
    let n = 1;
    let mut v = $0n * n;$0
    v += 1;
}
```
1. Select the snippet ($0) and use the "Extract into function" assist.
2. Extracted function is incorrect and does not compile:
```rust
fn foo() {
    let n = 1;
    let mut v = fun_name(n);
    v += 1;
}

fn fun_name(n: i32) {}
```
3. Omitting the ending semicolon from the selection fixes the extracted function:
```rust
fn fun_name(n: i32) -> i32 {
    n * n
}
```

**Cause:**
- When `extraction_target` uses a block extraction (semicolon case) instead of an expression extraction (no semicolon case), the user selection is directly used as the TextRange.
- However, the existing function extraction logic for blocks requires that the TextRange spans from start to end of complete statements to work correctly.
- For example:
```rust
fn foo() {
    let m = 2;
    let n = 1;
    let mut v = m $0* n;
    let mut w = 3;$0
    v += 1;
    w += 1;
}
```
produces
```rust
fn foo() {
    let m = 2;
    let n = 1;
    let mut v = m let mut w = fun_name(n);
    v += 1;
    w += 1;
}

fn fun_name(n: i32) -> i32 {
    let mut w = 3;
    w
}
```
- The user selected TextRange is directly replaced by the function call which is now in the middle of another statement. The extracted function body only contains statements that were fully covered by the TextRange and so the `* n` code is deleted. The logic for calculating variable usage and outlived variables for the function parameters and return type respectively search within the TextRange and so do not include `m` or `v`.

**Fix:**
- Only extract full statements when using block extraction. If a user selected part of a statement, extract that full statement.

8527: Switch introduce_named_lifetime assist to use mutable syntax tree  r=matklad a=iDawer

This extends `GenericParamsOwnerEdit` trait with `get_or_create_generic_param_list` method

Co-authored-by: Brandon <brandondong604@hotmail.com>
Co-authored-by: Dawer <7803845+iDawer@users.noreply.github.com>
This commit is contained in:
bors[bot] 2021-04-19 13:09:18 +00:00 committed by GitHub
commit e7a8977358
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 298 additions and 38 deletions

View file

@ -599,7 +599,12 @@ fn extraction_target(node: &SyntaxNode, selection_range: TextRange) -> Option<Fu
// we have selected a few statements in a block
// so covering_element returns the whole block
if node.kind() == BLOCK_EXPR {
let body = FunctionBody::from_range(node.clone(), selection_range);
// Extract the full statements.
let statements_range = node
.children()
.filter(|c| selection_range.intersect(c.text_range()).is_some())
.fold(selection_range, |acc, c| acc.cover(c.text_range()));
let body = FunctionBody::from_range(node.clone(), statements_range);
if body.is_some() {
return body;
}
@ -610,7 +615,8 @@ fn extraction_target(node: &SyntaxNode, selection_range: TextRange) -> Option<Fu
// so we try to expand covering_element to parent and repeat the previous
if let Some(parent) = node.parent() {
if parent.kind() == BLOCK_EXPR {
let body = FunctionBody::from_range(parent, selection_range);
// Extract the full statement.
let body = FunctionBody::from_range(parent, node.text_range());
if body.is_some() {
return body;
}
@ -1784,6 +1790,60 @@ fn $0fun_name() -> i32 {
);
}
#[test]
fn extract_partial_block_single_line() {
check_assist(
extract_function,
r#"
fn foo() {
let n = 1;
let mut v = $0n * n;$0
v += 1;
}"#,
r#"
fn foo() {
let n = 1;
let mut v = fun_name(n);
v += 1;
}
fn $0fun_name(n: i32) -> i32 {
let mut v = n * n;
v
}"#,
);
}
#[test]
fn extract_partial_block() {
check_assist(
extract_function,
r#"
fn foo() {
let m = 2;
let n = 1;
let mut v = m $0* n;
let mut w = 3;$0
v += 1;
w += 1;
}"#,
r#"
fn foo() {
let m = 2;
let n = 1;
let (mut v, mut w) = fun_name(m, n);
v += 1;
w += 1;
}
fn $0fun_name(m: i32, n: i32) -> (i32, i32) {
let mut v = m * n;
let mut w = 3;
(v, w)
}"#,
);
}
#[test]
fn argument_form_expr() {
check_assist(

View file

@ -1,7 +1,8 @@
use rustc_hash::FxHashSet;
use syntax::{
ast::{self, GenericParamsOwner, NameOwner},
AstNode, TextRange, TextSize,
ast::{self, edit_in_place::GenericParamsOwnerEdit, make, GenericParamsOwner},
ted::{self, Position},
AstNode, TextRange,
};
use crate::{assist_context::AssistBuilder, AssistContext, AssistId, AssistKind, Assists};
@ -37,10 +38,12 @@ static ASSIST_LABEL: &str = "Introduce named lifetime";
pub(crate) fn introduce_named_lifetime(acc: &mut Assists, ctx: &AssistContext) -> Option<()> {
let lifetime =
ctx.find_node_at_offset::<ast::Lifetime>().filter(|lifetime| lifetime.text() == "'_")?;
let lifetime_loc = lifetime.lifetime_ident_token()?.text_range();
if let Some(fn_def) = lifetime.syntax().ancestors().find_map(ast::Fn::cast) {
generate_fn_def_assist(acc, &fn_def, lifetime.lifetime_ident_token()?.text_range())
generate_fn_def_assist(acc, fn_def, lifetime_loc, lifetime)
} else if let Some(impl_def) = lifetime.syntax().ancestors().find_map(ast::Impl::cast) {
generate_impl_def_assist(acc, &impl_def, lifetime.lifetime_ident_token()?.text_range())
generate_impl_def_assist(acc, impl_def, lifetime_loc, lifetime)
} else {
None
}
@ -49,26 +52,26 @@ pub(crate) fn introduce_named_lifetime(acc: &mut Assists, ctx: &AssistContext) -
/// Generate the assist for the fn def case
fn generate_fn_def_assist(
acc: &mut Assists,
fn_def: &ast::Fn,
fn_def: ast::Fn,
lifetime_loc: TextRange,
lifetime: ast::Lifetime,
) -> Option<()> {
let param_list: ast::ParamList = fn_def.param_list()?;
let new_lifetime_param = generate_unique_lifetime_param_name(&fn_def.generic_param_list())?;
let end_of_fn_ident = fn_def.name()?.ident_token()?.text_range().end();
let new_lifetime_param = generate_unique_lifetime_param_name(fn_def.generic_param_list())?;
let self_param =
// use the self if it's a reference and has no explicit lifetime
param_list.self_param().filter(|p| p.lifetime().is_none() && p.amp_token().is_some());
// compute the location which implicitly has the same lifetime as the anonymous lifetime
let loc_needing_lifetime = if let Some(self_param) = self_param {
// if we have a self reference, use that
Some(self_param.name()?.syntax().text_range().start())
Some(NeedsLifetime::SelfParam(self_param))
} else {
// otherwise, if there's a single reference parameter without a named liftime, use that
let fn_params_without_lifetime: Vec<_> = param_list
.params()
.filter_map(|param| match param.ty() {
Some(ast::Type::RefType(ascribed_type)) if ascribed_type.lifetime().is_none() => {
Some(ascribed_type.amp_token()?.text_range().end())
Some(NeedsLifetime::RefType(ascribed_type))
}
_ => None,
})
@ -81,30 +84,46 @@ fn generate_fn_def_assist(
}
};
acc.add(AssistId(ASSIST_NAME, AssistKind::Refactor), ASSIST_LABEL, lifetime_loc, |builder| {
add_lifetime_param(fn_def, builder, end_of_fn_ident, new_lifetime_param);
builder.replace(lifetime_loc, format!("'{}", new_lifetime_param));
loc_needing_lifetime.map(|loc| builder.insert(loc, format!("'{} ", new_lifetime_param)));
let fn_def = builder.make_ast_mut(fn_def);
let lifetime = builder.make_ast_mut(lifetime);
let loc_needing_lifetime =
loc_needing_lifetime.and_then(|it| it.make_mut(builder).to_position());
add_lifetime_param(fn_def.get_or_create_generic_param_list(), new_lifetime_param);
ted::replace(
lifetime.syntax(),
make_ast_lifetime(new_lifetime_param).clone_for_update().syntax(),
);
loc_needing_lifetime.map(|position| {
ted::insert(position, make_ast_lifetime(new_lifetime_param).clone_for_update().syntax())
});
})
}
/// Generate the assist for the impl def case
fn generate_impl_def_assist(
acc: &mut Assists,
impl_def: &ast::Impl,
impl_def: ast::Impl,
lifetime_loc: TextRange,
lifetime: ast::Lifetime,
) -> Option<()> {
let new_lifetime_param = generate_unique_lifetime_param_name(&impl_def.generic_param_list())?;
let end_of_impl_kw = impl_def.impl_token()?.text_range().end();
let new_lifetime_param = generate_unique_lifetime_param_name(impl_def.generic_param_list())?;
acc.add(AssistId(ASSIST_NAME, AssistKind::Refactor), ASSIST_LABEL, lifetime_loc, |builder| {
add_lifetime_param(impl_def, builder, end_of_impl_kw, new_lifetime_param);
builder.replace(lifetime_loc, format!("'{}", new_lifetime_param));
let impl_def = builder.make_ast_mut(impl_def);
let lifetime = builder.make_ast_mut(lifetime);
add_lifetime_param(impl_def.get_or_create_generic_param_list(), new_lifetime_param);
ted::replace(
lifetime.syntax(),
make_ast_lifetime(new_lifetime_param).clone_for_update().syntax(),
);
})
}
/// Given a type parameter list, generate a unique lifetime parameter name
/// which is not in the list
fn generate_unique_lifetime_param_name(
existing_type_param_list: &Option<ast::GenericParamList>,
existing_type_param_list: Option<ast::GenericParamList>,
) -> Option<char> {
match existing_type_param_list {
Some(type_params) => {
@ -118,25 +137,37 @@ fn generate_unique_lifetime_param_name(
}
}
/// Add the lifetime param to `builder`. If there are type parameters in `type_params_owner`, add it to the end. Otherwise
/// add new type params brackets with the lifetime parameter at `new_type_params_loc`.
fn add_lifetime_param<TypeParamsOwner: ast::GenericParamsOwner>(
type_params_owner: &TypeParamsOwner,
builder: &mut AssistBuilder,
new_type_params_loc: TextSize,
new_lifetime_param: char,
) {
match type_params_owner.generic_param_list() {
// add the new lifetime parameter to an existing type param list
Some(type_params) => {
builder.insert(
(u32::from(type_params.syntax().text_range().end()) - 1).into(),
format!(", '{}", new_lifetime_param),
);
fn add_lifetime_param(type_params: ast::GenericParamList, new_lifetime_param: char) {
let generic_param =
make::generic_param(format!("'{}", new_lifetime_param), None).clone_for_update();
type_params.add_generic_param(generic_param);
}
fn make_ast_lifetime(new_lifetime_param: char) -> ast::Lifetime {
make::generic_param(format!("'{}", new_lifetime_param), None)
.syntax()
.descendants()
.find_map(ast::Lifetime::cast)
.unwrap()
}
enum NeedsLifetime {
SelfParam(ast::SelfParam),
RefType(ast::RefType),
}
impl NeedsLifetime {
fn make_mut(self, builder: &mut AssistBuilder) -> Self {
match self {
Self::SelfParam(it) => Self::SelfParam(builder.make_ast_mut(it)),
Self::RefType(it) => Self::RefType(builder.make_ast_mut(it)),
}
// create a new type param list containing only the new lifetime parameter
None => {
builder.insert(new_type_params_loc, format!("<'{}>", new_lifetime_param));
}
fn to_position(self) -> Option<Position> {
match self {
Self::SelfParam(it) => Some(Position::after(it.amp_token()?)),
Self::RefType(it) => Some(Position::after(it.amp_token()?)),
}
}
}
@ -312,4 +343,13 @@ mod tests {
r#"fn my_fun<'other, 'a>(self, f: &'a Foo, b: &'other Bar) -> X<'a>"#,
);
}
#[test]
fn test_function_add_lifetime_to_self_ref_mut() {
check_assist(
introduce_named_lifetime,
r#"fn foo(&mut self) -> &'_$0 ()"#,
r#"fn foo<'a>(&'a mut self) -> &'a ()"#,
);
}
}

View file

@ -14,10 +14,29 @@ use crate::{
use super::NameOwner;
pub trait GenericParamsOwnerEdit: ast::GenericParamsOwner + AstNodeEdit {
fn get_or_create_generic_param_list(&self) -> ast::GenericParamList;
fn get_or_create_where_clause(&self) -> ast::WhereClause;
}
impl GenericParamsOwnerEdit for ast::Fn {
fn get_or_create_generic_param_list(&self) -> ast::GenericParamList {
match self.generic_param_list() {
Some(it) => it,
None => {
let position = if let Some(name) = self.name() {
Position::after(name.syntax)
} else if let Some(fn_token) = self.fn_token() {
Position::after(fn_token)
} else if let Some(param_list) = self.param_list() {
Position::before(param_list.syntax)
} else {
Position::last_child_of(self.syntax())
};
create_generic_param_list(position)
}
}
}
fn get_or_create_where_clause(&self) -> WhereClause {
if self.where_clause().is_none() {
let position = if let Some(ty) = self.ret_type() {
@ -34,6 +53,20 @@ impl GenericParamsOwnerEdit for ast::Fn {
}
impl GenericParamsOwnerEdit for ast::Impl {
fn get_or_create_generic_param_list(&self) -> ast::GenericParamList {
match self.generic_param_list() {
Some(it) => it,
None => {
let position = if let Some(imp_token) = self.impl_token() {
Position::after(imp_token)
} else {
Position::last_child_of(self.syntax())
};
create_generic_param_list(position)
}
}
}
fn get_or_create_where_clause(&self) -> WhereClause {
if self.where_clause().is_none() {
let position = if let Some(items) = self.assoc_item_list() {
@ -48,6 +81,22 @@ impl GenericParamsOwnerEdit for ast::Impl {
}
impl GenericParamsOwnerEdit for ast::Trait {
fn get_or_create_generic_param_list(&self) -> ast::GenericParamList {
match self.generic_param_list() {
Some(it) => it,
None => {
let position = if let Some(name) = self.name() {
Position::after(name.syntax)
} else if let Some(trait_token) = self.trait_token() {
Position::after(trait_token)
} else {
Position::last_child_of(self.syntax())
};
create_generic_param_list(position)
}
}
}
fn get_or_create_where_clause(&self) -> WhereClause {
if self.where_clause().is_none() {
let position = if let Some(items) = self.assoc_item_list() {
@ -62,6 +111,22 @@ impl GenericParamsOwnerEdit for ast::Trait {
}
impl GenericParamsOwnerEdit for ast::Struct {
fn get_or_create_generic_param_list(&self) -> ast::GenericParamList {
match self.generic_param_list() {
Some(it) => it,
None => {
let position = if let Some(name) = self.name() {
Position::after(name.syntax)
} else if let Some(struct_token) = self.struct_token() {
Position::after(struct_token)
} else {
Position::last_child_of(self.syntax())
};
create_generic_param_list(position)
}
}
}
fn get_or_create_where_clause(&self) -> WhereClause {
if self.where_clause().is_none() {
let tfl = self.field_list().and_then(|fl| match fl {
@ -84,6 +149,22 @@ impl GenericParamsOwnerEdit for ast::Struct {
}
impl GenericParamsOwnerEdit for ast::Enum {
fn get_or_create_generic_param_list(&self) -> ast::GenericParamList {
match self.generic_param_list() {
Some(it) => it,
None => {
let position = if let Some(name) = self.name() {
Position::after(name.syntax)
} else if let Some(enum_token) = self.enum_token() {
Position::after(enum_token)
} else {
Position::last_child_of(self.syntax())
};
create_generic_param_list(position)
}
}
}
fn get_or_create_where_clause(&self) -> WhereClause {
if self.where_clause().is_none() {
let position = if let Some(gpl) = self.generic_param_list() {
@ -104,6 +185,37 @@ fn create_where_clause(position: Position) {
ted::insert(position, where_clause.syntax());
}
fn create_generic_param_list(position: Position) -> ast::GenericParamList {
let gpl = make::generic_param_list(empty()).clone_for_update();
ted::insert_raw(position, gpl.syntax());
gpl
}
impl ast::GenericParamList {
pub fn add_generic_param(&self, generic_param: ast::GenericParam) {
match self.generic_params().last() {
Some(last_param) => {
let mut elems = Vec::new();
if !last_param
.syntax()
.siblings_with_tokens(Direction::Next)
.any(|it| it.kind() == T![,])
{
elems.push(make::token(T![,]).into());
elems.push(make::tokens::single_space().into());
};
elems.push(generic_param.syntax().clone().into());
let after_last_param = Position::after(last_param.syntax());
ted::insert_all(after_last_param, elems);
}
None => {
let after_l_angle = Position::after(self.l_angle_token().unwrap());
ted::insert(after_l_angle, generic_param.syntax())
}
}
}
}
impl ast::WhereClause {
pub fn add_predicate(&self, predicate: ast::WherePred) {
if let Some(pred) = self.predicates().last() {
@ -164,3 +276,44 @@ impl ast::Use {
ted::remove(self.syntax())
}
}
#[cfg(test)]
mod tests {
use std::fmt;
use crate::SourceFile;
use super::*;
fn ast_mut_from_text<N: AstNode>(text: &str) -> N {
let parse = SourceFile::parse(text);
parse.tree().syntax().descendants().find_map(N::cast).unwrap().clone_for_update()
}
#[test]
fn test_create_generic_param_list() {
fn check_create_gpl<N: GenericParamsOwnerEdit + fmt::Display>(before: &str, after: &str) {
let gpl_owner = ast_mut_from_text::<N>(before);
gpl_owner.get_or_create_generic_param_list();
assert_eq!(gpl_owner.to_string(), after);
}
check_create_gpl::<ast::Fn>("fn foo", "fn foo<>");
check_create_gpl::<ast::Fn>("fn foo() {}", "fn foo<>() {}");
check_create_gpl::<ast::Impl>("impl", "impl<>");
check_create_gpl::<ast::Impl>("impl Struct {}", "impl<> Struct {}");
check_create_gpl::<ast::Impl>("impl Trait for Struct {}", "impl<> Trait for Struct {}");
check_create_gpl::<ast::Trait>("trait Trait<>", "trait Trait<>");
check_create_gpl::<ast::Trait>("trait Trait<> {}", "trait Trait<> {}");
check_create_gpl::<ast::Struct>("struct A", "struct A<>");
check_create_gpl::<ast::Struct>("struct A;", "struct A<>;");
check_create_gpl::<ast::Struct>("struct A();", "struct A<>();");
check_create_gpl::<ast::Struct>("struct A {}", "struct A<> {}");
check_create_gpl::<ast::Enum>("enum E", "enum E<>");
check_create_gpl::<ast::Enum>("enum E {", "enum E<> {");
}
}

View file

@ -165,6 +165,13 @@ fn ws_between(left: &SyntaxElement, right: &SyntaxElement) -> Option<SyntaxToken
if right.kind() == T![;] || right.kind() == T![,] {
return None;
}
if left.kind() == T![<] || right.kind() == T![>] {
return None;
}
if left.kind() == T![&] && right.kind() == SyntaxKind::LIFETIME {
return None;
}
if right.kind() == SyntaxKind::USE {
let indent = IndentLevel::from_element(left);
return Some(make::tokens::whitespace(&format!("\n{}", indent)));