Switch introduce_named_lifetime assist to use mutable syntax tree

This commit is contained in:
Dawer 2021-04-15 01:56:19 +05:00
parent 27e80e9438
commit 144afa55a6
3 changed files with 147 additions and 36 deletions

View file

@ -1,7 +1,8 @@
use rustc_hash::FxHashSet; use rustc_hash::FxHashSet;
use syntax::{ use syntax::{
ast::{self, GenericParamsOwner, NameOwner}, ast::{self, edit_in_place::GenericParamsOwnerEdit, make, GenericParamsOwner},
AstNode, TextRange, TextSize, ted::{self, Position},
AstNode, TextRange,
}; };
use crate::{assist_context::AssistBuilder, AssistContext, AssistId, AssistKind, Assists}; use crate::{assist_context::AssistBuilder, AssistContext, AssistId, AssistKind, Assists};
@ -37,10 +38,12 @@ static ASSIST_LABEL: &str = "Introduce named lifetime";
pub(crate) fn introduce_named_lifetime(acc: &mut Assists, ctx: &AssistContext) -> Option<()> { pub(crate) fn introduce_named_lifetime(acc: &mut Assists, ctx: &AssistContext) -> Option<()> {
let lifetime = let lifetime =
ctx.find_node_at_offset::<ast::Lifetime>().filter(|lifetime| lifetime.text() == "'_")?; ctx.find_node_at_offset::<ast::Lifetime>().filter(|lifetime| lifetime.text() == "'_")?;
let lifetime_loc = lifetime.lifetime_ident_token()?.text_range();
if let Some(fn_def) = lifetime.syntax().ancestors().find_map(ast::Fn::cast) { if let Some(fn_def) = lifetime.syntax().ancestors().find_map(ast::Fn::cast) {
generate_fn_def_assist(acc, &fn_def, lifetime.lifetime_ident_token()?.text_range()) generate_fn_def_assist(acc, fn_def, lifetime_loc, lifetime)
} else if let Some(impl_def) = lifetime.syntax().ancestors().find_map(ast::Impl::cast) { } else if let Some(impl_def) = lifetime.syntax().ancestors().find_map(ast::Impl::cast) {
generate_impl_def_assist(acc, &impl_def, lifetime.lifetime_ident_token()?.text_range()) generate_impl_def_assist(acc, impl_def, lifetime_loc, lifetime)
} else { } else {
None None
} }
@ -49,26 +52,26 @@ pub(crate) fn introduce_named_lifetime(acc: &mut Assists, ctx: &AssistContext) -
/// Generate the assist for the fn def case /// Generate the assist for the fn def case
fn generate_fn_def_assist( fn generate_fn_def_assist(
acc: &mut Assists, acc: &mut Assists,
fn_def: &ast::Fn, fn_def: ast::Fn,
lifetime_loc: TextRange, lifetime_loc: TextRange,
lifetime: ast::Lifetime,
) -> Option<()> { ) -> Option<()> {
let param_list: ast::ParamList = fn_def.param_list()?; let param_list: ast::ParamList = fn_def.param_list()?;
let new_lifetime_param = generate_unique_lifetime_param_name(&fn_def.generic_param_list())?; let new_lifetime_param = generate_unique_lifetime_param_name(fn_def.generic_param_list())?;
let end_of_fn_ident = fn_def.name()?.ident_token()?.text_range().end();
let self_param = let self_param =
// use the self if it's a reference and has no explicit lifetime // use the self if it's a reference and has no explicit lifetime
param_list.self_param().filter(|p| p.lifetime().is_none() && p.amp_token().is_some()); param_list.self_param().filter(|p| p.lifetime().is_none() && p.amp_token().is_some());
// compute the location which implicitly has the same lifetime as the anonymous lifetime // compute the location which implicitly has the same lifetime as the anonymous lifetime
let loc_needing_lifetime = if let Some(self_param) = self_param { let loc_needing_lifetime = if let Some(self_param) = self_param {
// if we have a self reference, use that // if we have a self reference, use that
Some(self_param.name()?.syntax().text_range().start()) Some(NeedsLifetime::SelfParam(self_param))
} else { } else {
// otherwise, if there's a single reference parameter without a named liftime, use that // otherwise, if there's a single reference parameter without a named liftime, use that
let fn_params_without_lifetime: Vec<_> = param_list let fn_params_without_lifetime: Vec<_> = param_list
.params() .params()
.filter_map(|param| match param.ty() { .filter_map(|param| match param.ty() {
Some(ast::Type::RefType(ascribed_type)) if ascribed_type.lifetime().is_none() => { Some(ast::Type::RefType(ascribed_type)) if ascribed_type.lifetime().is_none() => {
Some(ascribed_type.amp_token()?.text_range().end()) Some(NeedsLifetime::RefType(ascribed_type))
} }
_ => None, _ => None,
}) })
@ -81,30 +84,46 @@ fn generate_fn_def_assist(
} }
}; };
acc.add(AssistId(ASSIST_NAME, AssistKind::Refactor), ASSIST_LABEL, lifetime_loc, |builder| { acc.add(AssistId(ASSIST_NAME, AssistKind::Refactor), ASSIST_LABEL, lifetime_loc, |builder| {
add_lifetime_param(fn_def, builder, end_of_fn_ident, new_lifetime_param); let fn_def = builder.make_ast_mut(fn_def);
builder.replace(lifetime_loc, format!("'{}", new_lifetime_param)); let lifetime = builder.make_ast_mut(lifetime);
loc_needing_lifetime.map(|loc| builder.insert(loc, format!("'{} ", new_lifetime_param))); let loc_needing_lifetime =
loc_needing_lifetime.and_then(|it| it.make_mut(builder).to_position());
add_lifetime_param(fn_def.get_or_create_generic_param_list(), new_lifetime_param);
ted::replace(
lifetime.syntax(),
make_ast_lifetime(new_lifetime_param).clone_for_update().syntax(),
);
loc_needing_lifetime.map(|position| {
ted::insert(position, make_ast_lifetime(new_lifetime_param).clone_for_update().syntax())
});
}) })
} }
/// Generate the assist for the impl def case /// Generate the assist for the impl def case
fn generate_impl_def_assist( fn generate_impl_def_assist(
acc: &mut Assists, acc: &mut Assists,
impl_def: &ast::Impl, impl_def: ast::Impl,
lifetime_loc: TextRange, lifetime_loc: TextRange,
lifetime: ast::Lifetime,
) -> Option<()> { ) -> Option<()> {
let new_lifetime_param = generate_unique_lifetime_param_name(&impl_def.generic_param_list())?; let new_lifetime_param = generate_unique_lifetime_param_name(impl_def.generic_param_list())?;
let end_of_impl_kw = impl_def.impl_token()?.text_range().end();
acc.add(AssistId(ASSIST_NAME, AssistKind::Refactor), ASSIST_LABEL, lifetime_loc, |builder| { acc.add(AssistId(ASSIST_NAME, AssistKind::Refactor), ASSIST_LABEL, lifetime_loc, |builder| {
add_lifetime_param(impl_def, builder, end_of_impl_kw, new_lifetime_param); let impl_def = builder.make_ast_mut(impl_def);
builder.replace(lifetime_loc, format!("'{}", new_lifetime_param)); let lifetime = builder.make_ast_mut(lifetime);
add_lifetime_param(impl_def.get_or_create_generic_param_list(), new_lifetime_param);
ted::replace(
lifetime.syntax(),
make_ast_lifetime(new_lifetime_param).clone_for_update().syntax(),
);
}) })
} }
/// Given a type parameter list, generate a unique lifetime parameter name /// Given a type parameter list, generate a unique lifetime parameter name
/// which is not in the list /// which is not in the list
fn generate_unique_lifetime_param_name( fn generate_unique_lifetime_param_name(
existing_type_param_list: &Option<ast::GenericParamList>, existing_type_param_list: Option<ast::GenericParamList>,
) -> Option<char> { ) -> Option<char> {
match existing_type_param_list { match existing_type_param_list {
Some(type_params) => { Some(type_params) => {
@ -118,25 +137,37 @@ fn generate_unique_lifetime_param_name(
} }
} }
/// Add the lifetime param to `builder`. If there are type parameters in `type_params_owner`, add it to the end. Otherwise fn add_lifetime_param(type_params: ast::GenericParamList, new_lifetime_param: char) {
/// add new type params brackets with the lifetime parameter at `new_type_params_loc`. let generic_param =
fn add_lifetime_param<TypeParamsOwner: ast::GenericParamsOwner>( make::generic_param(format!("'{}", new_lifetime_param), None).clone_for_update();
type_params_owner: &TypeParamsOwner, type_params.add_generic_param(generic_param);
builder: &mut AssistBuilder, }
new_type_params_loc: TextSize,
new_lifetime_param: char, fn make_ast_lifetime(new_lifetime_param: char) -> ast::Lifetime {
) { make::generic_param(format!("'{}", new_lifetime_param), None)
match type_params_owner.generic_param_list() { .syntax()
// add the new lifetime parameter to an existing type param list .descendants()
Some(type_params) => { .find_map(ast::Lifetime::cast)
builder.insert( .unwrap()
(u32::from(type_params.syntax().text_range().end()) - 1).into(), }
format!(", '{}", new_lifetime_param),
); enum NeedsLifetime {
SelfParam(ast::SelfParam),
RefType(ast::RefType),
}
impl NeedsLifetime {
fn make_mut(self, builder: &mut AssistBuilder) -> Self {
match self {
Self::SelfParam(it) => Self::SelfParam(builder.make_ast_mut(it)),
Self::RefType(it) => Self::RefType(builder.make_ast_mut(it)),
} }
// create a new type param list containing only the new lifetime parameter }
None => {
builder.insert(new_type_params_loc, format!("<'{}>", new_lifetime_param)); fn to_position(self) -> Option<Position> {
match self {
Self::SelfParam(it) => Some(Position::after(it.amp_token()?)),
Self::RefType(it) => Some(Position::after(it.amp_token()?)),
} }
} }
} }
@ -312,4 +343,13 @@ mod tests {
r#"fn my_fun<'other, 'a>(self, f: &'a Foo, b: &'other Bar) -> X<'a>"#, r#"fn my_fun<'other, 'a>(self, f: &'a Foo, b: &'other Bar) -> X<'a>"#,
); );
} }
#[test]
fn test_function_add_lifetime_to_self_ref_mut() {
check_assist(
introduce_named_lifetime,
r#"fn foo(&mut self) -> &'_$0 ()"#,
r#"fn foo<'a>(&'a mut self) -> &'a ()"#,
);
}
} }

