Support "for loop" MIR lowering

This commit is contained in:
hkalbasi 2023-03-02 11:18:50 +03:30
parent ac04bfd7a7
commit 6377d50bd1
8 changed files with 292 additions and 79 deletions

View file

@ -415,6 +415,43 @@ fn loops() {
);
}
#[test]
fn for_loops() {
check_number(
r#"
//- minicore: iterator
struct Range {
start: u8,
end: u8,
}
impl Iterator for Range {
type Item = u8;
fn next(&mut self) -> Option<u8> {
if self.start >= self.end {
None
} else {
let r = self.start;
self.start = self.start + 1;
Some(r)
}
}
}
const GOAL: u8 = {
let mut sum = 0;
let ar = Range { start: 1, end: 11 };
for i in ar {
sum = sum + i;
}
sum
};
"#,
55,
);
}
#[test]
fn recursion() {
check_number(
@ -518,6 +555,33 @@ fn tuples() {
);
}
#[test]
fn path_pattern_matching() {
check_number(
r#"
enum Season {
Spring,
Summer,
Fall,
Winter,
}
use Season::*;
const fn f(x: Season) -> i32 {
match x {
Spring => 1,
Summer => 2,
Fall => 3,
Winter => 4,
}
}
const GOAL: i32 = f(Spring) + 10 * f(Summer) + 100 * f(Fall) + 1000 * f(Winter);
"#,
4321,
);
}
#[test]
fn pattern_matching_ergonomics() {
check_number(

View file

@ -354,6 +354,8 @@ pub struct InferenceResult {
pub type_of_pat: ArenaMap<PatId, Ty>,
pub type_of_binding: ArenaMap<BindingId, Ty>,
pub type_of_rpit: ArenaMap<RpitId, Ty>,
/// Type of the result of `.into_iter()` on the for. `ExprId` is the one of the whole for loop.
pub type_of_for_iterator: ArenaMap<ExprId, Ty>,
type_mismatches: FxHashMap<ExprOrPatId, TypeMismatch>,
/// Interned common types to return references to.
standard_types: InternedStandardTypes,
@ -549,6 +551,9 @@ impl<'a> InferenceContext<'a> {
for ty in result.type_of_rpit.values_mut() {
*ty = table.resolve_completely(ty.clone());
}
for ty in result.type_of_for_iterator.values_mut() {
*ty = table.resolve_completely(ty.clone());
}
for mismatch in result.type_mismatches.values_mut() {
mismatch.expected = table.resolve_completely(mismatch.expected.clone());
mismatch.actual = table.resolve_completely(mismatch.actual.clone());

View file

@ -242,8 +242,10 @@ impl<'a> InferenceContext<'a> {
let iterable_ty = self.infer_expr(iterable, &Expectation::none());
let into_iter_ty =
self.resolve_associated_type(iterable_ty, self.resolve_into_iter_item());
let pat_ty =
self.resolve_associated_type(into_iter_ty, self.resolve_iterator_item());
let pat_ty = self
.resolve_associated_type(into_iter_ty.clone(), self.resolve_iterator_item());
self.result.type_of_for_iterator.insert(tgt_expr, into_iter_ty);
self.infer_top_pat(pat, &pat_ty);
self.with_breakable_ctx(BreakableKind::Loop, None, label, |this| {

View file

@ -83,6 +83,10 @@ impl Operand {
fn from_bytes(data: Vec<u8>, ty: Ty) -> Self {
Operand::from_concrete_const(data, MemoryMap::default(), ty)
}
fn const_zst(ty: Ty) -> Operand {
Self::from_bytes(vec![], ty)
}
}
#[derive(Debug, PartialEq, Eq, Clone)]

View file

@ -1122,7 +1122,12 @@ impl Evaluator<'_> {
}
fn detect_lang_function(&self, def: FunctionId) -> Option<LangItem> {
lang_attr(self.db.upcast(), def)
let candidate = lang_attr(self.db.upcast(), def)?;
// filter normal lang functions out
if [LangItem::IntoIterIntoIter, LangItem::IteratorNext].contains(&candidate) {
return None;
}
Some(candidate)
}
fn create_memory_map(&self, bytes: &[u8], ty: &Ty, locals: &Locals<'_>) -> Result<MemoryMap> {

View file

@ -9,6 +9,7 @@ use hir_def::{
Array, BindingAnnotation, BindingId, ExprId, LabelId, Literal, MatchArm, Pat, PatId,
RecordLitField,
},
lang_item::{LangItem, LangItemTarget},
layout::LayoutError,
resolver::{resolver_for_expr, ResolveValueResult, ValueNs},
DefWithBodyId, EnumVariantId, HasModule,
@ -17,8 +18,8 @@ use la_arena::ArenaMap;
use crate::{
consteval::ConstEvalError, db::HirDatabase, display::HirDisplay, infer::TypeMismatch,
inhabitedness::is_ty_uninhabited_from, layout::layout_of_ty, mapping::ToChalk, utils::generics,
Adjust, AutoBorrow, CallableDefId, TyBuilder, TyExt,
inhabitedness::is_ty_uninhabited_from, layout::layout_of_ty, mapping::ToChalk, static_lifetime,
utils::generics, Adjust, AutoBorrow, CallableDefId, TyBuilder, TyExt,
};
use super::*;
@ -59,6 +60,7 @@ pub enum MirLowerError {
Loop,
/// Something that should never happen and is definitely a bug, but we don't want to panic if it happened
ImplementationError(&'static str),
LangItemNotFound(LangItem),
}
macro_rules! not_supported {
@ -484,13 +486,64 @@ impl MirLowerCtx<'_> {
Ok(())
})
}
Expr::For { .. } => not_supported!("for loop"),
&Expr::For { iterable, pat, body, label } => {
let into_iter_fn = self.resolve_lang_item(LangItem::IntoIterIntoIter)?
.as_function().ok_or(MirLowerError::LangItemNotFound(LangItem::IntoIterIntoIter))?;
let iter_next_fn = self.resolve_lang_item(LangItem::IteratorNext)?
.as_function().ok_or(MirLowerError::LangItemNotFound(LangItem::IteratorNext))?;
let option_some = self.resolve_lang_item(LangItem::OptionSome)?
.as_enum_variant().ok_or(MirLowerError::LangItemNotFound(LangItem::OptionSome))?;
let option = option_some.parent;
let into_iter_fn_op = Operand::const_zst(
TyKind::FnDef(
self.db.intern_callable_def(CallableDefId::FunctionId(into_iter_fn)).into(),
Substitution::from1(Interner, self.expr_ty(iterable))
).intern(Interner));
let iter_next_fn_op = Operand::const_zst(
TyKind::FnDef(
self.db.intern_callable_def(CallableDefId::FunctionId(iter_next_fn)).into(),
Substitution::from1(Interner, self.expr_ty(iterable))
).intern(Interner));
let iterator_ty = &self.infer.type_of_for_iterator[expr_id];
let ref_mut_iterator_ty = TyKind::Ref(Mutability::Mut, static_lifetime(), iterator_ty.clone()).intern(Interner);
let item_ty = &self.infer.type_of_pat[pat];
let option_item_ty = TyKind::Adt(chalk_ir::AdtId(option.into()), Substitution::from1(Interner, item_ty.clone())).intern(Interner);
let iterator_place: Place = self.temp(iterator_ty.clone())?.into();
let option_item_place: Place = self.temp(option_item_ty.clone())?.into();
let ref_mut_iterator_place: Place = self.temp(ref_mut_iterator_ty)?.into();
let Some(current) = self.lower_call_and_args(into_iter_fn_op, Some(iterable).into_iter(), iterator_place.clone(), current, false)?
else {
return Ok(None);
};
self.push_assignment(current, ref_mut_iterator_place.clone(), Rvalue::Ref(BorrowKind::Mut { allow_two_phase_borrow: false }, iterator_place), expr_id.into());
self.lower_loop(current, label, |this, begin| {
this.push_storage_live(pat, begin)?;
let Some(current) = this.lower_call(iter_next_fn_op, vec![Operand::Copy(ref_mut_iterator_place)], option_item_place.clone(), begin, false)?
else {
return Ok(());
};
let end = this.current_loop_end()?;
let (current, _) = this.pattern_matching_variant(
option_item_ty.clone(),
BindingAnnotation::Unannotated,
option_item_place.into(),
option_some.into(),
current,
pat.into(),
Some(end),
&[pat], &None)?;
if let (_, Some(block)) = this.lower_expr_to_some_place(body, current)? {
this.set_goto(block, begin);
}
Ok(())
})
},
Expr::Call { callee, args, .. } => {
let callee_ty = self.expr_ty_after_adjustments(*callee);
match &callee_ty.data(Interner).kind {
chalk_ir::TyKind::FnDef(..) => {
let func = Operand::from_bytes(vec![], callee_ty.clone());
self.lower_call(func, args.iter().copied(), place, current, self.is_uninhabited(expr_id))
self.lower_call_and_args(func, args.iter().copied(), place, current, self.is_uninhabited(expr_id))
}
TyKind::Scalar(_)
| TyKind::Tuple(_, _)
@ -527,7 +580,7 @@ impl MirLowerCtx<'_> {
)
.intern(Interner);
let func = Operand::from_bytes(vec![], ty);
self.lower_call(
self.lower_call_and_args(
func,
iter::once(*receiver).chain(args.iter().copied()),
place,
@ -962,7 +1015,7 @@ impl MirLowerCtx<'_> {
Ok(prev_block)
}
fn lower_call(
fn lower_call_and_args(
&mut self,
func: Operand,
args: impl Iterator<Item = ExprId>,
@ -983,6 +1036,17 @@ impl MirLowerCtx<'_> {
else {
return Ok(None);
};
self.lower_call(func, args, place, current, is_uninhabited)
}
fn lower_call(
&mut self,
func: Operand,
args: Vec<Operand>,
place: Place,
current: BasicBlockId,
is_uninhabited: bool,
) -> Result<Option<BasicBlockId>> {
let b = if is_uninhabited { None } else { Some(self.new_basic_block()) };
self.set_terminator(
current,
@ -1112,7 +1176,22 @@ impl MirLowerCtx<'_> {
Pat::Record { .. } => not_supported!("record pattern"),
Pat::Range { .. } => not_supported!("range pattern"),
Pat::Slice { .. } => not_supported!("slice pattern"),
Pat::Path(_) => not_supported!("path pattern"),
Pat::Path(_) => {
let Some(variant) = self.infer.variant_resolution_for_pat(pattern) else {
not_supported!("unresolved variant");
};
self.pattern_matching_variant(
cond_ty,
binding_mode,
cond_place,
variant,
current,
pattern.into(),
current_else,
&[],
&None,
)?
}
Pat::Lit(l) => {
let then_target = self.new_basic_block();
let else_target = current_else.unwrap_or_else(|| self.new_basic_block());
@ -1183,17 +1262,43 @@ impl MirLowerCtx<'_> {
let Some(variant) = self.infer.variant_resolution_for_pat(pattern) else {
not_supported!("unresolved variant");
};
self.pattern_matching_variant(
cond_ty,
binding_mode,
cond_place,
variant,
current,
pattern.into(),
current_else,
args,
ellipsis,
)?
}
Pat::Ref { .. } => not_supported!("& pattern"),
Pat::Box { .. } => not_supported!("box pattern"),
Pat::ConstBlock(_) => not_supported!("const block pattern"),
})
}
fn pattern_matching_variant(
&mut self,
mut cond_ty: Ty,
mut binding_mode: BindingAnnotation,
mut cond_place: Place,
variant: VariantId,
current: BasicBlockId,
span: MirSpan,
current_else: Option<BasicBlockId>,
args: &[PatId],
ellipsis: &Option<usize>,
) -> Result<(BasicBlockId, Option<BasicBlockId>)> {
pattern_matching_dereference(&mut cond_ty, &mut binding_mode, &mut cond_place);
let subst = match cond_ty.kind(Interner) {
TyKind::Adt(_, s) => s,
_ => {
return Err(MirLowerError::TypeError(
"non adt type matched with tuple struct",
))
}
_ => return Err(MirLowerError::TypeError("non adt type matched with tuple struct")),
};
let fields_type = self.db.field_types(variant);
match variant {
Ok(match variant {
VariantId::EnumVariantId(v) => {
let e = self.db.const_eval_discriminant(v)? as u128;
let next = self.new_basic_block();
@ -1202,7 +1307,7 @@ impl MirLowerCtx<'_> {
current,
tmp.clone(),
Rvalue::Discriminant(cond_place.clone()),
pattern.into(),
span,
);
let else_target = current_else.unwrap_or_else(|| self.new_basic_block());
self.set_terminator(
@ -1214,14 +1319,12 @@ impl MirLowerCtx<'_> {
);
let enum_data = self.db.enum_data(v.parent);
let fields =
enum_data.variants[v.local_id].variant_data.fields().iter().map(
|(x, _)| {
enum_data.variants[v.local_id].variant_data.fields().iter().map(|(x, _)| {
(
PlaceElem::Field(FieldId { parent: v.into(), local_id: x }),
fields_type[x].clone().substitute(Interner, subst),
)
},
);
});
self.pattern_match_tuple_like(
next,
Some(else_target),
@ -1251,11 +1354,6 @@ impl MirLowerCtx<'_> {
VariantId::UnionId(_) => {
return Err(MirLowerError::TypeError("pattern matching on union"))
}
}
}
Pat::Ref { .. } => not_supported!("& pattern"),
Pat::Box { .. } => not_supported!("box pattern"),
Pat::ConstBlock(_) => not_supported!("const block pattern"),
})
}
@ -1384,6 +1482,11 @@ impl MirLowerCtx<'_> {
});
Ok(())
}
fn resolve_lang_item(&self, item: LangItem) -> Result<LangItemTarget> {
let crate_id = self.owner.module(self.db.upcast()).krate();
self.db.lang_item(crate_id, item).ok_or(MirLowerError::LangItemNotFound(item))
}
}
fn pattern_matching_dereference(

View file

@ -507,6 +507,22 @@ fn f(x: i32) {
x = 5;
//^^^^^ 💡 error: cannot mutate immutable variable `x`
}
"#,
);
}
#[test]
fn for_loop() {
check_diagnostics(
r#"
//- minicore: iterators
fn f(x: [(i32, u8); 10]) {
for (a, mut b) in x {
//^^^^^ 💡 weak: remove this `mut`
a = 2;
//^^^^^ 💡 error: cannot mutate immutable variable `a`
}
}
"#,
);
}

View file

@ -728,6 +728,20 @@ pub mod iter {
self
}
}
pub struct IntoIter<T, const N: usize>([T; N]);
impl<T, const N: usize> IntoIterator for [T; N] {
type Item = T;
type IntoIter = IntoIter<T, N>;
fn into_iter(self) -> I {
IntoIter(self)
}
}
impl<T, const N: usize> Iterator for IntoIter<T, N> {
type Item = T;
fn next(&mut self) -> Option<T> {
loop {}
}
}
}
pub use self::collect::IntoIterator;
}