Skip to content

Commit

Permalink
Fix query transmute from table to archetype iteration unsoundness (#1…
Browse files Browse the repository at this point in the history
…4615)

# Objective

- Fixes #14348 
- Fixes #14528
- Less complex (but also likely less performant) alternative to #14611

## Solution

- Add a `is_dense` field flag to `QueryIter` indicating whether it is
dense or not, that is whether it can perform dense iteration or not;
- Check this flag any time iteration over a query is performed.

---

It would be nice if someone could try benching this change to see if it
actually matters.

~Note that this not 100% ready for mergin, since there are a bunch of
safety comments on the use of the various `IS_DENSE` for checks that
still need to be updated.~ This is ready modulo benchmarks

---------

Co-authored-by: Alice Cecile <alice.i.cecile@gmail.com>
  • Loading branch information
SkiFire13 and alice-i-cecile authored Aug 27, 2024
1 parent f06cd44 commit e320fa0
Show file tree
Hide file tree
Showing 4 changed files with 154 additions and 35 deletions.
44 changes: 44 additions & 0 deletions crates/bevy_ecs/src/query/builder.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use std::marker::PhantomData;

use crate::component::StorageType;
use crate::{component::ComponentId, prelude::*};

use super::{FilteredAccess, QueryData, QueryFilter};
Expand Down Expand Up @@ -68,6 +69,26 @@ impl<'w, D: QueryData, F: QueryFilter> QueryBuilder<'w, D, F> {
}
}

pub(super) fn is_dense(&self) -> bool {
// Note: `component_id` comes from the user in safe code, so we cannot trust it to
// exist. If it doesn't exist we pessimistically assume it's sparse.
let is_dense = |component_id| {
self.world()
.components()
.get_info(component_id)
.map_or(false, |info| info.storage_type() == StorageType::Table)
};

self.access
.access()
.component_reads_and_writes()
.all(is_dense)
&& self.access.access().archetypal().all(is_dense)
&& !self.access.access().has_read_all_components()
&& self.access.with_filters().all(is_dense)
&& self.access.without_filters().all(is_dense)
}

/// Returns a reference to the world passed to [`Self::new`].
pub fn world(&self) -> &World {
self.world
Expand Down Expand Up @@ -396,4 +417,27 @@ mod tests {
assert_eq!(1, b.deref::<B>().0);
}
}

