Auto merge of #13971 - lowr:fix/more-precise-builtin-binop-types, r=Veykril

fix: more precise binop inference

While inferring binary operator expressions, Rust puts some extra constraints on the types of the operands for better inference. Relevant part in rustc is [this](159ba8a92c/compiler/rustc_hir_typeck/src/op.rs (L128-L152)).

There are two things we currently fail to consider:
- we should enforce them only when both lhs and rhs type are builtin types that are applicable to the binop
- lhs and rhs types may be single reference to applicable builtin types

This PR basically ports [`enforce_builtin_binop_types()`](159ba8a92c/compiler/rustc_hir_typeck/src/op.rs (L159)) and [`is_builtin_binop()`](159ba8a92c/compiler/rustc_hir_typeck/src/op.rs (LL927)) to our inference context.
This commit is contained in:
bors 2023-01-17 12:39:53 +00:00
commit 492b3deba7
4 changed files with 250 additions and 126 deletions

View file

@ -1,6 +1,6 @@
//! Various extensions traits for Chalk types. //! Various extensions traits for Chalk types.
use chalk_ir::{FloatTy, IntTy, Mutability, Scalar, UintTy}; use chalk_ir::{FloatTy, IntTy, Mutability, Scalar, TyVariableKind, UintTy};
use hir_def::{ use hir_def::{
builtin_type::{BuiltinFloat, BuiltinInt, BuiltinType, BuiltinUint}, builtin_type::{BuiltinFloat, BuiltinInt, BuiltinType, BuiltinUint},
generics::TypeOrConstParamData, generics::TypeOrConstParamData,
@ -18,6 +18,8 @@ use crate::{
pub trait TyExt { pub trait TyExt {
fn is_unit(&self) -> bool; fn is_unit(&self) -> bool;
fn is_integral(&self) -> bool;
fn is_floating_point(&self) -> bool;
fn is_never(&self) -> bool; fn is_never(&self) -> bool;
fn is_unknown(&self) -> bool; fn is_unknown(&self) -> bool;
fn is_ty_var(&self) -> bool; fn is_ty_var(&self) -> bool;
@ -51,6 +53,21 @@ impl TyExt for Ty {
matches!(self.kind(Interner), TyKind::Tuple(0, _)) matches!(self.kind(Interner), TyKind::Tuple(0, _))
} }
fn is_integral(&self) -> bool {
matches!(
self.kind(Interner),
TyKind::Scalar(Scalar::Int(_) | Scalar::Uint(_))
| TyKind::InferenceVar(_, TyVariableKind::Integer)
)
}
fn is_floating_point(&self) -> bool {
matches!(
self.kind(Interner),
TyKind::Scalar(Scalar::Float(_)) | TyKind::InferenceVar(_, TyVariableKind::Float)
)
}
fn is_never(&self) -> bool { fn is_never(&self) -> bool {
matches!(self.kind(Interner), TyKind::Never) matches!(self.kind(Interner), TyKind::Never)
} }

View file

@ -1041,10 +1041,6 @@ impl Expectation {
} }
} }
fn from_option(ty: Option<Ty>) -> Self {
ty.map_or(Expectation::None, Expectation::HasType)
}
/// The following explanation is copied straight from rustc: /// The following explanation is copied straight from rustc:
/// Provides an expectation for an rvalue expression given an *optional* /// Provides an expectation for an rvalue expression given an *optional*
/// hint, which is not required for type safety (the resulting type might /// hint, which is not required for type safety (the resulting type might

View file

@ -10,8 +10,7 @@ use chalk_ir::{
}; };
use hir_def::{ use hir_def::{
expr::{ expr::{
ArithOp, Array, BinaryOp, ClosureKind, CmpOp, Expr, ExprId, LabelId, Literal, Statement, ArithOp, Array, BinaryOp, ClosureKind, Expr, ExprId, LabelId, Literal, Statement, UnaryOp,
UnaryOp,
}, },
generics::TypeOrConstParamData, generics::TypeOrConstParamData,
path::{GenericArg, GenericArgs}, path::{GenericArg, GenericArgs},
@ -1017,11 +1016,21 @@ impl<'a> InferenceContext<'a> {
let (trait_, func) = match trait_func { let (trait_, func) = match trait_func {
Some(it) => it, Some(it) => it,
None => { None => {
let rhs_ty = self.builtin_binary_op_rhs_expectation(op, lhs_ty.clone()); // HACK: `rhs_ty` is a general inference variable with no clue at all at this
let rhs_ty = self.infer_expr_coerce(rhs, &Expectation::from_option(rhs_ty)); // point. Passing `lhs_ty` as both operands just to check if `lhs_ty` is a builtin
return self // type applicable to `op`.
.builtin_binary_op_return_ty(op, lhs_ty, rhs_ty) let ret_ty = if self.is_builtin_binop(&lhs_ty, &lhs_ty, op) {
.unwrap_or_else(|| self.err_ty()); // Assume both operands are builtin so we can continue inference. No guarantee
// on the correctness, rustc would complain as necessary lang items don't seem
// to exist anyway.
self.enforce_builtin_binop_types(&lhs_ty, &rhs_ty, op)
} else {
self.err_ty()
};
self.infer_expr_coerce(rhs, &Expectation::has_type(rhs_ty));
return ret_ty;
} }
}; };
@ -1071,11 +1080,9 @@ impl<'a> InferenceContext<'a> {
let ret_ty = self.normalize_associated_types_in(ret_ty); let ret_ty = self.normalize_associated_types_in(ret_ty);
// use knowledge of built-in binary ops, which can sometimes help inference if self.is_builtin_binop(&lhs_ty, &rhs_ty, op) {
if let Some(builtin_rhs) = self.builtin_binary_op_rhs_expectation(op, lhs_ty.clone()) { // use knowledge of built-in binary ops, which can sometimes help inference
self.unify(&builtin_rhs, &rhs_ty); let builtin_ret = self.enforce_builtin_binop_types(&lhs_ty, &rhs_ty, op);
}
if let Some(builtin_ret) = self.builtin_binary_op_return_ty(op, lhs_ty, rhs_ty) {
self.unify(&builtin_ret, &ret_ty); self.unify(&builtin_ret, &ret_ty);
} }
@ -1477,92 +1484,124 @@ impl<'a> InferenceContext<'a> {
indices indices
} }
fn builtin_binary_op_return_ty(&mut self, op: BinaryOp, lhs_ty: Ty, rhs_ty: Ty) -> Option<Ty> { /// Dereferences a single level of immutable referencing.
let lhs_ty = self.resolve_ty_shallow(&lhs_ty); fn deref_ty_if_possible(&mut self, ty: &Ty) -> Ty {
let rhs_ty = self.resolve_ty_shallow(&rhs_ty); let ty = self.resolve_ty_shallow(ty);
match op { match ty.kind(Interner) {
BinaryOp::LogicOp(_) | BinaryOp::CmpOp(_) => { TyKind::Ref(Mutability::Not, _, inner) => self.resolve_ty_shallow(inner),
Some(TyKind::Scalar(Scalar::Bool).intern(Interner)) _ => ty,
}
BinaryOp::Assignment { .. } => Some(TyBuilder::unit()),
BinaryOp::ArithOp(ArithOp::Shl | ArithOp::Shr) => {
// all integer combinations are valid here
if matches!(
lhs_ty.kind(Interner),
TyKind::Scalar(Scalar::Int(_) | Scalar::Uint(_))
| TyKind::InferenceVar(_, TyVariableKind::Integer)
) && matches!(
rhs_ty.kind(Interner),
TyKind::Scalar(Scalar::Int(_) | Scalar::Uint(_))
| TyKind::InferenceVar(_, TyVariableKind::Integer)
) {
Some(lhs_ty)
} else {
None
}
}
BinaryOp::ArithOp(_) => match (lhs_ty.kind(Interner), rhs_ty.kind(Interner)) {
// (int, int) | (uint, uint) | (float, float)
(TyKind::Scalar(Scalar::Int(_)), TyKind::Scalar(Scalar::Int(_)))
| (TyKind::Scalar(Scalar::Uint(_)), TyKind::Scalar(Scalar::Uint(_)))
| (TyKind::Scalar(Scalar::Float(_)), TyKind::Scalar(Scalar::Float(_))) => {
Some(rhs_ty)
}
// ({int}, int) | ({int}, uint)
(
TyKind::InferenceVar(_, TyVariableKind::Integer),
TyKind::Scalar(Scalar::Int(_) | Scalar::Uint(_)),
) => Some(rhs_ty),
// (int, {int}) | (uint, {int})
(
TyKind::Scalar(Scalar::Int(_) | Scalar::Uint(_)),
TyKind::InferenceVar(_, TyVariableKind::Integer),
) => Some(lhs_ty),
// ({float} | float)
(
TyKind::InferenceVar(_, TyVariableKind::Float),
TyKind::Scalar(Scalar::Float(_)),
) => Some(rhs_ty),
// (float, {float})
(
TyKind::Scalar(Scalar::Float(_)),
TyKind::InferenceVar(_, TyVariableKind::Float),
) => Some(lhs_ty),
// ({int}, {int}) | ({float}, {float})
(
TyKind::InferenceVar(_, TyVariableKind::Integer),
TyKind::InferenceVar(_, TyVariableKind::Integer),
)
| (
TyKind::InferenceVar(_, TyVariableKind::Float),
TyKind::InferenceVar(_, TyVariableKind::Float),
) => Some(rhs_ty),
_ => None,
},
} }
} }
fn builtin_binary_op_rhs_expectation(&mut self, op: BinaryOp, lhs_ty: Ty) -> Option<Ty> { /// Enforces expectations on lhs type and rhs type depending on the operator and returns the
Some(match op { /// output type of the binary op.
BinaryOp::LogicOp(..) => TyKind::Scalar(Scalar::Bool).intern(Interner), fn enforce_builtin_binop_types(&mut self, lhs: &Ty, rhs: &Ty, op: BinaryOp) -> Ty {
BinaryOp::Assignment { op: None } => lhs_ty, // Special-case a single layer of referencing, so that things like `5.0 + &6.0f32` work (See rust-lang/rust#57447).
BinaryOp::CmpOp(CmpOp::Eq { .. }) => match self let lhs = self.deref_ty_if_possible(lhs);
.resolve_ty_shallow(&lhs_ty) let rhs = self.deref_ty_if_possible(rhs);
.kind(Interner)
{ let (op, is_assign) = match op {
TyKind::Scalar(_) | TyKind::Str => lhs_ty, BinaryOp::Assignment { op: Some(inner) } => (BinaryOp::ArithOp(inner), true),
TyKind::InferenceVar(_, TyVariableKind::Integer | TyVariableKind::Float) => lhs_ty, _ => (op, false),
_ => return None, };
},
BinaryOp::ArithOp(ArithOp::Shl | ArithOp::Shr) => return None, let output_ty = match op {
BinaryOp::CmpOp(CmpOp::Ord { .. }) BinaryOp::LogicOp(_) => {
| BinaryOp::Assignment { op: Some(_) } let bool_ = self.result.standard_types.bool_.clone();
| BinaryOp::ArithOp(_) => match self.resolve_ty_shallow(&lhs_ty).kind(Interner) { self.unify(&lhs, &bool_);
TyKind::Scalar(Scalar::Int(_) | Scalar::Uint(_) | Scalar::Float(_)) => lhs_ty, self.unify(&rhs, &bool_);
TyKind::InferenceVar(_, TyVariableKind::Integer | TyVariableKind::Float) => lhs_ty, bool_
_ => return None, }
},
}) BinaryOp::ArithOp(ArithOp::Shl | ArithOp::Shr) => {
// result type is same as LHS always
lhs
}
BinaryOp::ArithOp(_) => {
// LHS, RHS, and result will have the same type
self.unify(&lhs, &rhs);
lhs
}
BinaryOp::CmpOp(_) => {
// LHS and RHS will have the same type
self.unify(&lhs, &rhs);
self.result.standard_types.bool_.clone()
}
BinaryOp::Assignment { op: None } => {
stdx::never!("Simple assignment operator is not binary op.");
lhs
}
BinaryOp::Assignment { .. } => unreachable!("handled above"),
};
if is_assign {
self.result.standard_types.unit.clone()
} else {
output_ty
}
}
fn is_builtin_binop(&mut self, lhs: &Ty, rhs: &Ty, op: BinaryOp) -> bool {
// Special-case a single layer of referencing, so that things like `5.0 + &6.0f32` work (See rust-lang/rust#57447).
let lhs = self.deref_ty_if_possible(lhs);
let rhs = self.deref_ty_if_possible(rhs);
let op = match op {
BinaryOp::Assignment { op: Some(inner) } => BinaryOp::ArithOp(inner),
_ => op,
};
match op {
BinaryOp::LogicOp(_) => true,
BinaryOp::ArithOp(ArithOp::Shl | ArithOp::Shr) => {
lhs.is_integral() && rhs.is_integral()
}
BinaryOp::ArithOp(
ArithOp::Add | ArithOp::Sub | ArithOp::Mul | ArithOp::Div | ArithOp::Rem,
) => {
lhs.is_integral() && rhs.is_integral()
|| lhs.is_floating_point() && rhs.is_floating_point()
}
BinaryOp::ArithOp(ArithOp::BitAnd | ArithOp::BitOr | ArithOp::BitXor) => {
lhs.is_integral() && rhs.is_integral()
|| lhs.is_floating_point() && rhs.is_floating_point()
|| matches!(
(lhs.kind(Interner), rhs.kind(Interner)),
(TyKind::Scalar(Scalar::Bool), TyKind::Scalar(Scalar::Bool))
)
}
BinaryOp::CmpOp(_) => {
let is_scalar = |kind| {
matches!(
kind,
&TyKind::Scalar(_)
| TyKind::FnDef(..)
| TyKind::Function(_)
| TyKind::Raw(..)
| TyKind::InferenceVar(
_,
TyVariableKind::Integer | TyVariableKind::Float
)
)
};
is_scalar(lhs.kind(Interner)) && is_scalar(rhs.kind(Interner))
}
BinaryOp::Assignment { op: None } => {
stdx::never!("Simple assignment operator is not binary op.");
false
}
BinaryOp::Assignment { .. } => unreachable!("handled above"),
}
} }
fn with_breakable_ctx<T>( fn with_breakable_ctx<T>(

View file

@ -3507,14 +3507,9 @@ trait Request {
fn bin_op_adt_with_rhs_primitive() { fn bin_op_adt_with_rhs_primitive() {
check_infer_with_mismatches( check_infer_with_mismatches(
r#" r#"
#[lang = "add"] //- minicore: add
pub trait Add<Rhs = Self> {
type Output;
fn add(self, rhs: Rhs) -> Self::Output;
}
struct Wrapper(u32); struct Wrapper(u32);
impl Add<u32> for Wrapper { impl core::ops::Add<u32> for Wrapper {
type Output = Self; type Output = Self;
fn add(self, rhs: u32) -> Wrapper { fn add(self, rhs: u32) -> Wrapper {
Wrapper(rhs) Wrapper(rhs)
@ -3527,29 +3522,106 @@ fn main(){
}"#, }"#,
expect![[r#" expect![[r#"
72..76 'self': Self 95..99 'self': Wrapper
78..81 'rhs': Rhs 101..104 'rhs': u32
192..196 'self': Wrapper 122..150 '{ ... }': Wrapper
198..201 'rhs': u32 132..139 'Wrapper': Wrapper(u32) -> Wrapper
219..247 '{ ... }': Wrapper 132..144 'Wrapper(rhs)': Wrapper
229..236 'Wrapper': Wrapper(u32) -> Wrapper 140..143 'rhs': u32
229..241 'Wrapper(rhs)': Wrapper 162..248 '{ ...um; }': ()
237..240 'rhs': u32 172..179 'wrapped': Wrapper
259..345 '{ ...um; }': () 182..189 'Wrapper': Wrapper(u32) -> Wrapper
269..276 'wrapped': Wrapper 182..193 'Wrapper(10)': Wrapper
279..286 'Wrapper': Wrapper(u32) -> Wrapper 190..192 '10': u32
279..290 'Wrapper(10)': Wrapper 203..206 'num': u32
287..289 '10': u32 214..215 '2': u32
300..303 'num': u32 225..228 'res': Wrapper
311..312 '2': u32 231..238 'wrapped': Wrapper
322..325 'res': Wrapper 231..244 'wrapped + num': Wrapper
328..335 'wrapped': Wrapper 241..244 'num': u32
328..341 'wrapped + num': Wrapper
338..341 'num': u32
"#]], "#]],
) )
} }
#[test]
fn builtin_binop_expectation_works_on_single_reference() {
check_types(
r#"
//- minicore: add
use core::ops::Add;
impl Add<i32> for i32 { type Output = i32 }
impl Add<&i32> for i32 { type Output = i32 }
impl Add<u32> for u32 { type Output = u32 }
impl Add<&u32> for u32 { type Output = u32 }
struct V<T>;
impl<T> V<T> {
fn default() -> Self { loop {} }
fn get(&self, _: &T) -> &T { loop {} }
}
fn take_u32(_: u32) {}
fn minimized() {
let v = V::default();
let p = v.get(&0);
//^ &u32
take_u32(42 + p);
}
"#,
);
}
#[test]
fn no_builtin_binop_expectation_for_general_ty_var() {
// FIXME: Ideally type mismatch should be reported on `take_u32(42 - p)`.
check_types(
r#"
//- minicore: add
use core::ops::Add;
impl Add<i32> for i32 { type Output = i32; }
impl Add<&i32> for i32 { type Output = i32; }
// This is needed to prevent chalk from giving unique solution to `i32: Add<&?0>` after applying
// fallback to integer type variable for `42`.
impl Add<&()> for i32 { type Output = (); }
struct V<T>;
impl<T> V<T> {
fn default() -> Self { loop {} }
fn get(&self) -> &T { loop {} }
}
fn take_u32(_: u32) {}
fn minimized() {
let v = V::default();
let p = v.get();
//^ &{unknown}
take_u32(42 + p);
}
"#,
);
}
#[test]
fn no_builtin_binop_expectation_for_non_builtin_types() {
check_no_mismatches(
r#"
//- minicore: default, eq
struct S;
impl Default for S { fn default() -> Self { S } }
impl Default for i32 { fn default() -> Self { 0 } }
impl PartialEq<S> for i32 { fn eq(&self, _: &S) -> bool { true } }
impl PartialEq<i32> for i32 { fn eq(&self, _: &S) -> bool { true } }
fn take_s(_: S) {}
fn test() {
let s = Default::default();
let _eq = 0 == s;
take_s(s);
}
"#,
)
}
#[test] #[test]
fn array_length() { fn array_length() {
check_infer( check_infer(