From d389485aa6a58a6c2fc0a6dcad7f47366ded682a Mon Sep 17 00:00:00 2001 From: Andrew Gazelka Date: Sun, 15 Dec 2024 15:07:45 -0800 Subject: [PATCH] perf(bvh-region): more efficient range query --- crates/bvh-region/src/lib.rs | 89 -------------------- crates/bvh-region/src/query.rs | 1 + crates/bvh-region/src/query/closest.rs | 13 +-- crates/bvh-region/src/query/range.rs | 107 +++++++++++++++++++++++++ crates/bvh-region/tests/simple.rs | 56 ++++++++++++- crates/geometry/src/aabb.rs | 25 ++++++ crates/spatial/src/lib.rs | 2 +- 7 files changed, 190 insertions(+), 103 deletions(-) create mode 100644 crates/bvh-region/src/query/range.rs diff --git a/crates/bvh-region/src/lib.rs b/crates/bvh-region/src/lib.rs index cf458a18..94599431 100644 --- a/crates/bvh-region/src/lib.rs +++ b/crates/bvh-region/src/lib.rs @@ -9,7 +9,6 @@ use std::fmt::Debug; use arrayvec::ArrayVec; use geometry::aabb::Aabb; -use glam::Vec3; const ELEMENTS_TO_ACTIVATE_LEAF: usize = 16; const VOLUME_TO_ACTIVATE_LEAF: f32 = 5.0; @@ -17,8 +16,6 @@ const VOLUME_TO_ACTIVATE_LEAF: f32 = 5.0; mod node; use node::BvhNode; -use crate::utils::GetAabb; - mod build; mod query; mod utils; @@ -197,91 +194,5 @@ impl BvhNode { } } -struct BvhIter<'a, T> { - bvh: &'a Bvh, - target: Aabb, -} - -impl<'a, T> BvhIter<'a, T> { - fn consume( - bvh: &'a Bvh, - target: Aabb, - get_aabb: impl GetAabb + 'a, - ) -> Box + 'a> { - let root = bvh.root(); - - let root = match root { - Node::Internal(internal) => internal, - Node::Leaf(leaf) => { - for elem in leaf.iter() { - let aabb = get_aabb(elem); - if aabb.collides(&target) { - return Box::new(std::iter::once(elem)); - } - } - return Box::new(std::iter::empty()); - } - }; - - if !root.aabb.collides(&target) { - return Box::new(std::iter::empty()); - } - - let iter = Self { target, bvh }; - - Box::new(iter.process(root, get_aabb)) - } - - #[expect(clippy::excessive_nesting, reason = "todo: fix")] - pub fn process( - self, - on: &'a BvhNode, - get_aabb: impl GetAabb, - ) -> impl Iterator { - gen move { - let mut stack: ArrayVec<&'a BvhNode, 64> = ArrayVec::new(); - stack.push(on); - - while let Some(on) = stack.pop() { - for child in on.children(self.bvh) { - match child { - Node::Internal(child) => { - if child.aabb.collides(&self.target) { - stack.push(child); - } - } - Node::Leaf(elements) => { - for elem in elements { - let aabb = get_aabb(elem); - if aabb.collides(&self.target) { - yield elem; - } - } - } - } - } - } - } - } -} - -pub fn random_aabb(width: f32) -> Aabb { - let min = std::array::from_fn(|_| fastrand::f32() * width); - let min = Vec3::from_array(min); - let max = min + Vec3::splat(1.0); - - Aabb::new(min, max) -} - -pub fn create_random_elements_1(count: usize, width: f32) -> Vec { - let mut elements = Vec::new(); - - for _ in 0..count { - elements.push(random_aabb(width)); - } - - elements -} - #[cfg(test)] mod tests; diff --git a/crates/bvh-region/src/query.rs b/crates/bvh-region/src/query.rs index e9cc8996..4bd31c04 100644 --- a/crates/bvh-region/src/query.rs +++ b/crates/bvh-region/src/query.rs @@ -1 +1,2 @@ mod closest; +mod range; diff --git a/crates/bvh-region/src/query/closest.rs b/crates/bvh-region/src/query/closest.rs index c32af02f..68120b90 100644 --- a/crates/bvh-region/src/query/closest.rs +++ b/crates/bvh-region/src/query/closest.rs @@ -3,10 +3,7 @@ use std::{cmp::Reverse, collections::BinaryHeap, fmt::Debug}; use geometry::aabb::Aabb; use glam::Vec3; -use crate::{ - Bvh, BvhIter, Node, - utils::{GetAabb, NodeOrd}, -}; +use crate::{Bvh, Node, utils::NodeOrd}; impl Bvh { /// Returns the closest element to the target and the distance squared to it. @@ -73,12 +70,4 @@ impl Bvh { min_node.map(|elem| (elem, min_dist2)) } - - pub fn get_collisions<'a>( - &'a self, - target: Aabb, - get_aabb: impl GetAabb + 'a, - ) -> impl Iterator + 'a { - BvhIter::consume(self, target, get_aabb) - } } diff --git a/crates/bvh-region/src/query/range.rs b/crates/bvh-region/src/query/range.rs new file mode 100644 index 00000000..f6a09286 --- /dev/null +++ b/crates/bvh-region/src/query/range.rs @@ -0,0 +1,107 @@ +use std::fmt::Debug; + +use arrayvec::ArrayVec; +use geometry::aabb::Aabb; + +use crate::{Bvh, Node, utils::GetAabb}; + +impl Bvh { + pub fn range<'a>( + &'a self, + target: Aabb, + get_aabb: impl GetAabb + 'a, + ) -> impl Iterator + 'a { + CollisionIter::new(self, target, get_aabb) + } +} + +pub struct CollisionIter<'a, T, F> { + bvh: &'a Bvh, + target: Aabb, + get_aabb: F, + stack: ArrayVec, 64>, + current_leaf: Option<(&'a [T], usize)>, +} + +impl<'a, T, F> CollisionIter<'a, T, F> +where + F: GetAabb, +{ + fn new(bvh: &'a Bvh, target: Aabb, get_aabb: F) -> Self { + let mut stack = ArrayVec::new(); + // Initialize stack with root if it collides + match bvh.root() { + Node::Internal(root) => { + if root.aabb.collides(&target) { + stack.push(Node::Internal(root)); + } + } + Node::Leaf(leaf) => { + // We'll handle collision checks in next() as we iterate through leaves + stack.push(Node::Leaf(leaf)); + } + } + + Self { + bvh, + target, + get_aabb, + stack, + current_leaf: None, + } + } +} + +impl<'a, T, F> Iterator for CollisionIter<'a, T, F> +where + F: GetAabb, +{ + type Item = &'a T; + + fn next(&mut self) -> Option { + loop { + // If we're currently iterating over a leaf's elements + if let Some((leaf, index)) = &mut self.current_leaf { + if *index < leaf.len() { + let elem = &leaf[*index]; + *index += 1; + + let elem_aabb = (self.get_aabb)(elem); + if elem_aabb.collides(&self.target) { + return Some(elem); + } + // If not colliding, continue to next element in leaf + continue; + } else { + // Leaf exhausted + self.current_leaf = None; + } + } + + // If no current leaf, pop from stack + let node = self.stack.pop()?; + match node { + Node::Internal(internal) => { + // Push children that potentially collide + for child in internal.children(self.bvh) { + match child { + Node::Internal(child_node) => { + if child_node.aabb.collides(&self.target) { + self.stack.push(Node::Internal(child_node)); + } + } + Node::Leaf(child_leaf) => { + // We'll check collisions inside the leaf iteration + self.stack.push(Node::Leaf(child_leaf)); + } + } + } + } + Node::Leaf(leaf) => { + // Start iterating over this leaf's elements + self.current_leaf = Some((leaf, 0)); + } + } + } + } +} diff --git a/crates/bvh-region/tests/simple.rs b/crates/bvh-region/tests/simple.rs index db62dab6..a76f7b8e 100644 --- a/crates/bvh-region/tests/simple.rs +++ b/crates/bvh-region/tests/simple.rs @@ -1,6 +1,8 @@ +use std::collections::HashSet; + use approx::assert_relative_eq; use bvh_region::Bvh; -use geometry::aabb::Aabb; +use geometry::aabb::{Aabb, OrderedAabb}; use glam::Vec3; use proptest::prelude::*; @@ -102,3 +104,55 @@ proptest! { } } } + +proptest! { + #[test] + fn test_range_correctness( + elements in prop::collection::vec( + (any::(), any::(), any::(), any::(), any::(), any::()) + .prop_map(|(x1, y1, z1, x2, y2, z2)| { + let min_x = x1.min(x2); + let max_x = x1.max(x2); + let min_y = y1.min(y2); + let max_y = y1.max(y2); + let min_z = z1.min(z2); + let max_z = z1.max(z2); + Aabb::from([min_x, min_y, min_z, max_x, max_y, max_z]) + }), + 1..50 + ), + target in (any::(), any::(), any::(), any::(), any::(), any::()) + .prop_map(|(x1, y1, z1, x2, y2, z2)| { + let min_x = x1.min(x2); + let max_x = x1.max(x2); + let min_y = y1.min(y2); + let max_y = y1.max(y2); + let min_z = z1.min(z2); + let max_z = z1.max(z2); + Aabb::from([min_x, min_y, min_z, max_x, max_y, max_z]) + }) + ) { + let bvh = Bvh::build(elements.clone(), copied); + + // Compute brute force collisions + let mut brute_force_set = HashSet::new(); + for aabb in &elements { + if aabb.collides(&target) { + let aabb = OrderedAabb::try_from(*aabb).unwrap(); + brute_force_set.insert(aabb); + } + } + + // Compute BVH collisions + let mut bvh_set = HashSet::new(); + + for candidate in bvh.range(target, copied) { + // Find index of candidate in `elements`: + let candidate = OrderedAabb::try_from(*candidate).unwrap(); + bvh_set.insert(candidate); + } + + // Compare sets + prop_assert_eq!(&bvh_set, &brute_force_set, "Mismatch between BVH range and brute force collision sets: {:?} != {:?}", bvh_set, brute_force_set); + } +} diff --git a/crates/geometry/src/aabb.rs b/crates/geometry/src/aabb.rs index 22513585..1d5855e0 100644 --- a/crates/geometry/src/aabb.rs +++ b/crates/geometry/src/aabb.rs @@ -18,6 +18,31 @@ impl HasAabb for Aabb { } } +#[derive(Copy, Clone, Eq, PartialEq, Debug, Ord, PartialOrd, Hash)] +pub struct OrderedAabb { + min_x: ordered_float::NotNan, + min_y: ordered_float::NotNan, + min_z: ordered_float::NotNan, + max_x: ordered_float::NotNan, + max_y: ordered_float::NotNan, + max_z: ordered_float::NotNan, +} + +impl TryFrom for OrderedAabb { + type Error = ordered_float::FloatIsNan; + + fn try_from(value: Aabb) -> Result { + Ok(Self { + min_x: value.min.x.try_into()?, + min_y: value.min.y.try_into()?, + min_z: value.min.z.try_into()?, + max_x: value.max.x.try_into()?, + max_y: value.max.y.try_into()?, + max_z: value.max.z.try_into()?, + }) + } +} + #[derive(Copy, Clone, PartialEq, Serialize, Deserialize)] pub struct Aabb { pub min: Vec3, diff --git a/crates/spatial/src/lib.rs b/crates/spatial/src/lib.rs index d7d3725f..7ea1a3d8 100644 --- a/crates/spatial/src/lib.rs +++ b/crates/spatial/src/lib.rs @@ -51,7 +51,7 @@ impl SpatialIndex { world: &'a World, ) -> impl Iterator + 'a { let get_aabb = get_aabb_func(world); - self.query.get_collisions(target, get_aabb).copied() + self.query.range(target, get_aabb).copied() } /// Get the closest player to the given position.