bevy/examples/shader/custom_shader_instancing.rs
Patrick Walton f5de3f08fb
Use multidraw for opaque meshes when GPU culling is in use. (#16427)
This commit adds support for *multidraw*, which is a feature that allows
multiple meshes to be drawn in a single drawcall. `wgpu` currently
implements multidraw on Vulkan, so this feature is only enabled there.
Multiple meshes can be drawn at once if they're in the same vertex and
index buffers and are otherwise placed in the same bin. (Thus, for
example, at present the materials and textures must be identical, but
see #16368.) Multidraw is a significant performance improvement during
the draw phase because it reduces the number of rebindings, as well as
the number of drawcalls.

This feature is currently only enabled when GPU culling is used: i.e.
when `GpuCulling` is present on a camera. Therefore, if you run for
example `scene_viewer`, you will not see any performance improvements,
because `scene_viewer` doesn't add the `GpuCulling` component to its
camera.

Additionally, the multidraw feature is only implemented for opaque 3D
meshes and not for shadows or 2D meshes. I plan to make GPU culling the
default and to extend the feature to shadows in the future. Also, in the
future I suspect that polyfilling multidraw on APIs that don't support
it will be fruitful, as even without driver-level support use of
multidraw allows us to avoid expensive `wgpu` rebindings.
2024-12-06 17:22:03 +00:00

315 lines
11 KiB
Rust

//! A shader that renders a mesh multiple times in one draw call.
//!
//! Bevy will automatically batch and instance your meshes assuming you use the same
//! `Handle<Material>` and `Handle<Mesh>` for all of your instances.
//!
//! This example is intended for advanced users and shows how to make a custom instancing
//! implementation using bevy's low level rendering api.
//! It's generally recommended to try the built-in instancing before going with this approach.
use bevy::{
core_pipeline::core_3d::Transparent3d,
ecs::{
query::QueryItem,
system::{lifetimeless::*, SystemParamItem},
},
pbr::{
MeshPipeline, MeshPipelineKey, RenderMeshInstances, SetMeshBindGroup, SetMeshViewBindGroup,
},
prelude::*,
render::{
extract_component::{ExtractComponent, ExtractComponentPlugin},
mesh::{
allocator::MeshAllocator, MeshVertexBufferLayoutRef, RenderMesh, RenderMeshBufferInfo,
},
render_asset::RenderAssets,
render_phase::{
AddRenderCommand, DrawFunctions, PhaseItem, PhaseItemExtraIndex, RenderCommand,
RenderCommandResult, SetItemPipeline, TrackedRenderPass, ViewSortedRenderPhases,
},
render_resource::*,
renderer::RenderDevice,
sync_world::MainEntity,
view::{ExtractedView, NoFrustumCulling},
Render, RenderApp, RenderSet,
},
};
use bytemuck::{Pod, Zeroable};
/// This example uses a shader source file from the assets subdirectory
const SHADER_ASSET_PATH: &str = "shaders/instancing.wgsl";
fn main() {
App::new()
.add_plugins((DefaultPlugins, CustomMaterialPlugin))
.add_systems(Startup, setup)
.run();
}
fn setup(mut commands: Commands, mut meshes: ResMut<Assets<Mesh>>) {
commands.spawn((
Mesh3d(meshes.add(Cuboid::new(0.5, 0.5, 0.5))),
InstanceMaterialData(
(1..=10)
.flat_map(|x| (1..=10).map(move |y| (x as f32 / 10.0, y as f32 / 10.0)))
.map(|(x, y)| InstanceData {
position: Vec3::new(x * 10.0 - 5.0, y * 10.0 - 5.0, 0.0),
scale: 1.0,
color: LinearRgba::from(Color::hsla(x * 360., y, 0.5, 1.0)).to_f32_array(),
})
.collect(),
),
// NOTE: Frustum culling is done based on the Aabb of the Mesh and the GlobalTransform.
// As the cube is at the origin, if its Aabb moves outside the view frustum, all the
// instanced cubes will be culled.
// The InstanceMaterialData contains the 'GlobalTransform' information for this custom
// instancing, and that is not taken into account with the built-in frustum culling.
// We must disable the built-in frustum culling by adding the `NoFrustumCulling` marker
// component to avoid incorrect culling.
NoFrustumCulling,
));
// camera
commands.spawn((
Camera3d::default(),
Transform::from_xyz(0.0, 0.0, 15.0).looking_at(Vec3::ZERO, Vec3::Y),
));
}
#[derive(Component, Deref)]
struct InstanceMaterialData(Vec<InstanceData>);
impl ExtractComponent for InstanceMaterialData {
type QueryData = &'static InstanceMaterialData;
type QueryFilter = ();
type Out = Self;
fn extract_component(item: QueryItem<'_, Self::QueryData>) -> Option<Self> {
Some(InstanceMaterialData(item.0.clone()))
}
}
struct CustomMaterialPlugin;
impl Plugin for CustomMaterialPlugin {
fn build(&self, app: &mut App) {
app.add_plugins(ExtractComponentPlugin::<InstanceMaterialData>::default());
app.sub_app_mut(RenderApp)
.add_render_command::<Transparent3d, DrawCustom>()
.init_resource::<SpecializedMeshPipelines<CustomPipeline>>()
.add_systems(
Render,
(
queue_custom.in_set(RenderSet::QueueMeshes),
prepare_instance_buffers.in_set(RenderSet::PrepareResources),
),
);
}
fn finish(&self, app: &mut App) {
app.sub_app_mut(RenderApp).init_resource::<CustomPipeline>();
}
}
#[derive(Clone, Copy, Pod, Zeroable)]
#[repr(C)]
struct InstanceData {
position: Vec3,
scale: f32,
color: [f32; 4],
}
#[allow(clippy::too_many_arguments)]
fn queue_custom(
transparent_3d_draw_functions: Res<DrawFunctions<Transparent3d>>,
custom_pipeline: Res<CustomPipeline>,
mut pipelines: ResMut<SpecializedMeshPipelines<CustomPipeline>>,
pipeline_cache: Res<PipelineCache>,
meshes: Res<RenderAssets<RenderMesh>>,
render_mesh_instances: Res<RenderMeshInstances>,
material_meshes: Query<(Entity, &MainEntity), With<InstanceMaterialData>>,
mut transparent_render_phases: ResMut<ViewSortedRenderPhases<Transparent3d>>,
views: Query<(Entity, &ExtractedView, &Msaa)>,
) {
let draw_custom = transparent_3d_draw_functions.read().id::<DrawCustom>();
for (view_entity, view, msaa) in &views {
let Some(transparent_phase) = transparent_render_phases.get_mut(&view_entity) else {
continue;
};
let msaa_key = MeshPipelineKey::from_msaa_samples(msaa.samples());
let view_key = msaa_key | MeshPipelineKey::from_hdr(view.hdr);
let rangefinder = view.rangefinder3d();
for (entity, main_entity) in &material_meshes {
let Some(mesh_instance) = render_mesh_instances.render_mesh_queue_data(*main_entity)
else {
continue;
};
let Some(mesh) = meshes.get(mesh_instance.mesh_asset_id) else {
continue;
};
let key =
view_key | MeshPipelineKey::from_primitive_topology(mesh.primitive_topology());
let pipeline = pipelines
.specialize(&pipeline_cache, &custom_pipeline, key, &mesh.layout)
.unwrap();
transparent_phase.add(Transparent3d {
entity: (entity, *main_entity),
pipeline,
draw_function: draw_custom,
distance: rangefinder.distance_translation(&mesh_instance.translation),
batch_range: 0..1,
extra_index: PhaseItemExtraIndex::None,
});
}
}
}
#[derive(Component)]
struct InstanceBuffer {
buffer: Buffer,
length: usize,
}
fn prepare_instance_buffers(
mut commands: Commands,
query: Query<(Entity, &InstanceMaterialData)>,
render_device: Res<RenderDevice>,
) {
for (entity, instance_data) in &query {
let buffer = render_device.create_buffer_with_data(&BufferInitDescriptor {
label: Some("instance data buffer"),
contents: bytemuck::cast_slice(instance_data.as_slice()),
usage: BufferUsages::VERTEX | BufferUsages::COPY_DST,
});
commands.entity(entity).insert(InstanceBuffer {
buffer,
length: instance_data.len(),
});
}
}
#[derive(Resource)]
struct CustomPipeline {
shader: Handle<Shader>,
mesh_pipeline: MeshPipeline,
}
impl FromWorld for CustomPipeline {
fn from_world(world: &mut World) -> Self {
let mesh_pipeline = world.resource::<MeshPipeline>();
CustomPipeline {
shader: world.load_asset(SHADER_ASSET_PATH),
mesh_pipeline: mesh_pipeline.clone(),
}
}
}
impl SpecializedMeshPipeline for CustomPipeline {
type Key = MeshPipelineKey;
fn specialize(
&self,
key: Self::Key,
layout: &MeshVertexBufferLayoutRef,
) -> Result<RenderPipelineDescriptor, SpecializedMeshPipelineError> {
let mut descriptor = self.mesh_pipeline.specialize(key, layout)?;
descriptor.vertex.shader = self.shader.clone();
descriptor.vertex.buffers.push(VertexBufferLayout {
array_stride: size_of::<InstanceData>() as u64,
step_mode: VertexStepMode::Instance,
attributes: vec![
VertexAttribute {
format: VertexFormat::Float32x4,
offset: 0,
shader_location: 3, // shader locations 0-2 are taken up by Position, Normal and UV attributes
},
VertexAttribute {
format: VertexFormat::Float32x4,
offset: VertexFormat::Float32x4.size(),
shader_location: 4,
},
],
});
descriptor.fragment.as_mut().unwrap().shader = self.shader.clone();
Ok(descriptor)
}
}
type DrawCustom = (
SetItemPipeline,
SetMeshViewBindGroup<0>,
SetMeshBindGroup<1>,
DrawMeshInstanced,
);
struct DrawMeshInstanced;
impl<P: PhaseItem> RenderCommand<P> for DrawMeshInstanced {
type Param = (
SRes<RenderAssets<RenderMesh>>,
SRes<RenderMeshInstances>,
SRes<MeshAllocator>,
);
type ViewQuery = ();
type ItemQuery = Read<InstanceBuffer>;
#[inline]
fn render<'w>(
item: &P,
_view: (),
instance_buffer: Option<&'w InstanceBuffer>,
(meshes, render_mesh_instances, mesh_allocator): SystemParamItem<'w, '_, Self::Param>,
pass: &mut TrackedRenderPass<'w>,
) -> RenderCommandResult {
// A borrow check workaround.
let mesh_allocator = mesh_allocator.into_inner();
let Some(mesh_instance) = render_mesh_instances.render_mesh_queue_data(item.main_entity())
else {
return RenderCommandResult::Skip;
};
let Some(gpu_mesh) = meshes.into_inner().get(mesh_instance.mesh_asset_id) else {
return RenderCommandResult::Skip;
};
let Some(instance_buffer) = instance_buffer else {
return RenderCommandResult::Skip;
};
let Some(vertex_buffer_slice) =
mesh_allocator.mesh_vertex_slice(&mesh_instance.mesh_asset_id)
else {
return RenderCommandResult::Skip;
};
pass.set_vertex_buffer(0, vertex_buffer_slice.buffer.slice(..));
pass.set_vertex_buffer(1, instance_buffer.buffer.slice(..));
match &gpu_mesh.buffer_info {
RenderMeshBufferInfo::Indexed {
index_format,
count,
} => {
let Some(index_buffer_slice) =
mesh_allocator.mesh_index_slice(&mesh_instance.mesh_asset_id)
else {
return RenderCommandResult::Skip;
};
pass.set_index_buffer(index_buffer_slice.buffer.slice(..), 0, *index_format);
pass.draw_indexed(
index_buffer_slice.range.start..(index_buffer_slice.range.start + count),
vertex_buffer_slice.range.start as i32,
0..instance_buffer.length as u32,
);
}
RenderMeshBufferInfo::NonIndexed => {
pass.draw(vertex_buffer_slice.range, 0..instance_buffer.length as u32);
}
}
RenderCommandResult::Success
}
}