diff --git a/crates/assists/src/handlers/generate_function.rs b/crates/assists/src/handlers/generate_function.rs index b38d640581..d23f4293b3 100644 --- a/crates/assists/src/handlers/generate_function.rs +++ b/crates/assists/src/handlers/generate_function.rs @@ -36,8 +36,8 @@ use crate::{ // bar("", baz()); // } // -// fn bar(arg: &str, baz: Baz) { -// ${0:todo!()} +// fn bar(arg: &str, baz: Baz) ${0:-> ()} { +// todo!() // } // // ``` @@ -80,9 +80,9 @@ pub(crate) fn generate_function(acc: &mut Assists, ctx: &AssistContext) -> Optio struct FunctionTemplate { insert_offset: TextSize, - placeholder_expr: ast::MacroCall, leading_ws: String, fn_def: ast::Fn, + ret_type: ast::RetType, trailing_ws: String, file: FileId, } @@ -90,11 +90,9 @@ struct FunctionTemplate { impl FunctionTemplate { fn to_string(&self, cap: Option) -> String { let f = match cap { - Some(cap) => render_snippet( - cap, - self.fn_def.syntax(), - Cursor::Replace(self.placeholder_expr.syntax()), - ), + Some(cap) => { + render_snippet(cap, self.fn_def.syntax(), Cursor::Replace(self.ret_type.syntax())) + } None => self.fn_def.to_string(), }; format!("{}{}{}", self.leading_ws, f, self.trailing_ws) @@ -141,8 +139,14 @@ impl FunctionBuilder { let placeholder_expr = make::expr_todo(); let fn_body = make::block_expr(vec![], Some(placeholder_expr)); let visibility = if self.needs_pub { Some(make::visibility_pub_crate()) } else { None }; - let mut fn_def = - make::fn_(visibility, self.fn_name, self.type_params, self.params, fn_body); + let mut fn_def = make::fn_( + visibility, + self.fn_name, + self.type_params, + self.params, + fn_body, + Some(make::ret_type(make::ty("()"))), + ); let leading_ws; let trailing_ws; @@ -163,12 +167,10 @@ impl FunctionBuilder { } }; - let placeholder_expr = - fn_def.syntax().descendants().find_map(ast::MacroCall::cast).unwrap(); FunctionTemplate { insert_offset, - placeholder_expr, leading_ws, + ret_type: fn_def.ret_type().unwrap(), fn_def, trailing_ws, file: self.file, @@ -349,8 +351,8 @@ fn foo() { bar(); } -fn bar() { - ${0:todo!()} +fn bar() ${0:-> ()} { + todo!() } ", ) @@ -376,8 +378,8 @@ impl Foo { } } -fn bar() { - ${0:todo!()} +fn bar() ${0:-> ()} { + todo!() } ", ) @@ -400,8 +402,8 @@ fn foo1() { bar(); } -fn bar() { - ${0:todo!()} +fn bar() ${0:-> ()} { + todo!() } fn foo2() {} @@ -426,8 +428,8 @@ mod baz { bar(); } - fn bar() { - ${0:todo!()} + fn bar() ${0:-> ()} { + todo!() } } ", @@ -452,8 +454,8 @@ fn foo() { bar(baz()); } -fn bar(baz: Baz) { - ${0:todo!()} +fn bar(baz: Baz) ${0:-> ()} { + todo!() } ", ); @@ -485,8 +487,8 @@ impl Baz { } } -fn bar(baz: Baz) { - ${0:todo!()} +fn bar(baz: Baz) ${0:-> ()} { + todo!() } ", ) @@ -506,8 +508,8 @@ fn foo() { bar("bar") } -fn bar(arg: &str) { - ${0:todo!()} +fn bar(arg: &str) ${0:-> ()} { + todo!() } "#, ) @@ -527,8 +529,8 @@ fn foo() { bar('x') } -fn bar(arg: char) { - ${0:todo!()} +fn bar(arg: char) ${0:-> ()} { + todo!() } "#, ) @@ -548,8 +550,8 @@ fn foo() { bar(42) } -fn bar(arg: i32) { - ${0:todo!()} +fn bar(arg: i32) ${0:-> ()} { + todo!() } ", ) @@ -569,8 +571,8 @@ fn foo() { bar(42 as u8) } -fn bar(arg: u8) { - ${0:todo!()} +fn bar(arg: u8) ${0:-> ()} { + todo!() } ", ) @@ -594,8 +596,8 @@ fn foo() { bar(x as u8) } -fn bar(x: u8) { - ${0:todo!()} +fn bar(x: u8) ${0:-> ()} { + todo!() } ", ) @@ -617,8 +619,8 @@ fn foo() { bar(worble) } -fn bar(worble: ()) { - ${0:todo!()} +fn bar(worble: ()) ${0:-> ()} { + todo!() } ", ) @@ -646,8 +648,8 @@ fn baz() { bar(foo()) } -fn bar(foo: impl Foo) { - ${0:todo!()} +fn bar(foo: impl Foo) ${0:-> ()} { + todo!() } ", ) @@ -673,8 +675,8 @@ fn foo() { bar(&baz()) } -fn bar(baz: &Baz) { - ${0:todo!()} +fn bar(baz: &Baz) ${0:-> ()} { + todo!() } ", ) @@ -702,8 +704,8 @@ fn foo() { bar(Baz::baz()) } -fn bar(baz: Baz::Bof) { - ${0:todo!()} +fn bar(baz: Baz::Bof) ${0:-> ()} { + todo!() } ", ) @@ -725,8 +727,8 @@ fn foo(t: T) { bar(t) } -fn bar(t: T) { - ${0:todo!()} +fn bar(t: T) ${0:-> ()} { + todo!() } ", ) @@ -756,8 +758,8 @@ fn foo() { bar(Baz::new); } -fn bar(arg: fn() -> Baz) { - ${0:todo!()} +fn bar(arg: fn() -> Baz) ${0:-> ()} { + todo!() } ", ) @@ -781,8 +783,8 @@ fn foo() { bar(closure) } -fn bar(closure: impl Fn(i64) -> i64) { - ${0:todo!()} +fn bar(closure: impl Fn(i64) -> i64) ${0:-> ()} { + todo!() } ", ) @@ -802,8 +804,8 @@ fn foo() { bar(baz) } -fn bar(baz: ()) { - ${0:todo!()} +fn bar(baz: ()) ${0:-> ()} { + todo!() } ", ) @@ -827,8 +829,8 @@ fn foo() { bar(baz(), baz()) } -fn bar(baz_1: Baz, baz_2: Baz) { - ${0:todo!()} +fn bar(baz_1: Baz, baz_2: Baz) ${0:-> ()} { + todo!() } ", ) @@ -852,8 +854,8 @@ fn foo() { bar(baz(), baz(), "foo", "bar") } -fn bar(baz_1: Baz, baz_2: Baz, arg_1: &str, arg_2: &str) { - ${0:todo!()} +fn bar(baz_1: Baz, baz_2: Baz, arg_1: &str, arg_2: &str) ${0:-> ()} { + todo!() } "#, ) @@ -872,8 +874,8 @@ fn foo() { ", r" mod bar { - pub(crate) fn my_fn() { - ${0:todo!()} + pub(crate) fn my_fn() ${0:-> ()} { + todo!() } } @@ -911,8 +913,8 @@ fn bar() { baz(foo) } -fn baz(foo: foo::Foo) { - ${0:todo!()} +fn baz(foo: foo::Foo) ${0:-> ()} { + todo!() } ", ) @@ -935,8 +937,8 @@ fn foo() { mod bar { fn something_else() {} - pub(crate) fn my_fn() { - ${0:todo!()} + pub(crate) fn my_fn() ${0:-> ()} { + todo!() } } @@ -963,8 +965,8 @@ fn foo() { r" mod bar { mod baz { - pub(crate) fn my_fn() { - ${0:todo!()} + pub(crate) fn my_fn() ${0:-> ()} { + todo!() } } } @@ -992,8 +994,8 @@ fn main() { r" -pub(crate) fn bar() { - ${0:todo!()} +pub(crate) fn bar() ${0:-> ()} { + todo!() }", ) } diff --git a/crates/assists/src/tests/generated.rs b/crates/assists/src/tests/generated.rs index 7f6e98a548..41f536574e 100644 --- a/crates/assists/src/tests/generated.rs +++ b/crates/assists/src/tests/generated.rs @@ -454,8 +454,8 @@ fn foo() { bar("", baz()); } -fn bar(arg: &str, baz: Baz) { - ${0:todo!()} +fn bar(arg: &str, baz: Baz) ${0:-> ()} { + todo!() } "#####, diff --git a/crates/hir/src/code_model.rs b/crates/hir/src/code_model.rs index a101d724e4..b65be4fe1b 100644 --- a/crates/hir/src/code_model.rs +++ b/crates/hir/src/code_model.rs @@ -255,6 +255,25 @@ impl ModuleDef { ModuleDef::BuiltinType(it) => Some(it.as_name()), } } + + pub fn diagnostics(self, db: &dyn HirDatabase, sink: &mut DiagnosticSink) { + let id = match self { + ModuleDef::Adt(it) => match it { + Adt::Struct(it) => it.id.into(), + Adt::Enum(it) => it.id.into(), + Adt::Union(it) => it.id.into(), + }, + ModuleDef::Trait(it) => it.id.into(), + ModuleDef::Function(it) => it.id.into(), + ModuleDef::TypeAlias(it) => it.id.into(), + ModuleDef::Module(it) => it.id.into(), + ModuleDef::Const(it) => it.id.into(), + ModuleDef::Static(it) => it.id.into(), + _ => return, + }; + + hir_ty::diagnostics::validate_module_item(db, id, sink) + } } pub use hir_def::{ @@ -358,6 +377,8 @@ impl Module { let crate_def_map = db.crate_def_map(self.id.krate); crate_def_map.add_diagnostics(db.upcast(), self.id.local_id, sink); for decl in self.declarations(db) { + decl.diagnostics(db, sink); + match decl { crate::ModuleDef::Function(f) => f.diagnostics(db, sink), crate::ModuleDef::Module(m) => { @@ -760,7 +781,8 @@ impl Function { } pub fn diagnostics(self, db: &dyn HirDatabase, sink: &mut DiagnosticSink) { - hir_ty::diagnostics::validate_body(db, self.id.into(), sink) + hir_ty::diagnostics::validate_module_item(db, self.id.into(), sink); + hir_ty::diagnostics::validate_body(db, self.id.into(), sink); } /// Whether this function declaration has a definition. diff --git a/crates/hir/src/diagnostics.rs b/crates/hir/src/diagnostics.rs index 363164b9b4..da2b40849a 100644 --- a/crates/hir/src/diagnostics.rs +++ b/crates/hir/src/diagnostics.rs @@ -2,5 +2,6 @@ pub use hir_def::diagnostics::UnresolvedModule; pub use hir_expand::diagnostics::{Diagnostic, DiagnosticSink, DiagnosticSinkBuilder}; pub use hir_ty::diagnostics::{ - MismatchedArgCount, MissingFields, MissingMatchArms, MissingOkInTailExpr, NoSuchField, + IncorrectCase, MismatchedArgCount, MissingFields, MissingMatchArms, MissingOkInTailExpr, + NoSuchField, }; diff --git a/crates/hir_def/src/item_scope.rs b/crates/hir_def/src/item_scope.rs index 12c24e1ca3..a8b3fe844a 100644 --- a/crates/hir_def/src/item_scope.rs +++ b/crates/hir_def/src/item_scope.rs @@ -95,6 +95,12 @@ impl ItemScope { self.impls.iter().copied() } + pub fn values( + &self, + ) -> impl Iterator + ExactSizeIterator + '_ { + self.values.values().copied() + } + pub fn visibility_of(&self, def: ModuleDefId) -> Option { self.name_of(ItemInNs::Types(def)) .or_else(|| self.name_of(ItemInNs::Values(def))) diff --git a/crates/hir_ty/src/diagnostics.rs b/crates/hir_ty/src/diagnostics.rs index 9ba005fabd..dfe98571e7 100644 --- a/crates/hir_ty/src/diagnostics.rs +++ b/crates/hir_ty/src/diagnostics.rs @@ -2,10 +2,11 @@ mod expr; mod match_check; mod unsafe_check; +mod decl_check; -use std::any::Any; +use std::{any::Any, fmt}; -use hir_def::DefWithBodyId; +use hir_def::{DefWithBodyId, ModuleDefId}; use hir_expand::diagnostics::{Diagnostic, DiagnosticCode, DiagnosticSink}; use hir_expand::{name::Name, HirFileId, InFile}; use stdx::format_to; @@ -15,6 +16,16 @@ use crate::db::HirDatabase; pub use crate::diagnostics::expr::{record_literal_missing_fields, record_pattern_missing_fields}; +pub fn validate_module_item( + db: &dyn HirDatabase, + owner: ModuleDefId, + sink: &mut DiagnosticSink<'_>, +) { + let _p = profile::span("validate_module_item"); + let mut validator = decl_check::DeclValidator::new(owner, sink); + validator.validate_item(db); +} + pub fn validate_body(db: &dyn HirDatabase, owner: DefWithBodyId, sink: &mut DiagnosticSink<'_>) { let _p = profile::span("validate_body"); let infer = db.infer(owner); @@ -231,6 +242,66 @@ impl Diagnostic for MismatchedArgCount { } } +#[derive(Debug)] +pub enum CaseType { + // `some_var` + LowerSnakeCase, + // `SOME_CONST` + UpperSnakeCase, + // `SomeStruct` + UpperCamelCase, +} + +impl fmt::Display for CaseType { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let repr = match self { + CaseType::LowerSnakeCase => "snake_case", + CaseType::UpperSnakeCase => "UPPER_SNAKE_CASE", + CaseType::UpperCamelCase => "CamelCase", + }; + + write!(f, "{}", repr) + } +} + +#[derive(Debug)] +pub struct IncorrectCase { + pub file: HirFileId, + pub ident: AstPtr, + pub expected_case: CaseType, + pub ident_type: String, + pub ident_text: String, + pub suggested_text: String, +} + +impl Diagnostic for IncorrectCase { + fn code(&self) -> DiagnosticCode { + DiagnosticCode("incorrect-ident-case") + } + + fn message(&self) -> String { + format!( + "{} `{}` should have {} name, e.g. `{}`", + self.ident_type, + self.ident_text, + self.expected_case.to_string(), + self.suggested_text + ) + } + + fn display_source(&self) -> InFile { + InFile::new(self.file, self.ident.clone().into()) + } + + fn as_any(&self) -> &(dyn Any + Send + 'static) { + self + } + + fn is_experimental(&self) -> bool { + true + } +} + #[cfg(test)] mod tests { use base_db::{fixture::WithFixture, FileId, SourceDatabase, SourceDatabaseExt}; @@ -242,7 +313,10 @@ mod tests { use rustc_hash::FxHashMap; use syntax::{TextRange, TextSize}; - use crate::{diagnostics::validate_body, test_db::TestDB}; + use crate::{ + diagnostics::{validate_body, validate_module_item}, + test_db::TestDB, + }; impl TestDB { fn diagnostics(&self, mut cb: F) { @@ -253,6 +327,9 @@ mod tests { let mut fns = Vec::new(); for (module_id, _) in crate_def_map.modules.iter() { for decl in crate_def_map[module_id].scope.declarations() { + let mut sink = DiagnosticSinkBuilder::new().build(&mut cb); + validate_module_item(self, decl, &mut sink); + if let ModuleDefId::FunctionId(f) = decl { fns.push(f) } @@ -262,6 +339,8 @@ mod tests { let impl_data = self.impl_data(impl_id); for item in impl_data.items.iter() { if let AssocItemId::FunctionId(f) = item { + let mut sink = DiagnosticSinkBuilder::new().build(&mut cb); + validate_module_item(self, ModuleDefId::FunctionId(*f), &mut sink); fns.push(*f) } } diff --git a/crates/hir_ty/src/diagnostics/decl_check.rs b/crates/hir_ty/src/diagnostics/decl_check.rs new file mode 100644 index 0000000000..f987636fe5 --- /dev/null +++ b/crates/hir_ty/src/diagnostics/decl_check.rs @@ -0,0 +1,833 @@ +//! Provides validators for the item declarations. +//! +//! This includes the following items: +//! +//! - variable bindings (e.g. `let x = foo();`) +//! - struct fields (e.g. `struct Foo { field: u8 }`) +//! - enum variants (e.g. `enum Foo { Variant { field: u8 } }`) +//! - function/method arguments (e.g. `fn foo(arg: u8)`) +//! - constants (e.g. `const FOO: u8 = 10;`) +//! - static items (e.g. `static FOO: u8 = 10;`) +//! - match arm bindings (e.g. `foo @ Some(_)`) + +mod case_conv; + +use hir_def::{ + adt::VariantData, + expr::{Pat, PatId}, + src::HasSource, + AdtId, ConstId, EnumId, FunctionId, Lookup, ModuleDefId, StaticId, StructId, +}; +use hir_expand::{ + diagnostics::DiagnosticSink, + name::{AsName, Name}, +}; +use syntax::{ + ast::{self, NameOwner}, + AstNode, AstPtr, +}; + +use crate::{ + db::HirDatabase, + diagnostics::{decl_check::case_conv::*, CaseType, IncorrectCase}, +}; + +pub(super) struct DeclValidator<'a, 'b: 'a> { + owner: ModuleDefId, + sink: &'a mut DiagnosticSink<'b>, +} + +#[derive(Debug)] +struct Replacement { + current_name: Name, + suggested_text: String, + expected_case: CaseType, +} + +impl<'a, 'b> DeclValidator<'a, 'b> { + pub(super) fn new( + owner: ModuleDefId, + sink: &'a mut DiagnosticSink<'b>, + ) -> DeclValidator<'a, 'b> { + DeclValidator { owner, sink } + } + + pub(super) fn validate_item(&mut self, db: &dyn HirDatabase) { + match self.owner { + ModuleDefId::FunctionId(func) => self.validate_func(db, func), + ModuleDefId::AdtId(adt) => self.validate_adt(db, adt), + ModuleDefId::ConstId(const_id) => self.validate_const(db, const_id), + ModuleDefId::StaticId(static_id) => self.validate_static(db, static_id), + _ => return, + } + } + + fn validate_adt(&mut self, db: &dyn HirDatabase, adt: AdtId) { + match adt { + AdtId::StructId(struct_id) => self.validate_struct(db, struct_id), + AdtId::EnumId(enum_id) => self.validate_enum(db, enum_id), + AdtId::UnionId(_) => { + // Unions aren't yet supported by this validator. + } + } + } + + fn validate_func(&mut self, db: &dyn HirDatabase, func: FunctionId) { + let data = db.function_data(func); + let body = db.body(func.into()); + + // 1. Check the function name. + let function_name = data.name.to_string(); + let fn_name_replacement = if let Some(new_name) = to_lower_snake_case(&function_name) { + let replacement = Replacement { + current_name: data.name.clone(), + suggested_text: new_name, + expected_case: CaseType::LowerSnakeCase, + }; + Some(replacement) + } else { + None + }; + + // 2. Check the param names. + let mut fn_param_replacements = Vec::new(); + + for pat_id in body.params.iter().cloned() { + let pat = &body[pat_id]; + + let param_name = match pat { + Pat::Bind { name, .. } => name, + _ => continue, + }; + + let name = param_name.to_string(); + if let Some(new_name) = to_lower_snake_case(&name) { + let replacement = Replacement { + current_name: param_name.clone(), + suggested_text: new_name, + expected_case: CaseType::LowerSnakeCase, + }; + fn_param_replacements.push(replacement); + } + } + + // 3. Check the patterns inside the function body. + let mut pats_replacements = Vec::new(); + + for (pat_idx, pat) in body.pats.iter() { + if body.params.contains(&pat_idx) { + // We aren't interested in function parameters, we've processed them above. + continue; + } + + let bind_name = match pat { + Pat::Bind { name, .. } => name, + _ => continue, + }; + + let name = bind_name.to_string(); + if let Some(new_name) = to_lower_snake_case(&name) { + let replacement = Replacement { + current_name: bind_name.clone(), + suggested_text: new_name, + expected_case: CaseType::LowerSnakeCase, + }; + pats_replacements.push((pat_idx, replacement)); + } + } + + // 4. If there is at least one element to spawn a warning on, go to the source map and generate a warning. + self.create_incorrect_case_diagnostic_for_func( + func, + db, + fn_name_replacement, + fn_param_replacements, + ); + self.create_incorrect_case_diagnostic_for_variables(func, db, pats_replacements); + + // 5. Recursively validate inner scope items, such as static variables and constants. + for (item_id, _) in body.item_scope.values() { + let mut validator = DeclValidator::new(item_id, self.sink); + validator.validate_item(db); + } + } + + /// Given the information about incorrect names in the function declaration, looks up into the source code + /// for exact locations and adds diagnostics into the sink. + fn create_incorrect_case_diagnostic_for_func( + &mut self, + func: FunctionId, + db: &dyn HirDatabase, + fn_name_replacement: Option, + fn_param_replacements: Vec, + ) { + // XXX: only look at sources if we do have incorrect names + if fn_name_replacement.is_none() && fn_param_replacements.is_empty() { + return; + } + + let fn_loc = func.lookup(db.upcast()); + let fn_src = fn_loc.source(db.upcast()); + + // 1. Diagnostic for function name. + if let Some(replacement) = fn_name_replacement { + let ast_ptr = match fn_src.value.name() { + Some(name) => name, + None => { + // We don't want rust-analyzer to panic over this, but it is definitely some kind of error in the logic. + log::error!( + "Replacement ({:?}) was generated for a function without a name: {:?}", + replacement, + fn_src + ); + return; + } + }; + + let diagnostic = IncorrectCase { + file: fn_src.file_id, + ident_type: "Function".to_string(), + ident: AstPtr::new(&ast_ptr).into(), + expected_case: replacement.expected_case, + ident_text: replacement.current_name.to_string(), + suggested_text: replacement.suggested_text, + }; + + self.sink.push(diagnostic); + } + + // 2. Diagnostics for function params. + let fn_params_list = match fn_src.value.param_list() { + Some(params) => params, + None => { + if !fn_param_replacements.is_empty() { + log::error!( + "Replacements ({:?}) were generated for a function parameters which had no parameters list: {:?}", + fn_param_replacements, fn_src + ); + } + return; + } + }; + let mut fn_params_iter = fn_params_list.params(); + for param_to_rename in fn_param_replacements { + // We assume that parameters in replacement are in the same order as in the + // actual params list, but just some of them (ones that named correctly) are skipped. + let ast_ptr: ast::Name = loop { + match fn_params_iter.next() { + Some(element) + if pat_equals_to_name(element.pat(), ¶m_to_rename.current_name) => + { + if let ast::Pat::IdentPat(pat) = element.pat().unwrap() { + break pat.name().unwrap(); + } else { + // This is critical. If we consider this parameter the expected one, + // it **must** have a name. + panic!( + "Pattern {:?} equals to expected replacement {:?}, but has no name", + element, param_to_rename + ); + } + } + Some(_) => {} + None => { + log::error!( + "Replacement ({:?}) was generated for a function parameter which was not found: {:?}", + param_to_rename, fn_src + ); + return; + } + } + }; + + let diagnostic = IncorrectCase { + file: fn_src.file_id, + ident_type: "Argument".to_string(), + ident: AstPtr::new(&ast_ptr).into(), + expected_case: param_to_rename.expected_case, + ident_text: param_to_rename.current_name.to_string(), + suggested_text: param_to_rename.suggested_text, + }; + + self.sink.push(diagnostic); + } + } + + /// Given the information about incorrect variable names, looks up into the source code + /// for exact locations and adds diagnostics into the sink. + fn create_incorrect_case_diagnostic_for_variables( + &mut self, + func: FunctionId, + db: &dyn HirDatabase, + pats_replacements: Vec<(PatId, Replacement)>, + ) { + // XXX: only look at source_map if we do have missing fields + if pats_replacements.is_empty() { + return; + } + + let (_, source_map) = db.body_with_source_map(func.into()); + + for (id, replacement) in pats_replacements { + if let Ok(source_ptr) = source_map.pat_syntax(id) { + if let Some(expr) = source_ptr.value.as_ref().left() { + let root = source_ptr.file_syntax(db.upcast()); + if let ast::Pat::IdentPat(ident_pat) = expr.to_node(&root) { + let parent = match ident_pat.syntax().parent() { + Some(parent) => parent, + None => continue, + }; + let name_ast = match ident_pat.name() { + Some(name_ast) => name_ast, + None => continue, + }; + + // We have to check that it's either `let var = ...` or `var @ Variant(_)` statement, + // because e.g. match arms are patterns as well. + // In other words, we check that it's a named variable binding. + let is_binding = ast::LetStmt::cast(parent.clone()).is_some() + || (ast::MatchArm::cast(parent).is_some() + && ident_pat.at_token().is_some()); + if !is_binding { + // This pattern is not an actual variable declaration, e.g. `Some(val) => {..}` match arm. + continue; + } + + let diagnostic = IncorrectCase { + file: source_ptr.file_id, + ident_type: "Variable".to_string(), + ident: AstPtr::new(&name_ast).into(), + expected_case: replacement.expected_case, + ident_text: replacement.current_name.to_string(), + suggested_text: replacement.suggested_text, + }; + + self.sink.push(diagnostic); + } + } + } + } + } + + fn validate_struct(&mut self, db: &dyn HirDatabase, struct_id: StructId) { + let data = db.struct_data(struct_id); + + // 1. Check the structure name. + let struct_name = data.name.to_string(); + let struct_name_replacement = if let Some(new_name) = to_camel_case(&struct_name) { + let replacement = Replacement { + current_name: data.name.clone(), + suggested_text: new_name, + expected_case: CaseType::UpperCamelCase, + }; + Some(replacement) + } else { + None + }; + + // 2. Check the field names. + let mut struct_fields_replacements = Vec::new(); + + if let VariantData::Record(fields) = data.variant_data.as_ref() { + for (_, field) in fields.iter() { + let field_name = field.name.to_string(); + if let Some(new_name) = to_lower_snake_case(&field_name) { + let replacement = Replacement { + current_name: field.name.clone(), + suggested_text: new_name, + expected_case: CaseType::LowerSnakeCase, + }; + struct_fields_replacements.push(replacement); + } + } + } + + // 3. If there is at least one element to spawn a warning on, go to the source map and generate a warning. + self.create_incorrect_case_diagnostic_for_struct( + struct_id, + db, + struct_name_replacement, + struct_fields_replacements, + ); + } + + /// Given the information about incorrect names in the struct declaration, looks up into the source code + /// for exact locations and adds diagnostics into the sink. + fn create_incorrect_case_diagnostic_for_struct( + &mut self, + struct_id: StructId, + db: &dyn HirDatabase, + struct_name_replacement: Option, + struct_fields_replacements: Vec, + ) { + // XXX: only look at sources if we do have incorrect names + if struct_name_replacement.is_none() && struct_fields_replacements.is_empty() { + return; + } + + let struct_loc = struct_id.lookup(db.upcast()); + let struct_src = struct_loc.source(db.upcast()); + + if let Some(replacement) = struct_name_replacement { + let ast_ptr = match struct_src.value.name() { + Some(name) => name, + None => { + // We don't want rust-analyzer to panic over this, but it is definitely some kind of error in the logic. + log::error!( + "Replacement ({:?}) was generated for a structure without a name: {:?}", + replacement, + struct_src + ); + return; + } + }; + + let diagnostic = IncorrectCase { + file: struct_src.file_id, + ident_type: "Structure".to_string(), + ident: AstPtr::new(&ast_ptr).into(), + expected_case: replacement.expected_case, + ident_text: replacement.current_name.to_string(), + suggested_text: replacement.suggested_text, + }; + + self.sink.push(diagnostic); + } + + let struct_fields_list = match struct_src.value.field_list() { + Some(ast::FieldList::RecordFieldList(fields)) => fields, + _ => { + if !struct_fields_replacements.is_empty() { + log::error!( + "Replacements ({:?}) were generated for a structure fields which had no fields list: {:?}", + struct_fields_replacements, struct_src + ); + } + return; + } + }; + let mut struct_fields_iter = struct_fields_list.fields(); + for field_to_rename in struct_fields_replacements { + // We assume that parameters in replacement are in the same order as in the + // actual params list, but just some of them (ones that named correctly) are skipped. + let ast_ptr = loop { + match struct_fields_iter.next() { + Some(element) if names_equal(element.name(), &field_to_rename.current_name) => { + break element.name().unwrap() + } + Some(_) => {} + None => { + log::error!( + "Replacement ({:?}) was generated for a structure field which was not found: {:?}", + field_to_rename, struct_src + ); + return; + } + } + }; + + let diagnostic = IncorrectCase { + file: struct_src.file_id, + ident_type: "Field".to_string(), + ident: AstPtr::new(&ast_ptr).into(), + expected_case: field_to_rename.expected_case, + ident_text: field_to_rename.current_name.to_string(), + suggested_text: field_to_rename.suggested_text, + }; + + self.sink.push(diagnostic); + } + } + + fn validate_enum(&mut self, db: &dyn HirDatabase, enum_id: EnumId) { + let data = db.enum_data(enum_id); + + // 1. Check the enum name. + let enum_name = data.name.to_string(); + let enum_name_replacement = if let Some(new_name) = to_camel_case(&enum_name) { + let replacement = Replacement { + current_name: data.name.clone(), + suggested_text: new_name, + expected_case: CaseType::UpperCamelCase, + }; + Some(replacement) + } else { + None + }; + + // 2. Check the field names. + let mut enum_fields_replacements = Vec::new(); + + for (_, variant) in data.variants.iter() { + let variant_name = variant.name.to_string(); + if let Some(new_name) = to_camel_case(&variant_name) { + let replacement = Replacement { + current_name: variant.name.clone(), + suggested_text: new_name, + expected_case: CaseType::UpperCamelCase, + }; + enum_fields_replacements.push(replacement); + } + } + + // 3. If there is at least one element to spawn a warning on, go to the source map and generate a warning. + self.create_incorrect_case_diagnostic_for_enum( + enum_id, + db, + enum_name_replacement, + enum_fields_replacements, + ) + } + + /// Given the information about incorrect names in the struct declaration, looks up into the source code + /// for exact locations and adds diagnostics into the sink. + fn create_incorrect_case_diagnostic_for_enum( + &mut self, + enum_id: EnumId, + db: &dyn HirDatabase, + enum_name_replacement: Option, + enum_variants_replacements: Vec, + ) { + // XXX: only look at sources if we do have incorrect names + if enum_name_replacement.is_none() && enum_variants_replacements.is_empty() { + return; + } + + let enum_loc = enum_id.lookup(db.upcast()); + let enum_src = enum_loc.source(db.upcast()); + + if let Some(replacement) = enum_name_replacement { + let ast_ptr = match enum_src.value.name() { + Some(name) => name, + None => { + // We don't want rust-analyzer to panic over this, but it is definitely some kind of error in the logic. + log::error!( + "Replacement ({:?}) was generated for a enum without a name: {:?}", + replacement, + enum_src + ); + return; + } + }; + + let diagnostic = IncorrectCase { + file: enum_src.file_id, + ident_type: "Enum".to_string(), + ident: AstPtr::new(&ast_ptr).into(), + expected_case: replacement.expected_case, + ident_text: replacement.current_name.to_string(), + suggested_text: replacement.suggested_text, + }; + + self.sink.push(diagnostic); + } + + let enum_variants_list = match enum_src.value.variant_list() { + Some(variants) => variants, + _ => { + if !enum_variants_replacements.is_empty() { + log::error!( + "Replacements ({:?}) were generated for a enum variants which had no fields list: {:?}", + enum_variants_replacements, enum_src + ); + } + return; + } + }; + let mut enum_variants_iter = enum_variants_list.variants(); + for variant_to_rename in enum_variants_replacements { + // We assume that parameters in replacement are in the same order as in the + // actual params list, but just some of them (ones that named correctly) are skipped. + let ast_ptr = loop { + match enum_variants_iter.next() { + Some(variant) + if names_equal(variant.name(), &variant_to_rename.current_name) => + { + break variant.name().unwrap() + } + Some(_) => {} + None => { + log::error!( + "Replacement ({:?}) was generated for a enum variant which was not found: {:?}", + variant_to_rename, enum_src + ); + return; + } + } + }; + + let diagnostic = IncorrectCase { + file: enum_src.file_id, + ident_type: "Variant".to_string(), + ident: AstPtr::new(&ast_ptr).into(), + expected_case: variant_to_rename.expected_case, + ident_text: variant_to_rename.current_name.to_string(), + suggested_text: variant_to_rename.suggested_text, + }; + + self.sink.push(diagnostic); + } + } + + fn validate_const(&mut self, db: &dyn HirDatabase, const_id: ConstId) { + let data = db.const_data(const_id); + + let name = match &data.name { + Some(name) => name, + None => return, + }; + + let const_name = name.to_string(); + let replacement = if let Some(new_name) = to_upper_snake_case(&const_name) { + Replacement { + current_name: name.clone(), + suggested_text: new_name, + expected_case: CaseType::UpperSnakeCase, + } + } else { + // Nothing to do here. + return; + }; + + let const_loc = const_id.lookup(db.upcast()); + let const_src = const_loc.source(db.upcast()); + + let ast_ptr = match const_src.value.name() { + Some(name) => name, + None => return, + }; + + let diagnostic = IncorrectCase { + file: const_src.file_id, + ident_type: "Constant".to_string(), + ident: AstPtr::new(&ast_ptr).into(), + expected_case: replacement.expected_case, + ident_text: replacement.current_name.to_string(), + suggested_text: replacement.suggested_text, + }; + + self.sink.push(diagnostic); + } + + fn validate_static(&mut self, db: &dyn HirDatabase, static_id: StaticId) { + let data = db.static_data(static_id); + + let name = match &data.name { + Some(name) => name, + None => return, + }; + + let static_name = name.to_string(); + let replacement = if let Some(new_name) = to_upper_snake_case(&static_name) { + Replacement { + current_name: name.clone(), + suggested_text: new_name, + expected_case: CaseType::UpperSnakeCase, + } + } else { + // Nothing to do here. + return; + }; + + let static_loc = static_id.lookup(db.upcast()); + let static_src = static_loc.source(db.upcast()); + + let ast_ptr = match static_src.value.name() { + Some(name) => name, + None => return, + }; + + let diagnostic = IncorrectCase { + file: static_src.file_id, + ident_type: "Static variable".to_string(), + ident: AstPtr::new(&ast_ptr).into(), + expected_case: replacement.expected_case, + ident_text: replacement.current_name.to_string(), + suggested_text: replacement.suggested_text, + }; + + self.sink.push(diagnostic); + } +} + +fn names_equal(left: Option, right: &Name) -> bool { + if let Some(left) = left { + &left.as_name() == right + } else { + false + } +} + +fn pat_equals_to_name(pat: Option, name: &Name) -> bool { + if let Some(ast::Pat::IdentPat(ident)) = pat { + ident.to_string() == name.to_string() + } else { + false + } +} + +#[cfg(test)] +mod tests { + use crate::diagnostics::tests::check_diagnostics; + + #[test] + fn incorrect_function_name() { + check_diagnostics( + r#" +fn NonSnakeCaseName() {} +// ^^^^^^^^^^^^^^^^ Function `NonSnakeCaseName` should have snake_case name, e.g. `non_snake_case_name` +"#, + ); + } + + #[test] + fn incorrect_function_params() { + check_diagnostics( + r#" +fn foo(SomeParam: u8) {} + // ^^^^^^^^^ Argument `SomeParam` should have snake_case name, e.g. `some_param` + +fn foo2(ok_param: &str, CAPS_PARAM: u8) {} + // ^^^^^^^^^^ Argument `CAPS_PARAM` should have snake_case name, e.g. `caps_param` +"#, + ); + } + + #[test] + fn incorrect_variable_names() { + check_diagnostics( + r#" +fn foo() { + let SOME_VALUE = 10; + // ^^^^^^^^^^ Variable `SOME_VALUE` should have snake_case name, e.g. `some_value` + let AnotherValue = 20; + // ^^^^^^^^^^^^ Variable `AnotherValue` should have snake_case name, e.g. `another_value` +} +"#, + ); + } + + #[test] + fn incorrect_struct_name() { + check_diagnostics( + r#" +struct non_camel_case_name {} + // ^^^^^^^^^^^^^^^^^^^ Structure `non_camel_case_name` should have CamelCase name, e.g. `NonCamelCaseName` +"#, + ); + } + + #[test] + fn incorrect_struct_field() { + check_diagnostics( + r#" +struct SomeStruct { SomeField: u8 } + // ^^^^^^^^^ Field `SomeField` should have snake_case name, e.g. `some_field` +"#, + ); + } + + #[test] + fn incorrect_enum_name() { + check_diagnostics( + r#" +enum some_enum { Val(u8) } + // ^^^^^^^^^ Enum `some_enum` should have CamelCase name, e.g. `SomeEnum` +"#, + ); + } + + #[test] + fn incorrect_enum_variant_name() { + check_diagnostics( + r#" +enum SomeEnum { SOME_VARIANT(u8) } + // ^^^^^^^^^^^^ Variant `SOME_VARIANT` should have CamelCase name, e.g. `SomeVariant` +"#, + ); + } + + #[test] + fn incorrect_const_name() { + check_diagnostics( + r#" +const some_weird_const: u8 = 10; + // ^^^^^^^^^^^^^^^^ Constant `some_weird_const` should have UPPER_SNAKE_CASE name, e.g. `SOME_WEIRD_CONST` + +fn func() { + const someConstInFunc: &str = "hi there"; + // ^^^^^^^^^^^^^^^ Constant `someConstInFunc` should have UPPER_SNAKE_CASE name, e.g. `SOME_CONST_IN_FUNC` + +} +"#, + ); + } + + #[test] + fn incorrect_static_name() { + check_diagnostics( + r#" +static some_weird_const: u8 = 10; + // ^^^^^^^^^^^^^^^^ Static variable `some_weird_const` should have UPPER_SNAKE_CASE name, e.g. `SOME_WEIRD_CONST` + +fn func() { + static someConstInFunc: &str = "hi there"; + // ^^^^^^^^^^^^^^^ Static variable `someConstInFunc` should have UPPER_SNAKE_CASE name, e.g. `SOME_CONST_IN_FUNC` +} +"#, + ); + } + + #[test] + fn fn_inside_impl_struct() { + check_diagnostics( + r#" +struct someStruct; + // ^^^^^^^^^^ Structure `someStruct` should have CamelCase name, e.g. `SomeStruct` + +impl someStruct { + fn SomeFunc(&self) { + // ^^^^^^^^ Function `SomeFunc` should have snake_case name, e.g. `some_func` + static someConstInFunc: &str = "hi there"; + // ^^^^^^^^^^^^^^^ Static variable `someConstInFunc` should have UPPER_SNAKE_CASE name, e.g. `SOME_CONST_IN_FUNC` + let WHY_VAR_IS_CAPS = 10; + // ^^^^^^^^^^^^^^^ Variable `WHY_VAR_IS_CAPS` should have snake_case name, e.g. `why_var_is_caps` + } +} +"#, + ); + } + + #[test] + fn no_diagnostic_for_enum_varinats() { + check_diagnostics( + r#" +enum Option { Some, None } + +fn main() { + match Option::None { + None => (), + Some => (), + } +} +"#, + ); + } + + #[test] + fn non_let_bind() { + check_diagnostics( + r#" +enum Option { Some, None } + +fn main() { + match Option::None { + SOME_VAR @ None => (), + // ^^^^^^^^ Variable `SOME_VAR` should have snake_case name, e.g. `some_var` + Some => (), + } +} +"#, + ); + } +} diff --git a/crates/hir_ty/src/diagnostics/decl_check/case_conv.rs b/crates/hir_ty/src/diagnostics/decl_check/case_conv.rs new file mode 100644 index 0000000000..3800f2a6b3 --- /dev/null +++ b/crates/hir_ty/src/diagnostics/decl_check/case_conv.rs @@ -0,0 +1,194 @@ +//! Functions for string case manipulation, such as detecting the identifier case, +//! and converting it into appropriate form. + +#[derive(Debug)] +enum DetectedCase { + LowerCamelCase, + UpperCamelCase, + LowerSnakeCase, + UpperSnakeCase, + Unknown, +} + +fn detect_case(ident: &str) -> DetectedCase { + let trimmed_ident = ident.trim_matches('_'); + let first_lowercase = trimmed_ident.starts_with(|chr: char| chr.is_ascii_lowercase()); + let mut has_lowercase = first_lowercase; + let mut has_uppercase = false; + let mut has_underscore = false; + + for chr in trimmed_ident.chars() { + if chr == '_' { + has_underscore = true; + } else if chr.is_ascii_uppercase() { + has_uppercase = true; + } else if chr.is_ascii_lowercase() { + has_lowercase = true; + } + } + + if has_uppercase { + if !has_lowercase { + DetectedCase::UpperSnakeCase + } else if !has_underscore { + if first_lowercase { + DetectedCase::LowerCamelCase + } else { + DetectedCase::UpperCamelCase + } + } else { + // It has uppercase, it has lowercase, it has underscore. + // No assumptions here + DetectedCase::Unknown + } + } else { + DetectedCase::LowerSnakeCase + } +} + +/// Converts an identifier to an UpperCamelCase form. +/// Returns `None` if the string is already is UpperCamelCase. +pub fn to_camel_case(ident: &str) -> Option { + let detected_case = detect_case(ident); + + match detected_case { + DetectedCase::UpperCamelCase => return None, + DetectedCase::LowerCamelCase => { + let mut first_capitalized = false; + let output = ident + .chars() + .map(|chr| { + if !first_capitalized && chr.is_ascii_lowercase() { + first_capitalized = true; + chr.to_ascii_uppercase() + } else { + chr + } + }) + .collect(); + return Some(output); + } + _ => {} + } + + let mut output = String::with_capacity(ident.len()); + + let mut capital_added = false; + for chr in ident.chars() { + if chr.is_alphabetic() { + if !capital_added { + output.push(chr.to_ascii_uppercase()); + capital_added = true; + } else { + output.push(chr.to_ascii_lowercase()); + } + } else if chr == '_' { + // Skip this character and make the next one capital. + capital_added = false; + } else { + // Put the characted as-is. + output.push(chr); + } + } + + if output == ident { + // While we didn't detect the correct case at the beginning, there + // may be special cases: e.g. `A` is both valid CamelCase and UPPER_SNAKE_CASE. + None + } else { + Some(output) + } +} + +/// Converts an identifier to a lower_snake_case form. +/// Returns `None` if the string is already in lower_snake_case. +pub fn to_lower_snake_case(ident: &str) -> Option { + // First, assume that it's UPPER_SNAKE_CASE. + match detect_case(ident) { + DetectedCase::LowerSnakeCase => return None, + DetectedCase::UpperSnakeCase => { + return Some(ident.chars().map(|chr| chr.to_ascii_lowercase()).collect()) + } + _ => {} + } + + // Otherwise, assume that it's CamelCase. + let lower_snake_case = stdx::to_lower_snake_case(ident); + + if lower_snake_case == ident { + // While we didn't detect the correct case at the beginning, there + // may be special cases: e.g. `a` is both valid camelCase and snake_case. + None + } else { + Some(lower_snake_case) + } +} + +/// Converts an identifier to an UPPER_SNAKE_CASE form. +/// Returns `None` if the string is already is UPPER_SNAKE_CASE. +pub fn to_upper_snake_case(ident: &str) -> Option { + match detect_case(ident) { + DetectedCase::UpperSnakeCase => return None, + DetectedCase::LowerSnakeCase => { + return Some(ident.chars().map(|chr| chr.to_ascii_uppercase()).collect()) + } + _ => {} + } + + // Normalize the string from whatever form it's in currently, and then just make it uppercase. + let upper_snake_case = stdx::to_upper_snake_case(ident); + + if upper_snake_case == ident { + // While we didn't detect the correct case at the beginning, there + // may be special cases: e.g. `A` is both valid CamelCase and UPPER_SNAKE_CASE. + None + } else { + Some(upper_snake_case) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use expect_test::{expect, Expect}; + + fn check Option>(fun: F, input: &str, expect: Expect) { + // `None` is translated to empty string, meaning that there is nothing to fix. + let output = fun(input).unwrap_or_default(); + + expect.assert_eq(&output); + } + + #[test] + fn test_to_lower_snake_case() { + check(to_lower_snake_case, "lower_snake_case", expect![[""]]); + check(to_lower_snake_case, "UPPER_SNAKE_CASE", expect![["upper_snake_case"]]); + check(to_lower_snake_case, "Weird_Case", expect![["weird_case"]]); + check(to_lower_snake_case, "CamelCase", expect![["camel_case"]]); + check(to_lower_snake_case, "lowerCamelCase", expect![["lower_camel_case"]]); + check(to_lower_snake_case, "a", expect![[""]]); + } + + #[test] + fn test_to_camel_case() { + check(to_camel_case, "CamelCase", expect![[""]]); + check(to_camel_case, "CamelCase_", expect![[""]]); + check(to_camel_case, "_CamelCase", expect![[""]]); + check(to_camel_case, "lowerCamelCase", expect![["LowerCamelCase"]]); + check(to_camel_case, "lower_snake_case", expect![["LowerSnakeCase"]]); + check(to_camel_case, "UPPER_SNAKE_CASE", expect![["UpperSnakeCase"]]); + check(to_camel_case, "Weird_Case", expect![["WeirdCase"]]); + check(to_camel_case, "name", expect![["Name"]]); + check(to_camel_case, "A", expect![[""]]); + } + + #[test] + fn test_to_upper_snake_case() { + check(to_upper_snake_case, "UPPER_SNAKE_CASE", expect![[""]]); + check(to_upper_snake_case, "lower_snake_case", expect![["LOWER_SNAKE_CASE"]]); + check(to_upper_snake_case, "Weird_Case", expect![["WEIRD_CASE"]]); + check(to_upper_snake_case, "CamelCase", expect![["CAMEL_CASE"]]); + check(to_upper_snake_case, "lowerCamelCase", expect![["LOWER_CAMEL_CASE"]]); + check(to_upper_snake_case, "A", expect![[""]]); + } +} diff --git a/crates/hir_ty/src/diagnostics/unsafe_check.rs b/crates/hir_ty/src/diagnostics/unsafe_check.rs index 61ffbf5d15..21a121aad7 100644 --- a/crates/hir_ty/src/diagnostics/unsafe_check.rs +++ b/crates/hir_ty/src/diagnostics/unsafe_check.rs @@ -190,13 +190,13 @@ struct Ty { a: u8, } -static mut static_mut: Ty = Ty { a: 0 }; +static mut STATIC_MUT: Ty = Ty { a: 0 }; fn main() { - let x = static_mut.a; + let x = STATIC_MUT.a; //^^^^^^^^^^ This operation is unsafe and requires an unsafe function or block unsafe { - let x = static_mut.a; + let x = STATIC_MUT.a; } } "#, diff --git a/crates/ide/src/diagnostics.rs b/crates/ide/src/diagnostics.rs index f5d627b6ef..b30cdb6edd 100644 --- a/crates/ide/src/diagnostics.rs +++ b/crates/ide/src/diagnostics.rs @@ -96,6 +96,9 @@ pub(crate) fn diagnostics( .on::(|d| { res.borrow_mut().push(diagnostic_with_fix(d, &sema)); }) + .on::(|d| { + res.borrow_mut().push(warning_with_fix(d, &sema)); + }) // Only collect experimental diagnostics when they're enabled. .filter(|diag| !(diag.is_experimental() && config.disable_experimental)) .filter(|diag| !config.disabled.contains(diag.code().as_str())); @@ -130,6 +133,15 @@ fn diagnostic_with_fix(d: &D, sema: &Semantics(d: &D, sema: &Semantics) -> Diagnostic { + Diagnostic { + range: sema.diagnostics_display_range(d).range, + message: d.message(), + severity: Severity::WeakWarning, + fix: d.fix(&sema), + } +} + fn check_unnecessary_braces_in_use_statement( acc: &mut Vec, file_id: FileId, @@ -245,8 +257,37 @@ mod tests { assert_eq_text!(&after, &actual); assert!( - fix.fix_trigger_range.start() <= file_position.offset - && fix.fix_trigger_range.end() >= file_position.offset, + fix.fix_trigger_range.contains_inclusive(file_position.offset), + "diagnostic fix range {:?} does not touch cursor position {:?}", + fix.fix_trigger_range, + file_position.offset + ); + } + + /// Similar to `check_fix`, but applies all the available fixes. + fn check_fixes(ra_fixture_before: &str, ra_fixture_after: &str) { + let after = trim_indent(ra_fixture_after); + + let (analysis, file_position) = fixture::position(ra_fixture_before); + let diagnostic = analysis + .diagnostics(&DiagnosticsConfig::default(), file_position.file_id) + .unwrap() + .pop() + .unwrap(); + let fix = diagnostic.fix.unwrap(); + let target_file_contents = analysis.file_text(file_position.file_id).unwrap(); + let actual = { + let mut actual = target_file_contents.to_string(); + // Go from the last one to the first one, so that ranges won't be affected by previous edits. + for edit in fix.source_change.source_file_edits.iter().rev() { + edit.edit.apply(&mut actual); + } + actual + }; + + assert_eq_text!(&after, &actual); + assert!( + fix.fix_trigger_range.contains_inclusive(file_position.offset), "diagnostic fix range {:?} does not touch cursor position {:?}", fix.fix_trigger_range, file_position.offset @@ -790,4 +831,100 @@ struct Foo { let diagnostics = analysis.diagnostics(&DiagnosticsConfig::default(), file_id).unwrap(); assert!(!diagnostics.is_empty()); } + + #[test] + fn test_rename_incorrect_case() { + check_fixes( + r#" +pub struct test_struct<|> { one: i32 } + +pub fn some_fn(val: test_struct) -> test_struct { + test_struct { one: val.one + 1 } +} +"#, + r#" +pub struct TestStruct { one: i32 } + +pub fn some_fn(val: TestStruct) -> TestStruct { + TestStruct { one: val.one + 1 } +} +"#, + ); + + check_fixes( + r#" +pub fn some_fn(NonSnakeCase<|>: u8) -> u8 { + NonSnakeCase +} +"#, + r#" +pub fn some_fn(non_snake_case: u8) -> u8 { + non_snake_case +} +"#, + ); + + check_fixes( + r#" +pub fn SomeFn<|>(val: u8) -> u8 { + if val != 0 { SomeFn(val - 1) } else { val } +} +"#, + r#" +pub fn some_fn(val: u8) -> u8 { + if val != 0 { some_fn(val - 1) } else { val } +} +"#, + ); + + check_fixes( + r#" +fn some_fn() { + let whatAWeird_Formatting<|> = 10; + another_func(whatAWeird_Formatting); +} +"#, + r#" +fn some_fn() { + let what_a_weird_formatting = 10; + another_func(what_a_weird_formatting); +} +"#, + ); + } + + #[test] + fn test_uppercase_const_no_diagnostics() { + check_no_diagnostics( + r#" +fn foo() { + const ANOTHER_ITEM<|>: &str = "some_item"; +} +"#, + ); + } + + #[test] + fn test_rename_incorrect_case_struct_method() { + check_fixes( + r#" +pub struct TestStruct; + +impl TestStruct { + pub fn SomeFn<|>() -> TestStruct { + TestStruct + } +} +"#, + r#" +pub struct TestStruct; + +impl TestStruct { + pub fn some_fn() -> TestStruct { + TestStruct + } +} +"#, + ); + } } diff --git a/crates/ide/src/diagnostics/fixes.rs b/crates/ide/src/diagnostics/fixes.rs index 68ae1c2398..0c75e50b01 100644 --- a/crates/ide/src/diagnostics/fixes.rs +++ b/crates/ide/src/diagnostics/fixes.rs @@ -3,7 +3,10 @@ use base_db::FileId; use hir::{ db::AstDatabase, - diagnostics::{Diagnostic, MissingFields, MissingOkInTailExpr, NoSuchField, UnresolvedModule}, + diagnostics::{ + Diagnostic, IncorrectCase, MissingFields, MissingOkInTailExpr, NoSuchField, + UnresolvedModule, + }, HasSource, HirDisplay, Semantics, VariantDef, }; use ide_db::{ @@ -17,7 +20,7 @@ use syntax::{ }; use text_edit::TextEdit; -use crate::diagnostics::Fix; +use crate::{diagnostics::Fix, references::rename::rename_with_semantics, FilePosition}; /// A [Diagnostic] that potentially has a fix available. /// @@ -99,6 +102,23 @@ impl DiagnosticWithFix for MissingOkInTailExpr { } } +impl DiagnosticWithFix for IncorrectCase { + fn fix(&self, sema: &Semantics) -> Option { + let root = sema.db.parse_or_expand(self.file)?; + let name_node = self.ident.to_node(&root); + + let file_id = self.file.original_file(sema.db); + let offset = name_node.syntax().text_range().start(); + let file_position = FilePosition { file_id, offset }; + + let rename_changes = + rename_with_semantics(sema, file_position, &self.suggested_text).ok()?; + + let label = format!("Rename to {}", self.suggested_text); + Some(Fix::new(&label, rename_changes.info, rename_changes.range)) + } +} + fn missing_record_expr_field_fix( sema: &Semantics, usage_file_id: FileId, diff --git a/crates/ide/src/references.rs b/crates/ide/src/references.rs index f65a05ea33..88e2f2db3f 100644 --- a/crates/ide/src/references.rs +++ b/crates/ide/src/references.rs @@ -9,7 +9,7 @@ //! at the index that the match starts at and its tree parent is //! resolved to the search element definition, we get a reference. -mod rename; +pub(crate) mod rename; use hir::Semantics; use ide_db::{ diff --git a/crates/ide/src/references/rename.rs b/crates/ide/src/references/rename.rs index f3b5cfc8c1..f9a11e43d8 100644 --- a/crates/ide/src/references/rename.rs +++ b/crates/ide/src/references/rename.rs @@ -42,7 +42,14 @@ pub(crate) fn rename( new_name: &str, ) -> Result, RenameError> { let sema = Semantics::new(db); + rename_with_semantics(&sema, position, new_name) +} +pub(crate) fn rename_with_semantics( + sema: &Semantics, + position: FilePosition, + new_name: &str, +) -> Result, RenameError> { match lex_single_syntax_kind(new_name) { Some(res) => match res { (SyntaxKind::IDENT, _) => (), diff --git a/crates/stdx/src/lib.rs b/crates/stdx/src/lib.rs index 011935cade..59d89f47d1 100644 --- a/crates/stdx/src/lib.rs +++ b/crates/stdx/src/lib.rs @@ -28,20 +28,32 @@ pub fn timeit(label: &'static str) -> impl Drop { Guard { label, start: Instant::now() } } -pub fn to_lower_snake_case(s: &str) -> String { +fn to_snake_case char>(s: &str, change_case: F) -> String { let mut buf = String::with_capacity(s.len()); let mut prev = false; for c in s.chars() { + // `&& prev` is required to not insert `_` before the first symbol. if c.is_ascii_uppercase() && prev { - buf.push('_') + // This check is required to not translate `Weird_Case` into `weird__case`. + if !buf.ends_with('_') { + buf.push('_') + } } prev = true; - buf.push(c.to_ascii_lowercase()); + buf.push(change_case(&c)); } buf } +pub fn to_lower_snake_case(s: &str) -> String { + to_snake_case(s, char::to_ascii_lowercase) +} + +pub fn to_upper_snake_case(s: &str) -> String { + to_snake_case(s, char::to_ascii_uppercase) +} + pub fn replace(buf: &mut String, from: char, to: &str) { if !buf.contains(from) { return; diff --git a/crates/syntax/src/ast/make.rs b/crates/syntax/src/ast/make.rs index 3a184094c8..74dbdfaf7b 100644 --- a/crates/syntax/src/ast/make.rs +++ b/crates/syntax/src/ast/make.rs @@ -320,6 +320,10 @@ pub fn param(name: String, ty: String) -> ast::Param { ast_from_text(&format!("fn f({}: {}) {{ }}", name, ty)) } +pub fn ret_type(ty: ast::Type) -> ast::RetType { + ast_from_text(&format!("fn f() -> {} {{ }}", ty)) +} + pub fn param_list(pats: impl IntoIterator) -> ast::ParamList { let args = pats.into_iter().join(", "); ast_from_text(&format!("fn f({}) {{ }}", args)) @@ -350,14 +354,20 @@ pub fn fn_( type_params: Option, params: ast::ParamList, body: ast::BlockExpr, + ret_type: Option, ) -> ast::Fn { let type_params = if let Some(type_params) = type_params { format!("<{}>", type_params) } else { "".into() }; + let ret_type = if let Some(ret_type) = ret_type { format!("{} ", ret_type) } else { "".into() }; let visibility = match visibility { None => String::new(), Some(it) => format!("{} ", it), }; - ast_from_text(&format!("{}fn {}{}{} {}", visibility, fn_name, type_params, params, body)) + + ast_from_text(&format!( + "{}fn {}{}{} {}{}", + visibility, fn_name, type_params, params, ret_type, body + )) } fn ast_from_text(text: &str) -> N {