From d662cb1e9eed02381bd80003f8ea5ec4d1374a6c Mon Sep 17 00:00:00 2001 From: Gino Valente Date: Thu, 5 Sep 2024 12:11:13 -0700 Subject: [PATCH] Added `arg_count` method --- .../bevy_reflect/src/func/dynamic_function.rs | 55 +++++++++++++++++-- .../src/func/dynamic_function_mut.rs | 38 ++++++++++--- crates/bevy_reflect/src/func/error.rs | 8 ++- crates/bevy_reflect/src/func/function.rs | 12 +++- crates/bevy_reflect/src/func/function_map.rs | 13 +++++ crates/bevy_reflect/src/func/info.rs | 31 ++++++----- crates/bevy_reflect/src/func/mod.rs | 4 +- crates/bevy_reflect/src/func/reflect_fn.rs | 8 +-- .../bevy_reflect/src/func/reflect_fn_mut.rs | 8 +-- 9 files changed, 138 insertions(+), 39 deletions(-) diff --git a/crates/bevy_reflect/src/func/dynamic_function.rs b/crates/bevy_reflect/src/func/dynamic_function.rs index e5e270b5ff..33df39bdaf 100644 --- a/crates/bevy_reflect/src/func/dynamic_function.rs +++ b/crates/bevy_reflect/src/func/dynamic_function.rs @@ -13,6 +13,7 @@ use crate::{ use alloc::{borrow::Cow, boxed::Box, sync::Arc}; use bevy_reflect_derive::impl_type_path; use core::fmt::{Debug, Formatter}; +use core::ops::RangeInclusive; #[cfg(not(feature = "std"))] use alloc::{boxed::Box, format, vec}; @@ -299,7 +300,7 @@ impl<'env> DynamicFunction<'env> { let expected_arg_count = self.function_map.info().arg_count(); let received_arg_count = args.len(); - if !self.is_overloaded() && expected_arg_count != received_arg_count { + if !expected_arg_count.contains(&received_arg_count) { Err(FunctionError::ArgCountMismatch { expected: expected_arg_count, received: received_arg_count, @@ -350,6 +351,29 @@ impl<'env> DynamicFunction<'env> { pub fn is_overloaded(&self) -> bool { self.function_map.is_overloaded() } + + /// Returns the number of arguments the function expects. + /// + /// For [overloaded] functions that can have a variable number of arguments, + /// this will return the minimum and maximum number of arguments. + /// + /// Otherwise, the range will have the same start and end. + /// + /// # Example + /// + /// ``` + /// # use bevy_reflect::func::IntoFunction; + /// let add = (|a: i32, b: i32| a + b).into_function(); + /// assert_eq!(add.arg_count(), 2..=2); + /// + /// let add = add.with_overload(|a: f32, b: f32, c: f32| a + b + c); + /// assert_eq!(add.arg_count(), 2..=3); + /// ``` + /// + /// [overloaded]: Self::with_overload + pub fn arg_count(&self) -> RangeInclusive { + self.function_map.arg_count() + } } impl Function for DynamicFunction<'static> { @@ -548,13 +572,36 @@ mod tests { let args = ArgList::default().push_owned(25_i32); let error = func.call(args).unwrap_err(); - assert!(matches!( + + assert_eq!( error, FunctionError::ArgCountMismatch { - expected: 2, + expected: 2..=2, received: 1 } - )); + ); + } + + #[test] + fn should_return_error_on_arg_count_mismatch_overloaded() { + let func = (|a: i32, b: i32| a + b) + .into_function() + .with_overload(|a: i32, b: i32, c: i32| a + b + c); + + let args = ArgList::default() + .push_owned(1_i32) + .push_owned(2_i32) + .push_owned(3_i32) + .push_owned(4_i32); + let error = func.call(args).unwrap_err(); + + assert_eq!( + error, + FunctionError::ArgCountMismatch { + expected: 2..=3, + received: 4 + } + ); } #[test] diff --git a/crates/bevy_reflect/src/func/dynamic_function_mut.rs b/crates/bevy_reflect/src/func/dynamic_function_mut.rs index bdd6823a92..c57e2650b1 100644 --- a/crates/bevy_reflect/src/func/dynamic_function_mut.rs +++ b/crates/bevy_reflect/src/func/dynamic_function_mut.rs @@ -1,5 +1,6 @@ use alloc::{borrow::Cow, boxed::Box, sync::Arc}; use core::fmt::{Debug, Formatter}; +use core::ops::RangeInclusive; #[cfg(not(feature = "std"))] use alloc::{boxed::Box, format, vec}; @@ -269,7 +270,7 @@ impl<'env> DynamicFunctionMut<'env> { let expected_arg_count = self.function_map.info().arg_count(); let received_arg_count = args.len(); - if !self.is_overloaded() && expected_arg_count != received_arg_count { + if !expected_arg_count.contains(&received_arg_count) { Err(FunctionError::ArgCountMismatch { expected: expected_arg_count, received: received_arg_count, @@ -351,6 +352,29 @@ impl<'env> DynamicFunctionMut<'env> { pub fn is_overloaded(&self) -> bool { self.function_map.is_overloaded() } + + /// Returns the number of arguments the function expects. + /// + /// For [overloaded] functions that can have a variable number of arguments, + /// this will return the minimum and maximum number of arguments. + /// + /// Otherwise, the range will have the same start and end. + /// + /// # Example + /// + /// ``` + /// # use bevy_reflect::func::IntoFunctionMut; + /// let add = (|a: i32, b: i32| a + b).into_function_mut(); + /// assert_eq!(add.arg_count(), 2..=2); + /// + /// let add = add.with_overload(|a: f32, b: f32, c: f32| a + b + c); + /// assert_eq!(add.arg_count(), 2..=3); + /// ``` + /// + /// [overloaded]: Self::with_overload + pub fn arg_count(&self) -> RangeInclusive { + self.function_map.arg_count() + } } /// Outputs the function's signature. @@ -452,23 +476,23 @@ mod tests { let args = ArgList::default().push_owned(25_i32); let error = func.call(args).unwrap_err(); - assert!(matches!( + assert_eq!( error, FunctionError::ArgCountMismatch { - expected: 2, + expected: 2..=2, received: 1 } - )); + ); let args = ArgList::default().push_owned(25_i32); let error = func.call_once(args).unwrap_err(); - assert!(matches!( + assert_eq!( error, FunctionError::ArgCountMismatch { - expected: 2, + expected: 2..=2, received: 1 } - )); + ); } #[test] diff --git a/crates/bevy_reflect/src/func/error.rs b/crates/bevy_reflect/src/func/error.rs index 2b73aac757..cedb0eaad9 100644 --- a/crates/bevy_reflect/src/func/error.rs +++ b/crates/bevy_reflect/src/func/error.rs @@ -2,6 +2,7 @@ use crate::func::signature::ArgumentSignature; use crate::func::{args::ArgError, Return}; use alloc::borrow::Cow; use bevy_utils::HashSet; +use core::ops::RangeInclusive; use thiserror::Error; #[cfg(not(feature = "std"))] @@ -17,8 +18,11 @@ pub enum FunctionError { #[error(transparent)] ArgError(#[from] ArgError), /// The number of arguments provided does not match the expected number. - #[error("expected {expected} arguments but received {received}")] - ArgCountMismatch { expected: usize, received: usize }, + #[error("expected {expected:?} arguments but received {received}")] + ArgCountMismatch { + expected: RangeInclusive, + received: usize, + }, /// No overload was found for the given set of arguments. #[error("no overload found for arguments with signature `{received:?}`, expected one of `{expected:?}`")] NoOverload { diff --git a/crates/bevy_reflect/src/func/function.rs b/crates/bevy_reflect/src/func/function.rs index 97a2c85d2b..bfa40a8b75 100644 --- a/crates/bevy_reflect/src/func/function.rs +++ b/crates/bevy_reflect/src/func/function.rs @@ -4,6 +4,7 @@ use crate::{ }; use alloc::borrow::Cow; use core::fmt::Debug; +use core::ops::RangeInclusive; #[cfg(not(feature = "std"))] use alloc::{boxed::Box, format, vec}; @@ -47,8 +48,15 @@ pub trait Function: PartialReflect + Debug { /// [`IntoFunction`]: crate::func::IntoFunction fn name(&self) -> Option<&Cow<'static, str>>; - /// The number of arguments this function accepts. - fn arg_count(&self) -> usize { + /// Returns the number of arguments the function expects. + /// + /// For [overloaded] functions that can have a variable number of arguments, + /// this will return the minimum and maximum number of arguments. + /// + /// Otherwise, the range will have the same start and end. + /// + /// [overloaded]: FunctionInfoType::Overloaded + fn arg_count(&self) -> RangeInclusive { self.info().arg_count() } diff --git a/crates/bevy_reflect/src/func/function_map.rs b/crates/bevy_reflect/src/func/function_map.rs index b7e672fc1a..7c0a394932 100644 --- a/crates/bevy_reflect/src/func/function_map.rs +++ b/crates/bevy_reflect/src/func/function_map.rs @@ -3,6 +3,7 @@ use crate::func::{ArgList, FunctionError, FunctionInfo, FunctionInfoType, Functi use alloc::borrow::Cow; use bevy_utils::hashbrown::HashMap; use bevy_utils::NoOpHash; +use core::ops::RangeInclusive; /// A helper type for storing a mapping of overloaded functions /// along with the corresponding [function information]. @@ -76,6 +77,18 @@ impl FunctionMap { } } + /// Returns the number of arguments the function expects. + /// + /// For [overloaded] functions that can have a variable number of arguments, + /// this will return the minimum and maximum number of arguments. + /// + /// Otherwise, the range will have the same start and end. + /// + /// [overloaded]: Self::Overloaded + pub fn arg_count(&self) -> RangeInclusive { + self.info().arg_count() + } + /// Merge another [`FunctionMap`] into this one. /// /// If the other map contains any functions with the same signature as this one, diff --git a/crates/bevy_reflect/src/func/info.rs b/crates/bevy_reflect/src/func/info.rs index f9da702d1f..6787ab35e0 100644 --- a/crates/bevy_reflect/src/func/info.rs +++ b/crates/bevy_reflect/src/func/info.rs @@ -3,6 +3,7 @@ use alloc::{borrow::Cow, vec}; #[cfg(not(feature = "std"))] use alloc::{boxed::Box, format, vec}; +use core::ops::RangeInclusive; use variadics_please::all_tuples; use crate::{ @@ -62,21 +63,23 @@ impl IntoIterator for FunctionInfoType<'_> { } impl FunctionInfoType<'_> { - pub fn arg_count(&self) -> usize { + /// Returns the number of arguments the function expects. + /// + /// For [overloaded] functions that can have a variable number of arguments, + /// this will return the minimum and maximum number of arguments. + /// + /// Otherwise, the range will have the same start and end. + /// + /// [overloaded]: Self::Overloaded + pub fn arg_count(&self) -> RangeInclusive { match self { - Self::Standard(info) => info.arg_count(), - Self::Overloaded(infos) => { - // TODO: This needs proper implementation - infos.iter().map(FunctionInfo::arg_count).min().unwrap() - } // match self { - // Self::Standard(info) => RangeInclusive::new(info.arg_count(), info.arg_count()), - // Self::Overloaded(infos) => infos.iter().map(FunctionInfo::arg_count).fold( - // RangeInclusive::new(0, 0), - // |acc, count| { - // RangeInclusive::new((*acc.start()).min(count), (*acc.end()).max(count)) - // }, - // ), - // } + Self::Standard(info) => RangeInclusive::new(info.arg_count(), info.arg_count()), + Self::Overloaded(infos) => infos.iter().map(FunctionInfo::arg_count).fold( + RangeInclusive::new(usize::MAX, usize::MIN), + |acc, count| { + RangeInclusive::new((*acc.start()).min(count), (*acc.end()).max(count)) + }, + ), } } } diff --git a/crates/bevy_reflect/src/func/mod.rs b/crates/bevy_reflect/src/func/mod.rs index 4e0e89bc41..15d92cb085 100644 --- a/crates/bevy_reflect/src/func/mod.rs +++ b/crates/bevy_reflect/src/func/mod.rs @@ -178,7 +178,7 @@ mod tests { assert_eq!( result.unwrap_err(), FunctionError::ArgCountMismatch { - expected: 1, + expected: 1..=1, received: 0 } ); @@ -194,7 +194,7 @@ mod tests { assert_eq!( result.unwrap_err(), FunctionError::ArgCountMismatch { - expected: 0, + expected: 0..=0, received: 1 } ); diff --git a/crates/bevy_reflect/src/func/reflect_fn.rs b/crates/bevy_reflect/src/func/reflect_fn.rs index 486fa452aa..24cfdeebdc 100644 --- a/crates/bevy_reflect/src/func/reflect_fn.rs +++ b/crates/bevy_reflect/src/func/reflect_fn.rs @@ -96,7 +96,7 @@ macro_rules! impl_reflect_fn { if args.len() != COUNT { return Err(FunctionError::ArgCountMismatch { - expected: COUNT, + expected: COUNT..=COUNT, received: args.len(), }); } @@ -125,7 +125,7 @@ macro_rules! impl_reflect_fn { if args.len() != COUNT { return Err(FunctionError::ArgCountMismatch { - expected: COUNT, + expected: COUNT..=COUNT, received: args.len(), }); } @@ -155,7 +155,7 @@ macro_rules! impl_reflect_fn { if args.len() != COUNT { return Err(FunctionError::ArgCountMismatch { - expected: COUNT, + expected: COUNT..=COUNT, received: args.len(), }); } @@ -185,7 +185,7 @@ macro_rules! impl_reflect_fn { if args.len() != COUNT { return Err(FunctionError::ArgCountMismatch { - expected: COUNT, + expected: COUNT..=COUNT, received: args.len(), }); } diff --git a/crates/bevy_reflect/src/func/reflect_fn_mut.rs b/crates/bevy_reflect/src/func/reflect_fn_mut.rs index 6a8b9a6d73..c5747c49ef 100644 --- a/crates/bevy_reflect/src/func/reflect_fn_mut.rs +++ b/crates/bevy_reflect/src/func/reflect_fn_mut.rs @@ -102,7 +102,7 @@ macro_rules! impl_reflect_fn_mut { if args.len() != COUNT { return Err(FunctionError::ArgCountMismatch { - expected: COUNT, + expected: COUNT..=COUNT, received: args.len(), }); } @@ -131,7 +131,7 @@ macro_rules! impl_reflect_fn_mut { if args.len() != COUNT { return Err(FunctionError::ArgCountMismatch { - expected: COUNT, + expected: COUNT..=COUNT, received: args.len(), }); } @@ -161,7 +161,7 @@ macro_rules! impl_reflect_fn_mut { if args.len() != COUNT { return Err(FunctionError::ArgCountMismatch { - expected: COUNT, + expected: COUNT..=COUNT, received: args.len(), }); } @@ -191,7 +191,7 @@ macro_rules! impl_reflect_fn_mut { if args.len() != COUNT { return Err(FunctionError::ArgCountMismatch { - expected: COUNT, + expected: COUNT..=COUNT, received: args.len(), }); }