//! A shader that renders a mesh multiple times in one draw call. use bevy::{ core_pipeline::core_3d::Transparent3d, ecs::system::{lifetimeless::*, SystemParamItem}, math::prelude::*, pbr::{MeshPipeline, MeshPipelineKey, MeshUniform, SetMeshBindGroup, SetMeshViewBindGroup}, prelude::*, render::{ extract_component::{ExtractComponent, ExtractComponentPlugin}, mesh::{GpuBufferInfo, MeshVertexBufferLayout}, render_asset::RenderAssets, render_phase::{ AddRenderCommand, DrawFunctions, EntityRenderCommand, RenderCommandResult, RenderPhase, SetItemPipeline, TrackedRenderPass, }, render_resource::*, renderer::RenderDevice, view::{ComputedVisibility, ExtractedView, Msaa, NoFrustumCulling, Visibility}, RenderApp, RenderStage, }, }; use bytemuck::{Pod, Zeroable}; fn main() { App::new() .add_plugins(DefaultPlugins) .add_plugin(CustomMaterialPlugin) .add_startup_system(setup) .run(); } fn setup(mut commands: Commands, mut meshes: ResMut>) { commands.spawn().insert_bundle(( meshes.add(Mesh::from(shape::Cube { size: 0.5 })), Transform::from_xyz(0.0, 0.0, 0.0), GlobalTransform::default(), InstanceMaterialData( (1..=10) .flat_map(|x| (1..=10).map(move |y| (x as f32 / 10.0, y as f32 / 10.0))) .map(|(x, y)| InstanceData { position: Vec3::new(x * 10.0 - 5.0, y * 10.0 - 5.0, 0.0), scale: 1.0, color: Color::hsla(x * 360., y, 0.5, 1.0).as_rgba_f32(), }) .collect(), ), Visibility::default(), ComputedVisibility::default(), // NOTE: Frustum culling is done based on the Aabb of the Mesh and the GlobalTransform. // As the cube is at the origin, if its Aabb moves outside the view frustum, all the // instanced cubes will be culled. // The InstanceMaterialData contains the 'GlobalTransform' information for this custom // instancing, and that is not taken into account with the built-in frustum culling. // We must disable the built-in frustum culling by adding the `NoFrustumCulling` marker // component to avoid incorrect culling. NoFrustumCulling, )); // camera commands.spawn_bundle(Camera3dBundle { transform: Transform::from_xyz(0.0, 0.0, 15.0).looking_at(Vec3::ZERO, Vec3::Y), ..default() }); } #[derive(Component, Deref)] struct InstanceMaterialData(Vec); impl ExtractComponent for InstanceMaterialData { type Query = &'static InstanceMaterialData; type Filter = (); fn extract_component(item: bevy::ecs::query::QueryItem) -> Self { InstanceMaterialData(item.0.clone()) } } pub struct CustomMaterialPlugin; impl Plugin for CustomMaterialPlugin { fn build(&self, app: &mut App) { app.add_plugin(ExtractComponentPlugin::::default()); app.sub_app_mut(RenderApp) .add_render_command::() .init_resource::() .init_resource::>() .add_system_to_stage(RenderStage::Queue, queue_custom) .add_system_to_stage(RenderStage::Prepare, prepare_instance_buffers); } } #[derive(Clone, Copy, Pod, Zeroable)] #[repr(C)] struct InstanceData { position: Vec3, scale: f32, color: [f32; 4], } #[allow(clippy::too_many_arguments)] fn queue_custom( transparent_3d_draw_functions: Res>, custom_pipeline: Res, msaa: Res, mut pipelines: ResMut>, mut pipeline_cache: ResMut, meshes: Res>, material_meshes: Query<(Entity, &MeshUniform, &Handle), With>, mut views: Query<(&ExtractedView, &mut RenderPhase)>, ) { let draw_custom = transparent_3d_draw_functions .read() .get_id::() .unwrap(); let msaa_key = MeshPipelineKey::from_msaa_samples(msaa.samples); for (view, mut transparent_phase) in views.iter_mut() { let rangefinder = view.rangefinder3d(); for (entity, mesh_uniform, mesh_handle) in material_meshes.iter() { if let Some(mesh) = meshes.get(mesh_handle) { let key = msaa_key | MeshPipelineKey::from_primitive_topology(mesh.primitive_topology); let pipeline = pipelines .specialize(&mut pipeline_cache, &custom_pipeline, key, &mesh.layout) .unwrap(); transparent_phase.add(Transparent3d { entity, pipeline, draw_function: draw_custom, distance: rangefinder.distance(&mesh_uniform.transform), }); } } } } #[derive(Component)] pub struct InstanceBuffer { buffer: Buffer, length: usize, } fn prepare_instance_buffers( mut commands: Commands, query: Query<(Entity, &InstanceMaterialData)>, render_device: Res, ) { for (entity, instance_data) in query.iter() { let buffer = render_device.create_buffer_with_data(&BufferInitDescriptor { label: Some("instance data buffer"), contents: bytemuck::cast_slice(instance_data.as_slice()), usage: BufferUsages::VERTEX | BufferUsages::COPY_DST, }); commands.entity(entity).insert(InstanceBuffer { buffer, length: instance_data.len(), }); } } pub struct CustomPipeline { shader: Handle, mesh_pipeline: MeshPipeline, } impl FromWorld for CustomPipeline { fn from_world(world: &mut World) -> Self { let asset_server = world.resource::(); asset_server.watch_for_changes().unwrap(); let shader = asset_server.load("shaders/instancing.wgsl"); let mesh_pipeline = world.resource::(); CustomPipeline { shader, mesh_pipeline: mesh_pipeline.clone(), } } } impl SpecializedMeshPipeline for CustomPipeline { type Key = MeshPipelineKey; fn specialize( &self, key: Self::Key, layout: &MeshVertexBufferLayout, ) -> Result { let mut descriptor = self.mesh_pipeline.specialize(key, layout)?; descriptor.vertex.shader = self.shader.clone(); descriptor.vertex.buffers.push(VertexBufferLayout { array_stride: std::mem::size_of::() as u64, step_mode: VertexStepMode::Instance, attributes: vec![ VertexAttribute { format: VertexFormat::Float32x4, offset: 0, shader_location: 3, // shader locations 0-2 are taken up by Position, Normal and UV attributes }, VertexAttribute { format: VertexFormat::Float32x4, offset: VertexFormat::Float32x4.size(), shader_location: 4, }, ], }); descriptor.fragment.as_mut().unwrap().shader = self.shader.clone(); descriptor.layout = Some(vec![ self.mesh_pipeline.view_layout.clone(), self.mesh_pipeline.mesh_layout.clone(), ]); Ok(descriptor) } } type DrawCustom = ( SetItemPipeline, SetMeshViewBindGroup<0>, SetMeshBindGroup<1>, DrawMeshInstanced, ); pub struct DrawMeshInstanced; impl EntityRenderCommand for DrawMeshInstanced { type Param = ( SRes>, SQuery>>, SQuery>, ); #[inline] fn render<'w>( _view: Entity, item: Entity, (meshes, mesh_query, instance_buffer_query): SystemParamItem<'w, '_, Self::Param>, pass: &mut TrackedRenderPass<'w>, ) -> RenderCommandResult { let mesh_handle = mesh_query.get(item).unwrap(); let instance_buffer = instance_buffer_query.get_inner(item).unwrap(); let gpu_mesh = match meshes.into_inner().get(mesh_handle) { Some(gpu_mesh) => gpu_mesh, None => return RenderCommandResult::Failure, }; pass.set_vertex_buffer(0, gpu_mesh.vertex_buffer.slice(..)); pass.set_vertex_buffer(1, instance_buffer.buffer.slice(..)); match &gpu_mesh.buffer_info { GpuBufferInfo::Indexed { buffer, index_format, count, } => { pass.set_index_buffer(buffer.slice(..), 0, *index_format); pass.draw_indexed(0..*count, 0, 0..instance_buffer.length as u32); } GpuBufferInfo::NonIndexed { vertex_count } => { pass.draw(0..*vertex_count, 0..instance_buffer.length as u32); } } RenderCommandResult::Success } }