diff --git a/crates/hir-def/src/body.rs b/crates/hir-def/src/body.rs index 545d2bebf5..c6c1849003 100644 --- a/crates/hir-def/src/body.rs +++ b/crates/hir-def/src/body.rs @@ -422,6 +422,13 @@ impl Body { } } + pub fn walk_child_bindings(&self, pat: PatId, f: &mut impl FnMut(BindingId)) { + if let Pat::Bind { id, .. } = self[pat] { + f(id) + } + self[pat].walk_child_pats(|p| self.walk_child_bindings(p, f)); + } + pub fn pretty_print(&self, db: &dyn DefDatabase, owner: DefWithBodyId) -> String { pretty::print_body_hir(db, self, owner) } diff --git a/crates/hir-ty/src/consteval/tests.rs b/crates/hir-ty/src/consteval/tests.rs index f05688aa55..0f0e68a560 100644 --- a/crates/hir-ty/src/consteval/tests.rs +++ b/crates/hir-ty/src/consteval/tests.rs @@ -103,6 +103,22 @@ fn references() { "#, 5, ); + check_number( + r#" + struct Foo(i32); + impl Foo { + fn method(&mut self, x: i32) { + self.0 = 2 * self.0 + x; + } + } + const GOAL: i32 = { + let mut x = Foo(3); + x.method(5); + x.0 + }; + "#, + 11, + ); } #[test] @@ -358,7 +374,7 @@ fn ifs() { if a < b { b } else { a } } - const GOAL: u8 = max(max(1, max(10, 3)), 0-122); + const GOAL: i32 = max(max(1, max(10, 3)), 0-122); "#, 10, ); @@ -366,7 +382,7 @@ fn ifs() { check_number( r#" const fn max(a: &i32, b: &i32) -> &i32 { - if a < b { b } else { a } + if *a < *b { b } else { a } } const GOAL: i32 = *max(max(&1, max(&10, &3)), &5); @@ -464,6 +480,16 @@ fn tuples() { "#, 20, ); + check_number( + r#" + const GOAL: u8 = { + let mut a = (10, 20, 3, 15); + a.1 = 2; + a.0 + a.1 + a.2 + a.3 + }; + "#, + 30, + ); check_number( r#" struct TupleLike(i32, u8, i64, u16); @@ -539,7 +565,7 @@ fn let_else() { let Some(x) = x else { return 10 }; 2 * x } - const GOAL: u8 = f(Some(1000)) + f(None); + const GOAL: i32 = f(Some(1000)) + f(None); "#, 2010, ); @@ -615,7 +641,7 @@ fn options() { 0 } } - const GOAL: u8 = f(Some(Some(10))) + f(Some(None)) + f(None); + const GOAL: i32 = f(Some(Some(10))) + f(Some(None)) + f(None); "#, 11, ); @@ -746,24 +772,24 @@ fn enums() { r#" enum E { F1 = 1, - F2 = 2 * E::F1 as u8, - F3 = 3 * E::F2 as u8, + F2 = 2 * E::F1 as isize, // Rustc expects an isize here + F3 = 3 * E::F2 as isize, } - const GOAL: i32 = E::F3 as u8; + const GOAL: u8 = E::F3 as u8; "#, 6, ); check_number( r#" enum E { F1 = 1, F2, } - const GOAL: i32 = E::F2 as u8; + const GOAL: u8 = E::F2 as u8; "#, 2, ); check_number( r#" enum E { F1, } - const GOAL: i32 = E::F1 as u8; + const GOAL: u8 = E::F1 as u8; "#, 0, ); @@ -894,8 +920,22 @@ fn exec_limits() { } sum } - const GOAL: usize = f(10000); + const GOAL: i32 = f(10000); "#, 10000 * 10000, ); } + +#[test] +fn type_error() { + let e = eval_goal( + r#" + const GOAL: u8 = { + let x: u16 = 2; + let y: (u8, u8) = x; + y.0 + }; + "#, + ); + assert!(matches!(e, Err(ConstEvalError::MirLowerError(MirLowerError::TypeMismatch(_))))); +} diff --git a/crates/hir-ty/src/mir.rs b/crates/hir-ty/src/mir.rs index 140caad545..5d8a81a3ee 100644 --- a/crates/hir-ty/src/mir.rs +++ b/crates/hir-ty/src/mir.rs @@ -7,17 +7,19 @@ use crate::{ }; use chalk_ir::Mutability; use hir_def::{ - expr::{Expr, Ordering}, + expr::{BindingId, Expr, ExprId, Ordering, PatId}, DefWithBodyId, FieldId, UnionId, VariantId, }; -use la_arena::{Arena, Idx, RawIdx}; +use la_arena::{Arena, ArenaMap, Idx, RawIdx}; mod eval; mod lower; +pub mod borrowck; pub use eval::{interpret_mir, pad16, Evaluator, MirEvalError}; pub use lower::{lower_to_mir, mir_body_query, mir_body_recover, MirLowerError}; use smallvec::{smallvec, SmallVec}; +use stdx::impl_from; use super::consteval::{intern_const_scalar, try_const_usize}; @@ -181,6 +183,11 @@ impl SwitchTargets { iter::zip(&self.values, &self.targets).map(|(x, y)| (*x, *y)) } + /// Returns a slice with all possible jump targets (including the fallback target). + pub fn all_targets(&self) -> &[BasicBlockId] { + &self.targets + } + /// Finds the `BasicBlock` to which this `SwitchInt` will branch given the /// specific value. This cannot fail, as it'll return the `otherwise` /// branch if there's not a specific match for the value. @@ -758,7 +765,7 @@ pub enum Rvalue { } #[derive(Debug, PartialEq, Eq, Clone)] -pub enum Statement { +pub enum StatementKind { Assign(Place, Rvalue), //FakeRead(Box<(FakeReadCause, Place)>), //SetDiscriminant { @@ -773,6 +780,17 @@ pub enum Statement { //Intrinsic(Box), Nop, } +impl StatementKind { + fn with_span(self, span: MirSpan) -> Statement { + Statement { kind: self, span } + } +} + +#[derive(Debug, PartialEq, Eq, Clone)] +pub struct Statement { + pub kind: StatementKind, + pub span: MirSpan, +} #[derive(Debug, Default, PartialEq, Eq)] pub struct BasicBlock { @@ -803,6 +821,7 @@ pub struct MirBody { pub start_block: BasicBlockId, pub owner: DefWithBodyId, pub arg_count: usize, + pub binding_locals: ArenaMap, } impl MirBody {} @@ -810,3 +829,12 @@ impl MirBody {} fn const_as_usize(c: &Const) -> usize { try_const_usize(c).unwrap() as usize } + +#[derive(Debug, PartialEq, Eq, Clone, Copy)] +pub enum MirSpan { + ExprId(ExprId), + PatId(PatId), + Unknown, +} + +impl_from!(ExprId, PatId for MirSpan); diff --git a/crates/hir-ty/src/mir/borrowck.rs b/crates/hir-ty/src/mir/borrowck.rs new file mode 100644 index 0000000000..fcf9a67fe8 --- /dev/null +++ b/crates/hir-ty/src/mir/borrowck.rs @@ -0,0 +1,201 @@ +//! MIR borrow checker, which is used in diagnostics like `unused_mut` + +// Currently it is an ad-hoc implementation, only useful for mutability analysis. Feel free to remove all of these +// and implement a proper borrow checker. + +use la_arena::ArenaMap; +use stdx::never; + +use super::{ + BasicBlockId, BorrowKind, LocalId, MirBody, MirSpan, Place, ProjectionElem, Rvalue, + StatementKind, Terminator, +}; + +#[derive(Debug)] +pub enum Mutability { + Mut { span: MirSpan }, + Not, +} + +fn is_place_direct(lvalue: &Place) -> bool { + !lvalue.projection.iter().any(|x| *x == ProjectionElem::Deref) +} + +enum ProjectionCase { + /// Projection is a local + Direct, + /// Projection is some field or slice of a local + DirectPart, + /// Projection is deref of something + Indirect, +} + +fn place_case(lvalue: &Place) -> ProjectionCase { + let mut is_part_of = false; + for proj in lvalue.projection.iter().rev() { + match proj { + ProjectionElem::Deref => return ProjectionCase::Indirect, // It's indirect + ProjectionElem::ConstantIndex { .. } + | ProjectionElem::Subslice { .. } + | ProjectionElem::Field(_) + | ProjectionElem::TupleField(_) + | ProjectionElem::Index(_) => { + is_part_of = true; + } + ProjectionElem::OpaqueCast(_) => (), + } + } + if is_part_of { + ProjectionCase::DirectPart + } else { + ProjectionCase::Direct + } +} + +/// Returns a map from basic blocks to the set of locals that might be ever initialized before +/// the start of the block. Only `StorageDead` can remove something from this map, and we ignore +/// `Uninit` and `drop` and similars after initialization. +fn ever_initialized_map(body: &MirBody) -> ArenaMap> { + let mut result: ArenaMap> = + body.basic_blocks.iter().map(|x| (x.0, ArenaMap::default())).collect(); + fn dfs( + body: &MirBody, + b: BasicBlockId, + l: LocalId, + result: &mut ArenaMap>, + ) { + let mut is_ever_initialized = result[b][l]; // It must be filled, as we use it as mark for dfs + let block = &body.basic_blocks[b]; + for statement in &block.statements { + match &statement.kind { + StatementKind::Assign(p, _) => { + if p.projection.len() == 0 && p.local == l { + is_ever_initialized = true; + } + } + StatementKind::StorageDead(p) => { + if *p == l { + is_ever_initialized = false; + } + } + StatementKind::Deinit(_) | StatementKind::Nop | StatementKind::StorageLive(_) => (), + } + } + let Some(terminator) = &block.terminator else { + never!("Terminator should be none only in construction"); + return; + }; + let targets = match terminator { + Terminator::Goto { target } => vec![*target], + Terminator::SwitchInt { targets, .. } => targets.all_targets().to_vec(), + Terminator::Resume + | Terminator::Abort + | Terminator::Return + | Terminator::Unreachable => vec![], + Terminator::Call { target, cleanup, destination, .. } => { + if destination.projection.len() == 0 && destination.local == l { + is_ever_initialized = true; + } + target.into_iter().chain(cleanup.into_iter()).copied().collect() + } + Terminator::Drop { .. } + | Terminator::DropAndReplace { .. } + | Terminator::Assert { .. } + | Terminator::Yield { .. } + | Terminator::GeneratorDrop + | Terminator::FalseEdge { .. } + | Terminator::FalseUnwind { .. } => { + never!("We don't emit these MIR terminators yet"); + vec![] + } + }; + for target in targets { + if !result[target].contains_idx(l) || !result[target][l] && is_ever_initialized { + result[target].insert(l, is_ever_initialized); + dfs(body, target, l, result); + } + } + } + for (b, block) in body.basic_blocks.iter() { + for statement in &block.statements { + if let StatementKind::Assign(p, _) = &statement.kind { + if p.projection.len() == 0 { + let l = p.local; + if !result[b].contains_idx(l) { + result[b].insert(l, false); + dfs(body, b, l, &mut result); + } + } + } + } + } + result +} + +pub fn mutability_of_locals(body: &MirBody) -> ArenaMap { + let mut result: ArenaMap = + body.locals.iter().map(|x| (x.0, Mutability::Not)).collect(); + let ever_init_maps = ever_initialized_map(body); + for (block_id, ever_init_map) in ever_init_maps.iter() { + let mut ever_init_map = ever_init_map.clone(); + let block = &body.basic_blocks[block_id]; + for statement in &block.statements { + match &statement.kind { + StatementKind::Assign(place, value) => { + match place_case(place) { + ProjectionCase::Direct => { + if ever_init_map.get(place.local).copied().unwrap_or_default() { + result[place.local] = Mutability::Mut { span: statement.span }; + } else { + ever_init_map.insert(place.local, true); + } + } + ProjectionCase::DirectPart => { + // Partial initialization is not supported, so it is definitely `mut` + result[place.local] = Mutability::Mut { span: statement.span }; + } + ProjectionCase::Indirect => (), + } + if let Rvalue::Ref(BorrowKind::Mut { .. }, p) = value { + if is_place_direct(p) { + result[p.local] = Mutability::Mut { span: statement.span }; + } + } + } + StatementKind::StorageDead(p) => { + ever_init_map.insert(*p, false); + } + StatementKind::Deinit(_) | StatementKind::StorageLive(_) | StatementKind::Nop => (), + } + } + let Some(terminator) = &block.terminator else { + never!("Terminator should be none only in construction"); + continue; + }; + match terminator { + Terminator::Goto { .. } + | Terminator::Resume + | Terminator::Abort + | Terminator::Return + | Terminator::Unreachable + | Terminator::FalseEdge { .. } + | Terminator::FalseUnwind { .. } + | Terminator::GeneratorDrop + | Terminator::SwitchInt { .. } + | Terminator::Drop { .. } + | Terminator::DropAndReplace { .. } + | Terminator::Assert { .. } + | Terminator::Yield { .. } => (), + Terminator::Call { destination, .. } => { + if destination.projection.len() == 0 { + if ever_init_map.get(destination.local).copied().unwrap_or_default() { + result[destination.local] = Mutability::Mut { span: MirSpan::Unknown }; + } else { + ever_init_map.insert(destination.local, true); + } + } + } + } + } + result +} diff --git a/crates/hir-ty/src/mir/eval.rs b/crates/hir-ty/src/mir/eval.rs index 1ec32010a1..245cfdb4dd 100644 --- a/crates/hir-ty/src/mir/eval.rs +++ b/crates/hir-ty/src/mir/eval.rs @@ -29,7 +29,7 @@ use crate::{ use super::{ const_as_usize, return_slot, AggregateKind, BinOp, CastKind, LocalId, MirBody, MirLowerError, - Operand, Place, ProjectionElem, Rvalue, Statement, Terminator, UnOp, + Operand, Place, ProjectionElem, Rvalue, StatementKind, Terminator, UnOp, }; pub struct Evaluator<'a> { @@ -395,7 +395,8 @@ impl Evaluator<'_> { .locals .iter() .map(|(id, x)| { - let size = self.size_of_sized(&x.ty, &locals, "no unsized local")?; + let size = + self.size_of_sized(&x.ty, &locals, "no unsized local in extending stack")?; let my_ptr = stack_ptr; stack_ptr += size; Ok((id, Stack(my_ptr))) @@ -425,16 +426,16 @@ impl Evaluator<'_> { return Err(MirEvalError::ExecutionLimitExceeded); } for statement in ¤t_block.statements { - match statement { - Statement::Assign(l, r) => { + match &statement.kind { + StatementKind::Assign(l, r) => { let addr = self.place_addr(l, &locals)?; let result = self.eval_rvalue(r, &locals)?.to_vec(&self)?; self.write_memory(addr, &result)?; } - Statement::Deinit(_) => not_supported!("de-init statement"), - Statement::StorageLive(_) => not_supported!("storage-live statement"), - Statement::StorageDead(_) => not_supported!("storage-dead statement"), - Statement::Nop => (), + StatementKind::Deinit(_) => not_supported!("de-init statement"), + StatementKind::StorageLive(_) + | StatementKind::StorageDead(_) + | StatementKind::Nop => (), } } let Some(terminator) = current_block.terminator.as_ref() else { diff --git a/crates/hir-ty/src/mir/lower.rs b/crates/hir-ty/src/mir/lower.rs index 1bcdd3a505..73ae5eaeef 100644 --- a/crates/hir-ty/src/mir/lower.rs +++ b/crates/hir-ty/src/mir/lower.rs @@ -16,8 +16,9 @@ use hir_def::{ use la_arena::ArenaMap; use crate::{ - consteval::ConstEvalError, db::HirDatabase, layout::layout_of_ty, mapping::ToChalk, - utils::generics, Adjust, AutoBorrow, CallableDefId, TyBuilder, TyExt, + consteval::ConstEvalError, db::HirDatabase, infer::TypeMismatch, + inhabitedness::is_ty_uninhabited_from, layout::layout_of_ty, mapping::ToChalk, utils::generics, + Adjust, AutoBorrow, CallableDefId, TyBuilder, TyExt, }; use super::*; @@ -25,13 +26,13 @@ use super::*; #[derive(Debug, Clone, Copy)] struct LoopBlocks { begin: BasicBlockId, - end: BasicBlockId, + /// `None` for loops that are not terminating + end: Option, } struct MirLowerCtx<'a> { result: MirBody, owner: DefWithBodyId, - binding_locals: ArenaMap, current_loop_blocks: Option, discr_temp: Option, db: &'a dyn HirDatabase, @@ -48,11 +49,15 @@ pub enum MirLowerError { UnresolvedMethod, UnresolvedField, MissingFunctionDefinition, + TypeMismatch(TypeMismatch), + /// This should be never happen. Type mismatch should catch everything. TypeError(&'static str), NotSupported(String), ContinueWithoutLoop, BreakWithoutLoop, Loop, + /// Something that should never happen and is definitely a bug, but we don't want to panic if it happened + ImplementationError(&'static str), } macro_rules! not_supported { @@ -113,7 +118,9 @@ impl MirLowerCtx<'_> { ResolveValueResult::Partial(..) => return None, }; match pr { - ValueNs::LocalBinding(pat_id) => Some(self.binding_locals[pat_id].into()), + ValueNs::LocalBinding(pat_id) => { + Some(self.result.binding_locals[pat_id].into()) + } _ => None, } } @@ -125,6 +132,11 @@ impl MirLowerCtx<'_> { } _ => None, }, + Expr::Field { expr, .. } => { + let mut r = self.lower_expr_as_place(*expr)?; + self.push_field_projection(&mut r, expr_id).ok()?; + Some(r) + } _ => None, } } @@ -133,12 +145,12 @@ impl MirLowerCtx<'_> { &mut self, expr_id: ExprId, current: BasicBlockId, - ) -> Result<(Operand, BasicBlockId)> { + ) -> Result<(Operand, Option)> { if !self.has_adjustments(expr_id) { match &self.body.exprs[expr_id] { Expr::Literal(l) => { let ty = self.expr_ty(expr_id); - return Ok((self.lower_literal_to_operand(ty, l)?, current)); + return Ok((self.lower_literal_to_operand(ty, l)?, Some(current))); } _ => (), } @@ -151,27 +163,44 @@ impl MirLowerCtx<'_> { &mut self, expr_id: ExprId, prev_block: BasicBlockId, - ) -> Result<(Place, BasicBlockId)> { + ) -> Result<(Place, Option)> { if let Some(p) = self.lower_expr_as_place(expr_id) { - return Ok((p, prev_block)); + return Ok((p, Some(prev_block))); } let ty = self.expr_ty_after_adjustments(expr_id); let place = self.temp(ty)?; Ok((place.into(), self.lower_expr_to_place(expr_id, place.into(), prev_block)?)) } + fn lower_expr_to_some_place_without_adjust( + &mut self, + expr_id: ExprId, + prev_block: BasicBlockId, + ) -> Result<(Place, Option)> { + if let Some(p) = self.lower_expr_as_place_without_adjust(expr_id) { + return Ok((p, Some(prev_block))); + } + let ty = self.expr_ty(expr_id); + let place = self.temp(ty)?; + Ok(( + place.into(), + self.lower_expr_to_place_without_adjust(expr_id, place.into(), prev_block)?, + )) + } + fn lower_expr_to_place( &mut self, expr_id: ExprId, place: Place, prev_block: BasicBlockId, - ) -> Result { + ) -> Result> { if let Some(x) = self.infer.expr_adjustments.get(&expr_id) { if x.len() > 0 { - let tmp = self.temp(self.expr_ty(expr_id))?; - let current = - self.lower_expr_to_place_without_adjust(expr_id, tmp.into(), prev_block)?; - let mut r = Place::from(tmp); + let (mut r, Some(current)) = + self.lower_expr_to_some_place_without_adjust(expr_id, prev_block)? + else { + return Ok(None); + }; for adjustment in x { match &adjustment.kind { Adjust::NeverToAny => (), @@ -185,6 +214,7 @@ impl MirLowerCtx<'_> { current, tmp.into(), Rvalue::Ref(BorrowKind::from_chalk(*m), r), + expr_id.into(), ); r = tmp.into(); } @@ -199,13 +229,14 @@ impl MirLowerCtx<'_> { Operand::Copy(r).into(), target.clone(), ), + expr_id.into(), ); r = tmp.into(); } } } - self.push_assignment(current, place, Operand::Copy(r).into()); - return Ok(current); + self.push_assignment(current, place, Operand::Copy(r).into(), expr_id.into()); + return Ok(Some(current)); } } self.lower_expr_to_place_without_adjust(expr_id, place, prev_block) @@ -216,7 +247,7 @@ impl MirLowerCtx<'_> { expr_id: ExprId, place: Place, mut current: BasicBlockId, - ) -> Result { + ) -> Result> { match &self.body.exprs[expr_id] { Expr::Missing => Err(MirLowerError::IncompleteExpr), Expr::Path(p) => { @@ -235,7 +266,10 @@ impl MirLowerCtx<'_> { .0 //.ok_or(ConstEvalError::SemanticError("unresolved assoc item"))? { - hir_def::AssocItemId::ConstId(c) => self.lower_const(c, current, place), + hir_def::AssocItemId::ConstId(c) => { + self.lower_const(c, current, place, expr_id.into())?; + Ok(Some(current)) + }, _ => return Err(unresolved_name()), }; } @@ -245,14 +279,26 @@ impl MirLowerCtx<'_> { self.push_assignment( current, place, - Operand::Copy(self.binding_locals[pat_id].into()).into(), + Operand::Copy(self.result.binding_locals[pat_id].into()).into(), + expr_id.into(), ); - Ok(current) + Ok(Some(current)) + } + ValueNs::ConstId(const_id) => { + self.lower_const(const_id, current, place, expr_id.into())?; + Ok(Some(current)) } - ValueNs::ConstId(const_id) => self.lower_const(const_id, current, place), ValueNs::EnumVariantId(variant_id) => { let ty = self.infer.type_of_expr[expr_id].clone(); - self.lower_enum_variant(variant_id, current, place, ty, vec![]) + let current = self.lower_enum_variant( + variant_id, + current, + place, + ty, + vec![], + expr_id.into(), + )?; + Ok(Some(current)) } ValueNs::GenericParam(p) => { let Some(def) = self.owner.as_generic_def_id() else { @@ -276,12 +322,13 @@ impl MirLowerCtx<'_> { .intern(Interner), ) .into(), + expr_id.into(), ); - Ok(current) + Ok(Some(current)) } ValueNs::StructId(_) => { // It's probably a unit struct or a zero sized function, so no action is needed. - Ok(current) + Ok(Some(current)) } x => { not_supported!("unknown name {x:?} in value name space"); @@ -289,19 +336,18 @@ impl MirLowerCtx<'_> { } } Expr::If { condition, then_branch, else_branch } => { - let (discr, current) = self.lower_expr_to_some_operand(*condition, current)?; + let (discr, Some(current)) = self.lower_expr_to_some_operand(*condition, current)? else { + return Ok(None); + }; let start_of_then = self.new_basic_block(); - let end = self.new_basic_block(); let end_of_then = self.lower_expr_to_place(*then_branch, place.clone(), start_of_then)?; - self.set_goto(end_of_then, end); - let mut start_of_else = end; - if let Some(else_branch) = else_branch { - start_of_else = self.new_basic_block(); - let end_of_else = - self.lower_expr_to_place(*else_branch, place, start_of_else)?; - self.set_goto(end_of_else, end); - } + let start_of_else = self.new_basic_block(); + let end_of_else = if let Some(else_branch) = else_branch { + self.lower_expr_to_place(*else_branch, place, start_of_else)? + } else { + Some(start_of_else) + }; self.set_terminator( current, Terminator::SwitchInt { @@ -309,11 +355,13 @@ impl MirLowerCtx<'_> { targets: SwitchTargets::static_if(1, start_of_then, start_of_else), }, ); - Ok(end) + Ok(self.merge_blocks(end_of_then, end_of_else)) } Expr::Let { pat, expr } => { - let (cond_place, current) = self.lower_expr_to_some_place(*expr, current)?; - let result = self.new_basic_block(); + self.push_storage_live(*pat, current)?; + let (cond_place, Some(current)) = self.lower_expr_to_some_place(*expr, current)? else { + return Ok(None); + }; let (then_target, else_target) = self.pattern_match( current, None, @@ -322,13 +370,23 @@ impl MirLowerCtx<'_> { *pat, BindingAnnotation::Unannotated, )?; - self.write_bytes_to_place(then_target, place.clone(), vec![1], TyBuilder::bool())?; - self.set_goto(then_target, result); + self.write_bytes_to_place( + then_target, + place.clone(), + vec![1], + TyBuilder::bool(), + MirSpan::Unknown, + )?; if let Some(else_target) = else_target { - self.write_bytes_to_place(else_target, place, vec![0], TyBuilder::bool())?; - self.set_goto(else_target, result); + self.write_bytes_to_place( + else_target, + place, + vec![0], + TyBuilder::bool(), + MirSpan::Unknown, + )?; } - Ok(result) + Ok(self.merge_blocks(Some(then_target), else_target)) } Expr::Unsafe { id: _, statements, tail } => { self.lower_block_to_place(None, statements, current, *tail, place) @@ -344,52 +402,63 @@ impl MirLowerCtx<'_> { initializer, else_branch, type_ref: _, - } => match initializer { - Some(expr_id) => { - let else_block; - let init_place; - (init_place, current) = - self.lower_expr_to_some_place(*expr_id, current)?; - (current, else_block) = self.pattern_match( - current, - None, - init_place, - self.expr_ty_after_adjustments(*expr_id), - *pat, - BindingAnnotation::Unannotated, - )?; - match (else_block, else_branch) { - (None, _) => (), - (Some(else_block), None) => { - self.set_terminator(else_block, Terminator::Unreachable); - } - (Some(else_block), Some(else_branch)) => { - let (_, b) = self - .lower_expr_to_some_place(*else_branch, else_block)?; + } => { + self.push_storage_live(*pat, current)?; + if let Some(expr_id) = initializer { + let else_block; + let (init_place, Some(c)) = + self.lower_expr_to_some_place(*expr_id, current)? + else { + return Ok(None); + }; + current = c; + (current, else_block) = self.pattern_match( + current, + None, + init_place, + self.expr_ty_after_adjustments(*expr_id), + *pat, + BindingAnnotation::Unannotated, + )?; + match (else_block, else_branch) { + (None, _) => (), + (Some(else_block), None) => { + self.set_terminator(else_block, Terminator::Unreachable); + } + (Some(else_block), Some(else_branch)) => { + let (_, b) = self + .lower_expr_to_some_place(*else_branch, else_block)?; + if let Some(b) = b { self.set_terminator(b, Terminator::Unreachable); } } } - None => continue, - }, + } }, hir_def::expr::Statement::Expr { expr, has_semi: _ } => { - (_, current) = self.lower_expr_to_some_place(*expr, current)?; + let (_, Some(c)) = self.lower_expr_to_some_place(*expr, current)? else { + return Ok(None); + }; + current = c; } } } match tail { Some(tail) => self.lower_expr_to_place(*tail, place, current), - None => Ok(current), + None => Ok(Some(current)), } } - Expr::Loop { body, label } => self.lower_loop(current, *label, |this, begin, _| { - let (_, block) = this.lower_expr_to_some_place(*body, begin)?; - this.set_goto(block, begin); + Expr::Loop { body, label } => self.lower_loop(current, *label, |this, begin| { + if let (_, Some(block)) = this.lower_expr_to_some_place(*body, begin)? { + this.set_goto(block, begin); + } Ok(()) }), Expr::While { condition, body, label } => { - self.lower_loop(current, *label, |this, begin, end| { - let (discr, to_switch) = this.lower_expr_to_some_operand(*condition, begin)?; + self.lower_loop(current, *label, |this, begin| { + let (discr, Some(to_switch)) = this.lower_expr_to_some_operand(*condition, begin)? else { + return Ok(()); + }; + let end = this.current_loop_end()?; let after_cond = this.new_basic_block(); this.set_terminator( to_switch, @@ -398,8 +467,9 @@ impl MirLowerCtx<'_> { targets: SwitchTargets::static_if(1, after_cond, end), }, ); - let (_, block) = this.lower_expr_to_some_place(*body, after_cond)?; - this.set_goto(block, begin); + if let (_, Some(block)) = this.lower_expr_to_some_place(*body, after_cond)? { + this.set_goto(block, begin); + } Ok(()) }) } @@ -409,7 +479,7 @@ impl MirLowerCtx<'_> { match &callee_ty.data(Interner).kind { chalk_ir::TyKind::FnDef(..) => { let func = Operand::from_bytes(vec![], callee_ty.clone()); - self.lower_call(func, args.iter().copied(), place, current) + self.lower_call(func, args.iter().copied(), place, current, self.is_uninhabited(expr_id)) } TyKind::Scalar(_) | TyKind::Tuple(_, _) @@ -451,16 +521,21 @@ impl MirLowerCtx<'_> { iter::once(*receiver).chain(args.iter().copied()), place, current, + self.is_uninhabited(expr_id), ) } Expr::Match { expr, arms } => { - let (cond_place, mut current) = self.lower_expr_to_some_place(*expr, current)?; + let (cond_place, Some(mut current)) = self.lower_expr_to_some_place(*expr, current)? + else { + return Ok(None); + }; let cond_ty = self.expr_ty_after_adjustments(*expr); - let end = self.new_basic_block(); + let mut end = None; for MatchArm { pat, guard, expr } in arms.iter() { if guard.is_some() { not_supported!("pattern matching with guard"); } + self.push_storage_live(*pat, current)?; let (then, otherwise) = self.pattern_match( current, None, @@ -469,8 +544,10 @@ impl MirLowerCtx<'_> { *pat, BindingAnnotation::Unannotated, )?; - let block = self.lower_expr_to_place(*expr, place.clone(), then)?; - self.set_goto(block, end); + if let Some(block) = self.lower_expr_to_place(*expr, place.clone(), then)? { + let r = end.get_or_insert_with(|| self.new_basic_block()); + self.set_goto(block, *r); + } match otherwise { Some(o) => current = o, None => { @@ -491,8 +568,7 @@ impl MirLowerCtx<'_> { let loop_data = self.current_loop_blocks.ok_or(MirLowerError::ContinueWithoutLoop)?; self.set_goto(current, loop_data.begin); - let otherwise = self.new_basic_block(); - Ok(otherwise) + Ok(None) } }, Expr::Break { expr, label } => { @@ -502,19 +578,23 @@ impl MirLowerCtx<'_> { match label { Some(_) => not_supported!("break with label"), None => { - let loop_data = - self.current_loop_blocks.ok_or(MirLowerError::BreakWithoutLoop)?; - self.set_goto(current, loop_data.end); - Ok(self.new_basic_block()) + let end = + self.current_loop_end()?; + self.set_goto(current, end); + Ok(None) } } } Expr::Return { expr } => { if let Some(expr) = expr { - current = self.lower_expr_to_place(*expr, return_slot().into(), current)?; + if let Some(c) = self.lower_expr_to_place(*expr, return_slot().into(), current)? { + current = c; + } else { + return Ok(None); + } } self.set_terminator(current, Terminator::Return); - Ok(self.new_basic_block()) + Ok(None) } Expr::Yield { .. } => not_supported!("yield"), Expr::RecordLit { fields, .. } => { @@ -533,8 +613,10 @@ impl MirLowerCtx<'_> { for RecordLitField { name, expr } in fields.iter() { let field_id = variant_data.field(name).ok_or(MirLowerError::UnresolvedField)?; - let op; - (op, current) = self.lower_expr_to_some_operand(*expr, current)?; + let (op, Some(c)) = self.lower_expr_to_some_operand(*expr, current)? else { + return Ok(None); + }; + current = c; operands[u32::from(field_id.into_raw()) as usize] = Some(op); } self.push_assignment( @@ -546,8 +628,9 @@ impl MirLowerCtx<'_> { MirLowerError::TypeError("missing field in record literal"), )?, ), + expr_id.into(), ); - Ok(current) + Ok(Some(current)) } VariantId::UnionId(union_id) => { let [RecordLitField { name, expr }] = fields.as_ref() else { @@ -563,22 +646,18 @@ impl MirLowerCtx<'_> { } } } - Expr::Field { expr, name } => { - let (mut current_place, current) = self.lower_expr_to_some_place(*expr, current)?; - if let TyKind::Tuple(..) = self.expr_ty_after_adjustments(*expr).kind(Interner) { - let index = name - .as_tuple_index() - .ok_or(MirLowerError::TypeError("named field on tuple"))?; - current_place.projection.push(ProjectionElem::TupleField(index)) - } else { - let field = self - .infer - .field_resolution(expr_id) - .ok_or(MirLowerError::UnresolvedField)?; - current_place.projection.push(ProjectionElem::Field(field)); - } - self.push_assignment(current, place, Operand::Copy(current_place).into()); - Ok(current) + Expr::Field { expr, .. } => { + let (mut current_place, Some(current)) = self.lower_expr_to_some_place(*expr, current)? else { + return Ok(None); + }; + self.push_field_projection(&mut current_place, expr_id)?; + self.push_assignment( + current, + place, + Operand::Copy(current_place).into(), + expr_id.into(), + ); + Ok(Some(current)) } Expr::Await { .. } => not_supported!("await"), Expr::Try { .. } => not_supported!("? operator"), @@ -587,40 +666,53 @@ impl MirLowerCtx<'_> { Expr::Async { .. } => not_supported!("async block"), Expr::Const { .. } => not_supported!("anonymous const block"), Expr::Cast { expr, type_ref: _ } => { - let (x, current) = self.lower_expr_to_some_operand(*expr, current)?; + let (x, Some(current)) = self.lower_expr_to_some_operand(*expr, current)? else { + return Ok(None); + }; let source_ty = self.infer[*expr].clone(); let target_ty = self.infer[expr_id].clone(); self.push_assignment( current, place, Rvalue::Cast(cast_kind(&source_ty, &target_ty)?, x, target_ty), + expr_id.into(), ); - Ok(current) + Ok(Some(current)) } Expr::Ref { expr, rawness: _, mutability } => { - let p; - (p, current) = self.lower_expr_to_some_place(*expr, current)?; + let (p, Some(current)) = self.lower_expr_to_some_place(*expr, current)? else { + return Ok(None); + }; let bk = BorrowKind::from_hir(*mutability); - self.push_assignment(current, place, Rvalue::Ref(bk, p)); - Ok(current) + self.push_assignment(current, place, Rvalue::Ref(bk, p), expr_id.into()); + Ok(Some(current)) } Expr::Box { .. } => not_supported!("box expression"), Expr::UnaryOp { expr, op } => match op { hir_def::expr::UnaryOp::Deref => { - let (mut tmp, current) = self.lower_expr_to_some_place(*expr, current)?; + let (mut tmp, Some(current)) = self.lower_expr_to_some_place(*expr, current)? else { + return Ok(None); + }; tmp.projection.push(ProjectionElem::Deref); - self.push_assignment(current, place, Operand::Copy(tmp).into()); - Ok(current) + self.push_assignment(current, place, Operand::Copy(tmp).into(), expr_id.into()); + Ok(Some(current)) } - hir_def::expr::UnaryOp::Not => { - let (op, current) = self.lower_expr_to_some_operand(*expr, current)?; - self.push_assignment(current, place, Rvalue::UnaryOp(UnOp::Not, op)); - Ok(current) - } - hir_def::expr::UnaryOp::Neg => { - let (op, current) = self.lower_expr_to_some_operand(*expr, current)?; - self.push_assignment(current, place, Rvalue::UnaryOp(UnOp::Neg, op)); - Ok(current) + hir_def::expr::UnaryOp::Not | hir_def::expr::UnaryOp::Neg => { + let (operand, Some(current)) = self.lower_expr_to_some_operand(*expr, current)? else { + return Ok(None); + }; + let operation = match op { + hir_def::expr::UnaryOp::Not => UnOp::Not, + hir_def::expr::UnaryOp::Neg => UnOp::Neg, + _ => unreachable!(), + }; + self.push_assignment( + current, + place, + Rvalue::UnaryOp(operation, operand), + expr_id.into(), + ); + Ok(Some(current)) } }, Expr::BinaryOp { lhs, rhs, op } => { @@ -632,15 +724,18 @@ impl MirLowerCtx<'_> { let Some(lhs_place) = self.lower_expr_as_place(*lhs) else { not_supported!("assignment to complex place"); }; - let rhs_op; - (rhs_op, current) = self.lower_expr_to_some_operand(*rhs, current)?; - self.push_assignment(current, lhs_place, rhs_op.into()); - return Ok(current); + let (rhs_op, Some(current)) = self.lower_expr_to_some_operand(*rhs, current)? else { + return Ok(None); + }; + self.push_assignment(current, lhs_place, rhs_op.into(), expr_id.into()); + return Ok(Some(current)); } - let lhs_op; - (lhs_op, current) = self.lower_expr_to_some_operand(*lhs, current)?; - let rhs_op; - (rhs_op, current) = self.lower_expr_to_some_operand(*rhs, current)?; + let (lhs_op, Some(current)) = self.lower_expr_to_some_operand(*lhs, current)? else { + return Ok(None); + }; + let (rhs_op, Some(current)) = self.lower_expr_to_some_operand(*rhs, current)? else { + return Ok(None); + }; self.push_assignment( current, place, @@ -657,34 +752,44 @@ impl MirLowerCtx<'_> { lhs_op, rhs_op, ), + expr_id.into(), ); - Ok(current) + Ok(Some(current)) } Expr::Range { .. } => not_supported!("range"), Expr::Index { base, index } => { - let mut p_base; - (p_base, current) = self.lower_expr_to_some_place(*base, current)?; + let (mut p_base, Some(current)) = self.lower_expr_to_some_place(*base, current)? else { + return Ok(None); + }; let l_index = self.temp(self.expr_ty_after_adjustments(*index))?; - current = self.lower_expr_to_place(*index, l_index.into(), current)?; + let Some(current) = self.lower_expr_to_place(*index, l_index.into(), current)? else { + return Ok(None); + }; p_base.projection.push(ProjectionElem::Index(l_index)); - self.push_assignment(current, place, Operand::Copy(p_base).into()); - Ok(current) + self.push_assignment(current, place, Operand::Copy(p_base).into(), expr_id.into()); + Ok(Some(current)) } Expr::Closure { .. } => not_supported!("closure"), Expr::Tuple { exprs, is_assignee_expr: _ } => { - let r = Rvalue::Aggregate( - AggregateKind::Tuple(self.expr_ty(expr_id)), - exprs + let Some(values) = exprs .iter() .map(|x| { - let o; - (o, current) = self.lower_expr_to_some_operand(*x, current)?; - Ok(o) + let (o, Some(c)) = self.lower_expr_to_some_operand(*x, current)? else { + return Ok(None); + }; + current = c; + Ok(Some(o)) }) - .collect::>()?, + .collect::>>()? + else { + return Ok(None); + }; + let r = Rvalue::Aggregate( + AggregateKind::Tuple(self.expr_ty(expr_id)), + values, ); - self.push_assignment(current, place, r); - Ok(current) + self.push_assignment(current, place, r, expr_id.into()); + Ok(Some(current)) } Expr::Array(l) => match l { Array::ElementList { elements, .. } => { @@ -696,86 +801,54 @@ impl MirLowerCtx<'_> { )) } }; - let r = Rvalue::Aggregate( - AggregateKind::Array(elem_ty), - elements + let Some(values) = elements .iter() .map(|x| { - let o; - (o, current) = self.lower_expr_to_some_operand(*x, current)?; - Ok(o) + let (o, Some(c)) = self.lower_expr_to_some_operand(*x, current)? else { + return Ok(None); + }; + current = c; + Ok(Some(o)) }) - .collect::>()?, + .collect::>>()? + else { + return Ok(None); + }; + let r = Rvalue::Aggregate( + AggregateKind::Array(elem_ty), + values, ); - self.push_assignment(current, place, r); - Ok(current) + self.push_assignment(current, place, r, expr_id.into()); + Ok(Some(current)) } Array::Repeat { .. } => not_supported!("array repeat"), }, Expr::Literal(l) => { let ty = self.expr_ty(expr_id); let op = self.lower_literal_to_operand(ty, l)?; - self.push_assignment(current, place, op.into()); - Ok(current) + self.push_assignment(current, place, op.into(), expr_id.into()); + Ok(Some(current)) } Expr::Underscore => not_supported!("underscore"), } } - fn lower_block_to_place( - &mut self, - label: Option, - statements: &[hir_def::expr::Statement], - mut current: BasicBlockId, - tail: Option, - place: Place, - ) -> Result { - if label.is_some() { - not_supported!("block with label"); - } - for statement in statements.iter() { - match statement { - hir_def::expr::Statement::Let { pat, initializer, else_branch, type_ref: _ } => { - match initializer { - Some(expr_id) => { - let else_block; - let init_place; - (init_place, current) = - self.lower_expr_to_some_place(*expr_id, current)?; - (current, else_block) = self.pattern_match( - current, - None, - init_place, - self.expr_ty(*expr_id), - *pat, - BindingAnnotation::Unannotated, - )?; - match (else_block, else_branch) { - (None, _) => (), - (Some(else_block), None) => { - self.set_terminator(else_block, Terminator::Unreachable); - } - (Some(else_block), Some(else_branch)) => { - let (_, b) = - self.lower_expr_to_some_place(*else_branch, else_block)?; - self.set_terminator(b, Terminator::Unreachable); - } - } - } - None => continue, - } - } - hir_def::expr::Statement::Expr { expr, has_semi: _ } => { - let ty = self.expr_ty(*expr); - let temp = self.temp(ty)?; - current = self.lower_expr_to_place(*expr, temp.into(), current)?; - } + fn push_field_projection(&self, place: &mut Place, expr_id: ExprId) -> Result<()> { + if let Expr::Field { expr, name } = &self.body[expr_id] { + if let TyKind::Tuple(..) = self.expr_ty_after_adjustments(*expr).kind(Interner) { + let index = name + .as_tuple_index() + .ok_or(MirLowerError::TypeError("named field on tuple"))?; + place.projection.push(ProjectionElem::TupleField(index)) + } else { + let field = + self.infer.field_resolution(expr_id).ok_or(MirLowerError::UnresolvedField)?; + place.projection.push(ProjectionElem::Field(field)); } + } else { + not_supported!("") } - match tail { - Some(tail) => self.lower_expr_to_place(tail, place, current), - None => Ok(current), - } + Ok(()) } fn lower_literal_to_operand(&mut self, ty: Ty, l: &Literal) -> Result { @@ -824,9 +897,10 @@ impl MirLowerCtx<'_> { const_id: hir_def::ConstId, prev_block: BasicBlockId, place: Place, - ) -> Result { + span: MirSpan, + ) -> Result<()> { let c = self.db.const_eval(const_id)?; - self.write_const_to_place(c, prev_block, place) + self.write_const_to_place(c, prev_block, place, span) } fn write_const_to_place( @@ -834,9 +908,10 @@ impl MirLowerCtx<'_> { c: Const, prev_block: BasicBlockId, place: Place, - ) -> Result { - self.push_assignment(prev_block, place, Operand::Constant(c).into()); - Ok(prev_block) + span: MirSpan, + ) -> Result<()> { + self.push_assignment(prev_block, place, Operand::Constant(c).into(), span); + Ok(()) } fn write_bytes_to_place( @@ -845,9 +920,10 @@ impl MirLowerCtx<'_> { place: Place, cv: Vec, ty: Ty, - ) -> Result { - self.push_assignment(prev_block, place, Operand::from_bytes(cv, ty).into()); - Ok(prev_block) + span: MirSpan, + ) -> Result<()> { + self.push_assignment(prev_block, place, Operand::from_bytes(cv, ty).into(), span); + Ok(()) } fn lower_enum_variant( @@ -857,6 +933,7 @@ impl MirLowerCtx<'_> { place: Place, ty: Ty, fields: Vec, + span: MirSpan, ) -> Result { let subst = match ty.kind(Interner) { TyKind::Adt(_, subst) => subst.clone(), @@ -866,6 +943,7 @@ impl MirLowerCtx<'_> { prev_block, place, Rvalue::Aggregate(AggregateKind::Adt(variant_id.into(), subst), fields), + span, ); Ok(prev_block) } @@ -876,26 +954,29 @@ impl MirLowerCtx<'_> { args: impl Iterator, place: Place, mut current: BasicBlockId, - ) -> Result { - let args = args + is_uninhabited: bool, + ) -> Result> { + let Some(args) = args .map(|arg| { - let temp; - (temp, current) = self.lower_expr_to_some_operand(arg, current)?; - Ok(temp) + if let (temp, Some(c)) = self.lower_expr_to_some_operand(arg, current)? { + current = c; + Ok(Some(temp)) + } else { + Ok(None) + } }) - .collect::>>()?; - let b = self.result.basic_blocks.alloc(BasicBlock { - statements: vec![], - terminator: None, - is_cleanup: false, - }); + .collect::>>>()? + else { + return Ok(None); + }; + let b = if is_uninhabited { None } else { Some(self.new_basic_block()) }; self.set_terminator( current, Terminator::Call { func, args, destination: place, - target: Some(b), + target: b, cleanup: None, from_hir_call: true, }, @@ -929,8 +1010,18 @@ impl MirLowerCtx<'_> { ty.unwrap_or_else(|| self.expr_ty(e)) } - fn push_assignment(&mut self, block: BasicBlockId, place: Place, rvalue: Rvalue) { - self.result.basic_blocks[block].statements.push(Statement::Assign(place, rvalue)); + fn push_statement(&mut self, block: BasicBlockId, statement: Statement) { + self.result.basic_blocks[block].statements.push(statement); + } + + fn push_assignment( + &mut self, + block: BasicBlockId, + place: Place, + rvalue: Rvalue, + span: MirSpan, + ) { + self.push_statement(block, StatementKind::Assign(place, rvalue).with_span(span)); } /// It gets a `current` unterminated block, appends some statements and possibly a terminator to it to check if @@ -983,8 +1074,14 @@ impl MirLowerCtx<'_> { let then_target = self.new_basic_block(); let mut finished = false; for pat in &**pats { - let (next, next_else) = - self.pattern_match(current, None, cond_place.clone(), cond_ty.clone(), *pat, binding_mode)?; + let (next, next_else) = self.pattern_match( + current, + None, + cond_place.clone(), + cond_ty.clone(), + *pat, + binding_mode, + )?; self.set_goto(next, then_target); match next_else { Some(t) => { @@ -1036,7 +1133,7 @@ impl MirLowerCtx<'_> { (then_target, Some(else_target)) } Pat::Bind { id, subpat } => { - let target_place = self.binding_locals[*id]; + let target_place = self.result.binding_locals[*id]; let mode = self.body.bindings[*id].mode; if let Some(subpat) = subpat { (current, current_else) = self.pattern_match( @@ -1064,6 +1161,7 @@ impl MirLowerCtx<'_> { cond_place, ), }, + pattern.into(), ); (current, current_else) } @@ -1090,6 +1188,7 @@ impl MirLowerCtx<'_> { current, tmp.clone(), Rvalue::Discriminant(cond_place.clone()), + pattern.into(), ); let else_target = current_else.unwrap_or_else(|| self.new_basic_block()); self.set_terminator( @@ -1183,23 +1282,94 @@ impl MirLowerCtx<'_> { &mut self, prev_block: BasicBlockId, label: Option, - f: impl FnOnce(&mut MirLowerCtx<'_>, BasicBlockId, BasicBlockId) -> Result<()>, - ) -> Result { + f: impl FnOnce(&mut MirLowerCtx<'_>, BasicBlockId) -> Result<()>, + ) -> Result> { if label.is_some() { not_supported!("loop with label"); } let begin = self.new_basic_block(); - let end = self.new_basic_block(); - let prev = mem::replace(&mut self.current_loop_blocks, Some(LoopBlocks { begin, end })); + let prev = + mem::replace(&mut self.current_loop_blocks, Some(LoopBlocks { begin, end: None })); self.set_goto(prev_block, begin); - f(self, begin, end)?; - self.current_loop_blocks = prev; - Ok(end) + f(self, begin)?; + let my = mem::replace(&mut self.current_loop_blocks, prev) + .ok_or(MirLowerError::ImplementationError("current_loop_blocks is corrupt"))?; + Ok(my.end) } fn has_adjustments(&self, expr_id: ExprId) -> bool { !self.infer.expr_adjustments.get(&expr_id).map(|x| x.is_empty()).unwrap_or(true) } + + fn merge_blocks( + &mut self, + b1: Option, + b2: Option, + ) -> Option { + match (b1, b2) { + (None, None) => None, + (None, Some(b)) | (Some(b), None) => Some(b), + (Some(b1), Some(b2)) => { + let bm = self.new_basic_block(); + self.set_goto(b1, bm); + self.set_goto(b2, bm); + Some(bm) + } + } + } + + fn current_loop_end(&mut self) -> Result { + let r = match self + .current_loop_blocks + .as_mut() + .ok_or(MirLowerError::ImplementationError("Current loop access out of loop"))? + .end + { + Some(x) => x, + None => { + let s = self.new_basic_block(); + self.current_loop_blocks + .as_mut() + .ok_or(MirLowerError::ImplementationError("Current loop access out of loop"))? + .end = Some(s); + s + } + }; + Ok(r) + } + + fn is_uninhabited(&self, expr_id: ExprId) -> bool { + is_ty_uninhabited_from(&self.infer[expr_id], self.owner.module(self.db.upcast()), self.db) + } + + /// This function push `StorageLive` statements for each binding in the pattern. + fn push_storage_live(&mut self, pat: PatId, current: BasicBlockId) -> Result<()> { + // Current implementation is wrong. It adds no `StorageDead` at the end of scope, and before each break + // and continue. It just add a `StorageDead` before the `StorageLive`, which is not wrong, but unneeeded in + // the proper implementation. Due this limitation, implementing a borrow checker on top of this mir will falsely + // allow this: + // + // ``` + // let x; + // loop { + // let y = 2; + // x = &y; + // if some_condition { + // break; // we need to add a StorageDead(y) above this to kill the x borrow + // } + // } + // use(x) + // ``` + // But I think this approach work for mutability analysis, as user can't write code which mutates a binding + // after StorageDead, except loops, which are handled by this hack. + let span = pat.into(); + self.body.walk_child_bindings(pat, &mut |b| { + let l = self.result.binding_locals[b]; + self.push_statement(current, StatementKind::StorageDead(l).with_span(span)); + self.push_statement(current, StatementKind::StorageLive(l).with_span(span)); + }); + Ok(()) + } } fn pattern_matching_dereference( @@ -1257,6 +1427,11 @@ pub fn lower_to_mir( // need to take this input explicitly. root_expr: ExprId, ) -> Result { + if let (Some((_, x)), _) | (_, Some((_, x))) = + (infer.expr_type_mismatches().next(), infer.pat_type_mismatches().next()) + { + return Err(MirLowerError::TypeMismatch(x.clone())); + } let mut basic_blocks = Arena::new(); let start_block = basic_blocks.alloc(BasicBlock { statements: vec![], terminator: None, is_cleanup: false }); @@ -1299,13 +1474,19 @@ pub fn lower_to_mir( ); } } - let mir = MirBody { basic_blocks, locals, start_block, owner, arg_count: body.params.len() }; + let mir = MirBody { + basic_blocks, + locals, + start_block, + binding_locals, + owner, + arg_count: body.params.len(), + }; let mut ctx = MirLowerCtx { result: mir, db, infer, body, - binding_locals, owner, current_loop_blocks: None, discr_temp: None, @@ -1313,7 +1494,7 @@ pub fn lower_to_mir( let mut current = start_block; for ¶m in &body.params { if let Pat::Bind { id, .. } = body[param] { - if param_locals[param] == ctx.binding_locals[id] { + if param_locals[param] == ctx.result.binding_locals[id] { continue; } } @@ -1330,7 +1511,8 @@ pub fn lower_to_mir( } current = r.0; } - let b = ctx.lower_expr_to_place(root_expr, return_slot().into(), current)?; - ctx.result.basic_blocks[b].terminator = Some(Terminator::Return); + if let Some(b) = ctx.lower_expr_to_place(root_expr, return_slot().into(), current)? { + ctx.result.basic_blocks[b].terminator = Some(Terminator::Return); + } Ok(ctx.result) } diff --git a/crates/hir/src/diagnostics.rs b/crates/hir/src/diagnostics.rs index b30c664e24..c257ee2ae3 100644 --- a/crates/hir/src/diagnostics.rs +++ b/crates/hir/src/diagnostics.rs @@ -10,7 +10,7 @@ use hir_def::path::ModPath; use hir_expand::{name::Name, HirFileId, InFile}; use syntax::{ast, AstPtr, SyntaxNodePtr, TextRange}; -use crate::{AssocItem, Field, MacroKind, Type}; +use crate::{AssocItem, Field, Local, MacroKind, Type}; macro_rules! diagnostics { ($($diag:ident,)*) => { @@ -41,6 +41,7 @@ diagnostics![ MissingFields, MissingMatchArms, MissingUnsafe, + NeedMut, NoSuchField, PrivateAssocItem, PrivateField, @@ -54,6 +55,7 @@ diagnostics![ UnresolvedMethodCall, UnresolvedModule, UnresolvedProcMacro, + UnusedMut, ]; #[derive(Debug)] @@ -209,4 +211,15 @@ pub struct TypeMismatch { pub actual: Type, } +#[derive(Debug)] +pub struct NeedMut { + pub local: Local, + pub span: InFile, +} + +#[derive(Debug)] +pub struct UnusedMut { + pub local: Local, +} + pub use hir_ty::diagnostics::IncorrectCase; diff --git a/crates/hir/src/lib.rs b/crates/hir/src/lib.rs index b83d83b5ed..4b65a93cac 100644 --- a/crates/hir/src/lib.rs +++ b/crates/hir/src/lib.rs @@ -63,7 +63,7 @@ use hir_ty::{ display::HexifiedConst, layout::layout_of_ty, method_resolution::{self, TyFingerprint}, - mir::interpret_mir, + mir::{self, interpret_mir}, primitive::UintTy, traits::FnTrait, AliasTy, CallableDefId, CallableSig, Canonical, CanonicalVarKinds, Cast, ClosureId, @@ -85,12 +85,12 @@ use crate::db::{DefDatabase, HirDatabase}; pub use crate::{ attrs::{HasAttrs, Namespace}, diagnostics::{ - AnyDiagnostic, BreakOutsideOfLoop, ExpectedFunction, InactiveCode, IncorrectCase, - InvalidDeriveTarget, MacroError, MalformedDerive, MismatchedArgCount, MissingFields, - MissingMatchArms, MissingUnsafe, NoSuchField, PrivateAssocItem, PrivateField, + AnyDiagnostic, BreakOutsideOfLoop, InactiveCode, IncorrectCase, InvalidDeriveTarget, + MacroError, MalformedDerive, MismatchedArgCount, MissingFields, MissingMatchArms, + MissingUnsafe, NeedMut, NoSuchField, PrivateAssocItem, PrivateField, ReplaceFilterMapNextWithFindMap, TypeMismatch, UnimplementedBuiltinMacro, - UnresolvedExternCrate, UnresolvedField, UnresolvedImport, UnresolvedMacroCall, - UnresolvedMethodCall, UnresolvedModule, UnresolvedProcMacro, + UnresolvedExternCrate, UnresolvedImport, UnresolvedMacroCall, UnresolvedModule, + UnresolvedProcMacro, UnusedMut, }, has_source::HasSource, semantics::{PathResolution, Semantics, SemanticsScope, TypeInfo, VisibleTraits}, @@ -1500,6 +1500,38 @@ impl DefWithBody { } } + let hir_body = db.body(self.into()); + + if let Ok(mir_body) = db.mir_body(self.into()) { + let mol = mir::borrowck::mutability_of_locals(&mir_body); + for (binding_id, _) in hir_body.bindings.iter() { + let need_mut = &mol[mir_body.binding_locals[binding_id]]; + let local = Local { parent: self.into(), binding_id }; + match (need_mut, local.is_mut(db)) { + (mir::borrowck::Mutability::Mut { .. }, true) + | (mir::borrowck::Mutability::Not, false) => (), + (mir::borrowck::Mutability::Mut { span }, false) => { + let span: InFile = match span { + mir::MirSpan::ExprId(e) => match source_map.expr_syntax(*e) { + Ok(s) => s.map(|x| x.into()), + Err(_) => continue, + }, + mir::MirSpan::PatId(p) => match source_map.pat_syntax(*p) { + Ok(s) => s.map(|x| match x { + Either::Left(e) => e.into(), + Either::Right(e) => e.into(), + }), + Err(_) => continue, + }, + mir::MirSpan::Unknown => continue, + }; + acc.push(NeedMut { local, span }.into()); + } + (mir::borrowck::Mutability::Not, true) => acc.push(UnusedMut { local }.into()), + } + } + } + for diagnostic in BodyValidationDiagnostic::collect(db, self.into()) { match diagnostic { BodyValidationDiagnostic::RecordMissingFields { @@ -2490,6 +2522,10 @@ impl LocalSource { pub fn syntax(&self) -> &SyntaxNode { self.source.value.syntax() } + + pub fn syntax_ptr(self) -> InFile { + self.source.map(|x| SyntaxNodePtr::new(x.syntax())) + } } impl Local { diff --git a/crates/ide-diagnostics/src/handlers/mutability_errors.rs b/crates/ide-diagnostics/src/handlers/mutability_errors.rs new file mode 100644 index 0000000000..a78b58fdc8 --- /dev/null +++ b/crates/ide-diagnostics/src/handlers/mutability_errors.rs @@ -0,0 +1,302 @@ +use crate::{Diagnostic, DiagnosticsContext, Severity}; + +// Diagnostic: need-mut +// +// This diagnostic is triggered on mutating an immutable variable. +pub(crate) fn need_mut(ctx: &DiagnosticsContext<'_>, d: &hir::NeedMut) -> Diagnostic { + Diagnostic::new( + "need-mut", + format!("cannot mutate immutable variable `{}`", d.local.name(ctx.sema.db)), + ctx.sema.diagnostics_display_range(d.span.clone()).range, + ) +} + +// Diagnostic: unused-mut +// +// This diagnostic is triggered when a mutable variable isn't actually mutated. +pub(crate) fn unused_mut(ctx: &DiagnosticsContext<'_>, d: &hir::UnusedMut) -> Diagnostic { + Diagnostic::new( + "unused-mut", + "remove this `mut`", + ctx.sema.diagnostics_display_range(d.local.primary_source(ctx.sema.db).syntax_ptr()).range, + ) + .severity(Severity::WeakWarning) +} + +#[cfg(test)] +mod tests { + use crate::tests::check_diagnostics; + + #[test] + fn unused_mut_simple() { + check_diagnostics( + r#" +fn f(_: i32) {} +fn main() { + let mut x = 2; + //^^^^^ weak: remove this `mut` + f(x); +} +"#, + ); + } + + #[test] + fn no_false_positive_simple() { + check_diagnostics( + r#" +fn f(_: i32) {} +fn main() { + let x = 2; + f(x); +} +"#, + ); + check_diagnostics( + r#" +fn f(_: i32) {} +fn main() { + let mut x = 2; + x = 5; + f(x); +} +"#, + ); + } + + #[test] + fn field_mutate() { + check_diagnostics( + r#" +fn f(_: i32) {} +fn main() { + let mut x = (2, 7); + //^^^^^ weak: remove this `mut` + f(x.1); +} +"#, + ); + check_diagnostics( + r#" +fn f(_: i32) {} +fn main() { + let mut x = (2, 7); + x.0 = 5; + f(x.1); +} +"#, + ); + check_diagnostics( + r#" +fn f(_: i32) {} +fn main() { + let x = (2, 7); + x.0 = 5; + //^^^^^^^ error: cannot mutate immutable variable `x` + f(x.1); +} +"#, + ); + } + + #[test] + fn mutable_reference() { + check_diagnostics( + r#" +fn main() { + let mut x = &mut 2; + //^^^^^ weak: remove this `mut` + *x = 5; +} +"#, + ); + check_diagnostics( + r#" +fn main() { + let x = 2; + &mut x; + //^^^^^^ error: cannot mutate immutable variable `x` +} +"#, + ); + check_diagnostics( + r#" +fn main() { + let x_own = 2; + let ref mut x_ref = x_own; + //^^^^^^^^^^^^^ error: cannot mutate immutable variable `x_own` +} +"#, + ); + check_diagnostics( + r#" +struct Foo; +impl Foo { + fn method(&mut self, x: i32) {} +} +fn main() { + let x = Foo; + x.method(2); + //^ error: cannot mutate immutable variable `x` +} +"#, + ); + } + + #[test] + fn match_bindings() { + check_diagnostics( + r#" +fn main() { + match (2, 3) { + (x, mut y) => { + //^^^^^ weak: remove this `mut` + x = 7; + //^^^^^ error: cannot mutate immutable variable `x` + } + } +} +"#, + ); + } + + #[test] + fn mutation_in_dead_code() { + // This one is interesting. Dead code is not represented at all in the MIR, so + // there would be no mutablility error for locals in dead code. Rustc tries to + // not emit `unused_mut` in this case, but since it works without `mut`, and + // special casing it is not trivial, we emit it. + check_diagnostics( + r#" +fn main() { + return; + let mut x = 2; + //^^^^^ weak: remove this `mut` + &mut x; +} +"#, + ); + check_diagnostics( + r#" +fn main() { + loop {} + let mut x = 2; + //^^^^^ weak: remove this `mut` + &mut x; +} +"#, + ); + check_diagnostics( + r#" +enum X {} +fn g() -> X { + loop {} +} +fn f() -> ! { + loop {} +} +fn main(b: bool) { + if b { + f(); + } else { + g(); + } + let mut x = 2; + //^^^^^ weak: remove this `mut` + &mut x; +} +"#, + ); + check_diagnostics( + r#" +fn main(b: bool) { + if b { + loop {} + } else { + return; + } + let mut x = 2; + //^^^^^ weak: remove this `mut` + &mut x; +} +"#, + ); + } + + #[test] + fn initialization_is_not_mutation() { + check_diagnostics( + r#" +fn f(_: i32) {} +fn main() { + let mut x; + //^^^^^ weak: remove this `mut` + x = 5; + f(x); +} +"#, + ); + check_diagnostics( + r#" +fn f(_: i32) {} +fn main(b: bool) { + let mut x; + //^^^^^ weak: remove this `mut` + if b { + x = 1; + } else { + x = 3; + } + f(x); +} +"#, + ); + check_diagnostics( + r#" +fn f(_: i32) {} +fn main(b: bool) { + let x; + if b { + x = 1; + } + x = 3; + //^^^^^ error: cannot mutate immutable variable `x` + f(x); +} +"#, + ); + check_diagnostics( + r#" +fn f(_: i32) {} +fn main() { + let x; + loop { + x = 1; + //^^^^^ error: cannot mutate immutable variable `x` + f(x); + } +} +"#, + ); + check_diagnostics( + r#" +fn f(_: i32) {} +fn main() { + loop { + let mut x = 1; + //^^^^^ weak: remove this `mut` + f(x); + if let mut y = 2 { + //^^^^^ weak: remove this `mut` + f(y); + } + match 3 { + mut z => f(z), + //^^^^^ weak: remove this `mut` + } + } +} +"#, + ); + } +} diff --git a/crates/ide-diagnostics/src/lib.rs b/crates/ide-diagnostics/src/lib.rs index c8635ff801..f6c9b79c30 100644 --- a/crates/ide-diagnostics/src/lib.rs +++ b/crates/ide-diagnostics/src/lib.rs @@ -37,6 +37,7 @@ mod handlers { pub(crate) mod missing_fields; pub(crate) mod missing_match_arms; pub(crate) mod missing_unsafe; + pub(crate) mod mutability_errors; pub(crate) mod no_such_field; pub(crate) mod private_assoc_item; pub(crate) mod private_field; @@ -273,7 +274,8 @@ pub fn diagnostics( AnyDiagnostic::InvalidDeriveTarget(d) => handlers::invalid_derive_target::invalid_derive_target(&ctx, &d), AnyDiagnostic::UnresolvedField(d) => handlers::unresolved_field::unresolved_field(&ctx, &d), AnyDiagnostic::UnresolvedMethodCall(d) => handlers::unresolved_method::unresolved_method(&ctx, &d), - + AnyDiagnostic::NeedMut(d) => handlers::mutability_errors::need_mut(&ctx, &d), + AnyDiagnostic::UnusedMut(d) => handlers::mutability_errors::unused_mut(&ctx, &d), AnyDiagnostic::InactiveCode(d) => match handlers::inactive_code::inactive_code(&ctx, &d) { Some(it) => it, None => continue,