4667: Infer labelled breaks correctly r=flodiebold a=robojumper

Fixes #4663.

Co-authored-by: robojumper <robojumper@gmail.com>
This commit is contained in:
bors[bot] 2020-05-31 12:03:24 +00:00 committed by GitHub
commit 5579ba8af5
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
9 changed files with 172 additions and 37 deletions

View file

@ -134,7 +134,7 @@ impl ExprCollector<'_> {
self.make_expr(expr, Err(SyntheticSyntax))
}
fn empty_block(&mut self) -> ExprId {
self.alloc_expr_desugared(Expr::Block { statements: Vec::new(), tail: None })
self.alloc_expr_desugared(Expr::Block { statements: Vec::new(), tail: None, label: None })
}
fn missing_expr(&mut self) -> ExprId {
self.alloc_expr_desugared(Expr::Missing)
@ -215,7 +215,16 @@ impl ExprCollector<'_> {
ast::Expr::BlockExpr(e) => self.collect_block(e),
ast::Expr::LoopExpr(e) => {
let body = self.collect_block_opt(e.loop_body());
self.alloc_expr(Expr::Loop { body }, syntax_ptr)
self.alloc_expr(
Expr::Loop {
body,
label: e
.label()
.and_then(|l| l.lifetime_token())
.map(|l| Name::new_lifetime(&l)),
},
syntax_ptr,
)
}
ast::Expr::WhileExpr(e) => {
let body = self.collect_block_opt(e.loop_body());
@ -230,25 +239,56 @@ impl ExprCollector<'_> {
let pat = self.collect_pat(pat);
let match_expr = self.collect_expr_opt(condition.expr());
let placeholder_pat = self.missing_pat();
let break_ = self.alloc_expr_desugared(Expr::Break { expr: None });
let break_ =
self.alloc_expr_desugared(Expr::Break { expr: None, label: None });
let arms = vec![
MatchArm { pat, expr: body, guard: None },
MatchArm { pat: placeholder_pat, expr: break_, guard: None },
];
let match_expr =
self.alloc_expr_desugared(Expr::Match { expr: match_expr, arms });
return self.alloc_expr(Expr::Loop { body: match_expr }, syntax_ptr);
return self.alloc_expr(
Expr::Loop {
body: match_expr,
label: e
.label()
.and_then(|l| l.lifetime_token())
.map(|l| Name::new_lifetime(&l)),
},
syntax_ptr,
);
}
},
};
self.alloc_expr(Expr::While { condition, body }, syntax_ptr)
self.alloc_expr(
Expr::While {
condition,
body,
label: e
.label()
.and_then(|l| l.lifetime_token())
.map(|l| Name::new_lifetime(&l)),
},
syntax_ptr,
)
}
ast::Expr::ForExpr(e) => {
let iterable = self.collect_expr_opt(e.iterable());
let pat = self.collect_pat_opt(e.pat());
let body = self.collect_block_opt(e.loop_body());
self.alloc_expr(Expr::For { iterable, pat, body }, syntax_ptr)
self.alloc_expr(
Expr::For {
iterable,
pat,
body,
label: e
.label()
.and_then(|l| l.lifetime_token())
.map(|l| Name::new_lifetime(&l)),
},
syntax_ptr,
)
}
ast::Expr::CallExpr(e) => {
let callee = self.collect_expr_opt(e.expr());
@ -301,13 +341,16 @@ impl ExprCollector<'_> {
.unwrap_or(Expr::Missing);
self.alloc_expr(path, syntax_ptr)
}
ast::Expr::ContinueExpr(_e) => {
// FIXME: labels
self.alloc_expr(Expr::Continue, syntax_ptr)
}
ast::Expr::ContinueExpr(e) => self.alloc_expr(
Expr::Continue { label: e.lifetime_token().map(|l| Name::new_lifetime(&l)) },
syntax_ptr,
),
ast::Expr::BreakExpr(e) => {
let expr = e.expr().map(|e| self.collect_expr(e));
self.alloc_expr(Expr::Break { expr }, syntax_ptr)
self.alloc_expr(
Expr::Break { expr, label: e.lifetime_token().map(|l| Name::new_lifetime(&l)) },
syntax_ptr,
)
}
ast::Expr::ParenExpr(e) => {
let inner = self.collect_expr_opt(e.expr());
@ -529,7 +572,8 @@ impl ExprCollector<'_> {
})
.collect();
let tail = block.expr().map(|e| self.collect_expr(e));
self.alloc_expr(Expr::Block { statements, tail }, syntax_node_ptr)
let label = block.label().and_then(|l| l.lifetime_token()).map(|t| Name::new_lifetime(&t));
self.alloc_expr(Expr::Block { statements, tail, label }, syntax_node_ptr)
}
fn collect_block_items(&mut self, block: &ast::BlockExpr) {

View file

@ -138,10 +138,10 @@ fn compute_block_scopes(
fn compute_expr_scopes(expr: ExprId, body: &Body, scopes: &mut ExprScopes, scope: ScopeId) {
scopes.set_scope(expr, scope);
match &body[expr] {
Expr::Block { statements, tail } => {
Expr::Block { statements, tail, .. } => {
compute_block_scopes(&statements, *tail, body, scopes, scope);
}
Expr::For { iterable, pat, body: body_expr } => {
Expr::For { iterable, pat, body: body_expr, .. } => {
compute_expr_scopes(*iterable, body, scopes, scope);
let scope = scopes.new_scope(scope);
scopes.add_bindings(body, scope, *pat);

View file

@ -52,18 +52,22 @@ pub enum Expr {
Block {
statements: Vec<Statement>,
tail: Option<ExprId>,
label: Option<Name>,
},
Loop {
body: ExprId,
label: Option<Name>,
},
While {
condition: ExprId,
body: ExprId,
label: Option<Name>,
},
For {
iterable: ExprId,
pat: PatId,
body: ExprId,
label: Option<Name>,
},
Call {
callee: ExprId,
@ -79,9 +83,12 @@ pub enum Expr {
expr: ExprId,
arms: Vec<MatchArm>,
},
Continue,
Continue {
label: Option<Name>,
},
Break {
expr: Option<ExprId>,
label: Option<Name>,
},
Return {
expr: Option<ExprId>,
@ -225,7 +232,7 @@ impl Expr {
f(*else_branch);
}
}
Expr::Block { statements, tail } => {
Expr::Block { statements, tail, .. } => {
for stmt in statements {
match stmt {
Statement::Let { initializer, .. } => {
@ -241,8 +248,8 @@ impl Expr {
}
}
Expr::TryBlock { body } => f(*body),
Expr::Loop { body } => f(*body),
Expr::While { condition, body } => {
Expr::Loop { body, .. } => f(*body),
Expr::While { condition, body, .. } => {
f(*condition);
f(*body);
}
@ -268,8 +275,8 @@ impl Expr {
f(arm.expr);
}
}
Expr::Continue => {}
Expr::Break { expr } | Expr::Return { expr } => {
Expr::Continue { .. } => {}
Expr::Break { expr, .. } | Expr::Return { expr } => {
if let Some(expr) = expr {
f(*expr);
}

View file

@ -37,6 +37,11 @@ impl Name {
Name(Repr::TupleField(idx))
}
pub fn new_lifetime(lt: &ra_syntax::SyntaxToken) -> Name {
assert!(lt.kind() == ra_syntax::SyntaxKind::LIFETIME);
Name(Repr::Text(lt.text().clone()))
}
/// Shortcut to create inline plain text name
const fn new_inline_ascii(text: &[u8]) -> Name {
Name::new_text(SmolStr::new_inline_from_ascii(text.len(), text))

View file

@ -219,6 +219,17 @@ struct InferenceContext<'a> {
struct BreakableContext {
pub may_break: bool,
pub break_ty: Ty,
pub label: Option<name::Name>,
}
fn find_breakable<'c>(
ctxs: &'c mut [BreakableContext],
label: Option<&name::Name>,
) -> Option<&'c mut BreakableContext> {
match label {
Some(_) => ctxs.iter_mut().rev().find(|ctx| ctx.label.as_ref() == label),
None => ctxs.last_mut(),
}
}
impl<'a> InferenceContext<'a> {

View file

@ -22,8 +22,8 @@ use crate::{
};
use super::{
BindingMode, BreakableContext, Diverges, Expectation, InferenceContext, InferenceDiagnostic,
TypeMismatch,
find_breakable, BindingMode, BreakableContext, Diverges, Expectation, InferenceContext,
InferenceDiagnostic, TypeMismatch,
};
impl<'a> InferenceContext<'a> {
@ -86,16 +86,20 @@ impl<'a> InferenceContext<'a> {
self.coerce_merge_branch(&then_ty, &else_ty)
}
Expr::Block { statements, tail } => self.infer_block(statements, *tail, expected),
Expr::Block { statements, tail, .. } => {
// FIXME: Breakable block inference
self.infer_block(statements, *tail, expected)
}
Expr::TryBlock { body } => {
let _inner = self.infer_expr(*body, expected);
// FIXME should be std::result::Result<{inner}, _>
Ty::Unknown
}
Expr::Loop { body } => {
Expr::Loop { body, label } => {
self.breakables.push(BreakableContext {
may_break: false,
break_ty: self.table.new_type_var(),
label: label.clone(),
});
self.infer_expr(*body, &Expectation::has_type(Ty::unit()));
@ -110,8 +114,12 @@ impl<'a> InferenceContext<'a> {
Ty::simple(TypeCtor::Never)
}
}
Expr::While { condition, body } => {
self.breakables.push(BreakableContext { may_break: false, break_ty: Ty::Unknown });
Expr::While { condition, body, label } => {
self.breakables.push(BreakableContext {
may_break: false,
break_ty: Ty::Unknown,
label: label.clone(),
});
// while let is desugared to a match loop, so this is always simple while
self.infer_expr(*condition, &Expectation::has_type(Ty::simple(TypeCtor::Bool)));
self.infer_expr(*body, &Expectation::has_type(Ty::unit()));
@ -120,10 +128,14 @@ impl<'a> InferenceContext<'a> {
self.diverges = Diverges::Maybe;
Ty::unit()
}
Expr::For { iterable, body, pat } => {
Expr::For { iterable, body, pat, label } => {
let iterable_ty = self.infer_expr(*iterable, &Expectation::none());
self.breakables.push(BreakableContext { may_break: false, break_ty: Ty::Unknown });
self.breakables.push(BreakableContext {
may_break: false,
break_ty: Ty::Unknown,
label: label.clone(),
});
let pat_ty =
self.resolve_associated_type(iterable_ty, self.resolve_into_iter_item());
@ -236,23 +248,24 @@ impl<'a> InferenceContext<'a> {
let resolver = resolver_for_expr(self.db.upcast(), self.owner, tgt_expr);
self.infer_path(&resolver, p, tgt_expr.into()).unwrap_or(Ty::Unknown)
}
Expr::Continue => Ty::simple(TypeCtor::Never),
Expr::Break { expr } => {
Expr::Continue { .. } => Ty::simple(TypeCtor::Never),
Expr::Break { expr, label } => {
let val_ty = if let Some(expr) = expr {
self.infer_expr(*expr, &Expectation::none())
} else {
Ty::unit()
};
let last_ty = if let Some(ctxt) = self.breakables.last() {
ctxt.break_ty.clone()
} else {
Ty::Unknown
};
let last_ty =
if let Some(ctxt) = find_breakable(&mut self.breakables, label.as_ref()) {
ctxt.break_ty.clone()
} else {
Ty::Unknown
};
let merged_type = self.coerce_merge_branch(&last_ty, &val_ty);
if let Some(ctxt) = self.breakables.last_mut() {
if let Some(ctxt) = find_breakable(&mut self.breakables, label.as_ref()) {
ctxt.break_ty = merged_type;
ctxt.may_break = true;
} else {

View file

@ -1943,3 +1943,57 @@ fn test() {
"###
);
}
#[test]
fn infer_labelled_break_with_val() {
assert_snapshot!(
infer(r#"
fn foo() {
let _x = || 'outer: loop {
let inner = 'inner: loop {
let i = Default::default();
if (break 'outer i) {
loop { break 'inner 5i8; };
} else if true {
break 'inner 6;
}
break 7;
};
break inner < 8;
};
}
"#),
@r###"
10..336 '{ ... }; }': ()
20..22 '_x': || -> bool
25..333 '|| 'ou... }': || -> bool
28..333 ''outer... }': bool
41..333 '{ ... }': ()
55..60 'inner': i8
63..301 ''inner... }': i8
76..301 '{ ... }': ()
94..95 'i': bool
98..114 'Defaul...efault': {unknown}
98..116 'Defaul...ault()': bool
130..270 'if (br... }': ()
134..148 'break 'outer i': !
147..148 'i': bool
150..209 '{ ... }': ()
168..194 'loop {...5i8; }': !
173..194 '{ brea...5i8; }': ()
175..191 'break ...er 5i8': !
188..191 '5i8': i8
215..270 'if tru... }': ()
218..222 'true': bool
223..270 '{ ... }': ()
241..255 'break 'inner 6': !
254..255 '6': i8
283..290 'break 7': !
289..290 '7': i8
311..326 'break inner < 8': !
317..322 'inner': i8
317..326 'inner < 8': bool
325..326 '8': i8
"###
);
}

View file

@ -1081,6 +1081,7 @@ pub struct BlockExpr {
impl ast::AttrsOwner for BlockExpr {}
impl ast::ModuleItemOwner for BlockExpr {}
impl BlockExpr {
pub fn label(&self) -> Option<Label> { support::child(&self.syntax) }
pub fn l_curly_token(&self) -> Option<SyntaxToken> { support::token(&self.syntax, T!['{']) }
pub fn statements(&self) -> AstChildren<Stmt> { support::children(&self.syntax) }
pub fn expr(&self) -> Option<Expr> { support::child(&self.syntax) }

View file

@ -1058,7 +1058,7 @@ pub(crate) const AST_SRC: AstSrc = AstSrc {
/// [Reference](https://doc.rust-lang.org/reference/expressions/block-expr.html)
/// [Labels for blocks RFC](https://github.com/rust-lang/rfcs/blob/master/text/2046-label-break-value.md)
struct BlockExpr: AttrsOwner, ModuleItemOwner {
T!['{'], statements: [Stmt], Expr, T!['}'],
Label, T!['{'], statements: [Stmt], Expr, T!['}'],
}
/// Return expression.