From d00a285fa757307bbe0f8dac9e49ac247cf9dab1 Mon Sep 17 00:00:00 2001 From: Phil Ellison Date: Sat, 10 Aug 2019 16:40:48 +0100 Subject: [PATCH] Initial implementation of Ok-wrapping --- crates/ra_hir/src/diagnostics.rs | 31 +++++++++++++++ crates/ra_hir/src/expr/validation.rs | 56 +++++++++++++++++++++++++++- crates/ra_hir/src/ty.rs | 2 +- crates/ra_ide_api/src/diagnostics.rs | 50 +++++++++++++++++++++++++ 4 files changed, 136 insertions(+), 3 deletions(-) diff --git a/crates/ra_hir/src/diagnostics.rs b/crates/ra_hir/src/diagnostics.rs index 301109cb8d..718345d754 100644 --- a/crates/ra_hir/src/diagnostics.rs +++ b/crates/ra_hir/src/diagnostics.rs @@ -143,3 +143,34 @@ impl AstDiagnostic for MissingFields { ast::RecordFieldList::cast(node).unwrap() } } + +#[derive(Debug)] +pub struct MissingOkInTailExpr { + pub file: HirFileId, + pub expr: AstPtr, +} + +impl Diagnostic for MissingOkInTailExpr { + fn message(&self) -> String { + "wrap return expression in Ok".to_string() + } + fn file(&self) -> HirFileId { + self.file + } + fn syntax_node_ptr(&self) -> SyntaxNodePtr { + self.expr.into() + } + fn as_any(&self) -> &(dyn Any + Send + 'static) { + self + } +} + +impl AstDiagnostic for MissingOkInTailExpr { + type AST = ast::Expr; + + fn ast(&self, db: &impl HirDatabase) -> Self::AST { + let root = db.parse_or_expand(self.file()).unwrap(); + let node = self.syntax_node_ptr().to_node(&root); + ast::Expr::cast(node).unwrap() + } +} diff --git a/crates/ra_hir/src/expr/validation.rs b/crates/ra_hir/src/expr/validation.rs index 62f7d41f5d..f5e641557d 100644 --- a/crates/ra_hir/src/expr/validation.rs +++ b/crates/ra_hir/src/expr/validation.rs @@ -6,11 +6,12 @@ use ra_syntax::ast::{AstNode, RecordLit}; use super::{Expr, ExprId, RecordLitField}; use crate::{ adt::AdtDef, - diagnostics::{DiagnosticSink, MissingFields}, + diagnostics::{DiagnosticSink, MissingFields, MissingOkInTailExpr}, expr::AstPtr, - ty::InferenceResult, + ty::{InferenceResult, Ty, TypeCtor}, Function, HasSource, HirDatabase, Name, Path, }; +use ra_syntax::ast; pub(crate) struct ExprValidator<'a, 'b: 'a> { func: Function, @@ -29,11 +30,23 @@ impl<'a, 'b> ExprValidator<'a, 'b> { pub(crate) fn validate_body(&mut self, db: &impl HirDatabase) { let body = self.func.body(db); + + // The final expr in the function body is the whole body, + // so the expression being returned is the penultimate expr. + let mut penultimate_expr = None; + let mut final_expr = None; + for e in body.exprs() { + penultimate_expr = final_expr; + final_expr = Some(e); + if let (id, Expr::RecordLit { path, fields, spread }) = e { self.validate_record_literal(id, path, fields, *spread, db); } } + if let Some(e) = penultimate_expr { + self.validate_results_in_tail_expr(e.0, db); + } } fn validate_record_literal( @@ -87,4 +100,43 @@ impl<'a, 'b> ExprValidator<'a, 'b> { }) } } + + fn validate_results_in_tail_expr(&mut self, id: ExprId, db: &impl HirDatabase) { + let expr_ty = &self.infer[id]; + let func_ty = self.func.ty(db); + let func_sig = func_ty.callable_sig(db).unwrap(); + let ret = func_sig.ret(); + let ret = match ret { + Ty::Apply(t) => t, + _ => return, + }; + let ret_enum = match ret.ctor { + TypeCtor::Adt(AdtDef::Enum(e)) => e, + _ => return, + }; + let enum_name = ret_enum.name(db); + if enum_name.is_none() || enum_name.unwrap().to_string() != "Result" { + return; + } + let params = &ret.parameters; + if params.len() == 2 && ¶ms[0] == expr_ty { + let source_map = self.func.body_source_map(db); + let file_id = self.func.source(db).file_id; + let parse = db.parse(file_id.original_file(db)); + let source_file = parse.tree(); + let expr_syntax = source_map.expr_syntax(id); + if expr_syntax.is_none() { + return; + } + let expr_syntax = expr_syntax.unwrap(); + let node = expr_syntax.to_node(source_file.syntax()); + let ast = ast::Expr::cast(node); + if ast.is_none() { + return; + } + let ast = ast.unwrap(); + + self.sink.push(MissingOkInTailExpr { file: file_id, expr: AstPtr::new(&ast) }); + } + } } diff --git a/crates/ra_hir/src/ty.rs b/crates/ra_hir/src/ty.rs index b54c803185..0efd94cef2 100644 --- a/crates/ra_hir/src/ty.rs +++ b/crates/ra_hir/src/ty.rs @@ -516,7 +516,7 @@ impl Ty { } } - fn callable_sig(&self, db: &impl HirDatabase) -> Option { + pub fn callable_sig(&self, db: &impl HirDatabase) -> Option { match self { Ty::Apply(a_ty) => match a_ty.ctor { TypeCtor::FnPtr { .. } => Some(FnSig::from_fn_ptr_substs(&a_ty.parameters)), diff --git a/crates/ra_ide_api/src/diagnostics.rs b/crates/ra_ide_api/src/diagnostics.rs index c2b959cb3c..be51977679 100644 --- a/crates/ra_ide_api/src/diagnostics.rs +++ b/crates/ra_ide_api/src/diagnostics.rs @@ -75,6 +75,19 @@ pub(crate) fn diagnostics(db: &RootDatabase, file_id: FileId) -> Vec severity: Severity::Error, fix: Some(fix), }) + }) + .on::(|d| { + let node = d.ast(db); + let mut builder = TextEditBuilder::default(); + let replacement = format!("Ok({})", node.syntax().text()); + builder.replace(node.syntax().text_range(), replacement); + let fix = SourceChange::source_file_edit_from("wrap with ok", file_id, builder.finish()); + res.borrow_mut().push(Diagnostic { + range: d.highlight_range(), + message: d.message(), + severity: Severity::Error, + fix: Some(fix), + }) }); if let Some(m) = source_binder::module_from_file_id(db, file_id) { m.diagnostics(db, &mut sink); @@ -218,6 +231,43 @@ mod tests { assert_eq!(diagnostics.len(), 0); } + #[test] + fn test_wrap_return_type() { + let before = r#" + enum Result { Ok(T), Err(E) } + struct String { } + + fn div(x: i32, y: i32) -> Result { + if y == 0 { + return Err("div by zero".into()); + } + x / y + } + "#; + let after = r#" + enum Result { Ok(T), Err(E) } + struct String { } + + fn div(x: i32, y: i32) -> Result { + if y == 0 { + return Err("div by zero".into()); + } + Ok(x / y) + } + "#; + check_apply_diagnostic_fix(before, after); + } + + #[test] + fn test_wrap_return_type_not_applicable() { + let content = r#" + fn foo() -> Result { + 0 + } + "#; + check_no_diagnostic(content); + } + #[test] fn test_fill_struct_fields_empty() { let before = r"