Merge pull request #18652 from Giga-Bowser/extract-constant

feat: Add an assist to extract an expression into a constant
This commit is contained in:
Lukas Wirth 2024-12-12 13:22:05 +00:00 committed by GitHub
commit a6c291ed07
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 1638 additions and 310 deletions

View file

@ -7,6 +7,7 @@ use hir::{
TypeInfo, TypeParam, TypeInfo, TypeParam,
}; };
use ide_db::{ use ide_db::{
assists::GroupLabel,
defs::{Definition, NameRefClass}, defs::{Definition, NameRefClass},
famous_defs::FamousDefs, famous_defs::FamousDefs,
helpers::mod_path_to_ast, helpers::mod_path_to_ast,
@ -104,7 +105,8 @@ pub(crate) fn extract_function(acc: &mut Assists, ctx: &AssistContext<'_>) -> Op
let scope = ImportScope::find_insert_use_container(&node, &ctx.sema)?; let scope = ImportScope::find_insert_use_container(&node, &ctx.sema)?;
acc.add( acc.add_group(
&GroupLabel("Extract into...".to_owned()),
AssistId("extract_function", crate::AssistKind::RefactorExtract), AssistId("extract_function", crate::AssistKind::RefactorExtract),
"Extract into function", "Extract into function",
target_range, target_range,

File diff suppressed because it is too large Load diff

View file

@ -1,19 +1,17 @@
use hir::{HirDisplay, ModuleDef, PathResolution, Semantics}; use hir::HirDisplay;
use ide_db::{ use ide_db::{
assists::{AssistId, AssistKind}, assists::{AssistId, AssistKind},
defs::Definition, defs::Definition,
syntax_helpers::node_ext::preorder_expr,
RootDatabase,
}; };
use stdx::to_upper_snake_case; use stdx::to_upper_snake_case;
use syntax::{ use syntax::{
ast::{self, make, HasName}, ast::{self, make, HasName},
ted, AstNode, WalkEvent, ted, AstNode,
}; };
use crate::{ use crate::{
assist_context::{AssistContext, Assists}, assist_context::{AssistContext, Assists},
utils, utils::{self},
}; };
// Assist: promote_local_to_const // Assist: promote_local_to_const
@ -63,7 +61,7 @@ pub(crate) fn promote_local_to_const(acc: &mut Assists, ctx: &AssistContext<'_>)
}; };
let initializer = let_stmt.initializer()?; let initializer = let_stmt.initializer()?;
if !is_body_const(&ctx.sema, &initializer) { if !utils::is_body_const(&ctx.sema, &initializer) {
cov_mark::hit!(promote_local_non_const); cov_mark::hit!(promote_local_non_const);
return None; return None;
} }
@ -103,40 +101,6 @@ pub(crate) fn promote_local_to_const(acc: &mut Assists, ctx: &AssistContext<'_>)
) )
} }
fn is_body_const(sema: &Semantics<'_, RootDatabase>, expr: &ast::Expr) -> bool {
let mut is_const = true;
preorder_expr(expr, &mut |ev| {
let expr = match ev {
WalkEvent::Enter(_) if !is_const => return true,
WalkEvent::Enter(expr) => expr,
WalkEvent::Leave(_) => return false,
};
match expr {
ast::Expr::CallExpr(call) => {
if let Some(ast::Expr::PathExpr(path_expr)) = call.expr() {
if let Some(PathResolution::Def(ModuleDef::Function(func))) =
path_expr.path().and_then(|path| sema.resolve_path(&path))
{
is_const &= func.is_const(sema.db);
}
}
}
ast::Expr::MethodCallExpr(call) => {
is_const &=
sema.resolve_method_call(&call).map(|it| it.is_const(sema.db)).unwrap_or(true)
}
ast::Expr::ForExpr(_)
| ast::Expr::ReturnExpr(_)
| ast::Expr::TryExpr(_)
| ast::Expr::YieldExpr(_)
| ast::Expr::AwaitExpr(_) => is_const = false,
_ => (),
}
!is_const
});
is_const
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use crate::tests::{check_assist, check_assist_not_applicable}; use crate::tests::{check_assist, check_assist_not_applicable};

View file

@ -362,8 +362,7 @@ pub fn test_some_range(a: int) -> bool {
expect![[r#" expect![[r#"
Convert integer base Convert integer base
Extract into variable Extract into...
Extract into function
Replace if let with match Replace if let with match
"#]] "#]]
.assert_eq(&expected); .assert_eq(&expected);
@ -391,8 +390,7 @@ pub fn test_some_range(a: int) -> bool {
expect![[r#" expect![[r#"
Convert integer base Convert integer base
Extract into variable Extract into...
Extract into function
Replace if let with match Replace if let with match
"#]] "#]]
.assert_eq(&expected); .assert_eq(&expected);
@ -405,8 +403,7 @@ pub fn test_some_range(a: int) -> bool {
let expected = labels(&assists); let expected = labels(&assists);
expect![[r#" expect![[r#"
Extract into variable Extract into...
Extract into function
"#]] "#]]
.assert_eq(&expected); .assert_eq(&expected);
} }
@ -440,7 +437,7 @@ pub fn test_some_range(a: int) -> bool {
{ {
let assists = assists(&db, &cfg, AssistResolveStrategy::None, frange.into()); let assists = assists(&db, &cfg, AssistResolveStrategy::None, frange.into());
assert_eq!(2, assists.len()); assert_eq!(4, assists.len());
let mut assists = assists.into_iter(); let mut assists = assists.into_iter();
let extract_into_variable_assist = assists.next().unwrap(); let extract_into_variable_assist = assists.next().unwrap();
@ -451,7 +448,11 @@ pub fn test_some_range(a: int) -> bool {
RefactorExtract, RefactorExtract,
), ),
label: "Extract into variable", label: "Extract into variable",
group: None, group: Some(
GroupLabel(
"Extract into...",
),
),
target: 59..60, target: 59..60,
source_change: None, source_change: None,
command: None, command: None,
@ -459,6 +460,46 @@ pub fn test_some_range(a: int) -> bool {
"#]] "#]]
.assert_debug_eq(&extract_into_variable_assist); .assert_debug_eq(&extract_into_variable_assist);
let extract_into_constant_assist = assists.next().unwrap();
expect![[r#"
Assist {
id: AssistId(
"extract_constant",
RefactorExtract,
),
label: "Extract into constant",
group: Some(
GroupLabel(
"Extract into...",
),
),
target: 59..60,
source_change: None,
command: None,
}
"#]]
.assert_debug_eq(&extract_into_constant_assist);
let extract_into_static_assist = assists.next().unwrap();
expect![[r#"
Assist {
id: AssistId(
"extract_static",
RefactorExtract,
),
label: "Extract into static",
group: Some(
GroupLabel(
"Extract into...",
),
),
target: 59..60,
source_change: None,
command: None,
}
"#]]
.assert_debug_eq(&extract_into_static_assist);
let extract_into_function_assist = assists.next().unwrap(); let extract_into_function_assist = assists.next().unwrap();
expect![[r#" expect![[r#"
Assist { Assist {
@ -467,7 +508,11 @@ pub fn test_some_range(a: int) -> bool {
RefactorExtract, RefactorExtract,
), ),
label: "Extract into function", label: "Extract into function",
group: None, group: Some(
GroupLabel(
"Extract into...",
),
),
target: 59..60, target: 59..60,
source_change: None, source_change: None,
command: None, command: None,
@ -486,7 +531,7 @@ pub fn test_some_range(a: int) -> bool {
}), }),
frange.into(), frange.into(),
); );
assert_eq!(2, assists.len()); assert_eq!(4, assists.len());
let mut assists = assists.into_iter(); let mut assists = assists.into_iter();
let extract_into_variable_assist = assists.next().unwrap(); let extract_into_variable_assist = assists.next().unwrap();
@ -497,7 +542,11 @@ pub fn test_some_range(a: int) -> bool {
RefactorExtract, RefactorExtract,
), ),
label: "Extract into variable", label: "Extract into variable",
group: None, group: Some(
GroupLabel(
"Extract into...",
),
),
target: 59..60, target: 59..60,
source_change: None, source_change: None,
command: None, command: None,
@ -505,6 +554,46 @@ pub fn test_some_range(a: int) -> bool {
"#]] "#]]
.assert_debug_eq(&extract_into_variable_assist); .assert_debug_eq(&extract_into_variable_assist);
let extract_into_constant_assist = assists.next().unwrap();
expect![[r#"
Assist {
id: AssistId(
"extract_constant",
RefactorExtract,
),
label: "Extract into constant",
group: Some(
GroupLabel(
"Extract into...",
),
),
target: 59..60,
source_change: None,
command: None,
}
"#]]
.assert_debug_eq(&extract_into_constant_assist);
let extract_into_static_assist = assists.next().unwrap();
expect![[r#"
Assist {
id: AssistId(
"extract_static",
RefactorExtract,
),
label: "Extract into static",
group: Some(
GroupLabel(
"Extract into...",
),
),
target: 59..60,
source_change: None,
command: None,
}
"#]]
.assert_debug_eq(&extract_into_static_assist);
let extract_into_function_assist = assists.next().unwrap(); let extract_into_function_assist = assists.next().unwrap();
expect![[r#" expect![[r#"
Assist { Assist {
@ -513,7 +602,11 @@ pub fn test_some_range(a: int) -> bool {
RefactorExtract, RefactorExtract,
), ),
label: "Extract into function", label: "Extract into function",
group: None, group: Some(
GroupLabel(
"Extract into...",
),
),
target: 59..60, target: 59..60,
source_change: None, source_change: None,
command: None, command: None,
@ -532,7 +625,7 @@ pub fn test_some_range(a: int) -> bool {
}), }),
frange.into(), frange.into(),
); );
assert_eq!(2, assists.len()); assert_eq!(4, assists.len());
let mut assists = assists.into_iter(); let mut assists = assists.into_iter();
let extract_into_variable_assist = assists.next().unwrap(); let extract_into_variable_assist = assists.next().unwrap();
@ -543,7 +636,11 @@ pub fn test_some_range(a: int) -> bool {
RefactorExtract, RefactorExtract,
), ),
label: "Extract into variable", label: "Extract into variable",
group: None, group: Some(
GroupLabel(
"Extract into...",
),
),
target: 59..60, target: 59..60,
source_change: Some( source_change: Some(
SourceChange { SourceChange {
@ -594,6 +691,46 @@ pub fn test_some_range(a: int) -> bool {
"#]] "#]]
.assert_debug_eq(&extract_into_variable_assist); .assert_debug_eq(&extract_into_variable_assist);
let extract_into_constant_assist = assists.next().unwrap();
expect![[r#"
Assist {
id: AssistId(
"extract_constant",
RefactorExtract,
),
label: "Extract into constant",
group: Some(
GroupLabel(
"Extract into...",
),
),
target: 59..60,
source_change: None,
command: None,
}
"#]]
.assert_debug_eq(&extract_into_constant_assist);
let extract_into_static_assist = assists.next().unwrap();
expect![[r#"
Assist {
id: AssistId(
"extract_static",
RefactorExtract,
),
label: "Extract into static",
group: Some(
GroupLabel(
"Extract into...",
),
),
target: 59..60,
source_change: None,
command: None,
}
"#]]
.assert_debug_eq(&extract_into_static_assist);
let extract_into_function_assist = assists.next().unwrap(); let extract_into_function_assist = assists.next().unwrap();
expect![[r#" expect![[r#"
Assist { Assist {
@ -602,7 +739,11 @@ pub fn test_some_range(a: int) -> bool {
RefactorExtract, RefactorExtract,
), ),
label: "Extract into function", label: "Extract into function",
group: None, group: Some(
GroupLabel(
"Extract into...",
),
),
target: 59..60, target: 59..60,
source_change: None, source_change: None,
command: None, command: None,
@ -613,7 +754,7 @@ pub fn test_some_range(a: int) -> bool {
{ {
let assists = assists(&db, &cfg, AssistResolveStrategy::All, frange.into()); let assists = assists(&db, &cfg, AssistResolveStrategy::All, frange.into());
assert_eq!(2, assists.len()); assert_eq!(4, assists.len());
let mut assists = assists.into_iter(); let mut assists = assists.into_iter();
let extract_into_variable_assist = assists.next().unwrap(); let extract_into_variable_assist = assists.next().unwrap();
@ -624,7 +765,11 @@ pub fn test_some_range(a: int) -> bool {
RefactorExtract, RefactorExtract,
), ),
label: "Extract into variable", label: "Extract into variable",
group: None, group: Some(
GroupLabel(
"Extract into...",
),
),
target: 59..60, target: 59..60,
source_change: Some( source_change: Some(
SourceChange { SourceChange {
@ -675,6 +820,140 @@ pub fn test_some_range(a: int) -> bool {
"#]] "#]]
.assert_debug_eq(&extract_into_variable_assist); .assert_debug_eq(&extract_into_variable_assist);
let extract_into_constant_assist = assists.next().unwrap();
expect![[r#"
Assist {
id: AssistId(
"extract_constant",
RefactorExtract,
),
label: "Extract into constant",
group: Some(
GroupLabel(
"Extract into...",
),
),
target: 59..60,
source_change: Some(
SourceChange {
source_file_edits: {
FileId(
0,
): (
TextEdit {
indels: [
Indel {
insert: "const",
delete: 45..47,
},
Indel {
insert: "VAR_NAME:",
delete: 48..60,
},
Indel {
insert: "i32",
delete: 61..81,
},
Indel {
insert: "=",
delete: 82..86,
},
Indel {
insert: "5;\n if let 2..6 = VAR_NAME {\n true\n } else {\n false\n }",
delete: 87..108,
},
],
},
Some(
SnippetEdit(
[
(
0,
51..51,
),
],
),
),
),
},
file_system_edits: [],
is_snippet: true,
},
),
command: Some(
Rename,
),
}
"#]]
.assert_debug_eq(&extract_into_constant_assist);
let extract_into_static_assist = assists.next().unwrap();
expect![[r#"
Assist {
id: AssistId(
"extract_static",
RefactorExtract,
),
label: "Extract into static",
group: Some(
GroupLabel(
"Extract into...",
),
),
target: 59..60,
source_change: Some(
SourceChange {
source_file_edits: {
FileId(
0,
): (
TextEdit {
indels: [
Indel {
insert: "static",
delete: 45..47,
},
Indel {
insert: "VAR_NAME:",
delete: 48..60,
},
Indel {
insert: "i32",
delete: 61..81,
},
Indel {
insert: "=",
delete: 82..86,
},
Indel {
insert: "5;\n if let 2..6 = VAR_NAME {\n true\n } else {\n false\n }",
delete: 87..108,
},
],
},
Some(
SnippetEdit(
[
(
0,
52..52,
),
],
),
),
),
},
file_system_edits: [],
is_snippet: true,
},
),
command: Some(
Rename,
),
}
"#]]
.assert_debug_eq(&extract_into_static_assist);
let extract_into_function_assist = assists.next().unwrap(); let extract_into_function_assist = assists.next().unwrap();
expect![[r#" expect![[r#"
Assist { Assist {
@ -683,7 +962,11 @@ pub fn test_some_range(a: int) -> bool {
RefactorExtract, RefactorExtract,
), ),
label: "Extract into function", label: "Extract into function",
group: None, group: Some(
GroupLabel(
"Extract into...",
),
),
target: 59..60, target: 59..60,
source_change: Some( source_change: Some(
SourceChange { SourceChange {

View file

@ -932,6 +932,24 @@ enum TheEnum {
) )
} }
#[test]
fn doctest_extract_constant() {
check_doc_test(
"extract_constant",
r#####"
fn main() {
$0(1 + 2)$0 * 4;
}
"#####,
r#####"
fn main() {
const $0VAR_NAME: i32 = 1 + 2;
VAR_NAME * 4;
}
"#####,
)
}
#[test] #[test]
fn doctest_extract_expressions_from_format_string() { fn doctest_extract_expressions_from_format_string() {
check_doc_test( check_doc_test(
@ -1006,6 +1024,24 @@ fn bar(name: i32) -> i32 {
) )
} }
#[test]
fn doctest_extract_static() {
check_doc_test(
"extract_static",
r#####"
fn main() {
$0(1 + 2)$0 * 4;
}
"#####,
r#####"
fn main() {
static $0VAR_NAME: i32 = 1 + 2;
VAR_NAME * 4;
}
"#####,
)
}
#[test] #[test]
fn doctest_extract_struct_from_enum_variant() { fn doctest_extract_struct_from_enum_variant() {
check_doc_test( check_doc_test(

View file

@ -3,11 +3,13 @@
pub(crate) use gen_trait_fn_body::gen_trait_fn_body; pub(crate) use gen_trait_fn_body::gen_trait_fn_body;
use hir::{ use hir::{
db::{ExpandDatabase, HirDatabase}, db::{ExpandDatabase, HirDatabase},
HasAttrs as HirHasAttrs, HirDisplay, InFile, Semantics, HasAttrs as HirHasAttrs, HirDisplay, InFile, ModuleDef, PathResolution, Semantics,
}; };
use ide_db::{ use ide_db::{
famous_defs::FamousDefs, path_transform::PathTransform, famous_defs::FamousDefs,
syntax_helpers::prettify_macro_expansion, RootDatabase, path_transform::PathTransform,
syntax_helpers::{node_ext::preorder_expr, prettify_macro_expansion},
RootDatabase,
}; };
use stdx::format_to; use stdx::format_to;
use syntax::{ use syntax::{
@ -19,7 +21,7 @@ use syntax::{
}, },
ted, AstNode, AstToken, Direction, Edition, NodeOrToken, SourceFile, ted, AstNode, AstToken, Direction, Edition, NodeOrToken, SourceFile,
SyntaxKind::*, SyntaxKind::*,
SyntaxNode, SyntaxToken, TextRange, TextSize, T, SyntaxNode, SyntaxToken, TextRange, TextSize, WalkEvent, T,
}; };
use crate::assist_context::{AssistContext, SourceChangeBuilder}; use crate::assist_context::{AssistContext, SourceChangeBuilder};
@ -966,3 +968,37 @@ pub(crate) fn tt_from_syntax(node: SyntaxNode) -> Vec<NodeOrToken<ast::TokenTree
tt_stack.pop().expect("parent token tree was closed before it was completed").1 tt_stack.pop().expect("parent token tree was closed before it was completed").1
} }
pub fn is_body_const(sema: &Semantics<'_, RootDatabase>, expr: &ast::Expr) -> bool {
let mut is_const = true;
preorder_expr(expr, &mut |ev| {
let expr = match ev {
WalkEvent::Enter(_) if !is_const => return true,
WalkEvent::Enter(expr) => expr,
WalkEvent::Leave(_) => return false,
};
match expr {
ast::Expr::CallExpr(call) => {
if let Some(ast::Expr::PathExpr(path_expr)) = call.expr() {
if let Some(PathResolution::Def(ModuleDef::Function(func))) =
path_expr.path().and_then(|path| sema.resolve_path(&path))
{
is_const &= func.is_const(sema.db);
}
}
}
ast::Expr::MethodCallExpr(call) => {
is_const &=
sema.resolve_method_call(&call).map(|it| it.is_const(sema.db)).unwrap_or(true)
}
ast::Expr::ForExpr(_)
| ast::Expr::ReturnExpr(_)
| ast::Expr::TryExpr(_)
| ast::Expr::YieldExpr(_)
| ast::Expr::AwaitExpr(_) => is_const = false,
_ => (),
}
!is_const
});
is_const
}

View file

@ -895,7 +895,29 @@ pub fn item_const(
None => String::new(), None => String::new(),
Some(it) => format!("{it} "), Some(it) => format!("{it} "),
}; };
ast_from_text(&format!("{visibility} const {name}: {ty} = {expr};")) ast_from_text(&format!("{visibility}const {name}: {ty} = {expr};"))
}
pub fn item_static(
visibility: Option<ast::Visibility>,
is_unsafe: bool,
is_mut: bool,
name: ast::Name,
ty: ast::Type,
expr: Option<ast::Expr>,
) -> ast::Static {
let visibility = match visibility {
None => String::new(),
Some(it) => format!("{it} "),
};
let is_unsafe = if is_unsafe { "unsafe " } else { "" };
let is_mut = if is_mut { "mut " } else { "" };
let expr = match expr {
Some(it) => &format!(" = {it}"),
None => "",
};
ast_from_text(&format!("{visibility}{is_unsafe}static {is_mut}{name}: {ty}{expr};"))
} }
pub fn unnamed_param(ty: ast::Type) -> ast::Param { pub fn unnamed_param(ty: ast::Type) -> ast::Param {

View file

@ -188,6 +188,73 @@ impl SyntaxFactory {
ast ast
} }
pub fn item_const(
&self,
visibility: Option<ast::Visibility>,
name: ast::Name,
ty: ast::Type,
expr: ast::Expr,
) -> ast::Const {
let ast = make::item_const(visibility.clone(), name.clone(), ty.clone(), expr.clone())
.clone_for_update();
if let Some(mut mapping) = self.mappings() {
let mut builder = SyntaxMappingBuilder::new(ast.syntax().clone());
if let Some(visibility) = visibility {
builder.map_node(
visibility.syntax().clone(),
ast.visibility().unwrap().syntax().clone(),
);
}
builder.map_node(name.syntax().clone(), ast.name().unwrap().syntax().clone());
builder.map_node(ty.syntax().clone(), ast.ty().unwrap().syntax().clone());
builder.map_node(expr.syntax().clone(), ast.body().unwrap().syntax().clone());
builder.finish(&mut mapping);
}
ast
}
pub fn item_static(
&self,
visibility: Option<ast::Visibility>,
is_unsafe: bool,
is_mut: bool,
name: ast::Name,
ty: ast::Type,
expr: Option<ast::Expr>,
) -> ast::Static {
let ast = make::item_static(
visibility.clone(),
is_unsafe,
is_mut,
name.clone(),
ty.clone(),
expr.clone(),
)
.clone_for_update();
if let Some(mut mapping) = self.mappings() {
let mut builder = SyntaxMappingBuilder::new(ast.syntax().clone());
if let Some(visibility) = visibility {
builder.map_node(
visibility.syntax().clone(),
ast.visibility().unwrap().syntax().clone(),
);
}
builder.map_node(name.syntax().clone(), ast.name().unwrap().syntax().clone());
builder.map_node(ty.syntax().clone(), ast.ty().unwrap().syntax().clone());
if let Some(expr) = expr {
builder.map_node(expr.syntax().clone(), ast.body().unwrap().syntax().clone());
}
builder.finish(&mut mapping);
}
ast
}
pub fn turbofish_generic_arg_list( pub fn turbofish_generic_arg_list(
&self, &self,
args: impl IntoIterator<Item = ast::GenericArg> + Clone, args: impl IntoIterator<Item = ast::GenericArg> + Clone,