//! A shader that renders a mesh multiple times in one draw call. use bevy::{ core_pipeline::core_3d::Transparent3d, ecs::{ query::QueryItem, system::{lifetimeless::*, SystemParamItem}, }, pbr::{ MeshPipeline, MeshPipelineKey, RenderMeshInstances, SetMeshBindGroup, SetMeshViewBindGroup, }, prelude::*, render::{ extract_component::{ExtractComponent, ExtractComponentPlugin}, mesh::{ allocator::MeshAllocator, MeshVertexBufferLayoutRef, RenderMesh, RenderMeshBufferInfo, }, render_asset::RenderAssets, render_phase::{ AddRenderCommand, DrawFunctions, PhaseItem, PhaseItemExtraIndex, RenderCommand, RenderCommandResult, SetItemPipeline, TrackedRenderPass, ViewSortedRenderPhases, }, render_resource::*, renderer::RenderDevice, view::{ExtractedView, NoFrustumCulling}, Render, RenderApp, RenderSet, }, }; use bytemuck::{Pod, Zeroable}; use std::mem::size_of; /// This example uses a shader source file from the assets subdirectory const SHADER_ASSET_PATH: &str = "shaders/instancing.wgsl"; fn main() { App::new() .add_plugins((DefaultPlugins, CustomMaterialPlugin)) .add_systems(Startup, setup) .run(); } fn setup(mut commands: Commands, mut meshes: ResMut>) { commands.spawn(( meshes.add(Cuboid::new(0.5, 0.5, 0.5)), SpatialBundle::INHERITED_IDENTITY, InstanceMaterialData( (1..=10) .flat_map(|x| (1..=10).map(move |y| (x as f32 / 10.0, y as f32 / 10.0))) .map(|(x, y)| InstanceData { position: Vec3::new(x * 10.0 - 5.0, y * 10.0 - 5.0, 0.0), scale: 1.0, color: LinearRgba::from(Color::hsla(x * 360., y, 0.5, 1.0)).to_f32_array(), }) .collect(), ), // NOTE: Frustum culling is done based on the Aabb of the Mesh and the GlobalTransform. // As the cube is at the origin, if its Aabb moves outside the view frustum, all the // instanced cubes will be culled. // The InstanceMaterialData contains the 'GlobalTransform' information for this custom // instancing, and that is not taken into account with the built-in frustum culling. // We must disable the built-in frustum culling by adding the `NoFrustumCulling` marker // component to avoid incorrect culling. NoFrustumCulling, )); // camera commands.spawn(Camera3dBundle { transform: Transform::from_xyz(0.0, 0.0, 15.0).looking_at(Vec3::ZERO, Vec3::Y), ..default() }); } #[derive(Component, Deref)] struct InstanceMaterialData(Vec); impl ExtractComponent for InstanceMaterialData { type QueryData = &'static InstanceMaterialData; type QueryFilter = (); type Out = Self; fn extract_component(item: QueryItem<'_, Self::QueryData>) -> Option { Some(InstanceMaterialData(item.0.clone())) } } struct CustomMaterialPlugin; impl Plugin for CustomMaterialPlugin { fn build(&self, app: &mut App) { app.add_plugins(ExtractComponentPlugin::::default()); app.sub_app_mut(RenderApp) .add_render_command::() .init_resource::>() .add_systems( Render, ( queue_custom.in_set(RenderSet::QueueMeshes), prepare_instance_buffers.in_set(RenderSet::PrepareResources), ), ); } fn finish(&self, app: &mut App) { app.sub_app_mut(RenderApp).init_resource::(); } } #[derive(Clone, Copy, Pod, Zeroable)] #[repr(C)] struct InstanceData { position: Vec3, scale: f32, color: [f32; 4], } #[allow(clippy::too_many_arguments)] fn queue_custom( transparent_3d_draw_functions: Res>, custom_pipeline: Res, mut pipelines: ResMut>, pipeline_cache: Res, meshes: Res>, render_mesh_instances: Res, material_meshes: Query>, mut transparent_render_phases: ResMut>, views: Query<(Entity, &ExtractedView, &Msaa)>, ) { let draw_custom = transparent_3d_draw_functions.read().id::(); for (view_entity, view, msaa) in &views { let Some(transparent_phase) = transparent_render_phases.get_mut(&view_entity) else { continue; }; let msaa_key = MeshPipelineKey::from_msaa_samples(msaa.samples()); let view_key = msaa_key | MeshPipelineKey::from_hdr(view.hdr); let rangefinder = view.rangefinder3d(); for entity in &material_meshes { let Some(mesh_instance) = render_mesh_instances.render_mesh_queue_data(entity) else { continue; }; let Some(mesh) = meshes.get(mesh_instance.mesh_asset_id) else { continue; }; let key = view_key | MeshPipelineKey::from_primitive_topology(mesh.primitive_topology()); let pipeline = pipelines .specialize(&pipeline_cache, &custom_pipeline, key, &mesh.layout) .unwrap(); transparent_phase.add(Transparent3d { entity, pipeline, draw_function: draw_custom, distance: rangefinder.distance_translation(&mesh_instance.translation), batch_range: 0..1, extra_index: PhaseItemExtraIndex::NONE, }); } } } #[derive(Component)] struct InstanceBuffer { buffer: Buffer, length: usize, } fn prepare_instance_buffers( mut commands: Commands, query: Query<(Entity, &InstanceMaterialData)>, render_device: Res, ) { for (entity, instance_data) in &query { let buffer = render_device.create_buffer_with_data(&BufferInitDescriptor { label: Some("instance data buffer"), contents: bytemuck::cast_slice(instance_data.as_slice()), usage: BufferUsages::VERTEX | BufferUsages::COPY_DST, }); commands.entity(entity).insert(InstanceBuffer { buffer, length: instance_data.len(), }); } } #[derive(Resource)] struct CustomPipeline { shader: Handle, mesh_pipeline: MeshPipeline, } impl FromWorld for CustomPipeline { fn from_world(world: &mut World) -> Self { let mesh_pipeline = world.resource::(); CustomPipeline { shader: world.load_asset(SHADER_ASSET_PATH), mesh_pipeline: mesh_pipeline.clone(), } } } impl SpecializedMeshPipeline for CustomPipeline { type Key = MeshPipelineKey; fn specialize( &self, key: Self::Key, layout: &MeshVertexBufferLayoutRef, ) -> Result { let mut descriptor = self.mesh_pipeline.specialize(key, layout)?; descriptor.vertex.shader = self.shader.clone(); descriptor.vertex.buffers.push(VertexBufferLayout { array_stride: size_of::() as u64, step_mode: VertexStepMode::Instance, attributes: vec![ VertexAttribute { format: VertexFormat::Float32x4, offset: 0, shader_location: 3, // shader locations 0-2 are taken up by Position, Normal and UV attributes }, VertexAttribute { format: VertexFormat::Float32x4, offset: VertexFormat::Float32x4.size(), shader_location: 4, }, ], }); descriptor.fragment.as_mut().unwrap().shader = self.shader.clone(); Ok(descriptor) } } type DrawCustom = ( SetItemPipeline, SetMeshViewBindGroup<0>, SetMeshBindGroup<1>, DrawMeshInstanced, ); struct DrawMeshInstanced; impl RenderCommand

