From 77ebabc4fe0224565a2039bd9c0195901560b67f Mon Sep 17 00:00:00 2001 From: JMS55 <47158642+JMS55@users.noreply.github.com> Date: Sat, 4 May 2024 12:56:19 -0700 Subject: [PATCH] Meshlet remove per-cluster data upload (#13125) # Objective - Per-cluster (instance of a meshlet) data upload is ridiculously expensive in both CPU and GPU time (8 bytes per cluster, millions of clusters, you very quickly run into PCIE bandwidth maximums, and lots of CPU-side copies and malloc). - We need to be uploading only per-instance/entity data. Anything else needs to be done on the GPU. ## Solution - Per instance, upload: - `meshlet_instance_meshlet_counts_prefix_sum` - An exclusive prefix sum over the count of how many clusters each instance has. - `meshlet_instance_meshlet_slice_starts` - The starting index of the meshlets for each instance within the `meshlets` buffer. - A new `fill_cluster_buffers` pass once at the start of the frame has a thread per cluster, and finds its instance ID and meshlet ID via a binary search of `meshlet_instance_meshlet_counts_prefix_sum` to find what instance it belongs to, and then uses that plus `meshlet_instance_meshlet_slice_starts` to find what number meshlet within the instance it is. The shader then writes out the per-cluster instance/meshlet ID buffers for later passes to quickly read from. - I've gone from 45 -> 180 FPS in my stress test scene, and saved ~30ms/frame of overall CPU/GPU time. --- .../bevy_pbr/src/meshlet/cull_meshlets.wgsl | 16 +- .../src/meshlet/fill_cluster_buffers.wgsl | 42 +++++ crates/bevy_pbr/src/meshlet/gpu_scene.rs | 164 +++++++++++++----- .../src/meshlet/material_draw_nodes.rs | 6 +- .../src/meshlet/meshlet_bindings.wgsl | 44 +++-- crates/bevy_pbr/src/meshlet/mod.rs | 22 ++- crates/bevy_pbr/src/meshlet/pipelines.rs | 29 +++- .../src/meshlet/visibility_buffer_raster.wgsl | 8 +- .../meshlet/visibility_buffer_raster_node.rs | 61 +++++-- .../meshlet/visibility_buffer_resolve.wgsl | 15 +- 10 files changed, 305 insertions(+), 102 deletions(-) create mode 100644 crates/bevy_pbr/src/meshlet/fill_cluster_buffers.wgsl diff --git a/crates/bevy_pbr/src/meshlet/cull_meshlets.wgsl b/crates/bevy_pbr/src/meshlet/cull_meshlets.wgsl index 2e04f3332b20d..abfd9aed55db1 100644 --- a/crates/bevy_pbr/src/meshlet/cull_meshlets.wgsl +++ b/crates/bevy_pbr/src/meshlet/cull_meshlets.wgsl @@ -1,14 +1,14 @@ #import bevy_pbr::meshlet_bindings::{ - meshlet_thread_meshlet_ids, + meshlet_cluster_meshlet_ids, meshlet_bounding_spheres, - meshlet_thread_instance_ids, + meshlet_cluster_instance_ids, meshlet_instance_uniforms, meshlet_second_pass_candidates, depth_pyramid, view, previous_view, should_cull_instance, - meshlet_is_second_pass_candidate, + cluster_is_second_pass_candidate, meshlets, draw_indirect_args, draw_triangle_buffer, @@ -21,7 +21,7 @@ /// the instance, frustum, and LOD tests in the first pass, but were not visible last frame according to the occlusion culling. @compute -@workgroup_size(128, 1, 1) // 128 threads per workgroup, 1 instanced meshlet per thread +@workgroup_size(128, 1, 1) // 128 threads per workgroup, 1 cluster per thread fn cull_meshlets( @builtin(workgroup_id) workgroup_id: vec3, @builtin(num_workgroups) num_workgroups: vec3, @@ -29,21 +29,21 @@ fn cull_meshlets( ) { // Calculate the cluster ID for this thread let cluster_id = local_invocation_id.x + 128u * dot(workgroup_id, vec3(num_workgroups.x * num_workgroups.x, num_workgroups.x, 1u)); - if cluster_id >= arrayLength(&meshlet_thread_meshlet_ids) { return; } + if cluster_id >= arrayLength(&meshlet_cluster_meshlet_ids) { return; } #ifdef MESHLET_SECOND_CULLING_PASS - if !meshlet_is_second_pass_candidate(cluster_id) { return; } + if !cluster_is_second_pass_candidate(cluster_id) { return; } #endif // Check for instance culling - let instance_id = meshlet_thread_instance_ids[cluster_id]; + let instance_id = meshlet_cluster_instance_ids[cluster_id]; #ifdef MESHLET_FIRST_CULLING_PASS if should_cull_instance(instance_id) { return; } #endif // Calculate world-space culling bounding sphere for the cluster let instance_uniform = meshlet_instance_uniforms[instance_id]; - let meshlet_id = meshlet_thread_meshlet_ids[cluster_id]; + let meshlet_id = meshlet_cluster_meshlet_ids[cluster_id]; let model = affine3_to_square(instance_uniform.model); let model_scale = max(length(model[0]), max(length(model[1]), length(model[2]))); let bounding_spheres = meshlet_bounding_spheres[meshlet_id]; diff --git a/crates/bevy_pbr/src/meshlet/fill_cluster_buffers.wgsl b/crates/bevy_pbr/src/meshlet/fill_cluster_buffers.wgsl new file mode 100644 index 0000000000000..89e64de0c197b --- /dev/null +++ b/crates/bevy_pbr/src/meshlet/fill_cluster_buffers.wgsl @@ -0,0 +1,42 @@ +#import bevy_pbr::meshlet_bindings::{ + cluster_count, + meshlet_instance_meshlet_counts_prefix_sum, + meshlet_instance_meshlet_slice_starts, + meshlet_cluster_instance_ids, + meshlet_cluster_meshlet_ids, +} + +@compute +@workgroup_size(128, 1, 1) // 128 threads per workgroup, 1 cluster per thread +fn fill_cluster_buffers( + @builtin(workgroup_id) workgroup_id: vec3, + @builtin(num_workgroups) num_workgroups: vec3, + @builtin(local_invocation_id) local_invocation_id: vec3 +) { + // Calculate the cluster ID for this thread + let cluster_id = local_invocation_id.x + 128u * dot(workgroup_id, vec3(num_workgroups.x * num_workgroups.x, num_workgroups.x, 1u)); + if cluster_id >= cluster_count { return; } + + // Binary search to find the instance this cluster belongs to + var left = 0u; + var right = arrayLength(&meshlet_instance_meshlet_counts_prefix_sum) - 1u; + while left <= right { + let mid = (left + right) / 2u; + if meshlet_instance_meshlet_counts_prefix_sum[mid] <= cluster_id { + left = mid + 1u; + } else { + right = mid - 1u; + } + } + let instance_id = right; + + // Find the meshlet ID for this cluster within the instance's MeshletMesh + let meshlet_id_local = cluster_id - meshlet_instance_meshlet_counts_prefix_sum[instance_id]; + + // Find the overall meshlet ID in the global meshlet buffer + let meshlet_id = meshlet_id_local + meshlet_instance_meshlet_slice_starts[instance_id]; + + // Write results to buffers + meshlet_cluster_instance_ids[cluster_id] = instance_id; + meshlet_cluster_meshlet_ids[cluster_id] = meshlet_id; +} diff --git a/crates/bevy_pbr/src/meshlet/gpu_scene.rs b/crates/bevy_pbr/src/meshlet/gpu_scene.rs index 4cffa6da714bf..a986260003c71 100644 --- a/crates/bevy_pbr/src/meshlet/gpu_scene.rs +++ b/crates/bevy_pbr/src/meshlet/gpu_scene.rs @@ -31,7 +31,7 @@ use std::{ iter, mem::size_of, ops::{DerefMut, Range}, - sync::Arc, + sync::{atomic::AtomicBool, Arc}, }; /// Create and queue for uploading to the GPU [`MeshUniform`] components for @@ -91,17 +91,14 @@ pub fn extract_meshlet_meshes( } for ( - instance_index, - ( - instance, - handle, - transform, - previous_transform, - render_layers, - not_shadow_receiver, - not_shadow_caster, - ), - ) in instances_query.iter().enumerate() + instance, + handle, + transform, + previous_transform, + render_layers, + not_shadow_receiver, + not_shadow_caster, + ) in &instances_query { // Skip instances with an unloaded MeshletMesh asset if asset_server.is_managed(handle.id()) @@ -117,7 +114,6 @@ pub fn extract_meshlet_meshes( not_shadow_caster, handle, &mut assets, - instance_index as u32, ); // Build a MeshUniform for each instance @@ -235,12 +231,12 @@ pub fn prepare_meshlet_per_frame_resources( &render_queue, ); upload_storage_buffer( - &mut gpu_scene.thread_instance_ids, + &mut gpu_scene.instance_meshlet_counts_prefix_sum, &render_device, &render_queue, ); upload_storage_buffer( - &mut gpu_scene.thread_meshlet_ids, + &mut gpu_scene.instance_meshlet_slice_starts, &render_device, &render_queue, ); @@ -248,6 +244,34 @@ pub fn prepare_meshlet_per_frame_resources( // Early submission for GPU data uploads to start while the render graph records commands render_queue.submit([]); + let needed_buffer_size = 4 * gpu_scene.scene_meshlet_count as u64; + match &mut gpu_scene.cluster_instance_ids { + Some(buffer) if buffer.size() >= needed_buffer_size => buffer.clone(), + slot => { + let buffer = render_device.create_buffer(&BufferDescriptor { + label: Some("meshlet_cluster_instance_ids"), + size: needed_buffer_size, + usage: BufferUsages::STORAGE, + mapped_at_creation: false, + }); + *slot = Some(buffer.clone()); + buffer + } + }; + match &mut gpu_scene.cluster_meshlet_ids { + Some(buffer) if buffer.size() >= needed_buffer_size => buffer.clone(), + slot => { + let buffer = render_device.create_buffer(&BufferDescriptor { + label: Some("meshlet_cluster_meshlet_ids"), + size: needed_buffer_size, + usage: BufferUsages::STORAGE, + mapped_at_creation: false, + }); + *slot = Some(buffer.clone()); + buffer + } + }; + let needed_buffer_size = 4 * gpu_scene.scene_triangle_count; let visibility_buffer_draw_triangle_buffer = match &mut gpu_scene.visibility_buffer_draw_triangle_buffer { @@ -456,18 +480,44 @@ pub fn prepare_meshlet_view_bind_groups( render_device: Res, mut commands: Commands, ) { - let (Some(view_uniforms), Some(previous_view_uniforms)) = ( + let ( + Some(cluster_instance_ids), + Some(cluster_meshlet_ids), + Some(view_uniforms), + Some(previous_view_uniforms), + ) = ( + gpu_scene.cluster_instance_ids.as_ref(), + gpu_scene.cluster_meshlet_ids.as_ref(), view_uniforms.uniforms.binding(), previous_view_uniforms.uniforms.binding(), - ) else { + ) + else { return; }; + let first_node = Arc::new(AtomicBool::new(true)); + + // TODO: Some of these bind groups can be reused across multiple views for (view_entity, view_resources, view_depth) in &views { let entries = BindGroupEntries::sequential(( - gpu_scene.thread_meshlet_ids.binding().unwrap(), + gpu_scene + .instance_meshlet_counts_prefix_sum + .binding() + .unwrap(), + gpu_scene.instance_meshlet_slice_starts.binding().unwrap(), + cluster_instance_ids.as_entire_binding(), + cluster_meshlet_ids.as_entire_binding(), + )); + let fill_cluster_buffers = render_device.create_bind_group( + "meshlet_fill_cluster_buffers", + &gpu_scene.fill_cluster_buffers_bind_group_layout, + &entries, + ); + + let entries = BindGroupEntries::sequential(( + cluster_meshlet_ids.as_entire_binding(), gpu_scene.meshlet_bounding_spheres.binding(), - gpu_scene.thread_instance_ids.binding().unwrap(), + cluster_instance_ids.as_entire_binding(), gpu_scene.instance_uniforms.binding().unwrap(), view_resources.instance_visibility.as_entire_binding(), view_resources @@ -491,9 +541,9 @@ pub fn prepare_meshlet_view_bind_groups( ); let entries = BindGroupEntries::sequential(( - gpu_scene.thread_meshlet_ids.binding().unwrap(), + cluster_meshlet_ids.as_entire_binding(), gpu_scene.meshlet_bounding_spheres.binding(), - gpu_scene.thread_instance_ids.binding().unwrap(), + cluster_instance_ids.as_entire_binding(), gpu_scene.instance_uniforms.binding().unwrap(), view_resources.instance_visibility.as_entire_binding(), view_resources @@ -539,12 +589,12 @@ pub fn prepare_meshlet_view_bind_groups( .collect(); let entries = BindGroupEntries::sequential(( - gpu_scene.thread_meshlet_ids.binding().unwrap(), + cluster_meshlet_ids.as_entire_binding(), gpu_scene.meshlets.binding(), gpu_scene.indices.binding(), gpu_scene.vertex_ids.binding(), gpu_scene.vertex_data.binding(), - gpu_scene.thread_instance_ids.binding().unwrap(), + cluster_instance_ids.as_entire_binding(), gpu_scene.instance_uniforms.binding().unwrap(), gpu_scene.instance_material_ids.binding().unwrap(), view_resources @@ -581,12 +631,12 @@ pub fn prepare_meshlet_view_bind_groups( .map(|visibility_buffer| { let entries = BindGroupEntries::sequential(( &visibility_buffer.default_view, - gpu_scene.thread_meshlet_ids.binding().unwrap(), + cluster_meshlet_ids.as_entire_binding(), gpu_scene.meshlets.binding(), gpu_scene.indices.binding(), gpu_scene.vertex_ids.binding(), gpu_scene.vertex_data.binding(), - gpu_scene.thread_instance_ids.binding().unwrap(), + cluster_instance_ids.as_entire_binding(), gpu_scene.instance_uniforms.binding().unwrap(), )); render_device.create_bind_group( @@ -597,6 +647,8 @@ pub fn prepare_meshlet_view_bind_groups( }); commands.entity(view_entity).insert(MeshletViewBindGroups { + first_node: Arc::clone(&first_node), + fill_cluster_buffers, culling_first, culling_second, downsample_depth, @@ -629,12 +681,15 @@ pub struct MeshletGpuScene { /// Per-view per-instance visibility bit. Used for [`RenderLayers`] and [`NotShadowCaster`] support. view_instance_visibility: EntityHashMap>>, instance_material_ids: StorageBuffer>, - thread_instance_ids: StorageBuffer>, - thread_meshlet_ids: StorageBuffer>, + instance_meshlet_counts_prefix_sum: StorageBuffer>, + instance_meshlet_slice_starts: StorageBuffer>, + cluster_instance_ids: Option, + cluster_meshlet_ids: Option, second_pass_candidates_buffer: Option, previous_depth_pyramids: EntityHashMap, visibility_buffer_draw_triangle_buffer: Option, + fill_cluster_buffers_bind_group_layout: BindGroupLayout, culling_bind_group_layout: BindGroupLayout, visibility_buffer_raster_bind_group_layout: BindGroupLayout, downsample_depth_bind_group_layout: BindGroupLayout, @@ -675,21 +730,35 @@ impl FromWorld for MeshletGpuScene { buffer.set_label(Some("meshlet_instance_material_ids")); buffer }, - thread_instance_ids: { + instance_meshlet_counts_prefix_sum: { let mut buffer = StorageBuffer::default(); - buffer.set_label(Some("meshlet_thread_instance_ids")); + buffer.set_label(Some("meshlet_instance_meshlet_counts_prefix_sum")); buffer }, - thread_meshlet_ids: { + instance_meshlet_slice_starts: { let mut buffer = StorageBuffer::default(); - buffer.set_label(Some("meshlet_thread_meshlet_ids")); + buffer.set_label(Some("meshlet_instance_meshlet_slice_starts")); buffer }, + cluster_instance_ids: None, + cluster_meshlet_ids: None, second_pass_candidates_buffer: None, previous_depth_pyramids: EntityHashMap::default(), visibility_buffer_draw_triangle_buffer: None, // TODO: Buffer min sizes + fill_cluster_buffers_bind_group_layout: render_device.create_bind_group_layout( + "meshlet_fill_cluster_buffers_bind_group_layout", + &BindGroupLayoutEntries::sequential( + ShaderStages::COMPUTE, + ( + storage_buffer_read_only_sized(false, None), + storage_buffer_read_only_sized(false, None), + storage_buffer_sized(false, None), + storage_buffer_sized(false, None), + ), + ), + ), culling_bind_group_layout: render_device.create_bind_group_layout( "meshlet_culling_bind_group_layout", &BindGroupLayoutEntries::sequential( @@ -784,8 +853,8 @@ impl MeshletGpuScene { .for_each(|b| b.get_mut().clear()); self.instance_uniforms.get_mut().clear(); self.instance_material_ids.get_mut().clear(); - self.thread_instance_ids.get_mut().clear(); - self.thread_meshlet_ids.get_mut().clear(); + self.instance_meshlet_counts_prefix_sum.get_mut().clear(); + self.instance_meshlet_slice_starts.get_mut().clear(); // TODO: Remove unused entries for view_instance_visibility and previous_depth_pyramids } @@ -796,7 +865,6 @@ impl MeshletGpuScene { not_shadow_caster: bool, handle: &Handle, assets: &mut Assets, - instance_index: u32, ) { let queue_meshlet_mesh = |asset_id: &AssetId| { let meshlet_mesh = assets.remove_untracked(*asset_id).expect( @@ -833,11 +901,6 @@ impl MeshletGpuScene { ) }; - // Append instance data for this frame - self.instances - .push((instance, render_layers, not_shadow_caster)); - self.instance_material_ids.get_mut().push(0); - // If the MeshletMesh asset has not been uploaded to the GPU yet, queue it for uploading let ([_, _, _, meshlets_slice, _], triangle_count) = self .meshlet_mesh_slices @@ -848,14 +911,19 @@ impl MeshletGpuScene { let meshlets_slice = (meshlets_slice.start as u32 / size_of::() as u32) ..(meshlets_slice.end as u32 / size_of::() as u32); + // Append instance data for this frame + self.instances + .push((instance, render_layers, not_shadow_caster)); + self.instance_material_ids.get_mut().push(0); + self.instance_meshlet_counts_prefix_sum + .get_mut() + .push(self.scene_meshlet_count); + self.instance_meshlet_slice_starts + .get_mut() + .push(meshlets_slice.start); + self.scene_meshlet_count += meshlets_slice.end - meshlets_slice.start; self.scene_triangle_count += triangle_count; - - // Append per-cluster data for this frame - self.thread_instance_ids - .get_mut() - .extend(std::iter::repeat(instance_index).take(meshlets_slice.len())); - self.thread_meshlet_ids.get_mut().extend(meshlets_slice); } /// Get the depth value for use with the material depth texture for a given [`Material`] asset. @@ -873,6 +941,10 @@ impl MeshletGpuScene { self.material_ids_present_in_scene.contains(material_id) } + pub fn fill_cluster_buffers_bind_group_layout(&self) -> BindGroupLayout { + self.fill_cluster_buffers_bind_group_layout.clone() + } + pub fn culling_bind_group_layout(&self) -> BindGroupLayout { self.culling_bind_group_layout.clone() } @@ -912,6 +984,8 @@ pub struct MeshletViewResources { #[derive(Component)] pub struct MeshletViewBindGroups { + pub first_node: Arc, + pub fill_cluster_buffers: BindGroup, pub culling_first: BindGroup, pub culling_second: BindGroup, pub downsample_depth: Box<[BindGroup]>, diff --git a/crates/bevy_pbr/src/meshlet/material_draw_nodes.rs b/crates/bevy_pbr/src/meshlet/material_draw_nodes.rs index bbe1676bbe076..e327eadfcc9bf 100644 --- a/crates/bevy_pbr/src/meshlet/material_draw_nodes.rs +++ b/crates/bevy_pbr/src/meshlet/material_draw_nodes.rs @@ -116,8 +116,8 @@ impl ViewNode for MeshletMainOpaquePass3dNode { pipeline_cache.get_render_pipeline(*material_pipeline_id) { let x = *material_id * 3; - render_pass.set_bind_group(2, material_bind_group, &[]); render_pass.set_render_pipeline(material_pipeline); + render_pass.set_bind_group(2, material_bind_group, &[]); render_pass.draw(x..(x + 3), 0..1); } } @@ -237,8 +237,8 @@ impl ViewNode for MeshletPrepassNode { pipeline_cache.get_render_pipeline(*material_pipeline_id) { let x = *material_id * 3; - render_pass.set_bind_group(2, material_bind_group, &[]); render_pass.set_render_pipeline(material_pipeline); + render_pass.set_bind_group(2, material_bind_group, &[]); render_pass.draw(x..(x + 3), 0..1); } } @@ -363,8 +363,8 @@ impl ViewNode for MeshletDeferredGBufferPrepassNode { pipeline_cache.get_render_pipeline(*material_pipeline_id) { let x = *material_id * 3; - render_pass.set_bind_group(2, material_bind_group, &[]); render_pass.set_render_pipeline(material_pipeline); + render_pass.set_bind_group(2, material_bind_group, &[]); render_pass.draw(x..(x + 3), 0..1); } } diff --git a/crates/bevy_pbr/src/meshlet/meshlet_bindings.wgsl b/crates/bevy_pbr/src/meshlet/meshlet_bindings.wgsl index 2ca98b5d41bac..a3f18cbc9b29e 100644 --- a/crates/bevy_pbr/src/meshlet/meshlet_bindings.wgsl +++ b/crates/bevy_pbr/src/meshlet/meshlet_bindings.wgsl @@ -51,14 +51,22 @@ struct DrawIndirectArgs { first_instance: u32, } +#ifdef MESHLET_FILL_CLUSTER_BUFFERS_PASS +var cluster_count: u32; +@group(0) @binding(0) var meshlet_instance_meshlet_counts_prefix_sum: array; // Per entity instance +@group(0) @binding(1) var meshlet_instance_meshlet_slice_starts: array; // Per entity instance +@group(0) @binding(2) var meshlet_cluster_instance_ids: array; // Per cluster +@group(0) @binding(3) var meshlet_cluster_meshlet_ids: array; // Per cluster +#endif + #ifdef MESHLET_CULLING_PASS -@group(0) @binding(0) var meshlet_thread_meshlet_ids: array; // Per cluster (instance of a meshlet) -@group(0) @binding(1) var meshlet_bounding_spheres: array; // Per asset meshlet -@group(0) @binding(2) var meshlet_thread_instance_ids: array; // Per cluster (instance of a meshlet) +@group(0) @binding(0) var meshlet_cluster_meshlet_ids: array; // Per cluster +@group(0) @binding(1) var meshlet_bounding_spheres: array; // Per meshlet +@group(0) @binding(2) var meshlet_cluster_instance_ids: array; // Per cluster @group(0) @binding(3) var meshlet_instance_uniforms: array; // Per entity instance @group(0) @binding(4) var meshlet_view_instance_visibility: array; // 1 bit per entity instance, packed as a bitmask -@group(0) @binding(5) var meshlet_second_pass_candidates: array>; // 1 bit per cluster (instance of a meshlet), packed as a bitmask -@group(0) @binding(6) var meshlets: array; // Per asset meshlet +@group(0) @binding(5) var meshlet_second_pass_candidates: array>; // 1 bit per cluster , packed as a bitmask +@group(0) @binding(6) var meshlets: array; // Per meshlet @group(0) @binding(7) var draw_indirect_args: DrawIndirectArgs; // Single object shared between all workgroups/meshlets/triangles @group(0) @binding(8) var draw_triangle_buffer: array; // Single object shared between all workgroups/meshlets/triangles @group(0) @binding(9) var depth_pyramid: texture_2d; // From the end of the last frame for the first culling pass, and from the first raster pass for the second culling pass @@ -71,7 +79,7 @@ fn should_cull_instance(instance_id: u32) -> bool { return bool(extractBits(packed_visibility, bit_offset, 1u)); } -fn meshlet_is_second_pass_candidate(cluster_id: u32) -> bool { +fn cluster_is_second_pass_candidate(cluster_id: u32) -> bool { let packed_candidates = meshlet_second_pass_candidates[cluster_id / 32u]; let bit_offset = cluster_id % 32u; return bool(extractBits(packed_candidates, bit_offset, 1u)); @@ -79,12 +87,12 @@ fn meshlet_is_second_pass_candidate(cluster_id: u32) -> bool { #endif #ifdef MESHLET_VISIBILITY_BUFFER_RASTER_PASS -@group(0) @binding(0) var meshlet_thread_meshlet_ids: array; // Per cluster (instance of a meshlet) -@group(0) @binding(1) var meshlets: array; // Per asset meshlet -@group(0) @binding(2) var meshlet_indices: array; // Many per asset meshlet -@group(0) @binding(3) var meshlet_vertex_ids: array; // Many per asset meshlet -@group(0) @binding(4) var meshlet_vertex_data: array; // Many per asset meshlet -@group(0) @binding(5) var meshlet_thread_instance_ids: array; // Per cluster (instance of a meshlet) +@group(0) @binding(0) var meshlet_cluster_meshlet_ids: array; // Per cluster +@group(0) @binding(1) var meshlets: array; // Per meshlet +@group(0) @binding(2) var meshlet_indices: array; // Many per meshlet +@group(0) @binding(3) var meshlet_vertex_ids: array; // Many per meshlet +@group(0) @binding(4) var meshlet_vertex_data: array; // Many per meshlet +@group(0) @binding(5) var meshlet_cluster_instance_ids: array; // Per cluster @group(0) @binding(6) var meshlet_instance_uniforms: array; // Per entity instance @group(0) @binding(7) var meshlet_instance_material_ids: array; // Per entity instance @group(0) @binding(8) var draw_triangle_buffer: array; // Single object shared between all workgroups/meshlets/triangles @@ -99,12 +107,12 @@ fn get_meshlet_index(index_id: u32) -> u32 { #ifdef MESHLET_MESH_MATERIAL_PASS @group(1) @binding(0) var meshlet_visibility_buffer: texture_2d; // Generated from the meshlet raster passes -@group(1) @binding(1) var meshlet_thread_meshlet_ids: array; // Per cluster (instance of a meshlet) -@group(1) @binding(2) var meshlets: array; // Per asset meshlet -@group(1) @binding(3) var meshlet_indices: array; // Many per asset meshlet -@group(1) @binding(4) var meshlet_vertex_ids: array; // Many per asset meshlet -@group(1) @binding(5) var meshlet_vertex_data: array; // Many per asset meshlet -@group(1) @binding(6) var meshlet_thread_instance_ids: array; // Per cluster (instance of a meshlet) +@group(1) @binding(1) var meshlet_cluster_meshlet_ids: array; // Per cluster +@group(1) @binding(2) var meshlets: array; // Per meshlet +@group(1) @binding(3) var meshlet_indices: array; // Many per meshlet +@group(1) @binding(4) var meshlet_vertex_ids: array; // Many per meshlet +@group(1) @binding(5) var meshlet_vertex_data: array; // Many per meshlet +@group(1) @binding(6) var meshlet_cluster_instance_ids: array; // Per cluster @group(1) @binding(7) var meshlet_instance_uniforms: array; // Per entity instance fn get_meshlet_index(index_id: u32) -> u32 { diff --git a/crates/bevy_pbr/src/meshlet/mod.rs b/crates/bevy_pbr/src/meshlet/mod.rs index 00c7629bcba3f..69255f9084cf1 100644 --- a/crates/bevy_pbr/src/meshlet/mod.rs +++ b/crates/bevy_pbr/src/meshlet/mod.rs @@ -49,7 +49,8 @@ use self::{ }, pipelines::{ MeshletPipelines, MESHLET_COPY_MATERIAL_DEPTH_SHADER_HANDLE, MESHLET_CULLING_SHADER_HANDLE, - MESHLET_DOWNSAMPLE_DEPTH_SHADER_HANDLE, MESHLET_VISIBILITY_BUFFER_RASTER_SHADER_HANDLE, + MESHLET_DOWNSAMPLE_DEPTH_SHADER_HANDLE, MESHLET_FILL_CLUSTER_BUFFERS_SHADER_HANDLE, + MESHLET_VISIBILITY_BUFFER_RASTER_SHADER_HANDLE, }, visibility_buffer_raster_node::MeshletVisibilityBufferRasterPassNode, }; @@ -74,6 +75,8 @@ use bevy_ecs::{ use bevy_render::{ render_graph::{RenderGraphApp, ViewNodeRunner}, render_resource::{Shader, TextureUsages}, + renderer::RenderDevice, + settings::WgpuFeatures, view::{ check_visibility, prepare_view_targets, InheritedVisibility, Msaa, ViewVisibility, Visibility, VisibilitySystems, @@ -105,7 +108,7 @@ const MESHLET_MESH_MATERIAL_SHADER_HANDLE: Handle = /// /// This plugin is not compatible with [`Msaa`], and adding this plugin will disable it. /// -/// This plugin does not work on the WebGL2 backend. +/// This plugin does not work on WASM. /// /// ![A render of the Stanford dragon as a `MeshletMesh`](https://raw.githubusercontent.com/bevyengine/bevy/main/crates/bevy_pbr/src/meshlet/meshlet_preview.png) pub struct MeshletPlugin; @@ -124,6 +127,12 @@ impl Plugin for MeshletPlugin { "visibility_buffer_resolve.wgsl", Shader::from_wgsl ); + load_internal_asset!( + app, + MESHLET_FILL_CLUSTER_BUFFERS_SHADER_HANDLE, + "fill_cluster_buffers.wgsl", + Shader::from_wgsl + ); load_internal_asset!( app, MESHLET_CULLING_SHADER_HANDLE, @@ -169,6 +178,15 @@ impl Plugin for MeshletPlugin { return; }; + if !render_app + .world() + .resource::() + .features() + .contains(WgpuFeatures::PUSH_CONSTANTS) + { + panic!("MeshletPlugin can't be used. GPU lacks support: WgpuFeatures::PUSH_CONSTANTS is not supported."); + } + render_app .add_render_graph_node::( Core3d, diff --git a/crates/bevy_pbr/src/meshlet/pipelines.rs b/crates/bevy_pbr/src/meshlet/pipelines.rs index bb62c6bdf5020..551efbe176f19 100644 --- a/crates/bevy_pbr/src/meshlet/pipelines.rs +++ b/crates/bevy_pbr/src/meshlet/pipelines.rs @@ -9,16 +9,19 @@ use bevy_ecs::{ }; use bevy_render::render_resource::*; -pub const MESHLET_CULLING_SHADER_HANDLE: Handle = Handle::weak_from_u128(4325134235233421); +pub const MESHLET_FILL_CLUSTER_BUFFERS_SHADER_HANDLE: Handle = + Handle::weak_from_u128(4325134235233421); +pub const MESHLET_CULLING_SHADER_HANDLE: Handle = Handle::weak_from_u128(5325134235233421); pub const MESHLET_DOWNSAMPLE_DEPTH_SHADER_HANDLE: Handle = - Handle::weak_from_u128(5325134235233421); -pub const MESHLET_VISIBILITY_BUFFER_RASTER_SHADER_HANDLE: Handle = Handle::weak_from_u128(6325134235233421); -pub const MESHLET_COPY_MATERIAL_DEPTH_SHADER_HANDLE: Handle = +pub const MESHLET_VISIBILITY_BUFFER_RASTER_SHADER_HANDLE: Handle = Handle::weak_from_u128(7325134235233421); +pub const MESHLET_COPY_MATERIAL_DEPTH_SHADER_HANDLE: Handle = + Handle::weak_from_u128(8325134235233421); #[derive(Resource)] pub struct MeshletPipelines { + fill_cluster_buffers: CachedComputePipelineId, cull_first: CachedComputePipelineId, cull_second: CachedComputePipelineId, downsample_depth: CachedRenderPipelineId, @@ -31,6 +34,8 @@ pub struct MeshletPipelines { impl FromWorld for MeshletPipelines { fn from_world(world: &mut World) -> Self { let gpu_scene = world.resource::(); + let fill_cluster_buffers_bind_group_layout = + gpu_scene.fill_cluster_buffers_bind_group_layout(); let cull_layout = gpu_scene.culling_bind_group_layout(); let downsample_depth_layout = gpu_scene.downsample_depth_bind_group_layout(); let visibility_buffer_layout = gpu_scene.visibility_buffer_raster_bind_group_layout(); @@ -38,6 +43,20 @@ impl FromWorld for MeshletPipelines { let pipeline_cache = world.resource_mut::(); Self { + fill_cluster_buffers: pipeline_cache.queue_compute_pipeline( + ComputePipelineDescriptor { + label: Some("meshlet_fill_cluster_buffers_pipeline".into()), + layout: vec![fill_cluster_buffers_bind_group_layout.clone()], + push_constant_ranges: vec![PushConstantRange { + stages: ShaderStages::COMPUTE, + range: 0..4, + }], + shader: MESHLET_FILL_CLUSTER_BUFFERS_SHADER_HANDLE, + shader_defs: vec!["MESHLET_FILL_CLUSTER_BUFFERS_PASS".into()], + entry_point: "fill_cluster_buffers".into(), + }, + ), + cull_first: pipeline_cache.queue_compute_pipeline(ComputePipelineDescriptor { label: Some("meshlet_culling_first_pipeline".into()), layout: vec![cull_layout.clone()], @@ -242,6 +261,7 @@ impl MeshletPipelines { pub fn get( world: &World, ) -> Option<( + &ComputePipeline, &ComputePipeline, &ComputePipeline, &RenderPipeline, @@ -253,6 +273,7 @@ impl MeshletPipelines { let pipeline_cache = world.get_resource::()?; let pipeline = world.get_resource::()?; Some(( + pipeline_cache.get_compute_pipeline(pipeline.fill_cluster_buffers)?, pipeline_cache.get_compute_pipeline(pipeline.cull_first)?, pipeline_cache.get_compute_pipeline(pipeline.cull_second)?, pipeline_cache.get_render_pipeline(pipeline.downsample_depth)?, diff --git a/crates/bevy_pbr/src/meshlet/visibility_buffer_raster.wgsl b/crates/bevy_pbr/src/meshlet/visibility_buffer_raster.wgsl index e2c716de162e7..b72079b7f1065 100644 --- a/crates/bevy_pbr/src/meshlet/visibility_buffer_raster.wgsl +++ b/crates/bevy_pbr/src/meshlet/visibility_buffer_raster.wgsl @@ -1,10 +1,10 @@ #import bevy_pbr::{ meshlet_bindings::{ - meshlet_thread_meshlet_ids, + meshlet_cluster_meshlet_ids, meshlets, meshlet_vertex_ids, meshlet_vertex_data, - meshlet_thread_instance_ids, + meshlet_cluster_instance_ids, meshlet_instance_uniforms, meshlet_instance_material_ids, draw_triangle_buffer, @@ -42,12 +42,12 @@ fn vertex(@builtin(vertex_index) vertex_index: u32) -> VertexOutput { let cluster_id = packed_ids >> 6u; let triangle_id = extractBits(packed_ids, 0u, 6u); let index_id = (triangle_id * 3u) + (vertex_index % 3u); - let meshlet_id = meshlet_thread_meshlet_ids[cluster_id]; + let meshlet_id = meshlet_cluster_meshlet_ids[cluster_id]; let meshlet = meshlets[meshlet_id]; let index = get_meshlet_index(meshlet.start_index_id + index_id); let vertex_id = meshlet_vertex_ids[meshlet.start_vertex_id + index]; let vertex = unpack_meshlet_vertex(meshlet_vertex_data[vertex_id]); - let instance_id = meshlet_thread_instance_ids[cluster_id]; + let instance_id = meshlet_cluster_instance_ids[cluster_id]; let instance_uniform = meshlet_instance_uniforms[instance_id]; let model = affine3_to_square(instance_uniform.model); diff --git a/crates/bevy_pbr/src/meshlet/visibility_buffer_raster_node.rs b/crates/bevy_pbr/src/meshlet/visibility_buffer_raster_node.rs index 54303af71c74a..f3ffb1865ed50 100644 --- a/crates/bevy_pbr/src/meshlet/visibility_buffer_raster_node.rs +++ b/crates/bevy_pbr/src/meshlet/visibility_buffer_raster_node.rs @@ -15,6 +15,7 @@ use bevy_render::{ renderer::RenderContext, view::{ViewDepthTexture, ViewUniformOffset}, }; +use std::sync::atomic::Ordering; /// Rasterize meshlets into a depth buffer, and optional visibility buffer + material depth buffer for shading passes. pub struct MeshletVisibilityBufferRasterPassNode { @@ -72,6 +73,7 @@ impl Node for MeshletVisibilityBufferRasterPassNode { }; let Some(( + fill_cluster_buffers_pipeline, culling_first_pipeline, culling_second_pipeline, downsample_depth_pipeline, @@ -84,9 +86,14 @@ impl Node for MeshletVisibilityBufferRasterPassNode { return Ok(()); }; - let culling_workgroups = (meshlet_view_resources.scene_meshlet_count.div_ceil(128) as f32) - .cbrt() - .ceil() as u32; + let first_node = meshlet_view_bind_groups + .first_node + .fetch_and(false, Ordering::SeqCst); + + let thread_per_cluster_workgroups = + (meshlet_view_resources.scene_meshlet_count.div_ceil(128) as f32) + .cbrt() + .ceil() as u32; render_context .command_encoder() @@ -96,6 +103,15 @@ impl Node for MeshletVisibilityBufferRasterPassNode { 0, None, ); + if first_node { + fill_cluster_buffers_pass( + render_context, + &meshlet_view_bind_groups.fill_cluster_buffers, + fill_cluster_buffers_pipeline, + thread_per_cluster_workgroups, + meshlet_view_resources.scene_meshlet_count, + ); + } cull_pass( "culling_first", render_context, @@ -103,7 +119,7 @@ impl Node for MeshletVisibilityBufferRasterPassNode { view_offset, previous_view_offset, culling_first_pipeline, - culling_workgroups, + thread_per_cluster_workgroups, ); raster_pass( true, @@ -129,7 +145,7 @@ impl Node for MeshletVisibilityBufferRasterPassNode { view_offset, previous_view_offset, culling_second_pipeline, - culling_workgroups, + thread_per_cluster_workgroups, ); raster_pass( false, @@ -191,7 +207,7 @@ impl Node for MeshletVisibilityBufferRasterPassNode { view_offset, previous_view_offset, culling_first_pipeline, - culling_workgroups, + thread_per_cluster_workgroups, ); raster_pass( true, @@ -217,7 +233,7 @@ impl Node for MeshletVisibilityBufferRasterPassNode { view_offset, previous_view_offset, culling_second_pipeline, - culling_workgroups, + thread_per_cluster_workgroups, ); raster_pass( false, @@ -243,6 +259,29 @@ impl Node for MeshletVisibilityBufferRasterPassNode { } } +// TODO: Reuse same compute pass as cull_pass +fn fill_cluster_buffers_pass( + render_context: &mut RenderContext, + fill_cluster_buffers_bind_group: &BindGroup, + fill_cluster_buffers_pass_pipeline: &ComputePipeline, + fill_cluster_buffers_pass_workgroups: u32, + cluster_count: u32, +) { + let command_encoder = render_context.command_encoder(); + let mut cull_pass = command_encoder.begin_compute_pass(&ComputePassDescriptor { + label: Some("fill_cluster_buffers"), + timestamp_writes: None, + }); + cull_pass.set_pipeline(fill_cluster_buffers_pass_pipeline); + cull_pass.set_push_constants(0, &cluster_count.to_le_bytes()); + cull_pass.set_bind_group(0, fill_cluster_buffers_bind_group, &[]); + cull_pass.dispatch_workgroups( + fill_cluster_buffers_pass_workgroups, + fill_cluster_buffers_pass_workgroups, + fill_cluster_buffers_pass_workgroups, + ); +} + fn cull_pass( label: &'static str, render_context: &mut RenderContext, @@ -257,12 +296,12 @@ fn cull_pass( label: Some(label), timestamp_writes: None, }); + cull_pass.set_pipeline(culling_pipeline); cull_pass.set_bind_group( 0, culling_bind_group, &[view_offset.offset, previous_view_offset.offset], ); - cull_pass.set_pipeline(culling_pipeline); cull_pass.dispatch_workgroups(culling_workgroups, culling_workgroups, culling_workgroups); } @@ -327,12 +366,12 @@ fn raster_pass( draw_pass.set_camera_viewport(viewport); } + draw_pass.set_render_pipeline(visibility_buffer_raster_pipeline); draw_pass.set_bind_group( 0, &meshlet_view_bind_groups.visibility_buffer_raster, &[view_offset.offset], ); - draw_pass.set_render_pipeline(visibility_buffer_raster_pipeline); draw_pass.draw_indirect(visibility_buffer_draw_indirect_args, 0); } @@ -363,8 +402,8 @@ fn downsample_depth( }; let mut downsample_pass = render_context.begin_tracked_render_pass(downsample_pass); - downsample_pass.set_bind_group(0, &meshlet_view_bind_groups.downsample_depth[i], &[]); downsample_pass.set_render_pipeline(downsample_depth_pipeline); + downsample_pass.set_bind_group(0, &meshlet_view_bind_groups.downsample_depth[i], &[]); downsample_pass.draw(0..3, 0..1); } @@ -400,8 +439,8 @@ fn copy_material_depth_pass( copy_pass.set_camera_viewport(viewport); } - copy_pass.set_bind_group(0, copy_material_depth_bind_group, &[]); copy_pass.set_render_pipeline(copy_material_depth_pipeline); + copy_pass.set_bind_group(0, copy_material_depth_bind_group, &[]); copy_pass.draw(0..3, 0..1); } } diff --git a/crates/bevy_pbr/src/meshlet/visibility_buffer_resolve.wgsl b/crates/bevy_pbr/src/meshlet/visibility_buffer_resolve.wgsl index 947c9d49be99c..7f8e50573e109 100644 --- a/crates/bevy_pbr/src/meshlet/visibility_buffer_resolve.wgsl +++ b/crates/bevy_pbr/src/meshlet/visibility_buffer_resolve.wgsl @@ -3,11 +3,11 @@ #import bevy_pbr::{ meshlet_bindings::{ meshlet_visibility_buffer, - meshlet_thread_meshlet_ids, + meshlet_cluster_meshlet_ids, meshlets, meshlet_vertex_ids, meshlet_vertex_data, - meshlet_thread_instance_ids, + meshlet_cluster_instance_ids, meshlet_instance_uniforms, get_meshlet_index, unpack_meshlet_vertex, @@ -95,11 +95,11 @@ struct VertexOutput { /// Load the visibility buffer texture and resolve it into a VertexOutput. fn resolve_vertex_output(frag_coord: vec4) -> VertexOutput { - let vbuffer = textureLoad(meshlet_visibility_buffer, vec2(frag_coord.xy), 0).r; - let cluster_id = vbuffer >> 6u; - let meshlet_id = meshlet_thread_meshlet_ids[cluster_id]; + let packed_ids = textureLoad(meshlet_visibility_buffer, vec2(frag_coord.xy), 0).r; + let cluster_id = packed_ids >> 6u; + let meshlet_id = meshlet_cluster_meshlet_ids[cluster_id]; let meshlet = meshlets[meshlet_id]; - let triangle_id = extractBits(vbuffer, 0u, 6u); + let triangle_id = extractBits(packed_ids, 0u, 6u); let index_ids = meshlet.start_index_id + vec3(triangle_id * 3u) + vec3(0u, 1u, 2u); let indices = meshlet.start_vertex_id + vec3(get_meshlet_index(index_ids.x), get_meshlet_index(index_ids.y), get_meshlet_index(index_ids.z)); let vertex_ids = vec3(meshlet_vertex_ids[indices.x], meshlet_vertex_ids[indices.y], meshlet_vertex_ids[indices.z]); @@ -107,13 +107,14 @@ fn resolve_vertex_output(frag_coord: vec4) -> VertexOutput { let vertex_2 = unpack_meshlet_vertex(meshlet_vertex_data[vertex_ids.y]); let vertex_3 = unpack_meshlet_vertex(meshlet_vertex_data[vertex_ids.z]); - let instance_id = meshlet_thread_instance_ids[cluster_id]; + let instance_id = meshlet_cluster_instance_ids[cluster_id]; let instance_uniform = meshlet_instance_uniforms[instance_id]; let model = affine3_to_square(instance_uniform.model); let world_position_1 = mesh_position_local_to_world(model, vec4(vertex_1.position, 1.0)); let world_position_2 = mesh_position_local_to_world(model, vec4(vertex_2.position, 1.0)); let world_position_3 = mesh_position_local_to_world(model, vec4(vertex_3.position, 1.0)); + let clip_position_1 = position_world_to_clip(world_position_1.xyz); let clip_position_2 = position_world_to_clip(world_position_2.xyz); let clip_position_3 = position_world_to_clip(world_position_3.xyz);