Merge pull request #18594 from ChayimFriedman2/async-closures

feat: Support `AsyncFnX` traits
This commit is contained in:
Lukas Wirth 2024-12-06 12:48:47 +00:00 committed by GitHub
commit abc7147bb7
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 289 additions and 66 deletions

View file

@ -376,6 +376,9 @@ language_item_table! {
Fn, sym::fn_, fn_trait, Target::Trait, GenericRequirement::Exact(1); Fn, sym::fn_, fn_trait, Target::Trait, GenericRequirement::Exact(1);
FnMut, sym::fn_mut, fn_mut_trait, Target::Trait, GenericRequirement::Exact(1); FnMut, sym::fn_mut, fn_mut_trait, Target::Trait, GenericRequirement::Exact(1);
FnOnce, sym::fn_once, fn_once_trait, Target::Trait, GenericRequirement::Exact(1); FnOnce, sym::fn_once, fn_once_trait, Target::Trait, GenericRequirement::Exact(1);
AsyncFn, sym::async_fn, async_fn_trait, Target::Trait, GenericRequirement::Exact(1);
AsyncFnMut, sym::async_fn_mut, async_fn_mut_trait, Target::Trait, GenericRequirement::Exact(1);
AsyncFnOnce, sym::async_fn_once, async_fn_once_trait, Target::Trait, GenericRequirement::Exact(1);
FnOnceOutput, sym::fn_once_output, fn_once_output, Target::AssocTy, GenericRequirement::None; FnOnceOutput, sym::fn_once_output, fn_once_output, Target::AssocTy, GenericRequirement::None;

View file

@ -1287,8 +1287,8 @@ impl InferenceContext<'_> {
tgt_expr: ExprId, tgt_expr: ExprId,
) { ) {
match fn_x { match fn_x {
FnTrait::FnOnce => (), FnTrait::FnOnce | FnTrait::AsyncFnOnce => (),
FnTrait::FnMut => { FnTrait::FnMut | FnTrait::AsyncFnMut => {
if let TyKind::Ref(Mutability::Mut, lt, inner) = derefed_callee.kind(Interner) { if let TyKind::Ref(Mutability::Mut, lt, inner) = derefed_callee.kind(Interner) {
if adjustments if adjustments
.last() .last()
@ -1312,7 +1312,7 @@ impl InferenceContext<'_> {
)); ));
} }
} }
FnTrait::Fn => { FnTrait::Fn | FnTrait::AsyncFn => {
if !matches!(derefed_callee.kind(Interner), TyKind::Ref(Mutability::Not, _, _)) { if !matches!(derefed_callee.kind(Interner), TyKind::Ref(Mutability::Not, _, _)) {
adjustments.push(Adjustment::borrow( adjustments.push(Adjustment::borrow(
Mutability::Not, Mutability::Not,

View file

@ -794,11 +794,16 @@ impl<'a> InferenceTable<'a> {
ty: &Ty, ty: &Ty,
num_args: usize, num_args: usize,
) -> Option<(FnTrait, Vec<Ty>, Ty)> { ) -> Option<(FnTrait, Vec<Ty>, Ty)> {
for (fn_trait_name, output_assoc_name, subtraits) in [
(FnTrait::FnOnce, sym::Output.clone(), &[FnTrait::Fn, FnTrait::FnMut][..]),
(FnTrait::AsyncFnMut, sym::CallRefFuture.clone(), &[FnTrait::AsyncFn]),
(FnTrait::AsyncFnOnce, sym::CallOnceFuture.clone(), &[]),
] {
let krate = self.trait_env.krate; let krate = self.trait_env.krate;
let fn_once_trait = FnTrait::FnOnce.get_id(self.db, krate)?; let fn_trait = fn_trait_name.get_id(self.db, krate)?;
let trait_data = self.db.trait_data(fn_once_trait); let trait_data = self.db.trait_data(fn_trait);
let output_assoc_type = let output_assoc_type =
trait_data.associated_type_by_name(&Name::new_symbol_root(sym::Output.clone()))?; trait_data.associated_type_by_name(&Name::new_symbol_root(output_assoc_name))?;
let mut arg_tys = Vec::with_capacity(num_args); let mut arg_tys = Vec::with_capacity(num_args);
let arg_ty = TyBuilder::tuple(num_args) let arg_ty = TyBuilder::tuple(num_args)
@ -813,20 +818,19 @@ impl<'a> InferenceTable<'a> {
}) })
.build(); .build();
let b = TyBuilder::trait_ref(self.db, fn_once_trait); let b = TyBuilder::trait_ref(self.db, fn_trait);
if b.remaining() != 2 { if b.remaining() != 2 {
return None; return None;
} }
let mut trait_ref = b.push(ty.clone()).push(arg_ty).build(); let mut trait_ref = b.push(ty.clone()).push(arg_ty).build();
let projection = { let projection = TyBuilder::assoc_type_projection(
TyBuilder::assoc_type_projection(
self.db, self.db,
output_assoc_type, output_assoc_type,
Some(trait_ref.substitution.clone()), Some(trait_ref.substitution.clone()),
) )
.build() .fill_with_unknown()
}; .build();
let trait_env = self.trait_env.env.clone(); let trait_env = self.trait_env.env.clone();
let obligation = InEnvironment { let obligation = InEnvironment {
@ -834,13 +838,15 @@ impl<'a> InferenceTable<'a> {
environment: trait_env.clone(), environment: trait_env.clone(),
}; };
let canonical = self.canonicalize(obligation.clone()); let canonical = self.canonicalize(obligation.clone());
if self.db.trait_solve(krate, self.trait_env.block, canonical.cast(Interner)).is_some() { if self.db.trait_solve(krate, self.trait_env.block, canonical.cast(Interner)).is_some()
{
self.register_obligation(obligation.goal); self.register_obligation(obligation.goal);
let return_ty = self.normalize_projection_ty(projection); let return_ty = self.normalize_projection_ty(projection);
for fn_x in [FnTrait::Fn, FnTrait::FnMut, FnTrait::FnOnce] { for &fn_x in subtraits {
let fn_x_trait = fn_x.get_id(self.db, krate)?; let fn_x_trait = fn_x.get_id(self.db, krate)?;
trait_ref.trait_id = to_chalk_trait_id(fn_x_trait); trait_ref.trait_id = to_chalk_trait_id(fn_x_trait);
let obligation: chalk_ir::InEnvironment<chalk_ir::Goal<Interner>> = InEnvironment { let obligation: chalk_ir::InEnvironment<chalk_ir::Goal<Interner>> =
InEnvironment {
goal: trait_ref.clone().cast(Interner), goal: trait_ref.clone().cast(Interner),
environment: trait_env.clone(), environment: trait_env.clone(),
}; };
@ -853,11 +859,11 @@ impl<'a> InferenceTable<'a> {
return Some((fn_x, arg_tys, return_ty)); return Some((fn_x, arg_tys, return_ty));
} }
} }
unreachable!("It should at least implement FnOnce at this point"); return Some((fn_trait_name, arg_tys, return_ty));
} else {
None
} }
} }
None
}
pub(super) fn insert_type_vars<T>(&mut self, ty: T) -> T pub(super) fn insert_type_vars<T>(&mut self, ty: T) -> T
where where

View file

@ -2023,11 +2023,11 @@ pub fn mir_body_for_closure_query(
ctx.result.locals.alloc(Local { ty: infer[*root].clone() }); ctx.result.locals.alloc(Local { ty: infer[*root].clone() });
let closure_local = ctx.result.locals.alloc(Local { let closure_local = ctx.result.locals.alloc(Local {
ty: match kind { ty: match kind {
FnTrait::FnOnce => infer[expr].clone(), FnTrait::FnOnce | FnTrait::AsyncFnOnce => infer[expr].clone(),
FnTrait::FnMut => { FnTrait::FnMut | FnTrait::AsyncFnMut => {
TyKind::Ref(Mutability::Mut, error_lifetime(), infer[expr].clone()).intern(Interner) TyKind::Ref(Mutability::Mut, error_lifetime(), infer[expr].clone()).intern(Interner)
} }
FnTrait::Fn => { FnTrait::Fn | FnTrait::AsyncFn => {
TyKind::Ref(Mutability::Not, error_lifetime(), infer[expr].clone()).intern(Interner) TyKind::Ref(Mutability::Not, error_lifetime(), infer[expr].clone()).intern(Interner)
} }
}, },
@ -2055,8 +2055,10 @@ pub fn mir_body_for_closure_query(
let mut err = None; let mut err = None;
let closure_local = ctx.result.locals.iter().nth(1).unwrap().0; let closure_local = ctx.result.locals.iter().nth(1).unwrap().0;
let closure_projection = match kind { let closure_projection = match kind {
FnTrait::FnOnce => vec![], FnTrait::FnOnce | FnTrait::AsyncFnOnce => vec![],
FnTrait::FnMut | FnTrait::Fn => vec![ProjectionElem::Deref], FnTrait::FnMut | FnTrait::Fn | FnTrait::AsyncFnMut | FnTrait::AsyncFn => {
vec![ProjectionElem::Deref]
}
}; };
ctx.result.walk_places(|p, store| { ctx.result.walk_places(|p, store| {
if let Some(it) = upvar_map.get(&p.local) { if let Some(it) = upvar_map.get(&p.local) {

View file

@ -4834,3 +4834,53 @@ fn bar(v: *const ()) {
"#]], "#]],
); );
} }
#[test]
fn async_fn_traits() {
check_infer(
r#"
//- minicore: async_fn
async fn foo<T: AsyncFn(u32) -> i32>(a: T) {
let fut1 = a(0);
fut1.await;
}
async fn bar<T: AsyncFnMut(u32) -> i32>(mut b: T) {
let fut2 = b(0);
fut2.await;
}
async fn baz<T: AsyncFnOnce(u32) -> i32>(c: T) {
let fut3 = c(0);
fut3.await;
}
"#,
expect![[r#"
37..38 'a': T
43..83 '{ ...ait; }': ()
43..83 '{ ...ait; }': impl Future<Output = ()>
53..57 'fut1': AsyncFnMut::CallRefFuture<'?, T, (u32,)>
60..61 'a': T
60..64 'a(0)': AsyncFnMut::CallRefFuture<'?, T, (u32,)>
62..63 '0': u32
70..74 'fut1': AsyncFnMut::CallRefFuture<'?, T, (u32,)>
70..80 'fut1.await': i32
124..129 'mut b': T
134..174 '{ ...ait; }': ()
134..174 '{ ...ait; }': impl Future<Output = ()>
144..148 'fut2': AsyncFnMut::CallRefFuture<'?, T, (u32,)>
151..152 'b': T
151..155 'b(0)': AsyncFnMut::CallRefFuture<'?, T, (u32,)>
153..154 '0': u32
161..165 'fut2': AsyncFnMut::CallRefFuture<'?, T, (u32,)>
161..171 'fut2.await': i32
216..217 'c': T
222..262 '{ ...ait; }': ()
222..262 '{ ...ait; }': impl Future<Output = ()>
232..236 'fut3': AsyncFnOnce::CallOnceFuture<T, (u32,)>
239..240 'c': T
239..243 'c(0)': AsyncFnOnce::CallOnceFuture<T, (u32,)>
241..242 '0': u32
249..253 'fut3': AsyncFnOnce::CallOnceFuture<T, (u32,)>
249..259 'fut3.await': i32
"#]],
);
}

View file

@ -220,6 +220,10 @@ pub enum FnTrait {
FnOnce, FnOnce,
FnMut, FnMut,
Fn, Fn,
AsyncFnOnce,
AsyncFnMut,
AsyncFn,
} }
impl fmt::Display for FnTrait { impl fmt::Display for FnTrait {
@ -228,6 +232,9 @@ impl fmt::Display for FnTrait {
FnTrait::FnOnce => write!(f, "FnOnce"), FnTrait::FnOnce => write!(f, "FnOnce"),
FnTrait::FnMut => write!(f, "FnMut"), FnTrait::FnMut => write!(f, "FnMut"),
FnTrait::Fn => write!(f, "Fn"), FnTrait::Fn => write!(f, "Fn"),
FnTrait::AsyncFnOnce => write!(f, "AsyncFnOnce"),
FnTrait::AsyncFnMut => write!(f, "AsyncFnMut"),
FnTrait::AsyncFn => write!(f, "AsyncFn"),
} }
} }
} }
@ -238,6 +245,9 @@ impl FnTrait {
FnTrait::FnOnce => "call_once", FnTrait::FnOnce => "call_once",
FnTrait::FnMut => "call_mut", FnTrait::FnMut => "call_mut",
FnTrait::Fn => "call", FnTrait::Fn => "call",
FnTrait::AsyncFnOnce => "async_call_once",
FnTrait::AsyncFnMut => "async_call_mut",
FnTrait::AsyncFn => "async_call",
} }
} }
@ -246,6 +256,9 @@ impl FnTrait {
FnTrait::FnOnce => LangItem::FnOnce, FnTrait::FnOnce => LangItem::FnOnce,
FnTrait::FnMut => LangItem::FnMut, FnTrait::FnMut => LangItem::FnMut,
FnTrait::Fn => LangItem::Fn, FnTrait::Fn => LangItem::Fn,
FnTrait::AsyncFnOnce => LangItem::AsyncFnOnce,
FnTrait::AsyncFnMut => LangItem::AsyncFnMut,
FnTrait::AsyncFn => LangItem::AsyncFn,
} }
} }
@ -254,15 +267,19 @@ impl FnTrait {
LangItem::FnOnce => Some(FnTrait::FnOnce), LangItem::FnOnce => Some(FnTrait::FnOnce),
LangItem::FnMut => Some(FnTrait::FnMut), LangItem::FnMut => Some(FnTrait::FnMut),
LangItem::Fn => Some(FnTrait::Fn), LangItem::Fn => Some(FnTrait::Fn),
LangItem::AsyncFnOnce => Some(FnTrait::AsyncFnOnce),
LangItem::AsyncFnMut => Some(FnTrait::AsyncFnMut),
LangItem::AsyncFn => Some(FnTrait::AsyncFn),
_ => None, _ => None,
} }
} }
pub const fn to_chalk_ir(self) -> rust_ir::ClosureKind { pub const fn to_chalk_ir(self) -> rust_ir::ClosureKind {
// Chalk doesn't support async fn traits.
match self { match self {
FnTrait::FnOnce => rust_ir::ClosureKind::FnOnce, FnTrait::AsyncFnOnce | FnTrait::FnOnce => rust_ir::ClosureKind::FnOnce,
FnTrait::FnMut => rust_ir::ClosureKind::FnMut, FnTrait::AsyncFnMut | FnTrait::FnMut => rust_ir::ClosureKind::FnMut,
FnTrait::Fn => rust_ir::ClosureKind::Fn, FnTrait::AsyncFn | FnTrait::Fn => rust_ir::ClosureKind::Fn,
} }
} }
@ -271,6 +288,9 @@ impl FnTrait {
FnTrait::FnOnce => Name::new_symbol_root(sym::call_once.clone()), FnTrait::FnOnce => Name::new_symbol_root(sym::call_once.clone()),
FnTrait::FnMut => Name::new_symbol_root(sym::call_mut.clone()), FnTrait::FnMut => Name::new_symbol_root(sym::call_mut.clone()),
FnTrait::Fn => Name::new_symbol_root(sym::call.clone()), FnTrait::Fn => Name::new_symbol_root(sym::call.clone()),
FnTrait::AsyncFnOnce => Name::new_symbol_root(sym::async_call_once.clone()),
FnTrait::AsyncFnMut => Name::new_symbol_root(sym::async_call_mut.clone()),
FnTrait::AsyncFn => Name::new_symbol_root(sym::async_call.clone()),
} }
} }

