diff --git a/crates/ra_hir/src/code_model.rs b/crates/ra_hir/src/code_model.rs index 131180a638..13e763e520 100644 --- a/crates/ra_hir/src/code_model.rs +++ b/crates/ra_hir/src/code_model.rs @@ -25,7 +25,7 @@ use hir_expand::{ use hir_ty::{ autoderef, display::{HirDisplayError, HirFormatter}, - expr::ExprValidator, + expr::{ExprValidator, UnsafeValidator}, method_resolution, ApplicationTy, Canonical, GenericPredicate, InEnvironment, Substs, TraitEnvironment, Ty, TyDefId, TypeCtor, }; @@ -36,7 +36,6 @@ use rustc_hash::FxHashSet; use crate::{ db::{DefDatabase, HirDatabase}, - diagnostics::UnsafeValidator, has_source::HasSource, CallableDef, HirDisplay, InFile, Name, }; @@ -680,7 +679,7 @@ impl Function { infer.add_diagnostics(db, self.id, sink); let mut validator = ExprValidator::new(self.id, infer.clone(), sink); validator.validate_body(db); - let mut validator = UnsafeValidator::new(&self, infer, sink); + let mut validator = UnsafeValidator::new(self.id, infer, sink); validator.validate_body(db); } } diff --git a/crates/ra_hir/src/diagnostics.rs b/crates/ra_hir/src/diagnostics.rs index 562f3fe5c4..c82883d0c1 100644 --- a/crates/ra_hir/src/diagnostics.rs +++ b/crates/ra_hir/src/diagnostics.rs @@ -2,53 +2,3 @@ pub use hir_def::diagnostics::UnresolvedModule; pub use hir_expand::diagnostics::{AstDiagnostic, Diagnostic, DiagnosticSink}; pub use hir_ty::diagnostics::{MissingFields, MissingMatchArms, MissingOkInTailExpr, NoSuchField}; - -use std::sync::Arc; - -use crate::code_model::Function; -use crate::db::HirDatabase; -use crate::has_source::HasSource; -use hir_ty::{ - diagnostics::{MissingUnsafe, UnnecessaryUnsafe}, - expr::unsafe_expressions, - InferenceResult, -}; -use ra_syntax::AstPtr; - -pub struct UnsafeValidator<'a, 'b: 'a> { - func: &'a Function, - infer: Arc, - sink: &'a mut DiagnosticSink<'b>, -} - -impl<'a, 'b> UnsafeValidator<'a, 'b> { - pub fn new( - func: &'a Function, - infer: Arc, - sink: &'a mut DiagnosticSink<'b>, - ) -> UnsafeValidator<'a, 'b> { - UnsafeValidator { func, infer, sink } - } - - pub fn validate_body(&mut self, db: &dyn HirDatabase) { - let def = self.func.id.into(); - let unsafe_expressions = unsafe_expressions(db, self.infer.as_ref(), def); - let func_data = db.function_data(self.func.id); - let unnecessary = func_data.is_unsafe && unsafe_expressions.len() == 0; - let missing = !func_data.is_unsafe && unsafe_expressions.len() > 0; - if !(unnecessary || missing) { - return; - } - - let in_file = self.func.source(db); - let file = in_file.file_id; - let fn_def = AstPtr::new(&in_file.value); - let fn_name = func_data.name.clone().into(); - - if unnecessary { - self.sink.push(UnnecessaryUnsafe { file, fn_def, fn_name }) - } else { - self.sink.push(MissingUnsafe { file, fn_def, fn_name }) - } - } -} diff --git a/crates/ra_hir_ty/src/diagnostics.rs b/crates/ra_hir_ty/src/diagnostics.rs index 3469cc6806..c6ca322fa3 100644 --- a/crates/ra_hir_ty/src/diagnostics.rs +++ b/crates/ra_hir_ty/src/diagnostics.rs @@ -3,7 +3,7 @@ use std::any::Any; use hir_expand::{db::AstDatabase, name::Name, HirFileId, InFile}; -use ra_syntax::{ast, AstNode, AstPtr, SyntaxNodePtr}; +use ra_syntax::{ast::{self, NameOwner}, AstNode, AstPtr, SyntaxNodePtr}; use stdx::format_to; pub use hir_def::{diagnostics::UnresolvedModule, expr::MatchArm, path::Path}; @@ -174,12 +174,11 @@ impl AstDiagnostic for BreakOutsideOfLoop { pub struct MissingUnsafe { pub file: HirFileId, pub fn_def: AstPtr, - pub fn_name: Name, } impl Diagnostic for MissingUnsafe { fn message(&self) -> String { - format!("Missing unsafe marker on fn `{}`", self.fn_name) + format!("Missing unsafe keyword on fn") } fn source(&self) -> InFile { InFile { file_id: self.file, value: self.fn_def.clone().into() } @@ -190,12 +189,12 @@ impl Diagnostic for MissingUnsafe { } impl AstDiagnostic for MissingUnsafe { - type AST = ast::FnDef; + type AST = ast::Name; fn ast(&self, db: &impl AstDatabase) -> Self::AST { let root = db.parse_or_expand(self.source().file_id).unwrap(); let node = self.source().value.to_node(&root); - ast::FnDef::cast(node).unwrap() + ast::FnDef::cast(node).unwrap().name().unwrap() } } @@ -203,12 +202,11 @@ impl AstDiagnostic for MissingUnsafe { pub struct UnnecessaryUnsafe { pub file: HirFileId, pub fn_def: AstPtr, - pub fn_name: Name, } impl Diagnostic for UnnecessaryUnsafe { fn message(&self) -> String { - format!("Unnecessary unsafe marker on fn `{}`", self.fn_name) + format!("Unnecessary unsafe keyword on fn") } fn source(&self) -> InFile { InFile { file_id: self.file, value: self.fn_def.clone().into() } @@ -219,11 +217,11 @@ impl Diagnostic for UnnecessaryUnsafe { } impl AstDiagnostic for UnnecessaryUnsafe { - type AST = ast::FnDef; + type AST = ast::Name; fn ast(&self, db: &impl AstDatabase) -> Self::AST { let root = db.parse_or_expand(self.source().file_id).unwrap(); let node = self.source().value.to_node(&root); - ast::FnDef::cast(node).unwrap() + ast::FnDef::cast(node).unwrap().name().unwrap() } } diff --git a/crates/ra_hir_ty/src/expr.rs b/crates/ra_hir_ty/src/expr.rs index 795f1762c5..7532e2dc7c 100644 --- a/crates/ra_hir_ty/src/expr.rs +++ b/crates/ra_hir_ty/src/expr.rs @@ -2,14 +2,19 @@ use std::sync::Arc; -use hir_def::{path::path, resolver::HasResolver, AdtId, DefWithBodyId, FunctionId}; +use hir_def::{ + path::path, resolver::HasResolver, src::HasSource, AdtId, DefWithBodyId, FunctionId, Lookup, +}; use hir_expand::diagnostics::DiagnosticSink; use ra_syntax::{ast, AstPtr}; use rustc_hash::FxHashSet; use crate::{ db::HirDatabase, - diagnostics::{MissingFields, MissingMatchArms, MissingOkInTailExpr, MissingPatFields}, + diagnostics::{ + MissingFields, MissingMatchArms, MissingOkInTailExpr, MissingPatFields, MissingUnsafe, + UnnecessaryUnsafe, + }, utils::variant_data, ApplicationTy, InferenceResult, Ty, TypeCtor, _match::{is_useful, MatchCheckCtx, Matrix, PatStack, Usefulness}, @@ -321,16 +326,63 @@ pub fn unsafe_expressions( let mut unsafe_expr_ids = vec![]; let body = db.body(def); for (id, expr) in body.exprs.iter() { - if let Expr::Call { callee, .. } = expr { - if infer - .method_resolution(*callee) - .map(|func| db.function_data(func).is_unsafe) - .unwrap_or(false) - { - unsafe_expr_ids.push(id); + match expr { + Expr::Call { callee, .. } => { + if infer + .method_resolution(*callee) + .map(|func| db.function_data(func).is_unsafe) + .unwrap_or(false) + { + unsafe_expr_ids.push(id); + } } + Expr::UnaryOp { expr, op: UnaryOp::Deref } => { + if let Ty::Apply(ApplicationTy { ctor: TypeCtor::RawPtr(..), .. }) = &infer[*expr] { + unsafe_expr_ids.push(id); + } + } + _ => {} } } unsafe_expr_ids } + +pub struct UnsafeValidator<'a, 'b: 'a> { + func: FunctionId, + infer: Arc, + sink: &'a mut DiagnosticSink<'b>, +} + +impl<'a, 'b> UnsafeValidator<'a, 'b> { + pub fn new( + func: FunctionId, + infer: Arc, + sink: &'a mut DiagnosticSink<'b>, + ) -> UnsafeValidator<'a, 'b> { + UnsafeValidator { func, infer, sink } + } + + pub fn validate_body(&mut self, db: &dyn HirDatabase) { + let def = self.func.into(); + let unsafe_expressions = unsafe_expressions(db, self.infer.as_ref(), def); + let func_data = db.function_data(self.func); + let unnecessary = func_data.is_unsafe && unsafe_expressions.len() == 0; + let missing = !func_data.is_unsafe && unsafe_expressions.len() > 0; + if !(unnecessary || missing) { + return; + } + + let loc = self.func.lookup(db.upcast()); + let in_file = loc.source(db.upcast()); + + let file = in_file.file_id; + let fn_def = AstPtr::new(&in_file.value); + + if unnecessary { + self.sink.push(UnnecessaryUnsafe { file, fn_def }) + } else { + self.sink.push(MissingUnsafe { file, fn_def }) + } + } +} diff --git a/crates/ra_hir_ty/src/test_db.rs b/crates/ra_hir_ty/src/test_db.rs index ad04e3e0f9..9ccf2aa377 100644 --- a/crates/ra_hir_ty/src/test_db.rs +++ b/crates/ra_hir_ty/src/test_db.rs @@ -11,7 +11,11 @@ use ra_db::{salsa, CrateId, FileId, FileLoader, FileLoaderDelegate, SourceDataba use rustc_hash::FxHashSet; use stdx::format_to; -use crate::{db::HirDatabase, diagnostics::Diagnostic, expr::ExprValidator}; +use crate::{ + db::HirDatabase, + diagnostics::Diagnostic, + expr::{ExprValidator, UnsafeValidator}, +}; #[salsa::database( ra_db::SourceDatabaseExtStorage, @@ -119,7 +123,9 @@ impl TestDB { let infer = self.infer(f.into()); let mut sink = DiagnosticSink::new(&mut cb); infer.add_diagnostics(self, f, &mut sink); - let mut validator = ExprValidator::new(f, infer, &mut sink); + let mut validator = ExprValidator::new(f, infer.clone(), &mut sink); + validator.validate_body(self); + let mut validator = UnsafeValidator::new(f, infer, &mut sink); validator.validate_body(self); } } diff --git a/crates/ra_hir_ty/src/tests.rs b/crates/ra_hir_ty/src/tests.rs index 85ff26a368..4ff2b2d4a5 100644 --- a/crates/ra_hir_ty/src/tests.rs +++ b/crates/ra_hir_ty/src/tests.rs @@ -538,6 +538,84 @@ fn missing_record_pat_field_no_diagnostic_if_not_exhaustive() { assert_snapshot!(diagnostics, @""); } +#[test] +fn missing_unsafe_diagnostic_with_raw_ptr() { + let diagnostics = TestDB::with_files( + r" +//- /lib.rs +fn missing_unsafe() { + let x = &5 as *usize; + let y = *x; +} +", + ) + .diagnostics() + .0; + + assert_snapshot!(diagnostics, @r#""fn missing_unsafe() {\n let x = &5 as *usize;\n let y = *x;\n}": Missing unsafe keyword on fn"#); +} + +#[test] +fn missing_unsafe_diagnostic_with_unsafe_call() { + let diagnostics = TestDB::with_files( + r" +//- /lib.rs +unsafe fn unsafe_fn() { + let x = &5 as *usize; + let y = *x; +} + +fn missing_unsafe() { + unsafe_fn(); +} +", + ) + .diagnostics() + .0; + + assert_snapshot!(diagnostics, @r#""fn missing_unsafe() {\n unsafe_fn();\n}": Missing unsafe keyword on fn"#); +} + +#[test] +fn missing_unsafe_diagnostic_with_unsafe_method_call() { + let diagnostics = TestDB::with_files( + r" +//- /lib.rs +struct HasUnsafe; + +impl HasUnsafe { + unsafe fn unsafe_fn() { + let x = &5 as *usize; + let y = *x; + } +} + +fn missing_unsafe() { + HasUnsafe.unsafe_fn(); +} + +", + ) + .diagnostics() + .0; + + assert_snapshot!(diagnostics, @r#""fn missing_unsafe() {\n HasUnsafe.unsafe_fn();\n}": Missing unsafe keyword on fn"#); +} + +#[test] +fn unnecessary_unsafe_diagnostic() { + let diagnostics = TestDB::with_files( + r" +//- /lib.rs +unsafe fn actually_safe_fn() {} +", + ) + .diagnostics() + .0; + + assert_snapshot!(diagnostics, @r#""unsafe fn actually_safe_fn() {}": Unnecessary unsafe keyword on fn"#); +} + #[test] fn break_outside_of_loop() { let diagnostics = TestDB::with_files( diff --git a/crates/ra_ide/src/syntax_highlighting/tests.rs b/crates/ra_ide/src/syntax_highlighting/tests.rs index 39cd74ac3a..43f554a292 100644 --- a/crates/ra_ide/src/syntax_highlighting/tests.rs +++ b/crates/ra_ide/src/syntax_highlighting/tests.rs @@ -384,9 +384,11 @@ impl HasUnsafeFn { } fn main() { + let x = &5 as *usize; unsafe { unsafe_fn(); HasUnsafeFn.unsafe_method(); + let y = *x; } } "#