for DrawMeshInstanced { type Param = ( SRes>, SRes, SRes, ); type ViewQuery = (); type ItemQuery = Read; #[inline] fn render<'w>( item: &P, _view: (), instance_buffer: Option<&'w InstanceBuffer>, (meshes, render_mesh_instances, mesh_allocator): SystemParamItem<'w, '_, Self::Param>, pass: &mut TrackedRenderPass<'w>, ) -> RenderCommandResult { // A borrow check workaround. let mesh_allocator = mesh_allocator.into_inner(); let Some(mesh_instance) = render_mesh_instances.render_mesh_queue_data(item.entity()) else { return RenderCommandResult::Skip; }; let Some(gpu_mesh) = meshes.into_inner().get(mesh_instance.mesh_asset_id) else { return RenderCommandResult::Skip; }; let Some(instance_buffer) = instance_buffer else { return RenderCommandResult::Skip; }; let Some(vertex_buffer_slice) = mesh_allocator.mesh_vertex_slice(&mesh_instance.mesh_asset_id) else { return RenderCommandResult::Skip; }; pass.set_vertex_buffer(0, vertex_buffer_slice.buffer.slice(..)); pass.set_vertex_buffer(1, instance_buffer.buffer.slice(..)); match &gpu_mesh.buffer_info { RenderMeshBufferInfo::Indexed { index_format, count, } => { let Some(index_buffer_slice) = mesh_allocator.mesh_index_slice(&mesh_instance.mesh_asset_id) else { return RenderCommandResult::Skip; }; pass.set_index_buffer(index_buffer_slice.buffer.slice(..), 0, *index_format); pass.draw_indexed( index_buffer_slice.range.start..(index_buffer_slice.range.start + count), vertex_buffer_slice.range.start as i32, 0..instance_buffer.length as u32, ); } RenderMeshBufferInfo::NonIndexed => { pass.draw(vertex_buffer_slice.range, 0..instance_buffer.length as u32); } } RenderCommandResult::Success } }