From 6c06fc5b7c5b9b3c34abfbacc73d3aae795a5bab Mon Sep 17 00:00:00 2001 From: Nicola Papale Date: Wed, 13 Jul 2022 16:08:48 +0000 Subject: [PATCH] Add ExactSizeIterator implementation for QueryCombinatonIter (#5148) Following https://github.com/bevyengine/bevy/pull/5124 I decided to add the `ExactSizeIterator` impl for `QueryCombinationIter`. Also: - Clean up the tests for `size_hint` and `len` for both the normal `QueryIter` and `QueryCombinationIter`. - Add tests to `QueryCombinationIter` when it shouldn't be `ExactSizeIterator` --- ## Changelog - Added `ExactSizeIterator` implementation for `QueryCombinatonIter` --- crates/bevy_ecs/src/query/iter.rs | 40 ++- crates/bevy_ecs/src/query/mod.rs | 261 +++++++----------- ...uery_combin_exact_sized_iterator_safety.rs | 20 ++ ..._combin_exact_sized_iterator_safety.stderr | 51 ++++ 4 files changed, 201 insertions(+), 171 deletions(-) create mode 100644 crates/bevy_ecs_compile_fail_tests/tests/ui/query_combin_exact_sized_iterator_safety.rs create mode 100644 crates/bevy_ecs_compile_fail_tests/tests/ui/query_combin_exact_sized_iterator_safety.stderr diff --git a/crates/bevy_ecs/src/query/iter.rs b/crates/bevy_ecs/src/query/iter.rs index 65881a6a93..498e04fe42 100644 --- a/crates/bevy_ecs/src/query/iter.rs +++ b/crates/bevy_ecs/src/query/iter.rs @@ -343,17 +343,26 @@ where if max_size < K { return (0, Some(0)); } + if max_size == K { + return (1, Some(1)); + } - // n! / k!(n-k)! = (n*n-1*...*n-k+1) / k! - let max_combinations = (0..K) - .try_fold(1usize, |n, i| n.checked_mul(max_size - i)) - .map(|n| { - let k_factorial: usize = (1..=K).product(); - n / k_factorial - }); + // binomial coefficient: (n ; k) = n! / k!(n-k)! = (n*n-1*...*n-k+1) / k! + // See https://en.wikipedia.org/wiki/Binomial_coefficient + // See https://blog.plover.com/math/choose.html for implementation + // It was chosen to reduce overflow potential. + fn choose(n: usize, k: usize) -> Option { + let ks = 1..=k; + let ns = (n - k + 1..=n).rev(); + ks.zip(ns) + .try_fold(1_usize, |acc, (k, n)| Some(acc.checked_mul(n)? / k)) + } + let smallest = K.min(max_size - K); + let max_combinations = choose(max_size, smallest); let archetype_query = F::Fetch::IS_ARCHETYPAL && Q::Fetch::IS_ARCHETYPAL; - let min_combinations = if archetype_query { max_size } else { 0 }; + let known_max = max_combinations.unwrap_or(usize::MAX); + let min_combinations = if archetype_query { known_max } else { 0 }; (min_combinations, max_combinations) } } @@ -372,6 +381,21 @@ where } } +impl<'w, 's, Q: ReadOnlyWorldQuery, F: ReadOnlyWorldQuery + ArchetypeFilter, const K: usize> + ExactSizeIterator for QueryCombinationIter<'w, 's, Q, F, K> +where + QueryFetch<'w, Q>: Clone, + QueryFetch<'w, F>: Clone, +{ + /// Returns the exact length of the iterator. + /// + /// **NOTE**: When the iterator length overflows `usize`, this will + /// return `usize::MAX`. + fn len(&self) -> usize { + self.size_hint().0 + } +} + // This is correct as [`QueryCombinationIter`] always returns `None` once exhausted. impl<'w, 's, Q: ReadOnlyWorldQuery, F: ReadOnlyWorldQuery, const K: usize> FusedIterator for QueryCombinationIter<'w, 's, Q, F, K> diff --git a/crates/bevy_ecs/src/query/mod.rs b/crates/bevy_ecs/src/query/mod.rs index 0b260a1a80..2a6da06c31 100644 --- a/crates/bevy_ecs/src/query/mod.rs +++ b/crates/bevy_ecs/src/query/mod.rs @@ -20,9 +20,11 @@ pub(crate) unsafe fn debug_checked_unreachable() -> ! { #[cfg(test)] mod tests { use super::WorldQuery; - use crate::prelude::{AnyOf, Entity, Or, With, Without}; + use crate::prelude::{AnyOf, Entity, Or, QueryState, With, Without}; + use crate::query::{ArchetypeFilter, QueryCombinationIter, QueryFetch, ReadOnlyWorldQuery}; use crate::system::{IntoSystem, Query, System}; use crate::{self as bevy_ecs, component::Component, world::World}; + use std::any::type_name; use std::collections::HashSet; #[derive(Component, Debug, Hash, Eq, PartialEq, Clone, Copy)] @@ -54,24 +56,81 @@ mod tests { } #[test] - fn query_filtered_len() { + fn query_filtered_exactsizeiterator_len() { + fn choose(n: usize, k: usize) -> usize { + if n == 0 || k == 0 || n < k { + return 0; + } + let ks = 1..=k; + let ns = (n - k + 1..=n).rev(); + ks.zip(ns).fold(1, |acc, (k, n)| acc * n / k) + } + fn assert_combination(world: &mut World, expected_size: usize) + where + Q: ReadOnlyWorldQuery, + F: ReadOnlyWorldQuery + ArchetypeFilter, + for<'w> QueryFetch<'w, Q>: Clone, + for<'w> QueryFetch<'w, F>: Clone, + { + let mut query = world.query_filtered::(); + let iter = query.iter_combinations::(world); + let query_type = type_name::>(); + assert_all_sizes_iterator_equal(iter, expected_size, query_type); + } + fn assert_all_sizes_equal(world: &mut World, expected_size: usize) + where + Q: ReadOnlyWorldQuery, + F: ReadOnlyWorldQuery + ArchetypeFilter, + for<'w> QueryFetch<'w, Q>: Clone, + for<'w> QueryFetch<'w, F>: Clone, + { + let mut query = world.query_filtered::(); + let iter = query.iter(world); + let query_type = type_name::>(); + assert_all_sizes_iterator_equal(iter, expected_size, query_type); + + let expected = expected_size; + assert_combination::(world, choose(expected, 0)); + assert_combination::(world, choose(expected, 1)); + assert_combination::(world, choose(expected, 2)); + assert_combination::(world, choose(expected, 5)); + assert_combination::(world, choose(expected, 43)); + assert_combination::(world, choose(expected, 128)); + } + fn assert_all_sizes_iterator_equal( + iterator: impl ExactSizeIterator, + expected_size: usize, + query_type: &'static str, + ) { + let size_hint_0 = iterator.size_hint().0; + let size_hint_1 = iterator.size_hint().1; + let len = iterator.len(); + // `count` tests that not only it is the expected value, but also + // the value is accurate to what the query returns. + let count = iterator.count(); + // This will show up when one of the asserts in this function fails + println!( + r#"query declared sizes: +for query: {query_type} +expected: {expected_size} +len(): {len} +size_hint().0: {size_hint_0} +size_hint().1: {size_hint_1:?} +count(): {count}"# + ); + assert_eq!(len, expected_size); + assert_eq!(size_hint_0, expected_size); + assert_eq!(size_hint_1, Some(expected_size)); + assert_eq!(count, expected_size); + } + let mut world = World::new(); world.spawn().insert_bundle((A(1), B(1))); world.spawn().insert_bundle((A(2),)); world.spawn().insert_bundle((A(3),)); - let mut values = world.query_filtered::<&A, With>(); - let n = 1; - assert_eq!(values.iter(&world).size_hint().0, n); - assert_eq!(values.iter(&world).size_hint().1.unwrap(), n); - assert_eq!(values.iter(&world).len(), n); - assert_eq!(values.iter(&world).count(), n); - let mut values = world.query_filtered::<&A, Without>(); - let n = 2; - assert_eq!(values.iter(&world).size_hint().0, n); - assert_eq!(values.iter(&world).size_hint().1.unwrap(), n); - assert_eq!(values.iter(&world).len(), n); - assert_eq!(values.iter(&world).count(), n); + assert_all_sizes_equal::<&A, With>(&mut world, 1); + assert_all_sizes_equal::<&A, Without>(&mut world, 2); let mut world = World::new(); world.spawn().insert_bundle((A(1), B(1), C(1))); @@ -86,110 +145,37 @@ mod tests { world.spawn().insert_bundle((A(10),)); // With/Without for B and C - let mut values = world.query_filtered::<&A, With>(); - let n = 3; - assert_eq!(values.iter(&world).size_hint().0, n); - assert_eq!(values.iter(&world).size_hint().1.unwrap(), n); - assert_eq!(values.iter(&world).len(), n); - assert_eq!(values.iter(&world).count(), n); - let mut values = world.query_filtered::<&A, With>(); - let n = 4; - assert_eq!(values.iter(&world).size_hint().0, n); - assert_eq!(values.iter(&world).size_hint().1.unwrap(), n); - assert_eq!(values.iter(&world).len(), n); - assert_eq!(values.iter(&world).count(), n); - let mut values = world.query_filtered::<&A, Without>(); - let n = 7; - assert_eq!(values.iter(&world).size_hint().0, n); - assert_eq!(values.iter(&world).size_hint().1.unwrap(), n); - assert_eq!(values.iter(&world).len(), n); - assert_eq!(values.iter(&world).count(), n); - let mut values = world.query_filtered::<&A, Without>(); - let n = 6; - assert_eq!(values.iter(&world).size_hint().0, n); - assert_eq!(values.iter(&world).size_hint().1.unwrap(), n); - assert_eq!(values.iter(&world).len(), n); - assert_eq!(values.iter(&world).count(), n); + assert_all_sizes_equal::<&A, With>(&mut world, 3); + assert_all_sizes_equal::<&A, With>(&mut world, 4); + assert_all_sizes_equal::<&A, Without>(&mut world, 7); + assert_all_sizes_equal::<&A, Without>(&mut world, 6); // With/Without (And) combinations - let mut values = world.query_filtered::<&A, (With, With)>(); - let n = 1; - assert_eq!(values.iter(&world).size_hint().0, n); - assert_eq!(values.iter(&world).size_hint().1.unwrap(), n); - assert_eq!(values.iter(&world).len(), n); - assert_eq!(values.iter(&world).count(), n); - let mut values = world.query_filtered::<&A, (With, Without)>(); - let n = 2; - assert_eq!(values.iter(&world).size_hint().0, n); - assert_eq!(values.iter(&world).size_hint().1.unwrap(), n); - assert_eq!(values.iter(&world).len(), n); - assert_eq!(values.iter(&world).count(), n); - let mut values = world.query_filtered::<&A, (Without, With)>(); - let n = 3; - assert_eq!(values.iter(&world).size_hint().0, n); - assert_eq!(values.iter(&world).size_hint().1.unwrap(), n); - assert_eq!(values.iter(&world).len(), n); - assert_eq!(values.iter(&world).count(), n); - let mut values = world.query_filtered::<&A, (Without, Without)>(); - let n = 4; - assert_eq!(values.iter(&world).size_hint().0, n); - assert_eq!(values.iter(&world).size_hint().1.unwrap(), n); - assert_eq!(values.iter(&world).len(), n); - assert_eq!(values.iter(&world).count(), n); + assert_all_sizes_equal::<&A, (With, With)>(&mut world, 1); + assert_all_sizes_equal::<&A, (With, Without)>(&mut world, 2); + assert_all_sizes_equal::<&A, (Without, With)>(&mut world, 3); + assert_all_sizes_equal::<&A, (Without, Without)>(&mut world, 4); // With/Without Or<()> combinations - let mut values = world.query_filtered::<&A, Or<(With, With)>>(); - let n = 6; - assert_eq!(values.iter(&world).size_hint().0, n); - assert_eq!(values.iter(&world).size_hint().1.unwrap(), n); - assert_eq!(values.iter(&world).len(), n); - assert_eq!(values.iter(&world).count(), n); - let mut values = world.query_filtered::<&A, Or<(With, Without)>>(); - let n = 7; - assert_eq!(values.iter(&world).size_hint().0, n); - assert_eq!(values.iter(&world).size_hint().1.unwrap(), n); - assert_eq!(values.iter(&world).len(), n); - assert_eq!(values.iter(&world).count(), n); - let mut values = world.query_filtered::<&A, Or<(Without, With)>>(); - let n = 8; - assert_eq!(values.iter(&world).size_hint().0, n); - assert_eq!(values.iter(&world).size_hint().1.unwrap(), n); - assert_eq!(values.iter(&world).len(), n); - assert_eq!(values.iter(&world).count(), n); - let mut values = world.query_filtered::<&A, Or<(Without, Without)>>(); - let n = 9; - assert_eq!(values.iter(&world).size_hint().0, n); - assert_eq!(values.iter(&world).size_hint().1.unwrap(), n); - assert_eq!(values.iter(&world).len(), n); - assert_eq!(values.iter(&world).count(), n); + assert_all_sizes_equal::<&A, Or<(With, With)>>(&mut world, 6); + assert_all_sizes_equal::<&A, Or<(With, Without)>>(&mut world, 7); + assert_all_sizes_equal::<&A, Or<(Without, With)>>(&mut world, 8); + assert_all_sizes_equal::<&A, Or<(Without, Without)>>(&mut world, 9); + assert_all_sizes_equal::<&A, (Or<(With,)>, Or<(With,)>)>(&mut world, 1); + assert_all_sizes_equal::<&A, Or<(Or<(With, With)>, With)>>(&mut world, 6); - let mut values = world.query_filtered::<&A, (Or<(With,)>, Or<(With,)>)>(); - let n = 1; - assert_eq!(values.iter(&world).size_hint().0, n); - assert_eq!(values.iter(&world).size_hint().1.unwrap(), n); - assert_eq!(values.iter(&world).len(), n); - assert_eq!(values.iter(&world).count(), n); - let mut values = world.query_filtered::<&A, Or<(Or<(With, With)>, With)>>(); - let n = 6; - assert_eq!(values.iter(&world).size_hint().0, n); - assert_eq!(values.iter(&world).size_hint().1.unwrap(), n); - assert_eq!(values.iter(&world).len(), n); - assert_eq!(values.iter(&world).count(), n); + for i in 11..14 { + world.spawn().insert_bundle((A(i), D(i))); + } - world.spawn().insert_bundle((A(11), D(11))); + assert_all_sizes_equal::<&A, Or<(Or<(With, With)>, With)>>(&mut world, 9); + assert_all_sizes_equal::<&A, Or<(Or<(With, With)>, Without)>>(&mut world, 10); - let mut values = world.query_filtered::<&A, Or<(Or<(With, With)>, With)>>(); - let n = 7; - assert_eq!(values.iter(&world).size_hint().0, n); - assert_eq!(values.iter(&world).size_hint().1.unwrap(), n); - assert_eq!(values.iter(&world).len(), n); - assert_eq!(values.iter(&world).count(), n); - let mut values = world.query_filtered::<&A, Or<(Or<(With, With)>, Without)>>(); - let n = 10; - assert_eq!(values.iter(&world).size_hint().0, n); - assert_eq!(values.iter(&world).size_hint().1.unwrap(), n); - assert_eq!(values.iter(&world).len(), n); - assert_eq!(values.iter(&world).count(), n); + // a fair amount of entities + for i in 14..20 { + world.spawn().insert_bundle((C(i), D(i))); + } + assert_all_sizes_equal::, With)>(&mut world, 6); } #[test] @@ -201,23 +187,6 @@ mod tests { world.spawn().insert_bundle((A(3),)); world.spawn().insert_bundle((A(4),)); - let mut a_query = world.query::<&A>(); - let w = &world; - assert_eq!(a_query.iter_combinations::<0>(w).count(), 0); - assert_eq!(a_query.iter_combinations::<0>(w).size_hint().1, Some(0)); - assert_eq!(a_query.iter_combinations::<1>(w).count(), 4); - assert_eq!(a_query.iter_combinations::<1>(w).size_hint().1, Some(4)); - assert_eq!(a_query.iter_combinations::<2>(w).count(), 6); - assert_eq!(a_query.iter_combinations::<2>(w).size_hint().1, Some(6)); - assert_eq!(a_query.iter_combinations::<3>(w).count(), 4); - assert_eq!(a_query.iter_combinations::<3>(w).size_hint().1, Some(4)); - assert_eq!(a_query.iter_combinations::<4>(w).count(), 1); - assert_eq!(a_query.iter_combinations::<4>(w).size_hint().1, Some(1)); - assert_eq!(a_query.iter_combinations::<5>(w).count(), 0); - assert_eq!(a_query.iter_combinations::<5>(w).size_hint().1, Some(0)); - assert_eq!(a_query.iter_combinations::<128>(w).count(), 0); - assert_eq!(a_query.iter_combinations::<128>(w).size_hint().1, Some(0)); - let values: Vec<[&A; 2]> = world.query::<&A>().iter_combinations(&world).collect(); assert_eq!( values, @@ -230,8 +199,7 @@ mod tests { [&A(3), &A(4)], ] ); - let size = a_query.iter_combinations::<3>(&world).size_hint(); - assert_eq!(size.1, Some(4)); + let mut a_query = world.query::<&A>(); let values: Vec<[&A; 3]> = a_query.iter_combinations(&world).collect(); assert_eq!( values, @@ -282,40 +250,7 @@ mod tests { world.spawn().insert_bundle((A(3),)); world.spawn().insert_bundle((A(4),)); - let mut a_with_b = world.query_filtered::<&A, With>(); - let w = &world; - assert_eq!(a_with_b.iter_combinations::<0>(w).count(), 0); - assert_eq!(a_with_b.iter_combinations::<0>(w).size_hint().1, Some(0)); - assert_eq!(a_with_b.iter_combinations::<1>(w).count(), 1); - assert_eq!(a_with_b.iter_combinations::<1>(w).size_hint().1, Some(1)); - assert_eq!(a_with_b.iter_combinations::<2>(w).count(), 0); - assert_eq!(a_with_b.iter_combinations::<2>(w).size_hint().1, Some(0)); - assert_eq!(a_with_b.iter_combinations::<3>(w).count(), 0); - assert_eq!(a_with_b.iter_combinations::<3>(w).size_hint().1, Some(0)); - assert_eq!(a_with_b.iter_combinations::<4>(w).count(), 0); - assert_eq!(a_with_b.iter_combinations::<4>(w).size_hint().1, Some(0)); - assert_eq!(a_with_b.iter_combinations::<5>(w).count(), 0); - assert_eq!(a_with_b.iter_combinations::<5>(w).size_hint().1, Some(0)); - assert_eq!(a_with_b.iter_combinations::<128>(w).count(), 0); - assert_eq!(a_with_b.iter_combinations::<128>(w).size_hint().1, Some(0)); - let mut a_wout_b = world.query_filtered::<&A, Without>(); - let w = &world; - assert_eq!(a_wout_b.iter_combinations::<0>(w).count(), 0); - assert_eq!(a_wout_b.iter_combinations::<0>(w).size_hint().1, Some(0)); - assert_eq!(a_wout_b.iter_combinations::<1>(w).count(), 3); - assert_eq!(a_wout_b.iter_combinations::<1>(w).size_hint().1, Some(3)); - assert_eq!(a_wout_b.iter_combinations::<2>(w).count(), 3); - assert_eq!(a_wout_b.iter_combinations::<2>(w).size_hint().1, Some(3)); - assert_eq!(a_wout_b.iter_combinations::<3>(w).count(), 1); - assert_eq!(a_wout_b.iter_combinations::<3>(w).size_hint().1, Some(1)); - assert_eq!(a_wout_b.iter_combinations::<4>(w).count(), 0); - assert_eq!(a_wout_b.iter_combinations::<4>(w).size_hint().1, Some(0)); - assert_eq!(a_wout_b.iter_combinations::<5>(w).count(), 0); - assert_eq!(a_wout_b.iter_combinations::<5>(w).size_hint().1, Some(0)); - assert_eq!(a_wout_b.iter_combinations::<128>(w).count(), 0); - assert_eq!(a_wout_b.iter_combinations::<128>(w).size_hint().1, Some(0)); - let values: HashSet<[&A; 2]> = a_wout_b.iter_combinations(&world).collect(); assert_eq!( values, diff --git a/crates/bevy_ecs_compile_fail_tests/tests/ui/query_combin_exact_sized_iterator_safety.rs b/crates/bevy_ecs_compile_fail_tests/tests/ui/query_combin_exact_sized_iterator_safety.rs new file mode 100644 index 0000000000..6a9f1775c4 --- /dev/null +++ b/crates/bevy_ecs_compile_fail_tests/tests/ui/query_combin_exact_sized_iterator_safety.rs @@ -0,0 +1,20 @@ +use bevy_ecs::prelude::*; + +#[derive(Component)] +struct Foo; +#[derive(Component)] +struct Bar; + +fn on_changed(query: Query<&Foo, Or<(Changed, With)>>) { + // this should fail to compile + is_exact_size_iterator(query.iter_combinations::<2>()); +} + +fn on_added(query: Query<&Foo, (Added, Without)>) { + // this should fail to compile + is_exact_size_iterator(query.iter_combinations::<2>()); +} + +fn is_exact_size_iterator(_iter: T) {} + +fn main() {} diff --git a/crates/bevy_ecs_compile_fail_tests/tests/ui/query_combin_exact_sized_iterator_safety.stderr b/crates/bevy_ecs_compile_fail_tests/tests/ui/query_combin_exact_sized_iterator_safety.stderr new file mode 100644 index 0000000000..3de9d2c089 --- /dev/null +++ b/crates/bevy_ecs_compile_fail_tests/tests/ui/query_combin_exact_sized_iterator_safety.stderr @@ -0,0 +1,51 @@ +error[E0277]: the trait bound `bevy_ecs::query::Changed: ArchetypeFilter` is not satisfied + --> tests/ui/query_combin_exact_sized_iterator_safety.rs:10:28 + | +10 | is_exact_size_iterator(query.iter_combinations::<2>()); + | ---------------------- ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ the trait `ArchetypeFilter` is not implemented for `bevy_ecs::query::Changed` + | | + | required by a bound introduced by this call + | + = help: the following other types implement trait `ArchetypeFilter`: + () + (F0, F1) + (F0, F1, F2) + (F0, F1, F2, F3) + (F0, F1, F2, F3, F4) + (F0, F1, F2, F3, F4, F5) + (F0, F1, F2, F3, F4, F5, F6) + (F0, F1, F2, F3, F4, F5, F6, F7) + and 26 others + = note: required because of the requirements on the impl of `ArchetypeFilter` for `bevy_ecs::query::Or<(bevy_ecs::query::Changed, bevy_ecs::query::With)>` + = note: required because of the requirements on the impl of `ExactSizeIterator` for `QueryCombinationIter<'_, '_, &Foo, bevy_ecs::query::Or<(bevy_ecs::query::Changed, bevy_ecs::query::With)>, 2_usize>` +note: required by a bound in `is_exact_size_iterator` + --> tests/ui/query_combin_exact_sized_iterator_safety.rs:18:30 + | +18 | fn is_exact_size_iterator(_iter: T) {} + | ^^^^^^^^^^^^^^^^^ required by this bound in `is_exact_size_iterator` + +error[E0277]: the trait bound `bevy_ecs::query::Added: ArchetypeFilter` is not satisfied + --> tests/ui/query_combin_exact_sized_iterator_safety.rs:15:28 + | +15 | is_exact_size_iterator(query.iter_combinations::<2>()); + | ---------------------- ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ the trait `ArchetypeFilter` is not implemented for `bevy_ecs::query::Added` + | | + | required by a bound introduced by this call + | + = help: the following other types implement trait `ArchetypeFilter`: + () + (F0, F1) + (F0, F1, F2) + (F0, F1, F2, F3) + (F0, F1, F2, F3, F4) + (F0, F1, F2, F3, F4, F5) + (F0, F1, F2, F3, F4, F5, F6) + (F0, F1, F2, F3, F4, F5, F6, F7) + and 26 others + = note: required because of the requirements on the impl of `ArchetypeFilter` for `(bevy_ecs::query::Added, bevy_ecs::query::Without)` + = note: required because of the requirements on the impl of `ExactSizeIterator` for `QueryCombinationIter<'_, '_, &Foo, (bevy_ecs::query::Added, bevy_ecs::query::Without), 2_usize>` +note: required by a bound in `is_exact_size_iterator` + --> tests/ui/query_combin_exact_sized_iterator_safety.rs:18:30 + | +18 | fn is_exact_size_iterator(_iter: T) {} + | ^^^^^^^^^^^^^^^^^ required by this bound in `is_exact_size_iterator`