diff --git a/crates/hir-def/src/lang_item.rs b/crates/hir-def/src/lang_item.rs index 166c965d14..0629d87e54 100644 --- a/crates/hir-def/src/lang_item.rs +++ b/crates/hir-def/src/lang_item.rs @@ -376,6 +376,9 @@ language_item_table! { Fn, sym::fn_, fn_trait, Target::Trait, GenericRequirement::Exact(1); FnMut, sym::fn_mut, fn_mut_trait, Target::Trait, GenericRequirement::Exact(1); FnOnce, sym::fn_once, fn_once_trait, Target::Trait, GenericRequirement::Exact(1); + AsyncFn, sym::async_fn, async_fn_trait, Target::Trait, GenericRequirement::Exact(1); + AsyncFnMut, sym::async_fn_mut, async_fn_mut_trait, Target::Trait, GenericRequirement::Exact(1); + AsyncFnOnce, sym::async_fn_once, async_fn_once_trait, Target::Trait, GenericRequirement::Exact(1); FnOnceOutput, sym::fn_once_output, fn_once_output, Target::AssocTy, GenericRequirement::None; diff --git a/crates/hir-ty/src/infer/expr.rs b/crates/hir-ty/src/infer/expr.rs index 32b4ea2f28..c21ff19c45 100644 --- a/crates/hir-ty/src/infer/expr.rs +++ b/crates/hir-ty/src/infer/expr.rs @@ -1287,8 +1287,8 @@ impl InferenceContext<'_> { tgt_expr: ExprId, ) { match fn_x { - FnTrait::FnOnce => (), - FnTrait::FnMut => { + FnTrait::FnOnce | FnTrait::AsyncFnOnce => (), + FnTrait::FnMut | FnTrait::AsyncFnMut => { if let TyKind::Ref(Mutability::Mut, lt, inner) = derefed_callee.kind(Interner) { if adjustments .last() @@ -1312,7 +1312,7 @@ impl InferenceContext<'_> { )); } } - FnTrait::Fn => { + FnTrait::Fn | FnTrait::AsyncFn => { if !matches!(derefed_callee.kind(Interner), TyKind::Ref(Mutability::Not, _, _)) { adjustments.push(Adjustment::borrow( Mutability::Not, diff --git a/crates/hir-ty/src/infer/unify.rs b/crates/hir-ty/src/infer/unify.rs index 4d0539c135..165861c1b1 100644 --- a/crates/hir-ty/src/infer/unify.rs +++ b/crates/hir-ty/src/infer/unify.rs @@ -794,69 +794,75 @@ impl<'a> InferenceTable<'a> { ty: &Ty, num_args: usize, ) -> Option<(FnTrait, Vec, Ty)> { - let krate = self.trait_env.krate; - let fn_once_trait = FnTrait::FnOnce.get_id(self.db, krate)?; - let trait_data = self.db.trait_data(fn_once_trait); - let output_assoc_type = - trait_data.associated_type_by_name(&Name::new_symbol_root(sym::Output.clone()))?; + for (fn_trait_name, output_assoc_name, subtraits) in [ + (FnTrait::FnOnce, sym::Output.clone(), &[FnTrait::Fn, FnTrait::FnMut][..]), + (FnTrait::AsyncFnMut, sym::CallRefFuture.clone(), &[FnTrait::AsyncFn]), + (FnTrait::AsyncFnOnce, sym::CallOnceFuture.clone(), &[]), + ] { + let krate = self.trait_env.krate; + let fn_trait = fn_trait_name.get_id(self.db, krate)?; + let trait_data = self.db.trait_data(fn_trait); + let output_assoc_type = + trait_data.associated_type_by_name(&Name::new_symbol_root(output_assoc_name))?; - let mut arg_tys = Vec::with_capacity(num_args); - let arg_ty = TyBuilder::tuple(num_args) - .fill(|it| { - let arg = match it { - ParamKind::Type => self.new_type_var(), - ParamKind::Lifetime => unreachable!("Tuple with lifetime parameter"), - ParamKind::Const(_) => unreachable!("Tuple with const parameter"), - }; - arg_tys.push(arg.clone()); - arg.cast(Interner) - }) - .build(); + let mut arg_tys = Vec::with_capacity(num_args); + let arg_ty = TyBuilder::tuple(num_args) + .fill(|it| { + let arg = match it { + ParamKind::Type => self.new_type_var(), + ParamKind::Lifetime => unreachable!("Tuple with lifetime parameter"), + ParamKind::Const(_) => unreachable!("Tuple with const parameter"), + }; + arg_tys.push(arg.clone()); + arg.cast(Interner) + }) + .build(); - let b = TyBuilder::trait_ref(self.db, fn_once_trait); - if b.remaining() != 2 { - return None; - } - let mut trait_ref = b.push(ty.clone()).push(arg_ty).build(); + let b = TyBuilder::trait_ref(self.db, fn_trait); + if b.remaining() != 2 { + return None; + } + let mut trait_ref = b.push(ty.clone()).push(arg_ty).build(); - let projection = { - TyBuilder::assoc_type_projection( + let projection = TyBuilder::assoc_type_projection( self.db, output_assoc_type, Some(trait_ref.substitution.clone()), ) - .build() - }; + .fill_with_unknown() + .build(); - let trait_env = self.trait_env.env.clone(); - let obligation = InEnvironment { - goal: trait_ref.clone().cast(Interner), - environment: trait_env.clone(), - }; - let canonical = self.canonicalize(obligation.clone()); - if self.db.trait_solve(krate, self.trait_env.block, canonical.cast(Interner)).is_some() { - self.register_obligation(obligation.goal); - let return_ty = self.normalize_projection_ty(projection); - for fn_x in [FnTrait::Fn, FnTrait::FnMut, FnTrait::FnOnce] { - let fn_x_trait = fn_x.get_id(self.db, krate)?; - trait_ref.trait_id = to_chalk_trait_id(fn_x_trait); - let obligation: chalk_ir::InEnvironment> = InEnvironment { - goal: trait_ref.clone().cast(Interner), - environment: trait_env.clone(), - }; - let canonical = self.canonicalize(obligation.clone()); - if self - .db - .trait_solve(krate, self.trait_env.block, canonical.cast(Interner)) - .is_some() - { - return Some((fn_x, arg_tys, return_ty)); + let trait_env = self.trait_env.env.clone(); + let obligation = InEnvironment { + goal: trait_ref.clone().cast(Interner), + environment: trait_env.clone(), + }; + let canonical = self.canonicalize(obligation.clone()); + if self.db.trait_solve(krate, self.trait_env.block, canonical.cast(Interner)).is_some() + { + self.register_obligation(obligation.goal); + let return_ty = self.normalize_projection_ty(projection); + for &fn_x in subtraits { + let fn_x_trait = fn_x.get_id(self.db, krate)?; + trait_ref.trait_id = to_chalk_trait_id(fn_x_trait); + let obligation: chalk_ir::InEnvironment> = + InEnvironment { + goal: trait_ref.clone().cast(Interner), + environment: trait_env.clone(), + }; + let canonical = self.canonicalize(obligation.clone()); + if self + .db + .trait_solve(krate, self.trait_env.block, canonical.cast(Interner)) + .is_some() + { + return Some((fn_x, arg_tys, return_ty)); + } } + return Some((fn_trait_name, arg_tys, return_ty)); } - unreachable!("It should at least implement FnOnce at this point"); - } else { - None } + None } pub(super) fn insert_type_vars(&mut self, ty: T) -> T diff --git a/crates/hir-ty/src/mir/lower.rs b/crates/hir-ty/src/mir/lower.rs index c4e0640051..1d1044df6e 100644 --- a/crates/hir-ty/src/mir/lower.rs +++ b/crates/hir-ty/src/mir/lower.rs @@ -2023,11 +2023,11 @@ pub fn mir_body_for_closure_query( ctx.result.locals.alloc(Local { ty: infer[*root].clone() }); let closure_local = ctx.result.locals.alloc(Local { ty: match kind { - FnTrait::FnOnce => infer[expr].clone(), - FnTrait::FnMut => { + FnTrait::FnOnce | FnTrait::AsyncFnOnce => infer[expr].clone(), + FnTrait::FnMut | FnTrait::AsyncFnMut => { TyKind::Ref(Mutability::Mut, error_lifetime(), infer[expr].clone()).intern(Interner) } - FnTrait::Fn => { + FnTrait::Fn | FnTrait::AsyncFn => { TyKind::Ref(Mutability::Not, error_lifetime(), infer[expr].clone()).intern(Interner) } }, @@ -2055,8 +2055,10 @@ pub fn mir_body_for_closure_query( let mut err = None; let closure_local = ctx.result.locals.iter().nth(1).unwrap().0; let closure_projection = match kind { - FnTrait::FnOnce => vec![], - FnTrait::FnMut | FnTrait::Fn => vec![ProjectionElem::Deref], + FnTrait::FnOnce | FnTrait::AsyncFnOnce => vec![], + FnTrait::FnMut | FnTrait::Fn | FnTrait::AsyncFnMut | FnTrait::AsyncFn => { + vec![ProjectionElem::Deref] + } }; ctx.result.walk_places(|p, store| { if let Some(it) = upvar_map.get(&p.local) { diff --git a/crates/hir-ty/src/tests/traits.rs b/crates/hir-ty/src/tests/traits.rs index 624148cab2..82ff51927e 100644 --- a/crates/hir-ty/src/tests/traits.rs +++ b/crates/hir-ty/src/tests/traits.rs @@ -4811,3 +4811,53 @@ fn bar(v: *const ()) { "#]], ); } + +#[test] +fn async_fn_traits() { + check_infer( + r#" +//- minicore: async_fn +async fn foo i32>(a: T) { + let fut1 = a(0); + fut1.await; +} +async fn bar i32>(mut b: T) { + let fut2 = b(0); + fut2.await; +} +async fn baz i32>(c: T) { + let fut3 = c(0); + fut3.await; +} + "#, + expect![[r#" + 37..38 'a': T + 43..83 '{ ...ait; }': () + 43..83 '{ ...ait; }': impl Future + 53..57 'fut1': AsyncFnMut::CallRefFuture<'?, T, (u32,)> + 60..61 'a': T + 60..64 'a(0)': AsyncFnMut::CallRefFuture<'?, T, (u32,)> + 62..63 '0': u32 + 70..74 'fut1': AsyncFnMut::CallRefFuture<'?, T, (u32,)> + 70..80 'fut1.await': i32 + 124..129 'mut b': T + 134..174 '{ ...ait; }': () + 134..174 '{ ...ait; }': impl Future + 144..148 'fut2': AsyncFnMut::CallRefFuture<'?, T, (u32,)> + 151..152 'b': T + 151..155 'b(0)': AsyncFnMut::CallRefFuture<'?, T, (u32,)> + 153..154 '0': u32 + 161..165 'fut2': AsyncFnMut::CallRefFuture<'?, T, (u32,)> + 161..171 'fut2.await': i32 + 216..217 'c': T + 222..262 '{ ...ait; }': () + 222..262 '{ ...ait; }': impl Future + 232..236 'fut3': AsyncFnOnce::CallOnceFuture + 239..240 'c': T + 239..243 'c(0)': AsyncFnOnce::CallOnceFuture + 241..242 '0': u32 + 249..253 'fut3': AsyncFnOnce::CallOnceFuture + 249..259 'fut3.await': i32 + "#]], + ); +} diff --git a/crates/hir-ty/src/traits.rs b/crates/hir-ty/src/traits.rs index 51ccd4ef29..8cb7dbf60f 100644 --- a/crates/hir-ty/src/traits.rs +++ b/crates/hir-ty/src/traits.rs @@ -220,6 +220,10 @@ pub enum FnTrait { FnOnce, FnMut, Fn, + + AsyncFnOnce, + AsyncFnMut, + AsyncFn, } impl fmt::Display for FnTrait { @@ -228,6 +232,9 @@ impl fmt::Display for FnTrait { FnTrait::FnOnce => write!(f, "FnOnce"), FnTrait::FnMut => write!(f, "FnMut"), FnTrait::Fn => write!(f, "Fn"), + FnTrait::AsyncFnOnce => write!(f, "AsyncFnOnce"), + FnTrait::AsyncFnMut => write!(f, "AsyncFnMut"), + FnTrait::AsyncFn => write!(f, "AsyncFn"), } } } @@ -238,6 +245,9 @@ impl FnTrait { FnTrait::FnOnce => "call_once", FnTrait::FnMut => "call_mut", FnTrait::Fn => "call", + FnTrait::AsyncFnOnce => "async_call_once", + FnTrait::AsyncFnMut => "async_call_mut", + FnTrait::AsyncFn => "async_call", } } @@ -246,6 +256,9 @@ impl FnTrait { FnTrait::FnOnce => LangItem::FnOnce, FnTrait::FnMut => LangItem::FnMut, FnTrait::Fn => LangItem::Fn, + FnTrait::AsyncFnOnce => LangItem::AsyncFnOnce, + FnTrait::AsyncFnMut => LangItem::AsyncFnMut, + FnTrait::AsyncFn => LangItem::AsyncFn, } } @@ -254,15 +267,19 @@ impl FnTrait { LangItem::FnOnce => Some(FnTrait::FnOnce), LangItem::FnMut => Some(FnTrait::FnMut), LangItem::Fn => Some(FnTrait::Fn), + LangItem::AsyncFnOnce => Some(FnTrait::AsyncFnOnce), + LangItem::AsyncFnMut => Some(FnTrait::AsyncFnMut), + LangItem::AsyncFn => Some(FnTrait::AsyncFn), _ => None, } } pub const fn to_chalk_ir(self) -> rust_ir::ClosureKind { + // Chalk doesn't support async fn traits. match self { - FnTrait::FnOnce => rust_ir::ClosureKind::FnOnce, - FnTrait::FnMut => rust_ir::ClosureKind::FnMut, - FnTrait::Fn => rust_ir::ClosureKind::Fn, + FnTrait::AsyncFnOnce | FnTrait::FnOnce => rust_ir::ClosureKind::FnOnce, + FnTrait::AsyncFnMut | FnTrait::FnMut => rust_ir::ClosureKind::FnMut, + FnTrait::AsyncFn | FnTrait::Fn => rust_ir::ClosureKind::Fn, } } @@ -271,6 +288,9 @@ impl FnTrait { FnTrait::FnOnce => Name::new_symbol_root(sym::call_once.clone()), FnTrait::FnMut => Name::new_symbol_root(sym::call_mut.clone()), FnTrait::Fn => Name::new_symbol_root(sym::call.clone()), + FnTrait::AsyncFnOnce => Name::new_symbol_root(sym::async_call_once.clone()), + FnTrait::AsyncFnMut => Name::new_symbol_root(sym::async_call_mut.clone()), + FnTrait::AsyncFn => Name::new_symbol_root(sym::async_call.clone()), } } diff --git a/crates/ide-diagnostics/src/handlers/expected_function.rs b/crates/ide-diagnostics/src/handlers/expected_function.rs index 02299197b1..e3a1e12e02 100644 --- a/crates/ide-diagnostics/src/handlers/expected_function.rs +++ b/crates/ide-diagnostics/src/handlers/expected_function.rs @@ -37,4 +37,25 @@ fn foo() { "#, ); } + + #[test] + fn no_error_for_async_fn_traits() { + check_diagnostics( + r#" +//- minicore: async_fn +async fn f(it: impl AsyncFn(u32) -> i32) { + let fut = it(0); + let _: i32 = fut.await; +} +async fn g(mut it: impl AsyncFnMut(u32) -> i32) { + let fut = it(0); + let _: i32 = fut.await; +} +async fn h(it: impl AsyncFnOnce(u32) -> i32) { + let fut = it(0); + let _: i32 = fut.await; +} + "#, + ); + } } diff --git a/crates/intern/src/symbol/symbols.rs b/crates/intern/src/symbol/symbols.rs index 865518fe94..ee96eff330 100644 --- a/crates/intern/src/symbol/symbols.rs +++ b/crates/intern/src/symbol/symbols.rs @@ -150,6 +150,9 @@ define_symbols! { C, call_mut, call_once, + async_call_once, + async_call_mut, + async_call, call, cdecl, Center, @@ -221,6 +224,9 @@ define_symbols! { fn_mut, fn_once_output, fn_once, + async_fn_once, + async_fn_mut, + async_fn, fn_ptr_addr, fn_ptr_trait, format_alignment, @@ -334,6 +340,8 @@ define_symbols! { Option, Ord, Output, + CallRefFuture, + CallOnceFuture, owned_box, packed, panic_2015, diff --git a/crates/test-utils/src/minicore.rs b/crates/test-utils/src/minicore.rs index 07767d5ae9..f5c8466cb9 100644 --- a/crates/test-utils/src/minicore.rs +++ b/crates/test-utils/src/minicore.rs @@ -12,6 +12,7 @@ //! asm: //! assert: //! as_ref: sized +//! async_fn: fn, tuple, future, copy //! bool_impl: option, fn //! builtin_impls: //! cell: copy, drop @@ -29,7 +30,7 @@ //! eq: sized //! error: fmt //! fmt: option, result, transmute, coerce_unsized, copy, clone, derive -//! fn: +//! fn: tuple //! from: sized //! future: pin //! coroutine: pin @@ -60,6 +61,7 @@ //! sync: sized //! transmute: //! try: infallible +//! tuple: //! unpin: sized //! unsize: sized //! todo: panic @@ -138,10 +140,10 @@ pub mod marker { } // endregion:copy - // region:fn + // region:tuple #[lang = "tuple_trait"] pub trait Tuple {} - // endregion:fn + // endregion:tuple // region:phantom_data #[lang = "phantom_data"] @@ -682,6 +684,116 @@ pub mod ops { } pub use self::function::{Fn, FnMut, FnOnce}; // endregion:fn + + // region:async_fn + mod async_function { + use crate::{future::Future, marker::Tuple}; + + #[lang = "async_fn"] + #[fundamental] + pub trait AsyncFn: AsyncFnMut { + extern "rust-call" fn async_call(&self, args: Args) -> Self::CallRefFuture<'_>; + } + + #[lang = "async_fn_mut"] + #[fundamental] + pub trait AsyncFnMut: AsyncFnOnce { + #[lang = "call_ref_future"] + type CallRefFuture<'a>: Future + where + Self: 'a; + extern "rust-call" fn async_call_mut(&mut self, args: Args) -> Self::CallRefFuture<'_>; + } + + #[lang = "async_fn_once"] + #[fundamental] + pub trait AsyncFnOnce { + #[lang = "async_fn_once_output"] + type Output; + #[lang = "call_once_future"] + type CallOnceFuture: Future; + extern "rust-call" fn async_call_once(self, args: Args) -> Self::CallOnceFuture; + } + + mod impls { + use super::{AsyncFn, AsyncFnMut, AsyncFnOnce}; + use crate::marker::Tuple; + + impl AsyncFn for &F + where + F: AsyncFn, + { + extern "rust-call" fn async_call(&self, args: A) -> Self::CallRefFuture<'_> { + F::async_call(*self, args) + } + } + + #[unstable(feature = "async_fn_traits", issue = "none")] + impl AsyncFnMut for &F + where + F: AsyncFn, + { + type CallRefFuture<'a> + = F::CallRefFuture<'a> + where + Self: 'a; + + extern "rust-call" fn async_call_mut( + &mut self, + args: A, + ) -> Self::CallRefFuture<'_> { + F::async_call(*self, args) + } + } + + #[unstable(feature = "async_fn_traits", issue = "none")] + impl<'a, A: Tuple, F: ?Sized> AsyncFnOnce for &'a F + where + F: AsyncFn, + { + type Output = F::Output; + type CallOnceFuture = F::CallRefFuture<'a>; + + extern "rust-call" fn async_call_once(self, args: A) -> Self::CallOnceFuture { + F::async_call(self, args) + } + } + + #[unstable(feature = "async_fn_traits", issue = "none")] + impl AsyncFnMut for &mut F + where + F: AsyncFnMut, + { + type CallRefFuture<'a> + = F::CallRefFuture<'a> + where + Self: 'a; + + extern "rust-call" fn async_call_mut( + &mut self, + args: A, + ) -> Self::CallRefFuture<'_> { + F::async_call_mut(*self, args) + } + } + + #[unstable(feature = "async_fn_traits", issue = "none")] + impl<'a, A: Tuple, F: ?Sized> AsyncFnOnce for &'a mut F + where + F: AsyncFnMut, + { + type Output = F::Output; + type CallOnceFuture = F::CallRefFuture<'a>; + + extern "rust-call" fn async_call_once(self, args: A) -> Self::CallOnceFuture { + F::async_call_mut(self, args) + } + } + } + } + pub use self::async_function::{AsyncFn, AsyncFnMut, AsyncFnOnce}; + // endregion:async_fn + // region:try mod try_ { use crate::convert::Infallible; @@ -1684,6 +1796,7 @@ pub mod prelude { marker::Sync, // :sync mem::drop, // :drop ops::Drop, // :drop + ops::{AsyncFn, AsyncFnMut, AsyncFnOnce}, // :async_fn ops::{Fn, FnMut, FnOnce}, // :fn option::Option::{self, None, Some}, // :option panic, // :panic