Support overloaded index MIR lowering

This commit is contained in:
hkalbasi 2023-03-17 19:10:25 +03:30
parent eb4939e217
commit 9ad83deecc
4 changed files with 169 additions and 3 deletions

View file

@ -2,7 +2,8 @@ use base_db::fixture::WithFixture;
use hir_def::db::DefDatabase;
use crate::{
consteval::try_const_usize, db::HirDatabase, test_db::TestDB, Const, ConstScalar, Interner,
consteval::try_const_usize, db::HirDatabase, mir::pad16, test_db::TestDB, Const, ConstScalar,
Interner,
};
use super::{
@ -30,7 +31,12 @@ fn check_number(ra_fixture: &str, answer: i128) {
match &r.data(Interner).value {
chalk_ir::ConstValue::Concrete(c) => match &c.interned {
ConstScalar::Bytes(b, _) => {
assert_eq!(b, &answer.to_le_bytes()[0..b.len()]);
assert_eq!(
b,
&answer.to_le_bytes()[0..b.len()],
"Bytes differ. In decimal form: actual = {}, expected = {answer}",
i128::from_le_bytes(pad16(b, true))
);
}
x => panic!("Expected number but found {:?}", x),
},
@ -215,6 +221,42 @@ fn overloaded_deref_autoref() {
);
}
#[test]
fn overloaded_index() {
check_number(
r#"
//- minicore: index
struct Foo;
impl core::ops::Index<usize> for Foo {
type Output = i32;
fn index(&self, index: usize) -> &i32 {
if index == 7 {
&700
} else {
&1000
}
}
}
impl core::ops::IndexMut<usize> for Foo {
fn index_mut(&mut self, index: usize) -> &mut i32 {
if index == 7 {
&mut 7
} else {
&mut 10
}
}
}
const GOAL: i32 = {
(Foo[2]) + (Foo[7]) + (*&Foo[2]) + (*&Foo[7]) + (*&mut Foo[2]) + (*&mut Foo[7])
};
"#,
3417,
);
}
#[test]
fn function_call() {
check_number(

View file

@ -95,6 +95,21 @@ impl<'a> InferenceContext<'a> {
self.infer_mut_not_expr_iter(fields.iter().map(|x| x.expr).chain(*spread))
}
&Expr::Index { base, index } => {
if let Some((f, _)) = self.result.method_resolutions.get_mut(&tgt_expr) {
if mutability == Mutability::Mut {
if let Some(index_trait) = self
.db
.lang_item(self.table.trait_env.krate, LangItem::IndexMut)
.and_then(|l| l.as_trait())
{
if let Some(index_fn) =
self.db.trait_data(index_trait).method_by_name(&name![index_mut])
{
*f = index_fn;
}
}
}
}
self.infer_mut_expr(base, mutability);
self.infer_mut_expr(index, Mutability::Not);
}

View file

@ -1,6 +1,7 @@
//! MIR lowering for places
use super::*;
use hir_def::FunctionId;
use hir_expand::name;
macro_rules! not_supported {
@ -193,7 +194,24 @@ impl MirLowerCtx<'_> {
if index_ty != TyBuilder::usize()
|| !matches!(base_ty.kind(Interner), TyKind::Array(..) | TyKind::Slice(..))
{
not_supported!("overloaded index");
let Some(index_fn) = self.infer.method_resolution(expr_id) else {
return Err(MirLowerError::UnresolvedMethod);
};
let Some((base_place, current)) = self.lower_expr_as_place(current, *base, true)? else {
return Ok(None);
};
let Some((index_operand, current)) = self.lower_expr_to_some_operand(*index, current)? else {
return Ok(None);
};
return self.lower_overloaded_index(
current,
base_place,
self.expr_ty_after_adjustments(*base),
self.expr_ty(expr_id),
index_operand,
expr_id.into(),
index_fn,
);
}
let Some((mut p_base, current)) =
self.lower_expr_as_place(current, *base, true)? else {
@ -210,6 +228,49 @@ impl MirLowerCtx<'_> {
}
}
fn lower_overloaded_index(
&mut self,
current: BasicBlockId,
place: Place,
base_ty: Ty,
result_ty: Ty,
index_operand: Operand,
span: MirSpan,
index_fn: (FunctionId, Substitution),
) -> Result<Option<(Place, BasicBlockId)>> {
let is_mutable = 'b: {
if let Some(index_mut_trait) = self.resolve_lang_item(LangItem::IndexMut)?.as_trait() {
if let Some(index_mut_fn) =
self.db.trait_data(index_mut_trait).method_by_name(&name![index_mut])
{
break 'b index_mut_fn == index_fn.0;
}
}
false
};
let (mutability, borrow_kind) = match is_mutable {
true => (Mutability::Mut, BorrowKind::Mut { allow_two_phase_borrow: false }),
false => (Mutability::Not, BorrowKind::Shared),
};
let base_ref = TyKind::Ref(mutability, static_lifetime(), base_ty).intern(Interner);
let result_ref = TyKind::Ref(mutability, static_lifetime(), result_ty).intern(Interner);
let ref_place: Place = self.temp(base_ref)?.into();
self.push_assignment(current, ref_place.clone(), Rvalue::Ref(borrow_kind, place), span);
let mut result: Place = self.temp(result_ref)?.into();
let index_fn_op = Operand::const_zst(
TyKind::FnDef(
self.db.intern_callable_def(CallableDefId::FunctionId(index_fn.0)).into(),
index_fn.1,
)
.intern(Interner),
);
let Some(current) = self.lower_call(index_fn_op, vec![Operand::Copy(ref_place), index_operand], result.clone(), current, false)? else {
return Ok(None);
};
result.projection.push(ProjectionElem::Deref);
Ok(Some((result, current)))
}
fn lower_overloaded_deref(
&mut self,
current: BasicBlockId,

View file

@ -564,6 +564,54 @@ fn f(x: [(i32, u8); 10]) {
);
}
#[test]
fn overloaded_index() {
check_diagnostics(
r#"
//- minicore: index
use core::ops::{Index, IndexMut};
struct Foo;
impl Index<usize> for Foo {
type Output = (i32, u8);
fn index(&self, index: usize) -> &(i32, u8) {
&(5, 2)
}
}
impl IndexMut<usize> for Foo {
fn index_mut(&mut self, index: usize) -> &mut (i32, u8) {
&mut (5, 2)
}
}
fn f() {
let mut x = Foo;
//^^^^^ 💡 weak: variable does not need to be mutable
let y = &x[2];
let x = Foo;
let y = &mut x[2];
//^^^^ 💡 error: cannot mutate immutable variable `x`
let mut x = &mut Foo;
//^^^^^ 💡 weak: variable does not need to be mutable
let y: &mut (i32, u8) = &mut x[2];
let x = Foo;
let ref mut y = x[7];
//^^^^ 💡 error: cannot mutate immutable variable `x`
let (ref mut y, _) = x[3];
//^^^^ 💡 error: cannot mutate immutable variable `x`
match x[10] {
//^^^^^ 💡 error: cannot mutate immutable variable `x`
(ref y, _) => (),
(_, ref mut y) => (),
}
let mut x = Foo;
let mut i = 5;
//^^^^^ 💡 weak: variable does not need to be mutable
let y = &mut x[i];
}
"#,
);
}
#[test]
fn overloaded_deref() {
check_diagnostics(