View file

@ -14,10 +14,21 @@ use crate::{
use super::NameOwner; use super::NameOwner;
pub trait GenericParamsOwnerEdit: ast::GenericParamsOwner + AstNodeEdit { pub trait GenericParamsOwnerEdit: ast::GenericParamsOwner + AstNodeEdit {
fn get_or_create_generic_param_list(&self) -> ast::GenericParamList;
fn get_or_create_where_clause(&self) -> ast::WhereClause; fn get_or_create_where_clause(&self) -> ast::WhereClause;
} }
impl GenericParamsOwnerEdit for ast::Fn { impl GenericParamsOwnerEdit for ast::Fn {
fn get_or_create_generic_param_list(&self) -> ast::GenericParamList {
match self.generic_param_list() {
Some(it) => it,
None => {
let position = Position::after(self.name().unwrap().syntax);
create_generic_param_list(position)
}
}
}
fn get_or_create_where_clause(&self) -> WhereClause { fn get_or_create_where_clause(&self) -> WhereClause {
if self.where_clause().is_none() { if self.where_clause().is_none() {
let position = if let Some(ty) = self.ret_type() { let position = if let Some(ty) = self.ret_type() {
@ -34,6 +45,16 @@ impl GenericParamsOwnerEdit for ast::Fn {
} }
impl GenericParamsOwnerEdit for ast::Impl { impl GenericParamsOwnerEdit for ast::Impl {
fn get_or_create_generic_param_list(&self) -> ast::GenericParamList {
match self.generic_param_list() {
Some(it) => it,
None => {
let position = Position::after(self.impl_token().unwrap());
create_generic_param_list(position)
}
}
}
fn get_or_create_where_clause(&self) -> WhereClause { fn get_or_create_where_clause(&self) -> WhereClause {
if self.where_clause().is_none() { if self.where_clause().is_none() {
let position = if let Some(items) = self.assoc_item_list() { let position = if let Some(items) = self.assoc_item_list() {
@ -48,6 +69,10 @@ impl GenericParamsOwnerEdit for ast::Impl {
} }
impl GenericParamsOwnerEdit for ast::Trait { impl GenericParamsOwnerEdit for ast::Trait {
fn get_or_create_generic_param_list(&self) -> ast::GenericParamList {
todo!()
}
fn get_or_create_where_clause(&self) -> WhereClause { fn get_or_create_where_clause(&self) -> WhereClause {
if self.where_clause().is_none() { if self.where_clause().is_none() {
let position = if let Some(items) = self.assoc_item_list() { let position = if let Some(items) = self.assoc_item_list() {
@ -62,6 +87,10 @@ impl GenericParamsOwnerEdit for ast::Trait {
} }
impl GenericParamsOwnerEdit for ast::Struct { impl GenericParamsOwnerEdit for ast::Struct {
fn get_or_create_generic_param_list(&self) -> ast::GenericParamList {
todo!()
}
fn get_or_create_where_clause(&self) -> WhereClause { fn get_or_create_where_clause(&self) -> WhereClause {
if self.where_clause().is_none() { if self.where_clause().is_none() {
let tfl = self.field_list().and_then(|fl| match fl { let tfl = self.field_list().and_then(|fl| match fl {
@ -84,6 +113,10 @@ impl GenericParamsOwnerEdit for ast::Struct {
} }
impl GenericParamsOwnerEdit for ast::Enum { impl GenericParamsOwnerEdit for ast::Enum {
fn get_or_create_generic_param_list(&self) -> ast::GenericParamList {
todo!()
}
fn get_or_create_where_clause(&self) -> WhereClause { fn get_or_create_where_clause(&self) -> WhereClause {
if self.where_clause().is_none() { if self.where_clause().is_none() {
let position = if let Some(gpl) = self.generic_param_list() { let position = if let Some(gpl) = self.generic_param_list() {
@ -104,6 +137,37 @@ fn create_where_clause(position: Position) {
ted::insert(position, where_clause.syntax()); ted::insert(position, where_clause.syntax());
} }
fn create_generic_param_list(position: Position) -> ast::GenericParamList {
let gpl = make::generic_param_list(empty()).clone_for_update();
ted::insert_raw(position, gpl.syntax());
gpl
}
impl ast::GenericParamList {
pub fn add_generic_param(&self, generic_param: ast::GenericParam) {
match self.generic_params().last() {
Some(last_param) => {
let mut elems = Vec::new();
if !last_param
.syntax()
.siblings_with_tokens(Direction::Next)
.any(|it| it.kind() == T![,])
{
elems.push(make::token(T![,]).into());
elems.push(make::tokens::single_space().into());
};
elems.push(generic_param.syntax().clone().into());
let after_last_param = Position::after(last_param.syntax());
ted::insert_all(after_last_param, elems);
}
None => {
let after_l_angle = Position::after(self.l_angle_token().unwrap());
ted::insert(after_l_angle, generic_param.syntax())
}
}
}
}
impl ast::WhereClause { impl ast::WhereClause {
pub fn add_predicate(&self, predicate: ast::WherePred) { pub fn add_predicate(&self, predicate: ast::WherePred) {
if let Some(pred) = self.predicates().last() { if let Some(pred) = self.predicates().last() {

View file

@ -165,6 +165,13 @@ fn ws_between(left: &SyntaxElement, right: &SyntaxElement) -> Option<SyntaxToken
if right.kind() == T![;] || right.kind() == T![,] { if right.kind() == T![;] || right.kind() == T![,] {
return None; return None;
} }
if left.kind() == T![<] || right.kind() == T![>] {
return None;
}
if left.kind() == T![&] && right.kind() == SyntaxKind::LIFETIME {
return None;
}
if right.kind() == SyntaxKind::USE { if right.kind() == SyntaxKind::USE {
let indent = IndentLevel::from_element(left); let indent = IndentLevel::from_element(left);
return Some(make::tokens::whitespace(&format!("\n{}", indent))); return Some(make::tokens::whitespace(&format!("\n{}", indent)));