diff --git a/crates/bevy_reflect/src/func/dynamic_function.rs b/crates/bevy_reflect/src/func/dynamic_function.rs index aa0a47007e..8344f138f6 100644 --- a/crates/bevy_reflect/src/func/dynamic_function.rs +++ b/crates/bevy_reflect/src/func/dynamic_function.rs @@ -2,11 +2,9 @@ use crate::{ self as bevy_reflect, __macro_exports::RegisterForReflection, func::{ - args::ArgList, - function_map::{merge_function_map, FunctionMap}, - info::FunctionInfoType, - signature::ArgumentSignature, - DynamicFunctionMut, Function, FunctionError, FunctionResult, IntoFunction, IntoFunctionMut, + args::ArgList, function_map::FunctionMap, info::FunctionInfoType, + signature::ArgumentSignature, DynamicFunctionMut, Function, FunctionError, FunctionResult, + IntoFunction, IntoFunctionMut, }, serde::Serializable, ApplyError, MaybeTyped, PartialReflect, Reflect, ReflectKind, ReflectMut, ReflectOwned, @@ -49,7 +47,7 @@ type ArcFn<'env> = Arc Fn(ArgList<'a>) -> FunctionResult<'a> + Send /// Most of the time, a [`DynamicFunction`] can be created using the [`IntoFunction`] trait: /// /// ``` -/// # use bevy_reflect::func::{ArgList, DynamicFunction, FunctionInfo, IntoFunction}; +/// # use bevy_reflect::func::{ArgList, DynamicFunction, IntoFunction}; /// # /// fn add(a: i32, b: i32) -> i32 { /// a + b @@ -70,7 +68,6 @@ type ArcFn<'env> = Arc Fn(ArgList<'a>) -> FunctionResult<'a> + Send /// [module-level documentation]: crate::func pub struct DynamicFunction<'env> { pub(super) name: Option>, - pub(super) info: FunctionInfoType, pub(super) function_map: FunctionMap>, } @@ -104,16 +101,21 @@ impl<'env> DynamicFunction<'env> { FunctionInfoType::Standard(info) => info.name().cloned(), FunctionInfoType::Overloaded(_) => None, }, - function_map: match &info { - FunctionInfoType::Standard(_) => FunctionMap::Standard(func), - FunctionInfoType::Overloaded(infos) => { - FunctionMap::Overloaded(HashMap::from_iter(infos.iter().map(|info| { - let sig = ArgumentSignature::from(info); - (sig, func.clone()) - }))) - } + function_map: match info { + FunctionInfoType::Standard(info) => FunctionMap { + functions: vec![func], + indices: HashMap::from([(ArgumentSignature::from(&info), 0)]), + info: FunctionInfoType::Standard(info), + }, + FunctionInfoType::Overloaded(infos) => FunctionMap { + functions: vec![func], + indices: infos + .iter() + .map(|info| (ArgumentSignature::from(info), 0)) + .collect(), + info: FunctionInfoType::Overloaded(infos), + }, }, - info, } } @@ -144,6 +146,10 @@ impl<'env> DynamicFunction<'env> { /// /// Overloaded functions retain the [name] of the original function. /// + /// # Panics + /// + /// Panics if the function, `F`, contains a signature already found in this function. + /// /// # Examples /// /// ``` @@ -202,6 +208,23 @@ impl<'env> DynamicFunction<'env> { /// assert_eq!(result.try_take::().unwrap(), 200); /// ``` /// + ///```should_panic + /// # use bevy_reflect::func::IntoFunction; + /// + /// fn add(a: i32, b: i32) -> i32 { + /// a + b + /// } + /// + /// fn sub(a: i32, b: i32) -> i32 { + /// a - b + /// } + /// + /// let mut func = add.into_function(); + /// + /// // This will panic because the function already has an argument signature for `(i32, i32)`: + /// func = func.with_overload(sub); + /// ``` + /// /// [argument signature]: ArgumentSignature /// [name]: Self::name pub fn with_overload<'a, F: IntoFunction<'a, Marker>, Marker>( @@ -213,19 +236,15 @@ impl<'env> DynamicFunction<'env> { { let function = function.into_function(); - let name = self.name; - let (function_map, info) = merge_function_map( - self.function_map, - self.info, - function.function_map, - function.info, - ); + let name = self.name.clone(); + let mut function_map = self.function_map; + function_map + .merge(function.function_map) + .unwrap_or_else(|err| { + panic!("{}", err); + }); - DynamicFunction { - name, - info, - function_map, - } + DynamicFunction { name, function_map } } /// Call the function with the given arguments. @@ -252,35 +271,25 @@ impl<'env> DynamicFunction<'env> { /// /// The function itself may also return any errors it needs to. pub fn call<'a>(&self, args: ArgList<'a>) -> FunctionResult<'a> { - let expected_arg_count = self.info.arg_count(); + let expected_arg_count = self.function_map.info.arg_count(); let received_arg_count = args.len(); - match self.function_map { - FunctionMap::Standard(ref func) => { - if expected_arg_count != received_arg_count { - Err(FunctionError::ArgCountMismatch { - expected: expected_arg_count, - received: received_arg_count, - }) - } else { - func(args) - } - } - FunctionMap::Overloaded(ref map) => { - let sig = ArgumentSignature::from(&args); - let func = map.get(&sig).ok_or_else(|| FunctionError::NoOverload { - expected: map.keys().cloned().collect(), - received: sig, - })?; - - func(args) - } + if matches!(self.function_map.info, FunctionInfoType::Standard(_)) + && expected_arg_count != received_arg_count + { + Err(FunctionError::ArgCountMismatch { + expected: expected_arg_count, + received: received_arg_count, + }) + } else { + let func = self.function_map.get(&args)?; + func(args) } } /// Returns the function info. pub fn info(&self) -> &FunctionInfoType { - &self.info + &self.function_map.info } /// The name of the function. @@ -308,7 +317,7 @@ impl Function for DynamicFunction<'static> { } fn info(&self) -> &FunctionInfoType { - &self.info + &self.function_map.info } fn reflect_call<'a>(&self, args: ArgList<'a>) -> FunctionResult<'a> { @@ -444,7 +453,6 @@ impl<'env> Clone for DynamicFunction<'env> { fn clone(&self) -> Self { Self { name: self.name.clone(), - info: self.info.clone(), function_map: self.function_map.clone(), } } @@ -467,7 +475,7 @@ impl<'env> IntoFunctionMut<'env, ()> for DynamicFunction<'env> { #[cfg(test)] mod tests { use super::*; - use crate::func::{FunctionInfo, IntoReturn}; + use crate::func::{FunctionError, FunctionInfo, IntoReturn}; use crate::Type; use bevy_utils::HashSet; use std::ops::Add; diff --git a/crates/bevy_reflect/src/func/dynamic_function_mut.rs b/crates/bevy_reflect/src/func/dynamic_function_mut.rs index c00866c2f3..c934977b05 100644 --- a/crates/bevy_reflect/src/func/dynamic_function_mut.rs +++ b/crates/bevy_reflect/src/func/dynamic_function_mut.rs @@ -5,11 +5,10 @@ use core::fmt::{Debug, Formatter}; use alloc::{boxed::Box, format, vec}; use crate::func::{ - args::ArgList, - function_map::{merge_function_map, FunctionMap}, - signature::ArgumentSignature, - DynamicFunction, FunctionError, FunctionInfoType, FunctionResult, IntoFunctionMut, + args::ArgList, function_map::FunctionMap, signature::ArgumentSignature, DynamicFunction, + FunctionError, FunctionInfoType, FunctionResult, IntoFunctionMut, }; +use bevy_utils::HashMap; /// A [`Box`] containing a callback to a reflected function. type BoxFnMut<'env> = Box FnMut(ArgList<'a>) -> FunctionResult<'a> + 'env>; @@ -72,7 +71,6 @@ type BoxFnMut<'env> = Box FnMut(ArgList<'a>) -> FunctionResult<'a> + /// [module-level documentation]: crate::func pub struct DynamicFunctionMut<'env> { name: Option>, - info: FunctionInfoType, function_map: FunctionMap>, } @@ -99,13 +97,28 @@ impl<'env> DynamicFunctionMut<'env> { ) -> Self { let info = info.try_into().unwrap(); + let func: BoxFnMut = Box::new(func); + Self { name: match &info { FunctionInfoType::Standard(info) => info.name().cloned(), FunctionInfoType::Overloaded(_) => None, }, - info, - function_map: FunctionMap::Standard(Box::new(func)), + function_map: match info { + FunctionInfoType::Standard(info) => FunctionMap { + functions: vec![func], + indices: HashMap::from([(ArgumentSignature::from(&info), 0)]), + info: FunctionInfoType::Standard(info), + }, + FunctionInfoType::Overloaded(infos) => FunctionMap { + functions: vec![func], + indices: infos + .iter() + .map(|info| (ArgumentSignature::from(info), 0)) + .collect(), + info: FunctionInfoType::Overloaded(infos), + }, + }, } } @@ -141,6 +154,10 @@ impl<'env> DynamicFunctionMut<'env> { /// However, it's still possible to overload functions that do not capture their environment mutably, /// or those that maintain mutually exclusive mutable references to their environment. /// + /// # Panics + /// + /// Panics if the function, `F`, contains a signature already found in this function. + /// /// # Example /// /// ``` @@ -181,19 +198,13 @@ impl<'env> DynamicFunctionMut<'env> { { let function = function.into_function_mut(); - let name = self.name; - let (function_map, info) = merge_function_map( - self.function_map, - self.info, - function.function_map, - function.info, - ); + let name = self.name.clone(); + let mut function_map = self.function_map; + function_map + .merge(function.function_map) + .unwrap_or_else(|_| todo!()); - DynamicFunctionMut { - name, - info, - function_map, - } + DynamicFunctionMut { name, function_map } } /// Call the function with the given arguments. @@ -228,31 +239,19 @@ impl<'env> DynamicFunctionMut<'env> { /// /// [`call_once`]: DynamicFunctionMut::call_once pub fn call<'a>(&mut self, args: ArgList<'a>) -> FunctionResult<'a> { - let expected_arg_count = self.info.arg_count(); + let expected_arg_count = self.function_map.info.arg_count(); let received_arg_count = args.len(); - match self.function_map { - FunctionMap::Standard(ref mut func) => { - if expected_arg_count != received_arg_count { - Err(FunctionError::ArgCountMismatch { - expected: expected_arg_count, - received: received_arg_count, - }) - } else { - func(args) - } - } - FunctionMap::Overloaded(ref mut map) => { - let sig = ArgumentSignature::from(&args); - if let Some(func) = map.get_mut(&sig) { - func(args) - } else { - Err(FunctionError::NoOverload { - expected: map.keys().cloned().collect(), - received: sig, - }) - } - } + if matches!(self.function_map.info, FunctionInfoType::Standard(_)) + && expected_arg_count != received_arg_count + { + Err(FunctionError::ArgCountMismatch { + expected: expected_arg_count, + received: received_arg_count, + }) + } else { + let func = self.function_map.get_mut(&args)?; + func(args) } } @@ -290,7 +289,7 @@ impl<'env> DynamicFunctionMut<'env> { /// Returns the function info. pub fn info(&self) -> &FunctionInfoType { - &self.info + &self.function_map.info } /// The name of the function. @@ -346,15 +345,15 @@ impl<'env> From> for DynamicFunctionMut<'env> { fn from(function: DynamicFunction<'env>) -> Self { Self { name: function.name, - info: function.info, - function_map: match function.function_map { - FunctionMap::Standard(func) => FunctionMap::Standard(arc_to_box(func)), - FunctionMap::Overloaded(functions) => FunctionMap::Overloaded( - functions - .into_iter() - .map(|(name, func)| (name, arc_to_box(func))) - .collect(), - ), + function_map: FunctionMap { + info: function.function_map.info, + indices: function.function_map.indices, + functions: function + .function_map + .functions + .into_iter() + .map(arc_to_box) + .collect(), }, } } @@ -364,7 +363,7 @@ impl<'env> From> for DynamicFunctionMut<'env> { /// /// This is needed to help the compiler infer the correct types. fn arc_to_box<'env>( - f: Arc Fn(ArgList<'a>) -> FunctionResult<'a> + 'env>, + f: Arc Fn(ArgList<'a>) -> FunctionResult<'a> + Send + Sync + 'env>, ) -> BoxFnMut<'env> { Box::new(move |args| f(args)) } diff --git a/crates/bevy_reflect/src/func/error.rs b/crates/bevy_reflect/src/func/error.rs index 3dbd70386d..2b73aac757 100644 --- a/crates/bevy_reflect/src/func/error.rs +++ b/crates/bevy_reflect/src/func/error.rs @@ -43,6 +43,13 @@ pub type FunctionResult<'a> = Result, FunctionError>; #[error("expected a `FunctionInfo` but found none")] pub struct MissingFunctionInfoError; +/// An error that occurs when attempting to add a function overload with a duplicate signature. +#[derive(Debug, Error, PartialEq)] +#[error("could not add function overload: duplicate found for signature `{signature:?}`")] +pub struct FunctionOverloadError { + pub signature: ArgumentSignature, +} + /// An error that occurs when registering a function into a [`FunctionRegistry`]. /// /// [`FunctionRegistry`]: crate::func::FunctionRegistry diff --git a/crates/bevy_reflect/src/func/function_map.rs b/crates/bevy_reflect/src/func/function_map.rs index f4ebf81928..aaab049647 100644 --- a/crates/bevy_reflect/src/func/function_map.rs +++ b/crates/bevy_reflect/src/func/function_map.rs @@ -1,102 +1,203 @@ use crate::func::signature::ArgumentSignature; -use crate::func::FunctionInfoType; +use crate::func::{ArgList, FunctionError, FunctionInfoType, FunctionOverloadError}; use bevy_utils::HashMap; -/// A helper type for storing either a single function or a mapping of overloaded functions. -#[derive(Clone)] -pub(super) enum FunctionMap { - Standard(F), - Overloaded(HashMap), +/// A helper type for storing a mapping of overloaded functions +/// along with the corresponding [function information]. +/// +/// [function information]: FunctionInfoType +#[derive(Clone, Debug)] +pub(super) struct FunctionMap { + pub info: FunctionInfoType, + pub functions: Vec, + pub indices: HashMap, } -/// Merges the given [`FunctionMap`]s and [`FunctionInfoType`]s into a new [`FunctionMap`] and [`FunctionInfoType`]. -/// -/// # Panics -/// -/// Panics if a [`FunctionMap`]'s corresponding [`FunctionInfoType`] does not match its overload status. -pub(super) fn merge_function_map( - map_a: FunctionMap, - info_a: FunctionInfoType, - map_b: FunctionMap, - info_b: FunctionInfoType, -) -> (FunctionMap, FunctionInfoType) { - match (map_a, info_a, map_b, info_b) { - ( - FunctionMap::Standard(old), - FunctionInfoType::Standard(info_a), - FunctionMap::Standard(new), - FunctionInfoType::Standard(info_b), - ) => { - let sig_a = ArgumentSignature::from(&info_a); - let sig_b = ArgumentSignature::from(&info_b); - - if sig_a == sig_b { - ( - FunctionMap::Standard(new), - FunctionInfoType::Standard(info_b), - ) - } else { - ( - FunctionMap::Overloaded(HashMap::from([(sig_a, old), (sig_b, new)])), - FunctionInfoType::Overloaded(Box::new([info_a, info_b])), - ) - } - } - ( - FunctionMap::Overloaded(old), - FunctionInfoType::Overloaded(info_a), - FunctionMap::Standard(new), - FunctionInfoType::Standard(info_b), - ) => { - let sig_b = ArgumentSignature::from(&info_b); - let mut map = old; - map.insert(sig_b, new); - - let mut info = Vec::from_iter(info_a); - info.push(info_b); - - ( - FunctionMap::Overloaded(map), - FunctionInfoType::Overloaded(info.into_boxed_slice()), - ) - } - ( - FunctionMap::Standard(old), - FunctionInfoType::Standard(info_a), - FunctionMap::Overloaded(new), - FunctionInfoType::Overloaded(info_b), - ) => { - let sig_a = ArgumentSignature::from(&info_a); - let mut map = new; - map.insert(sig_a, old); - - let mut info = vec![info_a]; - info.extend(info_b); - - ( - FunctionMap::Overloaded(map), - FunctionInfoType::Overloaded(info.into_boxed_slice()), - ) - } - ( - FunctionMap::Overloaded(map1), - FunctionInfoType::Overloaded(info_a), - FunctionMap::Overloaded(map2), - FunctionInfoType::Overloaded(info_b), - ) => { - let mut map = map1; - map.extend(map2); - - let mut info = Vec::from_iter(info_a); - info.extend(info_b); - - ( - FunctionMap::Overloaded(map), - FunctionInfoType::Overloaded(info.into_boxed_slice()), - ) - } - _ => { - panic!("`FunctionMap` and `FunctionInfoType` mismatch"); +impl FunctionMap { + /// Get a reference to a function in the map. + /// + /// If there is only one function in the map, it will be returned. + /// Otherwise, the function will be selected based on the arguments provided. + /// + /// If no overload matches the provided arguments, an error will be returned. + pub fn get(&self, args: &ArgList) -> Result<&F, FunctionError> { + if self.functions.len() == 1 { + Ok(&self.functions[0]) + } else { + let signature = ArgumentSignature::from(args); + self.indices + .get(&signature) + .map(|index| &self.functions[*index]) + .ok_or_else(|| FunctionError::NoOverload { + expected: self.indices.keys().cloned().collect(), + received: signature, + }) } } + + /// Get a mutable reference to a function in the map. + /// + /// If there is only one function in the map, it will be returned. + /// Otherwise, the function will be selected based on the arguments provided. + /// + /// If no overload matches the provided arguments, an error will be returned. + pub fn get_mut(&mut self, args: &ArgList) -> Result<&mut F, FunctionError> { + if self.functions.len() == 1 { + Ok(&mut self.functions[0]) + } else { + let signature = ArgumentSignature::from(args); + self.indices + .get(&signature) + .map(|index| &mut self.functions[*index]) + .ok_or_else(|| FunctionError::NoOverload { + expected: self.indices.keys().cloned().collect(), + received: signature, + }) + } + } + + /// Merge another [`FunctionMap`] into this one. + /// + /// If the other map contains any functions with the same signature as this one, + /// an error will be returned and the original map will remain unchanged. + pub fn merge(&mut self, other: Self) -> Result<(), FunctionOverloadError> { + // === Function Map === // + let mut other_indices = HashMap::new(); + + for (sig, index) in other.indices { + if self.indices.contains_key(&sig) { + return Err(FunctionOverloadError { signature: sig }); + } + + other_indices.insert(sig, self.functions.len() + index); + } + + // === Function Info === // + let mut other_infos = Vec::new(); + + for info in other.info.into_iter() { + let sig = ArgumentSignature::from(&info); + if self.indices.contains_key(&sig) { + return Err(FunctionOverloadError { signature: sig }); + } + other_infos.push(info); + } + + // === Update === // + self.indices.extend(other_indices); + self.functions.extend(other.functions); + self.info = match &self.info { + FunctionInfoType::Standard(info) => FunctionInfoType::Overloaded( + std::iter::once(info.clone()).chain(other_infos).collect(), + ), + FunctionInfoType::Overloaded(infos) => { + FunctionInfoType::Overloaded(infos.iter().cloned().chain(other_infos).collect()) + } + }; + + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::func::FunctionInfo; + use crate::Type; + + #[test] + fn should_merge_function_maps() { + let mut map_a = FunctionMap { + info: FunctionInfoType::Overloaded(Box::new([ + FunctionInfo::anonymous().with_arg::("arg0"), + FunctionInfo::anonymous().with_arg::("arg0"), + FunctionInfo::anonymous().with_arg::("arg0"), + ])), + functions: vec!['a', 'b', 'c'], + indices: HashMap::from([ + (ArgumentSignature::from_iter([Type::of::()]), 0), + (ArgumentSignature::from_iter([Type::of::()]), 1), + (ArgumentSignature::from_iter([Type::of::()]), 2), + ]), + }; + + let map_b = FunctionMap { + info: FunctionInfoType::Overloaded(Box::new([ + FunctionInfo::anonymous().with_arg::("arg0"), + FunctionInfo::anonymous().with_arg::("arg0"), + FunctionInfo::anonymous().with_arg::("arg0"), + ])), + functions: vec!['d', 'e', 'f'], + indices: HashMap::from([ + (ArgumentSignature::from_iter([Type::of::()]), 0), + (ArgumentSignature::from_iter([Type::of::()]), 1), + (ArgumentSignature::from_iter([Type::of::()]), 2), + ]), + }; + + map_a.merge(map_b).unwrap(); + + assert_eq!(map_a.functions, vec!['a', 'b', 'c', 'd', 'e', 'f']); + assert_eq!( + map_a.indices, + HashMap::from([ + (ArgumentSignature::from_iter([Type::of::()]), 0), + (ArgumentSignature::from_iter([Type::of::()]), 1), + (ArgumentSignature::from_iter([Type::of::()]), 2), + (ArgumentSignature::from_iter([Type::of::()]), 3), + (ArgumentSignature::from_iter([Type::of::()]), 4), + (ArgumentSignature::from_iter([Type::of::()]), 5), + ]) + ); + } + + #[test] + fn should_return_error_on_duplicate_signature() { + let mut map_a = FunctionMap { + info: FunctionInfoType::Overloaded(Box::new([ + FunctionInfo::anonymous().with_arg::("arg0"), + FunctionInfo::anonymous().with_arg::("arg0"), + FunctionInfo::anonymous().with_arg::("arg0"), + ])), + functions: vec!['a', 'b', 'c'], + indices: HashMap::from([ + (ArgumentSignature::from_iter([Type::of::()]), 0), + (ArgumentSignature::from_iter([Type::of::()]), 1), + (ArgumentSignature::from_iter([Type::of::()]), 2), + ]), + }; + + let map_b = FunctionMap { + info: FunctionInfoType::Overloaded(Box::new([ + FunctionInfo::anonymous().with_arg::("arg0"), + FunctionInfo::anonymous().with_arg::("arg0"), + FunctionInfo::anonymous().with_arg::("arg0"), + ])), + functions: vec!['d', 'e', 'f'], + indices: HashMap::from([ + (ArgumentSignature::from_iter([Type::of::()]), 0), + (ArgumentSignature::from_iter([Type::of::()]), 1), + (ArgumentSignature::from_iter([Type::of::()]), 2), + ]), + }; + + let result = map_a.merge(map_b); + assert_eq!( + result.unwrap_err(), + FunctionOverloadError { + signature: ArgumentSignature::from_iter([Type::of::()]) + } + ); + + // Assert that the original map remains unchanged: + assert_eq!(map_a.functions, vec!['a', 'b', 'c']); + assert_eq!( + map_a.indices, + HashMap::from([ + (ArgumentSignature::from_iter([Type::of::()]), 0), + (ArgumentSignature::from_iter([Type::of::()]), 1), + (ArgumentSignature::from_iter([Type::of::()]), 2), + ]) + ); + } } diff --git a/crates/bevy_reflect/src/func/info.rs b/crates/bevy_reflect/src/func/info.rs index f388708c3e..9a4d1e7929 100644 --- a/crates/bevy_reflect/src/func/info.rs +++ b/crates/bevy_reflect/src/func/info.rs @@ -47,6 +47,18 @@ impl TryFrom> for FunctionInfoType { } } +impl IntoIterator for FunctionInfoType { + type Item = FunctionInfo; + type IntoIter = vec::IntoIter; + + fn into_iter(self) -> Self::IntoIter { + match self { + FunctionInfoType::Standard(info) => vec![info].into_iter(), + FunctionInfoType::Overloaded(infos) => infos.into_vec().into_iter(), + } + } +} + impl FunctionInfoType { pub fn arg_count(&self) -> usize { match self {