diff --git a/crates/ra_ide/src/lib.rs b/crates/ra_ide/src/lib.rs index 9d66c365b1..f86f98be73 100644 --- a/crates/ra_ide/src/lib.rs +++ b/crates/ra_ide/src/lib.rs @@ -37,6 +37,7 @@ mod display; mod inlay_hints; mod expand; mod expand_macro; +mod ssr; #[cfg(test)] mod marks; @@ -73,6 +74,7 @@ pub use crate::{ }, runnables::{Runnable, RunnableKind, TestId}, source_change::{FileSystemEdit, SourceChange, SourceFileEdit}, + ssr::SsrError, syntax_highlighting::HighlightedRange, }; @@ -464,6 +466,16 @@ impl Analysis { self.with_db(|db| references::rename(db, position, new_name)) } + pub fn structural_search_replace( + &self, + query: &str, + ) -> Cancelable> { + self.with_db(|db| { + let edits = ssr::parse_search_replace(query, db)?; + Ok(SourceChange::source_file_edits("ssr", edits)) + }) + } + /// Performs an operation on that may be Canceled. fn with_db T + std::panic::UnwindSafe, T>( &self, diff --git a/crates/ra_ide/src/ssr.rs b/crates/ra_ide/src/ssr.rs new file mode 100644 index 0000000000..14eb0b8b25 --- /dev/null +++ b/crates/ra_ide/src/ssr.rs @@ -0,0 +1,324 @@ +//! structural search replace + +use crate::source_change::SourceFileEdit; +use ra_ide_db::RootDatabase; +use ra_syntax::ast::make::expr_from_text; +use ra_syntax::AstNode; +use ra_syntax::SyntaxElement; +use ra_syntax::SyntaxNode; +use ra_text_edit::{TextEdit, TextEditBuilder}; +use rustc_hash::FxHashMap; +use std::collections::HashMap; +use std::str::FromStr; + +pub use ra_db::{SourceDatabase, SourceDatabaseExt}; +use ra_ide_db::symbol_index::SymbolsDatabase; + +#[derive(Debug, PartialEq)] +pub struct SsrError(String); + +impl std::fmt::Display for SsrError { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + write!(f, "Parse error: {}", self.0) + } +} + +impl std::error::Error for SsrError {} + +pub fn parse_search_replace( + query: &str, + db: &RootDatabase, +) -> Result, SsrError> { + let mut edits = vec![]; + let query: SsrQuery = query.parse()?; + for &root in db.local_roots().iter() { + let sr = db.source_root(root); + for file_id in sr.walk() { + dbg!(db.file_relative_path(file_id)); + let matches = find(&query.pattern, db.parse(file_id).tree().syntax()); + if !matches.matches.is_empty() { + edits.push(SourceFileEdit { file_id, edit: replace(&matches, &query.template) }); + } + } + } + Ok(edits) +} + +#[derive(Debug)] +struct SsrQuery { + pattern: SsrPattern, + template: SsrTemplate, +} + +#[derive(Debug)] +struct SsrPattern { + pattern: SyntaxNode, + vars: Vec, +} + +/// represents an `$var` in an SSR query +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +struct Var(String); + +#[derive(Debug)] +struct SsrTemplate { + template: SyntaxNode, + placeholders: FxHashMap, +} + +type Binding = HashMap; + +#[derive(Debug)] +struct Match { + place: SyntaxNode, + binding: Binding, +} + +#[derive(Debug)] +struct SsrMatches { + matches: Vec, +} + +impl FromStr for SsrQuery { + type Err = SsrError; + + fn from_str(query: &str) -> Result { + let mut it = query.split("==>>"); + let pattern = it.next().expect("at least empty string").trim(); + let mut template = + it.next().ok_or(SsrError("Cannot find delemiter `==>>`".into()))?.trim().to_string(); + if it.next().is_some() { + return Err(SsrError("More than one delimiter found".into())); + } + let mut vars = vec![]; + let mut it = pattern.split('$'); + let mut pattern = it.next().expect("something").to_string(); + + for part in it.map(split_by_var) { + let (var, var_type, remainder) = part?; + is_expr(var_type)?; + let new_var = create_name(var, &mut vars)?; + pattern.push_str(new_var); + pattern.push_str(remainder); + template = replace_in_template(template, var, new_var); + } + + let template = expr_from_text(&template).syntax().clone(); + let mut placeholders = FxHashMap::default(); + + traverse(&template, &mut |n| { + if let Some(v) = vars.iter().find(|v| v.0.as_str() == n.text()) { + placeholders.insert(n.clone(), v.clone()); + false + } else { + true + } + }); + + let pattern = SsrPattern { pattern: expr_from_text(&pattern).syntax().clone(), vars }; + let template = SsrTemplate { template, placeholders }; + Ok(SsrQuery { pattern, template }) + } +} + +fn traverse(node: &SyntaxNode, go: &mut impl FnMut(&SyntaxNode) -> bool) { + if !go(node) { + return; + } + for ref child in node.children() { + traverse(child, go); + } +} + +fn split_by_var(s: &str) -> Result<(&str, &str, &str), SsrError> { + let end_of_name = s.find(":").ok_or(SsrError("Use $:expr".into()))?; + let name = &s[0..end_of_name]; + is_name(name)?; + let type_begin = end_of_name + 1; + let type_length = s[type_begin..].find(|c| !char::is_ascii_alphanumeric(&c)).unwrap_or(s.len()); + let type_name = &s[type_begin..type_begin + type_length]; + Ok((name, type_name, &s[type_begin + type_length..])) +} + +fn is_name(s: &str) -> Result<(), SsrError> { + if s.chars().all(|c| c.is_ascii_alphanumeric() || c == '_') { + Ok(()) + } else { + Err(SsrError("Name can contain only alphanumerics and _".into())) + } +} + +fn is_expr(s: &str) -> Result<(), SsrError> { + if s == "expr" { + Ok(()) + } else { + Err(SsrError("Only $:expr is supported".into())) + } +} + +fn replace_in_template(template: String, var: &str, new_var: &str) -> String { + let name = format!("${}", var); + template.replace(&name, new_var) +} + +fn create_name<'a>(name: &str, vars: &'a mut Vec) -> Result<&'a str, SsrError> { + let sanitized_name = format!("__search_pattern_{}", name); + if vars.iter().any(|a| a.0 == sanitized_name) { + return Err(SsrError(format!("Name `{}` repeats more than once", name))); + } + vars.push(Var(sanitized_name)); + Ok(&vars.last().unwrap().0) +} + +fn find(pattern: &SsrPattern, code: &SyntaxNode) -> SsrMatches { + fn check( + pattern: &SyntaxElement, + code: &SyntaxElement, + placeholders: &[Var], + match_: &mut Match, + ) -> bool { + match (pattern, code) { + (SyntaxElement::Token(ref pattern), SyntaxElement::Token(ref code)) => { + pattern.text() == code.text() + } + (SyntaxElement::Node(ref pattern), SyntaxElement::Node(ref code)) => { + if placeholders.iter().find(|&n| n.0.as_str() == pattern.text()).is_some() { + match_.binding.insert(Var(pattern.text().to_string()), code.clone()); + true + } else { + pattern.green().children().count() == code.green().children().count() + && pattern + .children_with_tokens() + .zip(code.children_with_tokens()) + .all(|(a, b)| check(&a, &b, placeholders, match_)) + } + } + _ => false, + } + } + let kind = pattern.pattern.kind(); + let matches = code + .descendants_with_tokens() + .filter(|n| n.kind() == kind) + .filter_map(|code| { + let mut match_ = + Match { place: code.as_node().unwrap().clone(), binding: HashMap::new() }; + if check( + &SyntaxElement::from(pattern.pattern.clone()), + &code, + &pattern.vars, + &mut match_, + ) { + Some(match_) + } else { + None + } + }) + .collect(); + SsrMatches { matches } +} + +fn replace(matches: &SsrMatches, template: &SsrTemplate) -> TextEdit { + let mut builder = TextEditBuilder::default(); + for match_ in &matches.matches { + builder.replace(match_.place.text_range(), render_replace(&match_.binding, template)); + } + builder.finish() +} + +fn render_replace(binding: &Binding, template: &SsrTemplate) -> String { + let mut builder = TextEditBuilder::default(); + for element in template.template.descendants() { + if let Some(var) = template.placeholders.get(&element) { + builder.replace(element.text_range(), binding[var].to_string()) + } + } + builder.finish().apply(&template.template.text().to_string()) +} + +#[cfg(test)] +mod tests { + use super::*; + use ra_syntax::SourceFile; + + fn parse_error_text(query: &str) -> String { + format!("{}", query.parse::().unwrap_err()) + } + + #[test] + fn parser_happy_case() { + let result: SsrQuery = "foo($a:expr, $b:expr) ==>> bar($b, $a)".parse().unwrap(); + assert_eq!(&result.pattern.pattern.text(), "foo(__search_pattern_a, __search_pattern_b)"); + assert_eq!(result.pattern.vars.len(), 2); + assert_eq!(result.pattern.vars[0].0, "__search_pattern_a"); + assert_eq!(result.pattern.vars[1].0, "__search_pattern_b"); + assert_eq!(&result.template.template.text(), "bar(__search_pattern_b, __search_pattern_a)"); + dbg!(result.template.placeholders); + } + + #[test] + fn parser_empty_query() { + assert_eq!(parse_error_text(""), "Parse error: Cannot find delemiter `==>>`"); + } + + #[test] + fn parser_no_delimiter() { + assert_eq!(parse_error_text("foo()"), "Parse error: Cannot find delemiter `==>>`"); + } + + #[test] + fn parser_two_delimiters() { + assert_eq!( + parse_error_text("foo() ==>> a ==>> b "), + "Parse error: More than one delimiter found" + ); + } + + #[test] + fn parser_no_pattern_type() { + assert_eq!(parse_error_text("foo($a) ==>>"), "Parse error: Use $:expr"); + } + + #[test] + fn parser_invalid_name() { + assert_eq!( + parse_error_text("foo($a+:expr) ==>>"), + "Parse error: Name can contain only alphanumerics and _" + ); + } + + #[test] + fn parser_invalid_type() { + assert_eq!( + parse_error_text("foo($a:ident) ==>>"), + "Parse error: Only $:expr is supported" + ); + } + + #[test] + fn parser_repeated_name() { + assert_eq!( + parse_error_text("foo($a:expr, $a:expr) ==>>"), + "Parse error: Name `a` repeats more than once" + ); + } + + #[test] + fn parse_match_replace() { + let query: SsrQuery = "foo($x:expr) ==>> bar($x)".parse().unwrap(); + let input = "fn main() { foo(1+2); }"; + + let code = SourceFile::parse(input).tree(); + let matches = find(&query.pattern, code.syntax()); + assert_eq!(matches.matches.len(), 1); + assert_eq!(matches.matches[0].place.text(), "foo(1+2)"); + assert_eq!(matches.matches[0].binding.len(), 1); + assert_eq!( + matches.matches[0].binding[&Var("__search_pattern_x".to_string())].text(), + "1+2" + ); + + let edit = replace(&matches, &query.template); + assert_eq!(edit.apply(input), "fn main() { bar(1+2); }"); + } +} diff --git a/crates/ra_lsp_server/src/main_loop.rs b/crates/ra_lsp_server/src/main_loop.rs index 1e70cea136..7ae2e1e6fd 100644 --- a/crates/ra_lsp_server/src/main_loop.rs +++ b/crates/ra_lsp_server/src/main_loop.rs @@ -527,6 +527,7 @@ fn on_request( .on::(handlers::handle_call_hierarchy_prepare)? .on::(handlers::handle_call_hierarchy_incoming)? .on::(handlers::handle_call_hierarchy_outgoing)? + .on::(handlers::handle_ssr)? .finish(); Ok(()) } diff --git a/crates/ra_lsp_server/src/main_loop/handlers.rs b/crates/ra_lsp_server/src/main_loop/handlers.rs index 3893430c04..ae51141cbf 100644 --- a/crates/ra_lsp_server/src/main_loop/handlers.rs +++ b/crates/ra_lsp_server/src/main_loop/handlers.rs @@ -882,6 +882,11 @@ pub fn handle_document_highlight( )) } +pub fn handle_ssr(world: WorldSnapshot, params: req::SsrParams) -> Result { + let _p = profile("handle_ssr"); + world.analysis().structural_search_replace(¶ms.arg)??.try_conv_with(&world) +} + pub fn publish_diagnostics(world: &WorldSnapshot, file_id: FileId) -> Result { let _p = profile("publish_diagnostics"); let line_index = world.analysis().file_line_index(file_id)?; diff --git a/crates/ra_lsp_server/src/req.rs b/crates/ra_lsp_server/src/req.rs index dc327f53d2..7ff7f60b31 100644 --- a/crates/ra_lsp_server/src/req.rs +++ b/crates/ra_lsp_server/src/req.rs @@ -206,3 +206,16 @@ pub struct InlayHint { pub kind: InlayKind, pub label: String, } + +pub enum Ssr {} + +impl Request for Ssr { + type Params = SsrParams; + type Result = SourceChange; + const METHOD: &'static str = "rust-analyzer/ssr"; +} + +#[derive(Debug, Deserialize, Serialize)] +pub struct SsrParams { + pub arg: String, +} diff --git a/crates/ra_syntax/src/ast/make.rs b/crates/ra_syntax/src/ast/make.rs index 862eb11728..89d1403e78 100644 --- a/crates/ra_syntax/src/ast/make.rs +++ b/crates/ra_syntax/src/ast/make.rs @@ -84,7 +84,7 @@ pub fn expr_prefix(op: SyntaxKind, expr: ast::Expr) -> ast::Expr { let token = token(op); expr_from_text(&format!("{}{}", token, expr.syntax())) } -fn expr_from_text(text: &str) -> ast::Expr { +pub fn expr_from_text(text: &str) -> ast::Expr { ast_from_text(&format!("const C: () = {};", text)) } diff --git a/editors/code/package.json b/editors/code/package.json index 46acbfe76c..774fed21d1 100644 --- a/editors/code/package.json +++ b/editors/code/package.json @@ -124,6 +124,11 @@ "command": "rust-analyzer.onEnter", "title": "Enhanced enter key", "category": "Rust Analyzer" + }, + { + "command": "rust-analyzer.ssr", + "title": "Structural Search Replace", + "category": "Rust Analyzer" } ], "keybindings": [ diff --git a/editors/code/src/commands/index.ts b/editors/code/src/commands/index.ts index aee9694320..b5ebec117f 100644 --- a/editors/code/src/commands/index.ts +++ b/editors/code/src/commands/index.ts @@ -12,6 +12,7 @@ export * from './parent_module'; export * from './syntax_tree'; export * from './expand_macro'; export * from './runnables'; +export * from './ssr'; export function collectGarbage(ctx: Ctx): Cmd { return async () => { diff --git a/editors/code/src/commands/ssr.ts b/editors/code/src/commands/ssr.ts new file mode 100644 index 0000000000..6287bf47b4 --- /dev/null +++ b/editors/code/src/commands/ssr.ts @@ -0,0 +1,36 @@ +import { Ctx, Cmd } from '../ctx'; +import { applySourceChange, SourceChange } from '../source_change'; +import * as vscode from 'vscode'; + +export function ssr(ctx: Ctx): Cmd { + return async () => { + const client = ctx.client; + if (!client) return; + + const options: vscode.InputBoxOptions = { + placeHolder: "foo($a:expr, $b:expr) ==>> bar($a, foo($b))", + prompt: "Enter request", + validateInput: (x: string) => { + if (x.includes('==>>')) { + return null; + } + return "Enter request: pattern ==>> template" + } + } + const request = await vscode.window.showInputBox(options); + + if (!request) return; + + const ssrRequest: SsrRequest = { arg: request }; + const change = await client.sendRequest( + 'rust-analyzer/ssr', + ssrRequest, + ); + + await applySourceChange(ctx, change); + }; +} + +interface SsrRequest { + arg: string; +} diff --git a/editors/code/src/main.ts b/editors/code/src/main.ts index 5efce41f40..5a99e96f0e 100644 --- a/editors/code/src/main.ts +++ b/editors/code/src/main.ts @@ -22,6 +22,7 @@ export async function activate(context: vscode.ExtensionContext) { ctx.registerCommand('run', commands.run); ctx.registerCommand('reload', commands.reload); ctx.registerCommand('onEnter', commands.onEnter); + ctx.registerCommand('ssr', commands.ssr) // Internal commands which are invoked by the server. ctx.registerCommand('runSingle', commands.runSingle);