/// Regression test for issue #14348
#[test]
fn builder_static_dense_dynamic_sparse() {
#[derive(Component)]
struct Dense;

#[derive(Component)]
#[component(storage = "SparseSet")]
struct Sparse;

let mut world = World::new();

world.spawn(Dense);
world.spawn((Dense, Sparse));

let mut query = QueryBuilder::<&Dense>::new(&mut world)
.with::<Sparse>()
.build();

let matched = query.iter(&world).count();
assert_eq!(matched, 1);
}
}
64 changes: 37 additions & 27 deletions crates/bevy_ecs/src/query/iter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ impl<'w, 's, D: QueryData, F: QueryFilter> QueryIter<'w, 's, D, F> {
/// # Safety
/// - all `rows` must be in `[0, table.entity_count)`.
/// - `table` must match D and F
/// - Both `D::IS_DENSE` and `F::IS_DENSE` must be true.
/// - The query iteration must be dense (i.e. `self.query_state.is_dense` must be true).
#[inline]
pub(super) unsafe fn fold_over_table_range<B, Func>(
&mut self,
Expand Down Expand Up @@ -183,7 +183,7 @@ impl<'w, 's, D: QueryData, F: QueryFilter> QueryIter<'w, 's, D, F> {
/// # Safety
/// - all `indices` must be in `[0, archetype.len())`.
/// - `archetype` must match D and F
/// - Either `D::IS_DENSE` or `F::IS_DENSE` must be false.
/// - The query iteration must not be dense (i.e. `self.query_state.is_dense` must be false).
#[inline]
pub(super) unsafe fn fold_over_archetype_range<B, Func>(
&mut self,
Expand Down Expand Up @@ -252,7 +252,7 @@ impl<'w, 's, D: QueryData, F: QueryFilter> QueryIter<'w, 's, D, F> {
/// - all `indices` must be in `[0, archetype.len())`.
/// - `archetype` must match D and F
/// - `archetype` must have the same length with it's table.
/// - Either `D::IS_DENSE` or `F::IS_DENSE` must be false.
/// - The query iteration must not be dense (i.e. `self.query_state.is_dense` must be false).
#[inline]
pub(super) unsafe fn fold_over_dense_archetype_range<B, Func>(
&mut self,
Expand Down Expand Up @@ -1031,40 +1031,47 @@ impl<'w, 's, D: QueryData, F: QueryFilter> Iterator for QueryIter<'w, 's, D, F>
let Some(item) = self.next() else { break };
accum = func(accum, item);
}
for id in self.cursor.storage_id_iter.clone() {
if D::IS_DENSE && F::IS_DENSE {

if self.cursor.is_dense {
for id in self.cursor.storage_id_iter.clone() {
// SAFETY: `self.cursor.is_dense` is true, so storage ids are guaranteed to be table ids.
let table_id = unsafe { id.table_id };
// SAFETY: Matched table IDs are guaranteed to still exist.
let table = unsafe { self.tables.get(id.table_id).debug_checked_unwrap() };
let table = unsafe { self.tables.get(table_id).debug_checked_unwrap() };

accum =
// SAFETY:
// - The fetched table matches both D and F
// - The provided range is equivalent to [0, table.entity_count)
// - The if block ensures that D::IS_DENSE and F::IS_DENSE are both true
// - The if block ensures that the query iteration is dense
unsafe { self.fold_over_table_range(accum, &mut func, table, 0..table.entity_count()) };
} else {
let archetype =
// SAFETY: Matched archetype IDs are guaranteed to still exist.
unsafe { self.archetypes.get(id.archetype_id).debug_checked_unwrap() };
}
} else {
for id in self.cursor.storage_id_iter.clone() {
// SAFETY: `self.cursor.is_dense` is false, so storage ids are guaranteed to be archetype ids.
let archetype_id = unsafe { id.archetype_id };
// SAFETY: Matched archetype IDs are guaranteed to still exist.
let archetype = unsafe { self.archetypes.get(archetype_id).debug_checked_unwrap() };
// SAFETY: Matched table IDs are guaranteed to still exist.
let table = unsafe { self.tables.get(archetype.table_id()).debug_checked_unwrap() };

// When an archetype and its table have equal entity counts, dense iteration can be safely used.
// this leverages cache locality to optimize performance.
if table.entity_count() == archetype.len() {
accum =
// SAFETY:
// - The fetched archetype matches both D and F
// - The provided archetype and its' table have the same length.
// - The provided range is equivalent to [0, archetype.len)
// - The if block ensures that ether D::IS_DENSE or F::IS_DENSE are false
unsafe { self.fold_over_dense_archetype_range(accum, &mut func, archetype,0..archetype.len()) };
// SAFETY:
// - The fetched archetype matches both D and F
// - The provided archetype and its' table have the same length.
// - The provided range is equivalent to [0, archetype.len)
// - The if block ensures that the query iteration is not dense.
unsafe { self.fold_over_dense_archetype_range(accum, &mut func, archetype, 0..archetype.len()) };
} else {
accum =
// SAFETY:
// - The fetched archetype matches both D and F
// - The provided range is equivalent to [0, archetype.len)
// - The if block ensures that ether D::IS_DENSE or F::IS_DENSE are false
unsafe { self.fold_over_archetype_range(accum, &mut func, archetype,0..archetype.len()) };
// SAFETY:
// - The fetched archetype matches both D and F
// - The provided range is equivalent to [0, archetype.len)
// - The if block ensures that the query iteration is not dense.
unsafe { self.fold_over_archetype_range(accum, &mut func, archetype, 0..archetype.len()) };
}
}
}
Expand Down Expand Up @@ -1675,6 +1682,8 @@ impl<'w, 's, D: QueryData, F: QueryFilter, const K: usize> Debug
}

struct QueryIterationCursor<'w, 's, D: QueryData, F: QueryFilter> {
// whether the query iteration is dense or not. Mirrors QueryState's `is_dense` field.
is_dense: bool,
storage_id_iter: std::slice::Iter<'s, StorageId>,
table_entities: &'w [Entity],
archetype_entities: &'w [ArchetypeEntity],
Expand All @@ -1689,6 +1698,7 @@ struct QueryIterationCursor<'w, 's, D: QueryData, F: QueryFilter> {
impl<D: QueryData, F: QueryFilter> Clone for QueryIterationCursor<'_, '_, D, F> {
fn clone(&self) -> Self {
Self {
is_dense: self.is_dense,
storage_id_iter: self.storage_id_iter.clone(),
table_entities: self.table_entities,
archetype_entities: self.archetype_entities,
Expand All @@ -1701,8 +1711,6 @@ impl<D: QueryData, F: QueryFilter> Clone for QueryIterationCursor<'_, '_, D, F>
}

impl<'w, 's, D: QueryData, F: QueryFilter> QueryIterationCursor<'w, 's, D, F> {
const IS_DENSE: bool = D::IS_DENSE && F::IS_DENSE;

unsafe fn init_empty(
world: UnsafeWorldCell<'w>,
query_state: &'s QueryState<D, F>,
Expand Down Expand Up @@ -1732,13 +1740,15 @@ impl<'w, 's, D: QueryData, F: QueryFilter> QueryIterationCursor<'w, 's, D, F> {
table_entities: &[],
archetype_entities: &[],
storage_id_iter: query_state.matched_storage_ids.iter(),
is_dense: query_state.is_dense,
current_len: 0,
current_row: 0,
}
}

fn reborrow(&mut self) -> QueryIterationCursor<'_, 's, D, F> {
QueryIterationCursor {
is_dense: self.is_dense,
fetch: D::shrink_fetch(self.fetch.clone()),
filter: F::shrink_fetch(self.filter.clone()),
table_entities: self.table_entities,
Expand All @@ -1754,7 +1764,7 @@ impl<'w, 's, D: QueryData, F: QueryFilter> QueryIterationCursor<'w, 's, D, F> {
unsafe fn peek_last(&mut self) -> Option<D::Item<'w>> {
if self.current_row > 0 {
let index = self.current_row - 1;
if Self::IS_DENSE {
if self.is_dense {
let entity = self.table_entities.get_unchecked(index);
Some(D::fetch(
&mut self.fetch,
Expand All @@ -1780,7 +1790,7 @@ impl<'w, 's, D: QueryData, F: QueryFilter> QueryIterationCursor<'w, 's, D, F> {
/// will be **the exact count of remaining values**.
fn max_remaining(&self, tables: &'w Tables, archetypes: &'w Archetypes) -> usize {
let ids = self.storage_id_iter.clone();
let remaining_matched: usize = if Self::IS_DENSE {
let remaining_matched: usize = if self.is_dense {
// SAFETY: The if check ensures that storage_id_iter stores TableIds
unsafe { ids.map(|id| tables[id.table_id].entity_count()).sum() }
} else {
Expand All @@ -1803,7 +1813,7 @@ impl<'w, 's, D: QueryData, F: QueryFilter> QueryIterationCursor<'w, 's, D, F> {
archetypes: &'w Archetypes,
query_state: &'s QueryState<D, F>,
) -> Option<D::Item<'w>> {
if Self::IS_DENSE {
if self.is_dense {
loop {
// we are on the beginning of the query, or finished processing a table, so skip to the next
if self.current_row == self.current_len {
Expand Down
2 changes: 1 addition & 1 deletion crates/bevy_ecs/src/query/par_iter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ impl<'w, 's, D: QueryData, F: QueryFilter> QueryParIter<'w, 's, D, F> {
fn get_batch_size(&self, thread_count: usize) -> usize {
let max_items = || {
let id_iter = self.state.matched_storage_ids.iter();
if D::IS_DENSE && F::IS_DENSE {
if self.state.is_dense {
// SAFETY: We only access table metadata.
let tables = unsafe { &self.world.world_metadata().storages().tables };
id_iter
Expand Down
Loading

0 comments on commit e320fa0

Please sign in to comment.