View file

@ -37,4 +37,25 @@ fn foo() {
"#, "#,
); );
} }
#[test]
fn no_error_for_async_fn_traits() {
check_diagnostics(
r#"
//- minicore: async_fn
async fn f(it: impl AsyncFn(u32) -> i32) {
let fut = it(0);
let _: i32 = fut.await;
}
async fn g(mut it: impl AsyncFnMut(u32) -> i32) {
let fut = it(0);
let _: i32 = fut.await;
}
async fn h(it: impl AsyncFnOnce(u32) -> i32) {
let fut = it(0);
let _: i32 = fut.await;
}
"#,
);
}
} }

View file

@ -150,6 +150,9 @@ define_symbols! {
C, C,
call_mut, call_mut,
call_once, call_once,
async_call_once,
async_call_mut,
async_call,
call, call,
cdecl, cdecl,
Center, Center,
@ -221,6 +224,9 @@ define_symbols! {
fn_mut, fn_mut,
fn_once_output, fn_once_output,
fn_once, fn_once,
async_fn_once,
async_fn_mut,
async_fn,
fn_ptr_addr, fn_ptr_addr,
fn_ptr_trait, fn_ptr_trait,
format_alignment, format_alignment,
@ -334,6 +340,8 @@ define_symbols! {
Option, Option,
Ord, Ord,
Output, Output,
CallRefFuture,
CallOnceFuture,
owned_box, owned_box,
packed, packed,
panic_2015, panic_2015,

View file

@ -12,6 +12,7 @@
//! asm: //! asm:
//! assert: //! assert:
//! as_ref: sized //! as_ref: sized
//! async_fn: fn, tuple, future, copy
//! bool_impl: option, fn //! bool_impl: option, fn
//! builtin_impls: //! builtin_impls:
//! cell: copy, drop //! cell: copy, drop
@ -29,7 +30,7 @@
//! eq: sized //! eq: sized
//! error: fmt //! error: fmt
//! fmt: option, result, transmute, coerce_unsized, copy, clone, derive //! fmt: option, result, transmute, coerce_unsized, copy, clone, derive
//! fn: //! fn: tuple
//! from: sized //! from: sized
//! future: pin //! future: pin
//! coroutine: pin //! coroutine: pin
@ -60,6 +61,7 @@
//! sync: sized //! sync: sized
//! transmute: //! transmute:
//! try: infallible //! try: infallible
//! tuple:
//! unpin: sized //! unpin: sized
//! unsize: sized //! unsize: sized
//! todo: panic //! todo: panic
@ -138,10 +140,10 @@ pub mod marker {
} }
// endregion:copy // endregion:copy
// region:fn // region:tuple
#[lang = "tuple_trait"] #[lang = "tuple_trait"]
pub trait Tuple {} pub trait Tuple {}
// endregion:fn // endregion:tuple
// region:phantom_data // region:phantom_data
#[lang = "phantom_data"] #[lang = "phantom_data"]
@ -682,6 +684,116 @@ pub mod ops {
} }
pub use self::function::{Fn, FnMut, FnOnce}; pub use self::function::{Fn, FnMut, FnOnce};
// endregion:fn // endregion:fn
// region:async_fn
mod async_function {
use crate::{future::Future, marker::Tuple};
#[lang = "async_fn"]
#[fundamental]
pub trait AsyncFn<Args: Tuple>: AsyncFnMut<Args> {
extern "rust-call" fn async_call(&self, args: Args) -> Self::CallRefFuture<'_>;
}
#[lang = "async_fn_mut"]
#[fundamental]
pub trait AsyncFnMut<Args: Tuple>: AsyncFnOnce<Args> {
#[lang = "call_ref_future"]
type CallRefFuture<'a>: Future<Output = Self::Output>
where
Self: 'a;
extern "rust-call" fn async_call_mut(&mut self, args: Args) -> Self::CallRefFuture<'_>;
}
#[lang = "async_fn_once"]
#[fundamental]
pub trait AsyncFnOnce<Args: Tuple> {
#[lang = "async_fn_once_output"]
type Output;
#[lang = "call_once_future"]
type CallOnceFuture: Future<Output = Self::Output>;
extern "rust-call" fn async_call_once(self, args: Args) -> Self::CallOnceFuture;
}
mod impls {
use super::{AsyncFn, AsyncFnMut, AsyncFnOnce};
use crate::marker::Tuple;
impl<A: Tuple, F: ?Sized> AsyncFn<A> for &F
where
F: AsyncFn<A>,
{
extern "rust-call" fn async_call(&self, args: A) -> Self::CallRefFuture<'_> {
F::async_call(*self, args)
}
}
#[unstable(feature = "async_fn_traits", issue = "none")]
impl<A: Tuple, F: ?Sized> AsyncFnMut<A> for &F
where
F: AsyncFn<A>,
{
type CallRefFuture<'a>
= F::CallRefFuture<'a>
where
Self: 'a;
extern "rust-call" fn async_call_mut(
&mut self,
args: A,
) -> Self::CallRefFuture<'_> {
F::async_call(*self, args)
}
}
#[unstable(feature = "async_fn_traits", issue = "none")]
impl<'a, A: Tuple, F: ?Sized> AsyncFnOnce<A> for &'a F
where
F: AsyncFn<A>,
{
type Output = F::Output;
type CallOnceFuture = F::CallRefFuture<'a>;
extern "rust-call" fn async_call_once(self, args: A) -> Self::CallOnceFuture {
F::async_call(self, args)
}
}
#[unstable(feature = "async_fn_traits", issue = "none")]
impl<A: Tuple, F: ?Sized> AsyncFnMut<A> for &mut F
where
F: AsyncFnMut<A>,
{
type CallRefFuture<'a>
= F::CallRefFuture<'a>
where
Self: 'a;
extern "rust-call" fn async_call_mut(
&mut self,
args: A,
) -> Self::CallRefFuture<'_> {
F::async_call_mut(*self, args)
}
}
#[unstable(feature = "async_fn_traits", issue = "none")]
impl<'a, A: Tuple, F: ?Sized> AsyncFnOnce<A> for &'a mut F
where
F: AsyncFnMut<A>,
{
type Output = F::Output;
type CallOnceFuture = F::CallRefFuture<'a>;
extern "rust-call" fn async_call_once(self, args: A) -> Self::CallOnceFuture {
F::async_call_mut(self, args)
}
}
}
}
pub use self::async_function::{AsyncFn, AsyncFnMut, AsyncFnOnce};
// endregion:async_fn
// region:try // region:try
mod try_ { mod try_ {
use crate::convert::Infallible; use crate::convert::Infallible;
@ -1684,6 +1796,7 @@ pub mod prelude {
marker::Sync, // :sync marker::Sync, // :sync
mem::drop, // :drop mem::drop, // :drop
ops::Drop, // :drop ops::Drop, // :drop
ops::{AsyncFn, AsyncFnMut, AsyncFnOnce}, // :async_fn
ops::{Fn, FnMut, FnOnce}, // :fn ops::{Fn, FnMut, FnOnce}, // :fn
option::Option::{self, None, Some}, // :option option::Option::{self, None, Some}, // :option
panic, // :panic panic, // :panic