From 6615fda695dfe6cb2c6792eed5d472b43ac63423 Mon Sep 17 00:00:00 2001 From: Jonas Schievink Date: Thu, 17 Dec 2020 01:19:56 +0100 Subject: [PATCH] decl_check: don't pass `db` around so often --- crates/hir/src/code_model.rs | 10 +- crates/hir_ty/src/diagnostics.rs | 15 ++- crates/hir_ty/src/diagnostics/decl_check.rs | 103 ++++++++++---------- 3 files changed, 68 insertions(+), 60 deletions(-) diff --git a/crates/hir/src/code_model.rs b/crates/hir/src/code_model.rs index 9bfcd215a5..b1a1bc3374 100644 --- a/crates/hir/src/code_model.rs +++ b/crates/hir/src/code_model.rs @@ -267,7 +267,12 @@ impl ModuleDef { _ => return, }; - hir_ty::diagnostics::validate_module_item(db, id, sink) + let module = match self.module(db) { + Some(it) => it, + None => return, + }; + + hir_ty::diagnostics::validate_module_item(db, module.id.krate, id, sink) } } @@ -780,8 +785,9 @@ impl Function { } pub fn diagnostics(self, db: &dyn HirDatabase, sink: &mut DiagnosticSink) { + let krate = self.module(db).id.krate; hir_def::diagnostics::validate_body(db.upcast(), self.id.into(), sink); - hir_ty::diagnostics::validate_module_item(db, self.id.into(), sink); + hir_ty::diagnostics::validate_module_item(db, krate, self.id.into(), sink); hir_ty::diagnostics::validate_body(db, self.id.into(), sink); } diff --git a/crates/hir_ty/src/diagnostics.rs b/crates/hir_ty/src/diagnostics.rs index 1c72f766ea..14e18f5a11 100644 --- a/crates/hir_ty/src/diagnostics.rs +++ b/crates/hir_ty/src/diagnostics.rs @@ -6,6 +6,7 @@ mod decl_check; use std::{any::Any, fmt}; +use base_db::CrateId; use hir_def::{DefWithBodyId, ModuleDefId}; use hir_expand::diagnostics::{Diagnostic, DiagnosticCode, DiagnosticSink}; use hir_expand::{name::Name, HirFileId, InFile}; @@ -18,12 +19,13 @@ pub use crate::diagnostics::expr::{record_literal_missing_fields, record_pattern pub fn validate_module_item( db: &dyn HirDatabase, + krate: CrateId, owner: ModuleDefId, sink: &mut DiagnosticSink<'_>, ) { let _p = profile::span("validate_module_item"); - let mut validator = decl_check::DeclValidator::new(owner, sink); - validator.validate_item(db); + let mut validator = decl_check::DeclValidator::new(db, krate, sink); + validator.validate_item(owner); } pub fn validate_body(db: &dyn HirDatabase, owner: DefWithBodyId, sink: &mut DiagnosticSink<'_>) { @@ -407,7 +409,7 @@ mod tests { for (module_id, _) in crate_def_map.modules.iter() { for decl in crate_def_map[module_id].scope.declarations() { let mut sink = DiagnosticSinkBuilder::new().build(&mut cb); - validate_module_item(self, decl, &mut sink); + validate_module_item(self, krate, decl, &mut sink); if let ModuleDefId::FunctionId(f) = decl { fns.push(f) @@ -419,7 +421,12 @@ mod tests { for item in impl_data.items.iter() { if let AssocItemId::FunctionId(f) = item { let mut sink = DiagnosticSinkBuilder::new().build(&mut cb); - validate_module_item(self, ModuleDefId::FunctionId(*f), &mut sink); + validate_module_item( + self, + krate, + ModuleDefId::FunctionId(*f), + &mut sink, + ); fns.push(*f) } } diff --git a/crates/hir_ty/src/diagnostics/decl_check.rs b/crates/hir_ty/src/diagnostics/decl_check.rs index 25587e1168..eaeb6899f1 100644 --- a/crates/hir_ty/src/diagnostics/decl_check.rs +++ b/crates/hir_ty/src/diagnostics/decl_check.rs @@ -12,6 +12,7 @@ mod case_conv; +use base_db::CrateId; use hir_def::{ adt::VariantData, expr::{Pat, PatId}, @@ -40,7 +41,8 @@ mod allow { } pub(super) struct DeclValidator<'a, 'b: 'a> { - owner: ModuleDefId, + db: &'a dyn HirDatabase, + krate: CrateId, sink: &'a mut DiagnosticSink<'b>, } @@ -53,26 +55,27 @@ struct Replacement { impl<'a, 'b> DeclValidator<'a, 'b> { pub(super) fn new( - owner: ModuleDefId, + db: &'a dyn HirDatabase, + krate: CrateId, sink: &'a mut DiagnosticSink<'b>, ) -> DeclValidator<'a, 'b> { - DeclValidator { owner, sink } + DeclValidator { db, krate, sink } } - pub(super) fn validate_item(&mut self, db: &dyn HirDatabase) { - match self.owner { - ModuleDefId::FunctionId(func) => self.validate_func(db, func), - ModuleDefId::AdtId(adt) => self.validate_adt(db, adt), - ModuleDefId::ConstId(const_id) => self.validate_const(db, const_id), - ModuleDefId::StaticId(static_id) => self.validate_static(db, static_id), + pub(super) fn validate_item(&mut self, item: ModuleDefId) { + match item { + ModuleDefId::FunctionId(func) => self.validate_func(func), + ModuleDefId::AdtId(adt) => self.validate_adt(adt), + ModuleDefId::ConstId(const_id) => self.validate_const(const_id), + ModuleDefId::StaticId(static_id) => self.validate_static(static_id), _ => return, } } - fn validate_adt(&mut self, db: &dyn HirDatabase, adt: AdtId) { + fn validate_adt(&mut self, adt: AdtId) { match adt { - AdtId::StructId(struct_id) => self.validate_struct(db, struct_id), - AdtId::EnumId(enum_id) => self.validate_enum(db, enum_id), + AdtId::StructId(struct_id) => self.validate_struct(struct_id), + AdtId::EnumId(enum_id) => self.validate_enum(enum_id), AdtId::UnionId(_) => { // Unions aren't yet supported by this validator. } @@ -82,27 +85,27 @@ impl<'a, 'b> DeclValidator<'a, 'b> { /// Checks whether not following the convention is allowed for this item. /// /// Currently this method doesn't check parent attributes. - fn allowed(&self, db: &dyn HirDatabase, id: AttrDefId, allow_name: &str) -> bool { - db.attrs(id).by_key("allow").tt_values().any(|tt| tt.to_string().contains(allow_name)) + fn allowed(&self, id: AttrDefId, allow_name: &str) -> bool { + self.db.attrs(id).by_key("allow").tt_values().any(|tt| tt.to_string().contains(allow_name)) } - fn validate_func(&mut self, db: &dyn HirDatabase, func: FunctionId) { - let data = db.function_data(func); + fn validate_func(&mut self, func: FunctionId) { + let data = self.db.function_data(func); if data.is_extern { mark::hit!(extern_func_incorrect_case_ignored); return; } - let body = db.body(func.into()); + let body = self.db.body(func.into()); // Recursively validate inner scope items, such as static variables and constants. for (item_id, _) in body.item_scope.values() { - let mut validator = DeclValidator::new(item_id, self.sink); - validator.validate_item(db); + let mut validator = DeclValidator::new(self.db, self.krate, self.sink); + validator.validate_item(item_id); } // Check whether non-snake case identifiers are allowed for this function. - if self.allowed(db, func.into(), allow::NON_SNAKE_CASE) { + if self.allowed(func.into(), allow::NON_SNAKE_CASE) { return; } @@ -169,11 +172,10 @@ impl<'a, 'b> DeclValidator<'a, 'b> { // If there is at least one element to spawn a warning on, go to the source map and generate a warning. self.create_incorrect_case_diagnostic_for_func( func, - db, fn_name_replacement, fn_param_replacements, ); - self.create_incorrect_case_diagnostic_for_variables(func, db, pats_replacements); + self.create_incorrect_case_diagnostic_for_variables(func, pats_replacements); } /// Given the information about incorrect names in the function declaration, looks up into the source code @@ -181,7 +183,6 @@ impl<'a, 'b> DeclValidator<'a, 'b> { fn create_incorrect_case_diagnostic_for_func( &mut self, func: FunctionId, - db: &dyn HirDatabase, fn_name_replacement: Option, fn_param_replacements: Vec, ) { @@ -190,8 +191,8 @@ impl<'a, 'b> DeclValidator<'a, 'b> { return; } - let fn_loc = func.lookup(db.upcast()); - let fn_src = fn_loc.source(db.upcast()); + let fn_loc = func.lookup(self.db.upcast()); + let fn_src = fn_loc.source(self.db.upcast()); // Diagnostic for function name. if let Some(replacement) = fn_name_replacement { @@ -282,7 +283,6 @@ impl<'a, 'b> DeclValidator<'a, 'b> { fn create_incorrect_case_diagnostic_for_variables( &mut self, func: FunctionId, - db: &dyn HirDatabase, pats_replacements: Vec<(PatId, Replacement)>, ) { // XXX: only look at source_map if we do have missing fields @@ -290,12 +290,12 @@ impl<'a, 'b> DeclValidator<'a, 'b> { return; } - let (_, source_map) = db.body_with_source_map(func.into()); + let (_, source_map) = self.db.body_with_source_map(func.into()); for (id, replacement) in pats_replacements { if let Ok(source_ptr) = source_map.pat_syntax(id) { if let Some(expr) = source_ptr.value.as_ref().left() { - let root = source_ptr.file_syntax(db.upcast()); + let root = source_ptr.file_syntax(self.db.upcast()); if let ast::Pat::IdentPat(ident_pat) = expr.to_node(&root) { let parent = match ident_pat.syntax().parent() { Some(parent) => parent, @@ -333,12 +333,11 @@ impl<'a, 'b> DeclValidator<'a, 'b> { } } - fn validate_struct(&mut self, db: &dyn HirDatabase, struct_id: StructId) { - let data = db.struct_data(struct_id); + fn validate_struct(&mut self, struct_id: StructId) { + let data = self.db.struct_data(struct_id); - let non_camel_case_allowed = - self.allowed(db, struct_id.into(), allow::NON_CAMEL_CASE_TYPES); - let non_snake_case_allowed = self.allowed(db, struct_id.into(), allow::NON_SNAKE_CASE); + let non_camel_case_allowed = self.allowed(struct_id.into(), allow::NON_CAMEL_CASE_TYPES); + let non_snake_case_allowed = self.allowed(struct_id.into(), allow::NON_SNAKE_CASE); // Check the structure name. let struct_name = data.name.to_string(); @@ -379,7 +378,6 @@ impl<'a, 'b> DeclValidator<'a, 'b> { // If there is at least one element to spawn a warning on, go to the source map and generate a warning. self.create_incorrect_case_diagnostic_for_struct( struct_id, - db, struct_name_replacement, struct_fields_replacements, ); @@ -390,7 +388,6 @@ impl<'a, 'b> DeclValidator<'a, 'b> { fn create_incorrect_case_diagnostic_for_struct( &mut self, struct_id: StructId, - db: &dyn HirDatabase, struct_name_replacement: Option, struct_fields_replacements: Vec, ) { @@ -399,8 +396,8 @@ impl<'a, 'b> DeclValidator<'a, 'b> { return; } - let struct_loc = struct_id.lookup(db.upcast()); - let struct_src = struct_loc.source(db.upcast()); + let struct_loc = struct_id.lookup(self.db.upcast()); + let struct_src = struct_loc.source(self.db.upcast()); if let Some(replacement) = struct_name_replacement { let ast_ptr = match struct_src.value.name() { @@ -473,11 +470,11 @@ impl<'a, 'b> DeclValidator<'a, 'b> { } } - fn validate_enum(&mut self, db: &dyn HirDatabase, enum_id: EnumId) { - let data = db.enum_data(enum_id); + fn validate_enum(&mut self, enum_id: EnumId) { + let data = self.db.enum_data(enum_id); // Check whether non-camel case names are allowed for this enum. - if self.allowed(db, enum_id.into(), allow::NON_CAMEL_CASE_TYPES) { + if self.allowed(enum_id.into(), allow::NON_CAMEL_CASE_TYPES) { return; } @@ -512,7 +509,6 @@ impl<'a, 'b> DeclValidator<'a, 'b> { // If there is at least one element to spawn a warning on, go to the source map and generate a warning. self.create_incorrect_case_diagnostic_for_enum( enum_id, - db, enum_name_replacement, enum_fields_replacements, ) @@ -523,7 +519,6 @@ impl<'a, 'b> DeclValidator<'a, 'b> { fn create_incorrect_case_diagnostic_for_enum( &mut self, enum_id: EnumId, - db: &dyn HirDatabase, enum_name_replacement: Option, enum_variants_replacements: Vec, ) { @@ -532,8 +527,8 @@ impl<'a, 'b> DeclValidator<'a, 'b> { return; } - let enum_loc = enum_id.lookup(db.upcast()); - let enum_src = enum_loc.source(db.upcast()); + let enum_loc = enum_id.lookup(self.db.upcast()); + let enum_src = enum_loc.source(self.db.upcast()); if let Some(replacement) = enum_name_replacement { let ast_ptr = match enum_src.value.name() { @@ -608,10 +603,10 @@ impl<'a, 'b> DeclValidator<'a, 'b> { } } - fn validate_const(&mut self, db: &dyn HirDatabase, const_id: ConstId) { - let data = db.const_data(const_id); + fn validate_const(&mut self, const_id: ConstId) { + let data = self.db.const_data(const_id); - if self.allowed(db, const_id.into(), allow::NON_UPPER_CASE_GLOBAL) { + if self.allowed(const_id.into(), allow::NON_UPPER_CASE_GLOBAL) { return; } @@ -632,8 +627,8 @@ impl<'a, 'b> DeclValidator<'a, 'b> { return; }; - let const_loc = const_id.lookup(db.upcast()); - let const_src = const_loc.source(db.upcast()); + let const_loc = const_id.lookup(self.db.upcast()); + let const_src = const_loc.source(self.db.upcast()); let ast_ptr = match const_src.value.name() { Some(name) => name, @@ -652,14 +647,14 @@ impl<'a, 'b> DeclValidator<'a, 'b> { self.sink.push(diagnostic); } - fn validate_static(&mut self, db: &dyn HirDatabase, static_id: StaticId) { - let data = db.static_data(static_id); + fn validate_static(&mut self, static_id: StaticId) { + let data = self.db.static_data(static_id); if data.is_extern { mark::hit!(extern_static_incorrect_case_ignored); return; } - if self.allowed(db, static_id.into(), allow::NON_UPPER_CASE_GLOBAL) { + if self.allowed(static_id.into(), allow::NON_UPPER_CASE_GLOBAL) { return; } @@ -680,8 +675,8 @@ impl<'a, 'b> DeclValidator<'a, 'b> { return; }; - let static_loc = static_id.lookup(db.upcast()); - let static_src = static_loc.source(db.upcast()); + let static_loc = static_id.lookup(self.db.upcast()); + let static_src = static_loc.source(self.db.upcast()); let ast_ptr = match static_src.value.name() { Some(name) => name,