Meshlet remove per-cluster data upload (#13125)

# Objective

- Uploading per-cluster data (a cluster is an instance of a meshlet) is
extremely expensive in both CPU and GPU time: at 8 bytes per cluster and
millions of clusters, you very quickly hit PCIe bandwidth limits, on top of
lots of CPU-side copies and allocations.
- We need to be uploading only per-instance/entity data. Anything else
needs to be done on the GPU.

## Solution

- Per instance, upload:
  - `meshlet_instance_meshlet_counts_prefix_sum` - An exclusive prefix sum
    over the count of how many clusters each instance has.
  - `meshlet_instance_meshlet_slice_starts` - The starting index of the
    meshlets for each instance within the `meshlets` buffer.
- A new `fill_cluster_buffers` pass runs once at the start of the frame with
one thread per cluster. Each thread binary searches
`meshlet_instance_meshlet_counts_prefix_sum` to find the instance its
cluster belongs to, subtracts that instance's prefix sum entry from the
cluster ID to get the meshlet's index within the instance, and then adds
`meshlet_instance_meshlet_slice_starts` to get the meshlet's index within
the `meshlets` buffer. The shader then writes out the per-cluster
instance/meshlet ID buffers for later passes to quickly read from.
- I've gone from 45 -> 180 FPS in my stress test scene, and saved
~30ms/frame of overall CPU/GPU time.
This commit is contained in:
JMS55 2024-05-04 12:56:19 -07:00 committed by GitHub
parent ec418aa429
commit 77ebabc4fe
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
10 changed files with 305 additions and 102 deletions

View file

@ -1,14 +1,14 @@
#import bevy_pbr::meshlet_bindings::{
meshlet_thread_meshlet_ids,
meshlet_cluster_meshlet_ids,
meshlet_bounding_spheres,
meshlet_thread_instance_ids,
meshlet_cluster_instance_ids,
meshlet_instance_uniforms,
meshlet_second_pass_candidates,
depth_pyramid,
view,
previous_view,
should_cull_instance,
meshlet_is_second_pass_candidate,
cluster_is_second_pass_candidate,
meshlets,
draw_indirect_args,
draw_triangle_buffer,
@ -21,7 +21,7 @@
/// the instance, frustum, and LOD tests in the first pass, but were not visible last frame according to the occlusion culling.
@compute
@workgroup_size(128, 1, 1) // 128 threads per workgroup, 1 instanced meshlet per thread
@workgroup_size(128, 1, 1) // 128 threads per workgroup, 1 cluster per thread
fn cull_meshlets(
@builtin(workgroup_id) workgroup_id: vec3<u32>,
@builtin(num_workgroups) num_workgroups: vec3<u32>,
@ -29,21 +29,21 @@ fn cull_meshlets(
) {
// Calculate the cluster ID for this thread
let cluster_id = local_invocation_id.x + 128u * dot(workgroup_id, vec3(num_workgroups.x * num_workgroups.x, num_workgroups.x, 1u));
if cluster_id >= arrayLength(&meshlet_thread_meshlet_ids) { return; }
if cluster_id >= arrayLength(&meshlet_cluster_meshlet_ids) { return; }
#ifdef MESHLET_SECOND_CULLING_PASS
if !meshlet_is_second_pass_candidate(cluster_id) { return; }
if !cluster_is_second_pass_candidate(cluster_id) { return; }
#endif
// Check for instance culling
let instance_id = meshlet_thread_instance_ids[cluster_id];
let instance_id = meshlet_cluster_instance_ids[cluster_id];
#ifdef MESHLET_FIRST_CULLING_PASS
if should_cull_instance(instance_id) { return; }
#endif
// Calculate world-space culling bounding sphere for the cluster
let instance_uniform = meshlet_instance_uniforms[instance_id];
let meshlet_id = meshlet_thread_meshlet_ids[cluster_id];
let meshlet_id = meshlet_cluster_meshlet_ids[cluster_id];
let model = affine3_to_square(instance_uniform.model);
let model_scale = max(length(model[0]), max(length(model[1]), length(model[2])));
let bounding_spheres = meshlet_bounding_spheres[meshlet_id];

View file

@ -0,0 +1,42 @@
#import bevy_pbr::meshlet_bindings::{
cluster_count,
meshlet_instance_meshlet_counts_prefix_sum,
meshlet_instance_meshlet_slice_starts,
meshlet_cluster_instance_ids,
meshlet_cluster_meshlet_ids,
}
/// Writes, for every cluster in the scene, the ID of the instance it belongs to and the ID of its meshlet
/// into `meshlet_cluster_instance_ids` and `meshlet_cluster_meshlet_ids`.
@compute
@workgroup_size(128, 1, 1) // 128 threads per workgroup, 1 cluster per thread
fn fill_cluster_buffers(
    @builtin(workgroup_id) workgroup_id: vec3<u32>,
    @builtin(num_workgroups) num_workgroups: vec3<u32>,
    @builtin(local_invocation_id) local_invocation_id: vec3<u32>
) {
    // Flatten the 3D workgroup ID into a linear workgroup index, then offset by this thread's position within the workgroup
    let workgroup_strides = vec3(num_workgroups.x * num_workgroups.x, num_workgroups.x, 1u);
    let workgroup_index = dot(workgroup_id, workgroup_strides);
    let cluster_id = 128u * workgroup_index + local_invocation_id.x;

    // Threads past the end of the cluster list have nothing to write
    if cluster_id >= cluster_count { return; }

    // Binary search over the closed range [low, high] of instance indices.
    // When the loop ends, `high` is the last index whose prefix sum entry is <= cluster_id.
    var low = 0u;
    var high = arrayLength(&meshlet_instance_meshlet_counts_prefix_sum) - 1u;
    while low <= high {
        let middle = (low + high) / 2u;
        if cluster_id < meshlet_instance_meshlet_counts_prefix_sum[middle] {
            high = middle - 1u;
        } else {
            low = middle + 1u;
        }
    }
    let instance_id = high;

    // Offset of this cluster from the first cluster of its instance
    let meshlet_id_local = cluster_id - meshlet_instance_meshlet_counts_prefix_sum[instance_id];

    // Shift by the instance's slice start to get the meshlet ID used by later passes
    let meshlet_id = meshlet_instance_meshlet_slice_starts[instance_id] + meshlet_id_local;

    // Store the results at this cluster's slot
    meshlet_cluster_meshlet_ids[cluster_id] = meshlet_id;
    meshlet_cluster_instance_ids[cluster_id] = instance_id;
}

View file

@ -31,7 +31,7 @@ use std::{
iter,
mem::size_of,
ops::{DerefMut, Range},
sync::Arc,
sync::{atomic::AtomicBool, Arc},
};
/// Create and queue for uploading to the GPU [`MeshUniform`] components for
@ -91,17 +91,14 @@ pub fn extract_meshlet_meshes(
}
for (
instance_index,
(
instance,
handle,
transform,
previous_transform,
render_layers,
not_shadow_receiver,
not_shadow_caster,
),
) in instances_query.iter().enumerate()
instance,
handle,
transform,
previous_transform,
render_layers,
not_shadow_receiver,
not_shadow_caster,
) in &instances_query
{
// Skip instances with an unloaded MeshletMesh asset
if asset_server.is_managed(handle.id())
@ -117,7 +114,6 @@ pub fn extract_meshlet_meshes(
not_shadow_caster,
handle,
&mut assets,
instance_index as u32,
);
// Build a MeshUniform for each instance
@ -235,12 +231,12 @@ pub fn prepare_meshlet_per_frame_resources(
&render_queue,
);
upload_storage_buffer(
&mut gpu_scene.thread_instance_ids,
&mut gpu_scene.instance_meshlet_counts_prefix_sum,
&render_device,
&render_queue,
);
upload_storage_buffer(
&mut gpu_scene.thread_meshlet_ids,
&mut gpu_scene.instance_meshlet_slice_starts,
&render_device,
&render_queue,
);
@ -248,6 +244,34 @@ pub fn prepare_meshlet_per_frame_resources(
// Early submission for GPU data uploads to start while the render graph records commands
render_queue.submit([]);
let needed_buffer_size = 4 * gpu_scene.scene_meshlet_count as u64;
match &mut gpu_scene.cluster_instance_ids {
Some(buffer) if buffer.size() >= needed_buffer_size => buffer.clone(),
slot => {
let buffer = render_device.create_buffer(&BufferDescriptor {
label: Some("meshlet_cluster_instance_ids"),
size: needed_buffer_size,
usage: BufferUsages::STORAGE,
mapped_at_creation: false,
});
*slot = Some(buffer.clone());
buffer
}
};
match &mut gpu_scene.cluster_meshlet_ids {
Some(buffer) if buffer.size() >= needed_buffer_size => buffer.clone(),
slot => {
let buffer = render_device.create_buffer(&BufferDescriptor {
label: Some("meshlet_cluster_meshlet_ids"),
size: needed_buffer_size,
usage: BufferUsages::STORAGE,
mapped_at_creation: false,
});
*slot = Some(buffer.clone());
buffer
}
};
let needed_buffer_size = 4 * gpu_scene.scene_triangle_count;
let visibility_buffer_draw_triangle_buffer =
match &mut gpu_scene.visibility_buffer_draw_triangle_buffer {
@ -456,18 +480,44 @@ pub fn prepare_meshlet_view_bind_groups(
render_device: Res<RenderDevice>,
mut commands: Commands,
) {
let (Some(view_uniforms), Some(previous_view_uniforms)) = (
let (
Some(cluster_instance_ids),
Some(cluster_meshlet_ids),
Some(view_uniforms),
Some(previous_view_uniforms),
) = (
gpu_scene.cluster_instance_ids.as_ref(),
gpu_scene.cluster_meshlet_ids.as_ref(),
view_uniforms.uniforms.binding(),
previous_view_uniforms.uniforms.binding(),
) else {
)
else {
return;
};
let first_node = Arc::new(AtomicBool::new(true));
// TODO: Some of these bind groups can be reused across multiple views
for (view_entity, view_resources, view_depth) in &views {
let entries = BindGroupEntries::sequential((
gpu_scene.thread_meshlet_ids.binding().unwrap(),
gpu_scene
.instance_meshlet_counts_prefix_sum
.binding()
.unwrap(),
gpu_scene.instance_meshlet_slice_starts.binding().unwrap(),
cluster_instance_ids.as_entire_binding(),
cluster_meshlet_ids.as_entire_binding(),
));
let fill_cluster_buffers = render_device.create_bind_group(
"meshlet_fill_cluster_buffers",
&gpu_scene.fill_cluster_buffers_bind_group_layout,
&entries,
);
let entries = BindGroupEntries::sequential((
cluster_meshlet_ids.as_entire_binding(),
gpu_scene.meshlet_bounding_spheres.binding(),
gpu_scene.thread_instance_ids.binding().unwrap(),
cluster_instance_ids.as_entire_binding(),
gpu_scene.instance_uniforms.binding().unwrap(),
view_resources.instance_visibility.as_entire_binding(),
view_resources
@ -491,9 +541,9 @@ pub fn prepare_meshlet_view_bind_groups(
);
let entries = BindGroupEntries::sequential((
gpu_scene.thread_meshlet_ids.binding().unwrap(),
cluster_meshlet_ids.as_entire_binding(),
gpu_scene.meshlet_bounding_spheres.binding(),
gpu_scene.thread_instance_ids.binding().unwrap(),
cluster_instance_ids.as_entire_binding(),
gpu_scene.instance_uniforms.binding().unwrap(),
view_resources.instance_visibility.as_entire_binding(),
view_resources
@ -539,12 +589,12 @@ pub fn prepare_meshlet_view_bind_groups(
.collect();
let entries = BindGroupEntries::sequential((
gpu_scene.thread_meshlet_ids.binding().unwrap(),
cluster_meshlet_ids.as_entire_binding(),
gpu_scene.meshlets.binding(),
gpu_scene.indices.binding(),
gpu_scene.vertex_ids.binding(),
gpu_scene.vertex_data.binding(),
gpu_scene.thread_instance_ids.binding().unwrap(),
cluster_instance_ids.as_entire_binding(),
gpu_scene.instance_uniforms.binding().unwrap(),
gpu_scene.instance_material_ids.binding().unwrap(),
view_resources
@ -581,12 +631,12 @@ pub fn prepare_meshlet_view_bind_groups(
.map(|visibility_buffer| {
let entries = BindGroupEntries::sequential((
&visibility_buffer.default_view,
gpu_scene.thread_meshlet_ids.binding().unwrap(),
cluster_meshlet_ids.as_entire_binding(),
gpu_scene.meshlets.binding(),
gpu_scene.indices.binding(),
gpu_scene.vertex_ids.binding(),
gpu_scene.vertex_data.binding(),
gpu_scene.thread_instance_ids.binding().unwrap(),
cluster_instance_ids.as_entire_binding(),
gpu_scene.instance_uniforms.binding().unwrap(),
));
render_device.create_bind_group(
@ -597,6 +647,8 @@ pub fn prepare_meshlet_view_bind_groups(
});
commands.entity(view_entity).insert(MeshletViewBindGroups {
first_node: Arc::clone(&first_node),
fill_cluster_buffers,
culling_first,
culling_second,
downsample_depth,
@ -629,12 +681,15 @@ pub struct MeshletGpuScene {
/// Per-view per-instance visibility bit. Used for [`RenderLayers`] and [`NotShadowCaster`] support.
view_instance_visibility: EntityHashMap<StorageBuffer<Vec<u32>>>,
instance_material_ids: StorageBuffer<Vec<u32>>,
thread_instance_ids: StorageBuffer<Vec<u32>>,
thread_meshlet_ids: StorageBuffer<Vec<u32>>,
instance_meshlet_counts_prefix_sum: StorageBuffer<Vec<u32>>,
instance_meshlet_slice_starts: StorageBuffer<Vec<u32>>,
cluster_instance_ids: Option<Buffer>,
cluster_meshlet_ids: Option<Buffer>,
second_pass_candidates_buffer: Option<Buffer>,
previous_depth_pyramids: EntityHashMap<TextureView>,
visibility_buffer_draw_triangle_buffer: Option<Buffer>,
fill_cluster_buffers_bind_group_layout: BindGroupLayout,
culling_bind_group_layout: BindGroupLayout,
visibility_buffer_raster_bind_group_layout: BindGroupLayout,
downsample_depth_bind_group_layout: BindGroupLayout,
@ -675,21 +730,35 @@ impl FromWorld for MeshletGpuScene {
buffer.set_label(Some("meshlet_instance_material_ids"));
buffer
},
thread_instance_ids: {
instance_meshlet_counts_prefix_sum: {
let mut buffer = StorageBuffer::default();
buffer.set_label(Some("meshlet_thread_instance_ids"));
buffer.set_label(Some("meshlet_instance_meshlet_counts_prefix_sum"));
buffer
},
thread_meshlet_ids: {
instance_meshlet_slice_starts: {
let mut buffer = StorageBuffer::default();
buffer.set_label(Some("meshlet_thread_meshlet_ids"));
buffer.set_label(Some("meshlet_instance_meshlet_slice_starts"));
buffer
},
cluster_instance_ids: None,
cluster_meshlet_ids: None,
second_pass_candidates_buffer: None,
previous_depth_pyramids: EntityHashMap::default(),
visibility_buffer_draw_triangle_buffer: None,
// TODO: Buffer min sizes
fill_cluster_buffers_bind_group_layout: render_device.create_bind_group_layout(
"meshlet_fill_cluster_buffers_bind_group_layout",
&BindGroupLayoutEntries::sequential(
ShaderStages::COMPUTE,
(
storage_buffer_read_only_sized(false, None),
storage_buffer_read_only_sized(false, None),
storage_buffer_sized(false, None),
storage_buffer_sized(false, None),
),
),
),
culling_bind_group_layout: render_device.create_bind_group_layout(
"meshlet_culling_bind_group_layout",
&BindGroupLayoutEntries::sequential(
@ -784,8 +853,8 @@ impl MeshletGpuScene {
.for_each(|b| b.get_mut().clear());
self.instance_uniforms.get_mut().clear();
self.instance_material_ids.get_mut().clear();
self.thread_instance_ids.get_mut().clear();
self.thread_meshlet_ids.get_mut().clear();
self.instance_meshlet_counts_prefix_sum.get_mut().clear();
self.instance_meshlet_slice_starts.get_mut().clear();
// TODO: Remove unused entries for view_instance_visibility and previous_depth_pyramids
}
@ -796,7 +865,6 @@ impl MeshletGpuScene {
not_shadow_caster: bool,
handle: &Handle<MeshletMesh>,
assets: &mut Assets<MeshletMesh>,
instance_index: u32,
) {
let queue_meshlet_mesh = |asset_id: &AssetId<MeshletMesh>| {
let meshlet_mesh = assets.remove_untracked(*asset_id).expect(
@ -833,11 +901,6 @@ impl MeshletGpuScene {
)
};
// Append instance data for this frame
self.instances
.push((instance, render_layers, not_shadow_caster));
self.instance_material_ids.get_mut().push(0);
// If the MeshletMesh asset has not been uploaded to the GPU yet, queue it for uploading
let ([_, _, _, meshlets_slice, _], triangle_count) = self
.meshlet_mesh_slices
@ -848,14 +911,19 @@ impl MeshletGpuScene {
let meshlets_slice = (meshlets_slice.start as u32 / size_of::<Meshlet>() as u32)
..(meshlets_slice.end as u32 / size_of::<Meshlet>() as u32);
// Append instance data for this frame
self.instances
.push((instance, render_layers, not_shadow_caster));
self.instance_material_ids.get_mut().push(0);
self.instance_meshlet_counts_prefix_sum
.get_mut()
.push(self.scene_meshlet_count);
self.instance_meshlet_slice_starts
.get_mut()
.push(meshlets_slice.start);
self.scene_meshlet_count += meshlets_slice.end - meshlets_slice.start;
self.scene_triangle_count += triangle_count;
// Append per-cluster data for this frame
self.thread_instance_ids
.get_mut()
.extend(std::iter::repeat(instance_index).take(meshlets_slice.len()));
self.thread_meshlet_ids.get_mut().extend(meshlets_slice);
}
/// Get the depth value for use with the material depth texture for a given [`Material`] asset.
@ -873,6 +941,10 @@ impl MeshletGpuScene {
self.material_ids_present_in_scene.contains(material_id)
}
pub fn fill_cluster_buffers_bind_group_layout(&self) -> BindGroupLayout {
self.fill_cluster_buffers_bind_group_layout.clone()
}
pub fn culling_bind_group_layout(&self) -> BindGroupLayout {
self.culling_bind_group_layout.clone()
}
@ -912,6 +984,8 @@ pub struct MeshletViewResources {
#[derive(Component)]
pub struct MeshletViewBindGroups {
pub first_node: Arc<AtomicBool>,
pub fill_cluster_buffers: BindGroup,
pub culling_first: BindGroup,
pub culling_second: BindGroup,
pub downsample_depth: Box<[BindGroup]>,

View file

@ -116,8 +116,8 @@ impl ViewNode for MeshletMainOpaquePass3dNode {
pipeline_cache.get_render_pipeline(*material_pipeline_id)
{
let x = *material_id * 3;
render_pass.set_bind_group(2, material_bind_group, &[]);
render_pass.set_render_pipeline(material_pipeline);
render_pass.set_bind_group(2, material_bind_group, &[]);
render_pass.draw(x..(x + 3), 0..1);
}
}
@ -237,8 +237,8 @@ impl ViewNode for MeshletPrepassNode {
pipeline_cache.get_render_pipeline(*material_pipeline_id)
{
let x = *material_id * 3;
render_pass.set_bind_group(2, material_bind_group, &[]);
render_pass.set_render_pipeline(material_pipeline);
render_pass.set_bind_group(2, material_bind_group, &[]);
render_pass.draw(x..(x + 3), 0..1);
}
}
@ -363,8 +363,8 @@ impl ViewNode for MeshletDeferredGBufferPrepassNode {
pipeline_cache.get_render_pipeline(*material_pipeline_id)
{
let x = *material_id * 3;
render_pass.set_bind_group(2, material_bind_group, &[]);
render_pass.set_render_pipeline(material_pipeline);
render_pass.set_bind_group(2, material_bind_group, &[]);
render_pass.draw(x..(x + 3), 0..1);
}
}

View file

@ -51,14 +51,22 @@ struct DrawIndirectArgs {
first_instance: u32,
}
#ifdef MESHLET_FILL_CLUSTER_BUFFERS_PASS
var<push_constant> cluster_count: u32;
@group(0) @binding(0) var<storage, read> meshlet_instance_meshlet_counts_prefix_sum: array<u32>; // Per entity instance
@group(0) @binding(1) var<storage, read> meshlet_instance_meshlet_slice_starts: array<u32>; // Per entity instance
@group(0) @binding(2) var<storage, read_write> meshlet_cluster_instance_ids: array<u32>; // Per cluster
@group(0) @binding(3) var<storage, read_write> meshlet_cluster_meshlet_ids: array<u32>; // Per cluster
#endif
#ifdef MESHLET_CULLING_PASS
@group(0) @binding(0) var<storage, read> meshlet_thread_meshlet_ids: array<u32>; // Per cluster (instance of a meshlet)
@group(0) @binding(1) var<storage, read> meshlet_bounding_spheres: array<MeshletBoundingSpheres>; // Per asset meshlet
@group(0) @binding(2) var<storage, read> meshlet_thread_instance_ids: array<u32>; // Per cluster (instance of a meshlet)
@group(0) @binding(0) var<storage, read> meshlet_cluster_meshlet_ids: array<u32>; // Per cluster
@group(0) @binding(1) var<storage, read> meshlet_bounding_spheres: array<MeshletBoundingSpheres>; // Per meshlet
@group(0) @binding(2) var<storage, read> meshlet_cluster_instance_ids: array<u32>; // Per cluster
@group(0) @binding(3) var<storage, read> meshlet_instance_uniforms: array<Mesh>; // Per entity instance
@group(0) @binding(4) var<storage, read> meshlet_view_instance_visibility: array<u32>; // 1 bit per entity instance, packed as a bitmask
@group(0) @binding(5) var<storage, read_write> meshlet_second_pass_candidates: array<atomic<u32>>; // 1 bit per cluster (instance of a meshlet), packed as a bitmask
@group(0) @binding(6) var<storage, read> meshlets: array<Meshlet>; // Per asset meshlet
@group(0) @binding(5) var<storage, read_write> meshlet_second_pass_candidates: array<atomic<u32>>; // 1 bit per cluster , packed as a bitmask
@group(0) @binding(6) var<storage, read> meshlets: array<Meshlet>; // Per meshlet
@group(0) @binding(7) var<storage, read_write> draw_indirect_args: DrawIndirectArgs; // Single object shared between all workgroups/meshlets/triangles
@group(0) @binding(8) var<storage, read_write> draw_triangle_buffer: array<u32>; // Single object shared between all workgroups/meshlets/triangles
@group(0) @binding(9) var depth_pyramid: texture_2d<f32>; // From the end of the last frame for the first culling pass, and from the first raster pass for the second culling pass
@ -71,7 +79,7 @@ fn should_cull_instance(instance_id: u32) -> bool {
return bool(extractBits(packed_visibility, bit_offset, 1u));
}
fn meshlet_is_second_pass_candidate(cluster_id: u32) -> bool {
fn cluster_is_second_pass_candidate(cluster_id: u32) -> bool {
let packed_candidates = meshlet_second_pass_candidates[cluster_id / 32u];
let bit_offset = cluster_id % 32u;
return bool(extractBits(packed_candidates, bit_offset, 1u));
@ -79,12 +87,12 @@ fn meshlet_is_second_pass_candidate(cluster_id: u32) -> bool {
#endif
#ifdef MESHLET_VISIBILITY_BUFFER_RASTER_PASS
@group(0) @binding(0) var<storage, read> meshlet_thread_meshlet_ids: array<u32>; // Per cluster (instance of a meshlet)
@group(0) @binding(1) var<storage, read> meshlets: array<Meshlet>; // Per asset meshlet
@group(0) @binding(2) var<storage, read> meshlet_indices: array<u32>; // Many per asset meshlet
@group(0) @binding(3) var<storage, read> meshlet_vertex_ids: array<u32>; // Many per asset meshlet
@group(0) @binding(4) var<storage, read> meshlet_vertex_data: array<PackedMeshletVertex>; // Many per asset meshlet
@group(0) @binding(5) var<storage, read> meshlet_thread_instance_ids: array<u32>; // Per cluster (instance of a meshlet)
@group(0) @binding(0) var<storage, read> meshlet_cluster_meshlet_ids: array<u32>; // Per cluster
@group(0) @binding(1) var<storage, read> meshlets: array<Meshlet>; // Per meshlet
@group(0) @binding(2) var<storage, read> meshlet_indices: array<u32>; // Many per meshlet
@group(0) @binding(3) var<storage, read> meshlet_vertex_ids: array<u32>; // Many per meshlet
@group(0) @binding(4) var<storage, read> meshlet_vertex_data: array<PackedMeshletVertex>; // Many per meshlet
@group(0) @binding(5) var<storage, read> meshlet_cluster_instance_ids: array<u32>; // Per cluster
@group(0) @binding(6) var<storage, read> meshlet_instance_uniforms: array<Mesh>; // Per entity instance
@group(0) @binding(7) var<storage, read> meshlet_instance_material_ids: array<u32>; // Per entity instance
@group(0) @binding(8) var<storage, read> draw_triangle_buffer: array<u32>; // Single object shared between all workgroups/meshlets/triangles
@ -99,12 +107,12 @@ fn get_meshlet_index(index_id: u32) -> u32 {
#ifdef MESHLET_MESH_MATERIAL_PASS
@group(1) @binding(0) var meshlet_visibility_buffer: texture_2d<u32>; // Generated from the meshlet raster passes
@group(1) @binding(1) var<storage, read> meshlet_thread_meshlet_ids: array<u32>; // Per cluster (instance of a meshlet)
@group(1) @binding(2) var<storage, read> meshlets: array<Meshlet>; // Per asset meshlet
@group(1) @binding(3) var<storage, read> meshlet_indices: array<u32>; // Many per asset meshlet
@group(1) @binding(4) var<storage, read> meshlet_vertex_ids: array<u32>; // Many per asset meshlet
@group(1) @binding(5) var<storage, read> meshlet_vertex_data: array<PackedMeshletVertex>; // Many per asset meshlet
@group(1) @binding(6) var<storage, read> meshlet_thread_instance_ids: array<u32>; // Per cluster (instance of a meshlet)
@group(1) @binding(1) var<storage, read> meshlet_cluster_meshlet_ids: array<u32>; // Per cluster
@group(1) @binding(2) var<storage, read> meshlets: array<Meshlet>; // Per meshlet
@group(1) @binding(3) var<storage, read> meshlet_indices: array<u32>; // Many per meshlet
@group(1) @binding(4) var<storage, read> meshlet_vertex_ids: array<u32>; // Many per meshlet
@group(1) @binding(5) var<storage, read> meshlet_vertex_data: array<PackedMeshletVertex>; // Many per meshlet
@group(1) @binding(6) var<storage, read> meshlet_cluster_instance_ids: array<u32>; // Per cluster
@group(1) @binding(7) var<storage, read> meshlet_instance_uniforms: array<Mesh>; // Per entity instance
fn get_meshlet_index(index_id: u32) -> u32 {

View file

@ -49,7 +49,8 @@ use self::{
},
pipelines::{
MeshletPipelines, MESHLET_COPY_MATERIAL_DEPTH_SHADER_HANDLE, MESHLET_CULLING_SHADER_HANDLE,
MESHLET_DOWNSAMPLE_DEPTH_SHADER_HANDLE, MESHLET_VISIBILITY_BUFFER_RASTER_SHADER_HANDLE,
MESHLET_DOWNSAMPLE_DEPTH_SHADER_HANDLE, MESHLET_FILL_CLUSTER_BUFFERS_SHADER_HANDLE,
MESHLET_VISIBILITY_BUFFER_RASTER_SHADER_HANDLE,
},
visibility_buffer_raster_node::MeshletVisibilityBufferRasterPassNode,
};
@ -74,6 +75,8 @@ use bevy_ecs::{
use bevy_render::{
render_graph::{RenderGraphApp, ViewNodeRunner},
render_resource::{Shader, TextureUsages},
renderer::RenderDevice,
settings::WgpuFeatures,
view::{
check_visibility, prepare_view_targets, InheritedVisibility, Msaa, ViewVisibility,
Visibility, VisibilitySystems,
@ -105,7 +108,7 @@ const MESHLET_MESH_MATERIAL_SHADER_HANDLE: Handle<Shader> =
///
/// This plugin is not compatible with [`Msaa`], and adding this plugin will disable it.
///
/// This plugin does not work on the WebGL2 backend.
/// This plugin does not work on WASM.
///
/// ![A render of the Stanford dragon as a `MeshletMesh`](https://raw.githubusercontent.com/bevyengine/bevy/main/crates/bevy_pbr/src/meshlet/meshlet_preview.png)
pub struct MeshletPlugin;
@ -124,6 +127,12 @@ impl Plugin for MeshletPlugin {
"visibility_buffer_resolve.wgsl",
Shader::from_wgsl
);
load_internal_asset!(
app,
MESHLET_FILL_CLUSTER_BUFFERS_SHADER_HANDLE,
"fill_cluster_buffers.wgsl",
Shader::from_wgsl
);
load_internal_asset!(
app,
MESHLET_CULLING_SHADER_HANDLE,
@ -169,6 +178,15 @@ impl Plugin for MeshletPlugin {
return;
};
if !render_app
.world()
.resource::<RenderDevice>()
.features()
.contains(WgpuFeatures::PUSH_CONSTANTS)
{
panic!("MeshletPlugin can't be used. GPU lacks support: WgpuFeatures::PUSH_CONSTANTS is not supported.");
}
render_app
.add_render_graph_node::<MeshletVisibilityBufferRasterPassNode>(
Core3d,

View file

@ -9,16 +9,19 @@ use bevy_ecs::{
};
use bevy_render::render_resource::*;
pub const MESHLET_CULLING_SHADER_HANDLE: Handle<Shader> = Handle::weak_from_u128(4325134235233421);
pub const MESHLET_FILL_CLUSTER_BUFFERS_SHADER_HANDLE: Handle<Shader> =
Handle::weak_from_u128(4325134235233421);
pub const MESHLET_CULLING_SHADER_HANDLE: Handle<Shader> = Handle::weak_from_u128(5325134235233421);
pub const MESHLET_DOWNSAMPLE_DEPTH_SHADER_HANDLE: Handle<Shader> =
Handle::weak_from_u128(5325134235233421);
pub const MESHLET_VISIBILITY_BUFFER_RASTER_SHADER_HANDLE: Handle<Shader> =
Handle::weak_from_u128(6325134235233421);
pub const MESHLET_COPY_MATERIAL_DEPTH_SHADER_HANDLE: Handle<Shader> =
pub const MESHLET_VISIBILITY_BUFFER_RASTER_SHADER_HANDLE: Handle<Shader> =
Handle::weak_from_u128(7325134235233421);
pub const MESHLET_COPY_MATERIAL_DEPTH_SHADER_HANDLE: Handle<Shader> =
Handle::weak_from_u128(8325134235233421);
#[derive(Resource)]
pub struct MeshletPipelines {
fill_cluster_buffers: CachedComputePipelineId,
cull_first: CachedComputePipelineId,
cull_second: CachedComputePipelineId,
downsample_depth: CachedRenderPipelineId,
@ -31,6 +34,8 @@ pub struct MeshletPipelines {
impl FromWorld for MeshletPipelines {
fn from_world(world: &mut World) -> Self {
let gpu_scene = world.resource::<MeshletGpuScene>();
let fill_cluster_buffers_bind_group_layout =
gpu_scene.fill_cluster_buffers_bind_group_layout();
let cull_layout = gpu_scene.culling_bind_group_layout();
let downsample_depth_layout = gpu_scene.downsample_depth_bind_group_layout();
let visibility_buffer_layout = gpu_scene.visibility_buffer_raster_bind_group_layout();
@ -38,6 +43,20 @@ impl FromWorld for MeshletPipelines {
let pipeline_cache = world.resource_mut::<PipelineCache>();
Self {
fill_cluster_buffers: pipeline_cache.queue_compute_pipeline(
ComputePipelineDescriptor {
label: Some("meshlet_fill_cluster_buffers_pipeline".into()),
layout: vec![fill_cluster_buffers_bind_group_layout.clone()],
push_constant_ranges: vec![PushConstantRange {
stages: ShaderStages::COMPUTE,
range: 0..4,
}],
shader: MESHLET_FILL_CLUSTER_BUFFERS_SHADER_HANDLE,
shader_defs: vec!["MESHLET_FILL_CLUSTER_BUFFERS_PASS".into()],
entry_point: "fill_cluster_buffers".into(),
},
),
cull_first: pipeline_cache.queue_compute_pipeline(ComputePipelineDescriptor {
label: Some("meshlet_culling_first_pipeline".into()),
layout: vec![cull_layout.clone()],
@ -242,6 +261,7 @@ impl MeshletPipelines {
pub fn get(
world: &World,
) -> Option<(
&ComputePipeline,
&ComputePipeline,
&ComputePipeline,
&RenderPipeline,
@ -253,6 +273,7 @@ impl MeshletPipelines {
let pipeline_cache = world.get_resource::<PipelineCache>()?;
let pipeline = world.get_resource::<Self>()?;
Some((
pipeline_cache.get_compute_pipeline(pipeline.fill_cluster_buffers)?,
pipeline_cache.get_compute_pipeline(pipeline.cull_first)?,
pipeline_cache.get_compute_pipeline(pipeline.cull_second)?,
pipeline_cache.get_render_pipeline(pipeline.downsample_depth)?,

View file

@ -1,10 +1,10 @@
#import bevy_pbr::{
meshlet_bindings::{
meshlet_thread_meshlet_ids,
meshlet_cluster_meshlet_ids,
meshlets,
meshlet_vertex_ids,
meshlet_vertex_data,
meshlet_thread_instance_ids,
meshlet_cluster_instance_ids,
meshlet_instance_uniforms,
meshlet_instance_material_ids,
draw_triangle_buffer,
@ -42,12 +42,12 @@ fn vertex(@builtin(vertex_index) vertex_index: u32) -> VertexOutput {
let cluster_id = packed_ids >> 6u;
let triangle_id = extractBits(packed_ids, 0u, 6u);
let index_id = (triangle_id * 3u) + (vertex_index % 3u);
let meshlet_id = meshlet_thread_meshlet_ids[cluster_id];
let meshlet_id = meshlet_cluster_meshlet_ids[cluster_id];
let meshlet = meshlets[meshlet_id];
let index = get_meshlet_index(meshlet.start_index_id + index_id);
let vertex_id = meshlet_vertex_ids[meshlet.start_vertex_id + index];
let vertex = unpack_meshlet_vertex(meshlet_vertex_data[vertex_id]);
let instance_id = meshlet_thread_instance_ids[cluster_id];
let instance_id = meshlet_cluster_instance_ids[cluster_id];
let instance_uniform = meshlet_instance_uniforms[instance_id];
let model = affine3_to_square(instance_uniform.model);

View file

@ -15,6 +15,7 @@ use bevy_render::{
renderer::RenderContext,
view::{ViewDepthTexture, ViewUniformOffset},
};
use std::sync::atomic::Ordering;
/// Rasterize meshlets into a depth buffer, and optional visibility buffer + material depth buffer for shading passes.
pub struct MeshletVisibilityBufferRasterPassNode {
@ -72,6 +73,7 @@ impl Node for MeshletVisibilityBufferRasterPassNode {
};
let Some((
fill_cluster_buffers_pipeline,
culling_first_pipeline,
culling_second_pipeline,
downsample_depth_pipeline,
@ -84,9 +86,14 @@ impl Node for MeshletVisibilityBufferRasterPassNode {
return Ok(());
};
let culling_workgroups = (meshlet_view_resources.scene_meshlet_count.div_ceil(128) as f32)
.cbrt()
.ceil() as u32;
let first_node = meshlet_view_bind_groups
.first_node
.fetch_and(false, Ordering::SeqCst);
let thread_per_cluster_workgroups =
(meshlet_view_resources.scene_meshlet_count.div_ceil(128) as f32)
.cbrt()
.ceil() as u32;
render_context
.command_encoder()
@ -96,6 +103,15 @@ impl Node for MeshletVisibilityBufferRasterPassNode {
0,
None,
);
if first_node {
fill_cluster_buffers_pass(
render_context,
&meshlet_view_bind_groups.fill_cluster_buffers,
fill_cluster_buffers_pipeline,
thread_per_cluster_workgroups,
meshlet_view_resources.scene_meshlet_count,
);
}
cull_pass(
"culling_first",
render_context,
@ -103,7 +119,7 @@ impl Node for MeshletVisibilityBufferRasterPassNode {
view_offset,
previous_view_offset,
culling_first_pipeline,
culling_workgroups,
thread_per_cluster_workgroups,
);
raster_pass(
true,
@ -129,7 +145,7 @@ impl Node for MeshletVisibilityBufferRasterPassNode {
view_offset,
previous_view_offset,
culling_second_pipeline,
culling_workgroups,
thread_per_cluster_workgroups,
);
raster_pass(
false,
@ -191,7 +207,7 @@ impl Node for MeshletVisibilityBufferRasterPassNode {
view_offset,
previous_view_offset,
culling_first_pipeline,
culling_workgroups,
thread_per_cluster_workgroups,
);
raster_pass(
true,
@ -217,7 +233,7 @@ impl Node for MeshletVisibilityBufferRasterPassNode {
view_offset,
previous_view_offset,
culling_second_pipeline,
culling_workgroups,
thread_per_cluster_workgroups,
);
raster_pass(
false,
@ -243,6 +259,29 @@ impl Node for MeshletVisibilityBufferRasterPassNode {
}
}
// TODO: Reuse same compute pass as cull_pass
/// Records a compute pass labeled `fill_cluster_buffers` on the render context's command encoder.
///
/// The pass binds the given pipeline, sets `cluster_count` as a 4-byte push constant at offset 0,
/// binds `fill_cluster_buffers_bind_group` at group 0, and dispatches
/// `fill_cluster_buffers_pass_workgroups` workgroups along each of the three axes.
fn fill_cluster_buffers_pass(
    render_context: &mut RenderContext,
    fill_cluster_buffers_bind_group: &BindGroup,
    fill_cluster_buffers_pass_pipeline: &ComputePipeline,
    fill_cluster_buffers_pass_workgroups: u32,
    cluster_count: u32,
) {
    let pass_descriptor = ComputePassDescriptor {
        label: Some("fill_cluster_buffers"),
        timestamp_writes: None,
    };
    let mut fill_pass = render_context
        .command_encoder()
        .begin_compute_pass(&pass_descriptor);

    // The pipeline is set first, before the push constants and bind group that target its layout.
    fill_pass.set_pipeline(fill_cluster_buffers_pass_pipeline);
    let cluster_count_bytes = cluster_count.to_le_bytes();
    fill_pass.set_push_constants(0, &cluster_count_bytes);
    fill_pass.set_bind_group(0, fill_cluster_buffers_bind_group, &[]);

    // Same workgroup count on the x, y and z axes.
    let per_axis = fill_cluster_buffers_pass_workgroups;
    fill_pass.dispatch_workgroups(per_axis, per_axis, per_axis);
}
fn cull_pass(
label: &'static str,
render_context: &mut RenderContext,
@ -257,12 +296,12 @@ fn cull_pass(
label: Some(label),
timestamp_writes: None,
});
cull_pass.set_pipeline(culling_pipeline);
cull_pass.set_bind_group(
0,
culling_bind_group,
&[view_offset.offset, previous_view_offset.offset],
);
cull_pass.set_pipeline(culling_pipeline);
cull_pass.dispatch_workgroups(culling_workgroups, culling_workgroups, culling_workgroups);
}
@ -327,12 +366,12 @@ fn raster_pass(
draw_pass.set_camera_viewport(viewport);
}
draw_pass.set_render_pipeline(visibility_buffer_raster_pipeline);
draw_pass.set_bind_group(
0,
&meshlet_view_bind_groups.visibility_buffer_raster,
&[view_offset.offset],
);
draw_pass.set_render_pipeline(visibility_buffer_raster_pipeline);
draw_pass.draw_indirect(visibility_buffer_draw_indirect_args, 0);
}
@ -363,8 +402,8 @@ fn downsample_depth(
};
let mut downsample_pass = render_context.begin_tracked_render_pass(downsample_pass);
downsample_pass.set_bind_group(0, &meshlet_view_bind_groups.downsample_depth[i], &[]);
downsample_pass.set_render_pipeline(downsample_depth_pipeline);
downsample_pass.set_bind_group(0, &meshlet_view_bind_groups.downsample_depth[i], &[]);
downsample_pass.draw(0..3, 0..1);
}
@ -400,8 +439,8 @@ fn copy_material_depth_pass(
copy_pass.set_camera_viewport(viewport);
}
copy_pass.set_bind_group(0, copy_material_depth_bind_group, &[]);
copy_pass.set_render_pipeline(copy_material_depth_pipeline);
copy_pass.set_bind_group(0, copy_material_depth_bind_group, &[]);
copy_pass.draw(0..3, 0..1);
}
}

View file

@ -3,11 +3,11 @@
#import bevy_pbr::{
meshlet_bindings::{
meshlet_visibility_buffer,
meshlet_thread_meshlet_ids,
meshlet_cluster_meshlet_ids,
meshlets,
meshlet_vertex_ids,
meshlet_vertex_data,
meshlet_thread_instance_ids,
meshlet_cluster_instance_ids,
meshlet_instance_uniforms,
get_meshlet_index,
unpack_meshlet_vertex,
@ -95,11 +95,11 @@ struct VertexOutput {
/// Load the visibility buffer texture and resolve it into a VertexOutput.
fn resolve_vertex_output(frag_coord: vec4<f32>) -> VertexOutput {
let vbuffer = textureLoad(meshlet_visibility_buffer, vec2<i32>(frag_coord.xy), 0).r;
let cluster_id = vbuffer >> 6u;
let meshlet_id = meshlet_thread_meshlet_ids[cluster_id];
let packed_ids = textureLoad(meshlet_visibility_buffer, vec2<i32>(frag_coord.xy), 0).r;
let cluster_id = packed_ids >> 6u;
let meshlet_id = meshlet_cluster_meshlet_ids[cluster_id];
let meshlet = meshlets[meshlet_id];
let triangle_id = extractBits(vbuffer, 0u, 6u);
let triangle_id = extractBits(packed_ids, 0u, 6u);
let index_ids = meshlet.start_index_id + vec3(triangle_id * 3u) + vec3(0u, 1u, 2u);
let indices = meshlet.start_vertex_id + vec3(get_meshlet_index(index_ids.x), get_meshlet_index(index_ids.y), get_meshlet_index(index_ids.z));
let vertex_ids = vec3(meshlet_vertex_ids[indices.x], meshlet_vertex_ids[indices.y], meshlet_vertex_ids[indices.z]);
@ -107,13 +107,14 @@ fn resolve_vertex_output(frag_coord: vec4<f32>) -> VertexOutput {
let vertex_2 = unpack_meshlet_vertex(meshlet_vertex_data[vertex_ids.y]);
let vertex_3 = unpack_meshlet_vertex(meshlet_vertex_data[vertex_ids.z]);
let instance_id = meshlet_thread_instance_ids[cluster_id];
let instance_id = meshlet_cluster_instance_ids[cluster_id];
let instance_uniform = meshlet_instance_uniforms[instance_id];
let model = affine3_to_square(instance_uniform.model);
let world_position_1 = mesh_position_local_to_world(model, vec4(vertex_1.position, 1.0));
let world_position_2 = mesh_position_local_to_world(model, vec4(vertex_2.position, 1.0));
let world_position_3 = mesh_position_local_to_world(model, vec4(vertex_3.position, 1.0));
let clip_position_1 = position_world_to_clip(world_position_1.xyz);
let clip_position_2 = position_world_to_clip(world_position_2.xyz);
let clip_position_3 = position_world_to_clip(world_position_3.xyz);