From aa626e4f0b4b43c543b95168baaa6dc9b9471359 Mon Sep 17 00:00:00 2001 From: JMS55 <47158642+JMS55@users.noreply.github.com> Date: Tue, 8 Oct 2024 11:42:55 -0700 Subject: [PATCH] Per-meshlet compressed vertex data (#15643) # Objective - Prepare for streaming by storing vertex data per-meshlet, rather than per-mesh (this means duplicating vertices per-meshlet) - Compress vertex data to reduce the cost of this ## Solution The important parts are in from_mesh.rs, the changes to the Meshlet type in asset.rs, and the changes in meshlet_bindings.wgsl. Everything else is pretty secondary/boilerplate/straightforward changes. - Positions are quantized in centimeters with a user-provided power of 2 factor (ideally auto-determined, but that's a TODO for the future), encoded as an offset relative to the minimum value within the meshlet, and then stored as a packed list of bits using the minimum number of bits needed for each vertex position channel for that meshlet - E.g. quantize positions (lossily, throws away precision that's not needed leading to using fewer bits in the bitstream encoding) - Get the min/max quantized value of each X/Y/Z channel of the quantized positions within a meshlet - Encode values relative to the min value of the meshlet. E.g. convert from [min, max] to [0, max - min] - The new max value in the meshlet is (max - min), which only takes N bits, so we only need N bits to store each channel within the meshlet (lossless) - We can store the min value and that it takes N bits per channel in the meshlet metadata, and reconstruct the position from the bitstream - Normals are octahedral encoded and then snorm2x16 packed and stored as a single u32. - Would be better to implement the precise variant of octahedral encoding for extra precision (no extra decode cost), but decided to keep it simple for now and leave that as a followup - Tried doing a quantizing and bitstream encoding scheme like I did for positions, but struggled to get it smaller. 
Decided to go with this for simplicity for now - UVs are uncompressed and take a full 64bits per vertex which is expensive - In the future this should be improved - Tangents, as of the previous PR, are not explicitly stored and are instead derived from screen space gradients - While I'm here, split up MeshletMeshSaverLoader into two separate types Other future changes include implementing a smaller encoding of triangle data (3 u8 indices = 24 bits per triangle currently), and more disk-oriented compression schemes. References: * "A Deep Dive into UE5's Nanite Virtualized Geometry" https://advances.realtimerendering.com/s2021/Karis_Nanite_SIGGRAPH_Advances_2021_final.pdf#page=128 (also available on youtube) * "Towards Practical Meshlet Compression" https://arxiv.org/pdf/2404.06359 * "Vertex quantization in Omniforce Game Engine" https://daniilvinn.github.io/2024/05/04/omniforce-vertex-quantization.html ## Testing - Did you test these changes? If so, how? - Converted the stanford bunny, and rendered it with a debug material showing normals, and confirmed that it's identical to what's on main. EDIT: See additional testing in the comments below. - Are there any parts that need more testing? - Could use some more size comparisons on various meshes, and testing different quantization factors. Not sure if 4 is a good default. EDIT: See additional testing in the comments below. - Also did not test runtime performance of the shaders. EDIT: See additional testing in the comments below. - How can other people (reviewers) test your changes? Is there anything specific they need to know? 
- Use my unholy script, replacing the meshlet example https://paste.rs/7xQHk.rs (must make MeshletMesh fields pub instead of pub crate, must add lz4_flex as a dev-dependency) (must compile with meshlet and meshlet_processor features, mesh must have only positions, normals, and UVs, no vertex colors or tangents) --- ## Migration Guide - TBD by JMS55 at the end of the release --- Cargo.toml | 2 +- crates/bevy_pbr/Cargo.toml | 9 +- crates/bevy_pbr/src/meshlet/asset.rs | 77 +++++--- crates/bevy_pbr/src/meshlet/from_mesh.rs | 183 ++++++++++++++++-- .../src/meshlet/meshlet_bindings.wgsl | 147 ++++++++++---- .../src/meshlet/meshlet_mesh_manager.rs | 59 +++--- crates/bevy_pbr/src/meshlet/mod.rs | 8 +- .../src/meshlet/persistent_buffer_impls.rs | 41 ++-- .../bevy_pbr/src/meshlet/resource_manager.rs | 10 +- .../visibility_buffer_hardware_raster.wgsl | 18 +- .../meshlet/visibility_buffer_resolve.wgsl | 54 ++++-- .../visibility_buffer_software_raster.wgsl | 21 +- crates/bevy_pbr/src/render/utils.wgsl | 8 +- examples/3d/meshlet.rs | 2 +- 14 files changed, 467 insertions(+), 172 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 01d1b07ba2..fd47803f37 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1209,7 +1209,7 @@ setup = [ "curl", "-o", "assets/models/bunny.meshlet_mesh", - "https://raw.githubusercontent.com/JMS55/bevy_meshlet_asset/854eb98353ad94aea1104f355fc24dbe4fda679d/bunny.meshlet_mesh", + "https://raw.githubusercontent.com/JMS55/bevy_meshlet_asset/8443bbdee0bf517e6c297dede7f6a46ab712ee4c/bunny.meshlet_mesh", ], ] diff --git a/crates/bevy_pbr/Cargo.toml b/crates/bevy_pbr/Cargo.toml index a4c46dfdbc..5afe67229d 100644 --- a/crates/bevy_pbr/Cargo.toml +++ b/crates/bevy_pbr/Cargo.toml @@ -20,7 +20,13 @@ ios_simulator = ["bevy_render/ios_simulator"] # Enables the meshlet renderer for dense high-poly scenes (experimental) meshlet = ["dep:lz4_flex", "dep:thiserror", "dep:range-alloc", "dep:bevy_tasks"] # Enables processing meshes into meshlet meshes -meshlet_processor = 
["meshlet", "dep:meshopt", "dep:metis", "dep:itertools"] +meshlet_processor = [ + "meshlet", + "dep:meshopt", + "dep:metis", + "dep:itertools", + "dep:bitvec", +] [dependencies] # bevy @@ -53,6 +59,7 @@ range-alloc = { version = "0.1.3", optional = true } meshopt = { version = "0.3.0", optional = true } metis = { version = "0.2", optional = true } itertools = { version = "0.13", optional = true } +bitvec = { version = "1", optional = true } # direct dependency required for derive macro bytemuck = { version = "1", features = ["derive", "must_cast"] } radsort = "0.1" diff --git a/crates/bevy_pbr/src/meshlet/asset.rs b/crates/bevy_pbr/src/meshlet/asset.rs index 87e3483d0f..2ba821a9da 100644 --- a/crates/bevy_pbr/src/meshlet/asset.rs +++ b/crates/bevy_pbr/src/meshlet/asset.rs @@ -4,7 +4,7 @@ use bevy_asset::{ saver::{AssetSaver, SavedAsset}, Asset, AssetLoader, AsyncReadExt, AsyncWriteExt, LoadContext, }; -use bevy_math::Vec3; +use bevy_math::{Vec2, Vec3}; use bevy_reflect::TypePath; use bevy_tasks::block_on; use bytemuck::{Pod, Zeroable}; @@ -38,30 +38,51 @@ pub const MESHLET_MESH_ASSET_VERSION: u64 = 1; /// See also [`super::MaterialMeshletMeshBundle`] and [`super::MeshletPlugin`]. #[derive(Asset, TypePath, Clone)] pub struct MeshletMesh { - /// Raw vertex data bytes for the overall mesh. - pub(crate) vertex_data: Arc<[u8]>, - /// Indices into `vertex_data`. - pub(crate) vertex_ids: Arc<[u32]>, - /// Indices into `vertex_ids`. + /// Quantized and bitstream-packed vertex positions for meshlet vertices. + pub(crate) vertex_positions: Arc<[u32]>, + /// Octahedral-encoded and 2x16snorm packed normals for meshlet vertices. + pub(crate) vertex_normals: Arc<[u32]>, + /// Uncompressed vertex texture coordinates for meshlet vertices. + pub(crate) vertex_uvs: Arc<[Vec2]>, + /// Triangle indices for meshlets. pub(crate) indices: Arc<[u8]>, /// The list of meshlets making up this mesh. pub(crate) meshlets: Arc<[Meshlet]>, /// Spherical bounding volumes. 
- pub(crate) bounding_spheres: Arc<[MeshletBoundingSpheres]>, + pub(crate) meshlet_bounding_spheres: Arc<[MeshletBoundingSpheres]>, } /// A single meshlet within a [`MeshletMesh`]. #[derive(Copy, Clone, Pod, Zeroable)] #[repr(C)] pub struct Meshlet { - /// The offset within the parent mesh's [`MeshletMesh::vertex_ids`] buffer where the indices for this meshlet begin. - pub start_vertex_id: u32, + /// The bit offset within the parent mesh's [`MeshletMesh::vertex_positions`] buffer where the vertex positions for this meshlet begin. + pub start_vertex_position_bit: u32, + /// The offset within the parent mesh's [`MeshletMesh::vertex_normals`] and [`MeshletMesh::vertex_uvs`] buffers + /// where non-position vertex attributes for this meshlet begin. + pub start_vertex_attribute_id: u32, /// The offset within the parent mesh's [`MeshletMesh::indices`] buffer where the indices for this meshlet begin. pub start_index_id: u32, /// The amount of vertices in this meshlet. - pub vertex_count: u32, + pub vertex_count: u8, /// The amount of triangles in this meshlet. - pub triangle_count: u32, + pub triangle_count: u8, + /// Unused. + pub padding: u16, + /// Number of bits used to store the X channel of vertex positions within this meshlet. + pub bits_per_vertex_position_channel_x: u8, + /// Number of bits used to store the Y channel of vertex positions within this meshlet. + pub bits_per_vertex_position_channel_y: u8, + /// Number of bits used to store the Z channel of vertex positions within this meshlet. + pub bits_per_vertex_position_channel_z: u8, + /// Power of 2 factor used to quantize vertex positions within this meshlet. + pub vertex_position_quantization_factor: u8, + /// Minimum quantized X channel value of vertex positions within this meshlet. + pub min_vertex_position_channel_x: f32, + /// Minimum quantized Y channel value of vertex positions within this meshlet. 
+ pub min_vertex_position_channel_y: f32, + /// Minimum quantized Z channel value of vertex positions within this meshlet. + pub min_vertex_position_channel_z: f32, } /// Bounding spheres used for culling and choosing level of detail for a [`Meshlet`]. @@ -84,13 +105,13 @@ pub struct MeshletBoundingSphere { pub radius: f32, } -/// An [`AssetLoader`] and [`AssetSaver`] for `.meshlet_mesh` [`MeshletMesh`] assets. -pub struct MeshletMeshSaverLoader; +/// An [`AssetSaver`] for `.meshlet_mesh` [`MeshletMesh`] assets. +pub struct MeshletMeshSaver; -impl AssetSaver for MeshletMeshSaverLoader { +impl AssetSaver for MeshletMeshSaver { type Asset = MeshletMesh; type Settings = (); - type OutputLoader = Self; + type OutputLoader = MeshletMeshLoader; type Error = MeshletMeshSaveOrLoadError; async fn save( @@ -111,18 +132,22 @@ impl AssetSaver for MeshletMeshSaverLoader { // Compress and write asset data let mut writer = FrameEncoder::new(AsyncWriteSyncAdapter(writer)); - write_slice(&asset.vertex_data, &mut writer)?; - write_slice(&asset.vertex_ids, &mut writer)?; + write_slice(&asset.vertex_positions, &mut writer)?; + write_slice(&asset.vertex_normals, &mut writer)?; + write_slice(&asset.vertex_uvs, &mut writer)?; write_slice(&asset.indices, &mut writer)?; write_slice(&asset.meshlets, &mut writer)?; - write_slice(&asset.bounding_spheres, &mut writer)?; + write_slice(&asset.meshlet_bounding_spheres, &mut writer)?; writer.finish()?; Ok(()) } } -impl AssetLoader for MeshletMeshSaverLoader { +/// An [`AssetLoader`] for `.meshlet_mesh` [`MeshletMesh`] assets. 
+pub struct MeshletMeshLoader; + +impl AssetLoader for MeshletMeshLoader { type Asset = MeshletMesh; type Settings = (); type Error = MeshletMeshSaveOrLoadError; @@ -147,18 +172,20 @@ impl AssetLoader for MeshletMeshSaverLoader { // Load and decompress asset data let reader = &mut FrameDecoder::new(AsyncReadSyncAdapter(reader)); - let vertex_data = read_slice(reader)?; - let vertex_ids = read_slice(reader)?; + let vertex_positions = read_slice(reader)?; + let vertex_normals = read_slice(reader)?; + let vertex_uvs = read_slice(reader)?; let indices = read_slice(reader)?; let meshlets = read_slice(reader)?; - let bounding_spheres = read_slice(reader)?; + let meshlet_bounding_spheres = read_slice(reader)?; Ok(MeshletMesh { - vertex_data, - vertex_ids, + vertex_positions, + vertex_normals, + vertex_uvs, indices, meshlets, - bounding_spheres, + meshlet_bounding_spheres, }) } diff --git a/crates/bevy_pbr/src/meshlet/from_mesh.rs b/crates/bevy_pbr/src/meshlet/from_mesh.rs index 668ff32eba..367e572276 100644 --- a/crates/bevy_pbr/src/meshlet/from_mesh.rs +++ b/crates/bevy_pbr/src/meshlet/from_mesh.rs @@ -1,31 +1,62 @@ use super::asset::{Meshlet, MeshletBoundingSphere, MeshletBoundingSpheres, MeshletMesh}; use alloc::borrow::Cow; +use bevy_math::{ops::log2, IVec3, Vec2, Vec3, Vec3Swizzles}; use bevy_render::{ mesh::{Indices, Mesh}, render_resource::PrimitiveTopology, }; use bevy_utils::HashMap; +use bitvec::{order::Lsb0, vec::BitVec, view::BitView}; use core::ops::Range; use itertools::Itertools; use meshopt::{ - build_meshlets, compute_cluster_bounds, compute_meshlet_bounds, ffi::meshopt_Bounds, simplify, - Meshlets, SimplifyOptions, VertexDataAdapter, + build_meshlets, compute_cluster_bounds, compute_meshlet_bounds, + ffi::{meshopt_Bounds, meshopt_Meshlet}, + simplify, Meshlets, SimplifyOptions, VertexDataAdapter, }; use metis::Graph; use smallvec::SmallVec; +/// Default vertex position quantization factor for use with [`MeshletMesh::from_mesh`]. 
+/// +/// Snaps vertices to the nearest 1/16th of a centimeter (1/2^4). +pub const DEFAULT_VERTEX_POSITION_QUANTIZATION_FACTOR: u8 = 4; + +const MESHLET_VERTEX_SIZE_IN_BYTES: usize = 32; +const CENTIMETERS_PER_METER: f32 = 100.0; + impl MeshletMesh { /// Process a [`Mesh`] to generate a [`MeshletMesh`]. /// /// This process is very slow, and should be done ahead of time, and not at runtime. /// + /// # Requirements + /// /// This function requires the `meshlet_processor` cargo feature. /// /// The input mesh must: /// 1. Use [`PrimitiveTopology::TriangleList`] /// 2. Use indices /// 3. Have the exact following set of vertex attributes: `{POSITION, NORMAL, UV_0}` (tangents can be used in material shaders, but are calculated at runtime and are not stored in the mesh) - pub fn from_mesh(mesh: &Mesh) -> Result { + /// + /// # Vertex precision + /// + /// `vertex_position_quantization_factor` is the amount of precision to use when quantizing vertex positions. + /// + /// Vertices are snapped to the nearest (1/2^x)th of a centimeter, where x = `vertex_position_quantization_factor`. + /// E.g. if x = 4, then vertices are snapped to the nearest 1/2^4 = 1/16th of a centimeter. + /// + /// Use [`DEFAULT_VERTEX_POSITION_QUANTIZATION_FACTOR`] as a default, adjusting lower to save memory and disk space, and higher to prevent artifacts if needed. 
+ /// + /// To ensure that two different meshes do not have cracks between them when placed directly next to each other: + /// * Use the same quantization factor when converting each mesh to a meshlet mesh + /// * Ensure that their [`bevy_transform::components::Transform::translation`]s are a multiple of 1/2^x centimeters (note that translations are in meters) + /// * Ensure that their [`bevy_transform::components::Transform::scale`]s are the same + /// * Ensure that their [`bevy_transform::components::Transform::rotation`]s are a multiple of 90 degrees + pub fn from_mesh( + mesh: &Mesh, + vertex_position_quantization_factor: u8, + ) -> Result { // Validate mesh format let indices = validate_input_mesh(mesh)?; @@ -121,24 +152,32 @@ impl MeshletMesh { simplification_queue = next_lod_start..meshlets.len(); } - // Convert meshopt_Meshlet data to a custom format - let bevy_meshlets = meshlets - .meshlets - .into_iter() - .map(|m| Meshlet { - start_vertex_id: m.vertex_offset, - start_index_id: m.triangle_offset, - vertex_count: m.vertex_count, - triangle_count: m.triangle_count, - }) - .collect(); + // Copy vertex attributes per meshlet and compress + let mut vertex_positions = BitVec::::new(); + let mut vertex_normals = Vec::new(); + let mut vertex_uvs = Vec::new(); + let mut bevy_meshlets = Vec::with_capacity(meshlets.len()); + for (i, meshlet) in meshlets.meshlets.iter().enumerate() { + build_and_compress_meshlet_vertex_data( + meshlet, + meshlets.get(i).vertices, + &vertex_buffer, + &mut vertex_positions, + &mut vertex_normals, + &mut vertex_uvs, + &mut bevy_meshlets, + vertex_position_quantization_factor, + ); + } + vertex_positions.set_uninitialized(false); Ok(Self { - vertex_data: vertex_buffer.into(), - vertex_ids: meshlets.vertices.into(), + vertex_positions: vertex_positions.into_vec().into(), + vertex_normals: vertex_normals.into(), + vertex_uvs: vertex_uvs.into(), indices: meshlets.triangles.into(), - meshlets: bevy_meshlets, - bounding_spheres: 
bounding_spheres.into(), + meshlets: bevy_meshlets.into(), + meshlet_bounding_spheres: bounding_spheres.into(), }) } } @@ -323,6 +362,92 @@ fn split_simplified_group_into_new_meshlets( new_meshlets_count } +#[allow(clippy::too_many_arguments)] +fn build_and_compress_meshlet_vertex_data( + meshlet: &meshopt_Meshlet, + meshlet_vertex_ids: &[u32], + vertex_buffer: &[u8], + vertex_positions: &mut BitVec, + vertex_normals: &mut Vec, + vertex_uvs: &mut Vec, + meshlets: &mut Vec, + vertex_position_quantization_factor: u8, +) { + let start_vertex_position_bit = vertex_positions.len() as u32; + let start_vertex_attribute_id = vertex_normals.len() as u32; + + let quantization_factor = + (1 << vertex_position_quantization_factor) as f32 * CENTIMETERS_PER_METER; + + let mut min_quantized_position_channels = IVec3::MAX; + let mut max_quantized_position_channels = IVec3::MIN; + + // Lossy vertex compression + let mut quantized_positions = [IVec3::ZERO; 255]; + for (i, vertex_id) in meshlet_vertex_ids.iter().enumerate() { + // Load source vertex attributes + let vertex_id_byte = *vertex_id as usize * MESHLET_VERTEX_SIZE_IN_BYTES; + let vertex_data = + &vertex_buffer[vertex_id_byte..(vertex_id_byte + MESHLET_VERTEX_SIZE_IN_BYTES)]; + let position = Vec3::from_slice(bytemuck::cast_slice(&vertex_data[0..12])); + let normal = Vec3::from_slice(bytemuck::cast_slice(&vertex_data[12..24])); + let uv = Vec2::from_slice(bytemuck::cast_slice(&vertex_data[24..32])); + + // Copy uncompressed UV + vertex_uvs.push(uv); + + // Compress normal + vertex_normals.push(pack2x16snorm(octahedral_encode(normal))); + + // Quantize position to a fixed-point IVec3 + let quantized_position = (position * quantization_factor + 0.5).as_ivec3(); + quantized_positions[i] = quantized_position; + + // Compute per X/Y/Z-channel quantized position min/max for this meshlet + min_quantized_position_channels = min_quantized_position_channels.min(quantized_position); + max_quantized_position_channels = 
max_quantized_position_channels.max(quantized_position); + } + + // Calculate bits needed to encode each quantized vertex position channel based on the range of each channel + let range = max_quantized_position_channels - min_quantized_position_channels + 1; + let bits_per_vertex_position_channel_x = log2(range.x as f32).ceil() as u8; + let bits_per_vertex_position_channel_y = log2(range.y as f32).ceil() as u8; + let bits_per_vertex_position_channel_z = log2(range.z as f32).ceil() as u8; + + // Lossless encoding of vertex positions in the minimum number of bits per channel + for quantized_position in quantized_positions.iter().take(meshlet_vertex_ids.len()) { + // Remap [range_min, range_max] IVec3 to [0, range_max - range_min] UVec3 + let position = (quantized_position - min_quantized_position_channels).as_uvec3(); + + // Store as a packed bitstream + vertex_positions.extend_from_bitslice( + &position.x.view_bits::()[..bits_per_vertex_position_channel_x as usize], + ); + vertex_positions.extend_from_bitslice( + &position.y.view_bits::()[..bits_per_vertex_position_channel_y as usize], + ); + vertex_positions.extend_from_bitslice( + &position.z.view_bits::()[..bits_per_vertex_position_channel_z as usize], + ); + } + + meshlets.push(Meshlet { + start_vertex_position_bit, + start_vertex_attribute_id, + start_index_id: meshlet.triangle_offset, + vertex_count: meshlet.vertex_count as u8, + triangle_count: meshlet.triangle_count as u8, + padding: 0, + bits_per_vertex_position_channel_x, + bits_per_vertex_position_channel_y, + bits_per_vertex_position_channel_z, + vertex_position_quantization_factor, + min_vertex_position_channel_x: min_quantized_position_channels.x as f32, + min_vertex_position_channel_y: min_quantized_position_channels.y as f32, + min_vertex_position_channel_z: min_quantized_position_channels.z as f32, + }); +} + fn convert_meshlet_bounds(bounds: meshopt_Bounds) -> MeshletBoundingSphere { MeshletBoundingSphere { center: bounds.center.into(), @@ -330,6 
+455,28 @@ fn convert_meshlet_bounds(bounds: meshopt_Bounds) -> MeshletBoundingSphere { } } +// TODO: Precise encode variant +fn octahedral_encode(v: Vec3) -> Vec2 { + let n = v / (v.x.abs() + v.y.abs() + v.z.abs()); + let octahedral_wrap = (1.0 - n.yx().abs()) + * Vec2::new( + if n.x >= 0.0 { 1.0 } else { -1.0 }, + if n.y >= 0.0 { 1.0 } else { -1.0 }, + ); + if n.z >= 0.0 { + n.xy() + } else { + octahedral_wrap + } +} + +// https://www.w3.org/TR/WGSL/#pack2x16snorm-builtin +fn pack2x16snorm(v: Vec2) -> u32 { + let v = v.clamp(Vec2::NEG_ONE, Vec2::ONE); + let v = (v * 32767.0 + 0.5).floor().as_i16vec2(); + bytemuck::cast(v) +} + /// An error produced by [`MeshletMesh::from_mesh`]. #[derive(thiserror::Error, Debug)] pub enum MeshToMeshletMeshConversionError { diff --git a/crates/bevy_pbr/src/meshlet/meshlet_bindings.wgsl b/crates/bevy_pbr/src/meshlet/meshlet_bindings.wgsl index 876eed1145..f870758353 100644 --- a/crates/bevy_pbr/src/meshlet/meshlet_bindings.wgsl +++ b/crates/bevy_pbr/src/meshlet/meshlet_bindings.wgsl @@ -3,31 +3,25 @@ #import bevy_pbr::mesh_types::Mesh #import bevy_render::view::View #import bevy_pbr::prepass_bindings::PreviousViewUniforms - -struct PackedMeshletVertex { - a: vec4, - b: vec4, -} - -struct MeshletVertex { - position: vec3, - normal: vec3, - uv: vec2, -} - -fn unpack_meshlet_vertex(packed: PackedMeshletVertex) -> MeshletVertex { - var vertex: MeshletVertex; - vertex.position = packed.a.xyz; - vertex.normal = vec3(packed.a.w, packed.b.xy); - vertex.uv = packed.b.zw; - return vertex; -} +#import bevy_pbr::utils::octahedral_decode_signed struct Meshlet { - start_vertex_id: u32, + start_vertex_position_bit: u32, + start_vertex_attribute_id: u32, start_index_id: u32, - vertex_count: u32, - triangle_count: u32, + packed_a: u32, + packed_b: u32, + min_vertex_position_channel_x: f32, + min_vertex_position_channel_y: f32, + min_vertex_position_channel_z: f32, +} + +fn get_meshlet_vertex_count(meshlet: ptr) -> u32 { + return 
extractBits((*meshlet).packed_a, 0u, 8u); +} + +fn get_meshlet_triangle_count(meshlet: ptr) -> u32 { + return extractBits((*meshlet).packed_a, 8u, 8u); } struct MeshletBoundingSpheres { @@ -54,6 +48,8 @@ struct DrawIndirectArgs { first_instance: u32, } +const CENTIMETERS_PER_METER = 100.0; + #ifdef MESHLET_FILL_CLUSTER_BUFFERS_PASS var cluster_count: u32; @group(0) @binding(0) var meshlet_instance_meshlet_counts_prefix_sum: array; // Per entity instance @@ -95,25 +91,58 @@ fn cluster_is_second_pass_candidate(cluster_id: u32) -> bool { @group(0) @binding(0) var meshlet_cluster_meshlet_ids: array; // Per cluster @group(0) @binding(1) var meshlets: array; // Per meshlet @group(0) @binding(2) var meshlet_indices: array; // Many per meshlet -@group(0) @binding(3) var meshlet_vertex_ids: array; // Many per meshlet -@group(0) @binding(4) var meshlet_vertex_data: array; // Many per meshlet -@group(0) @binding(5) var meshlet_cluster_instance_ids: array; // Per cluster -@group(0) @binding(6) var meshlet_instance_uniforms: array; // Per entity instance -@group(0) @binding(7) var meshlet_raster_clusters: array; // Single object shared between all workgroups/clusters/triangles -@group(0) @binding(8) var meshlet_software_raster_cluster_count: u32; +@group(0) @binding(3) var meshlet_vertex_positions: array; // Many per meshlet +@group(0) @binding(4) var meshlet_cluster_instance_ids: array; // Per cluster +@group(0) @binding(5) var meshlet_instance_uniforms: array; // Per entity instance +@group(0) @binding(6) var meshlet_raster_clusters: array; // Single object shared between all workgroups/clusters/triangles +@group(0) @binding(7) var meshlet_software_raster_cluster_count: u32; #ifdef MESHLET_VISIBILITY_BUFFER_RASTER_PASS_OUTPUT -@group(0) @binding(9) var meshlet_visibility_buffer: array>; // Per pixel +@group(0) @binding(8) var meshlet_visibility_buffer: array>; // Per pixel #else -@group(0) @binding(9) var meshlet_visibility_buffer: array>; // Per pixel +@group(0) @binding(8) 
var meshlet_visibility_buffer: array>; // Per pixel #endif -@group(0) @binding(10) var view: View; +@group(0) @binding(9) var view: View; // TODO: Load only twice, instead of 3x in cases where you load 3 indices per thread? -fn get_meshlet_index(index_id: u32) -> u32 { +fn get_meshlet_vertex_id(index_id: u32) -> u32 { let packed_index = meshlet_indices[index_id / 4u]; let bit_offset = (index_id % 4u) * 8u; return extractBits(packed_index, bit_offset, 8u); } + +fn get_meshlet_vertex_position(meshlet: ptr, vertex_id: u32) -> vec3 { + // Get bitstream start for the vertex + let unpacked = unpack4xU8((*meshlet).packed_b); + let bits_per_channel = unpacked.xyz; + let bits_per_vertex = bits_per_channel.x + bits_per_channel.y + bits_per_channel.z; + var start_bit = (*meshlet).start_vertex_position_bit + (vertex_id * bits_per_vertex); + + // Read each vertex channel from the bitstream + var vertex_position_packed = vec3(0u); + for (var i = 0u; i < 3u; i++) { + let lower_word_index = start_bit / 32u; + let lower_word_bit_offset = start_bit & 31u; + var next_32_bits = meshlet_vertex_positions[lower_word_index] >> lower_word_bit_offset; + if lower_word_bit_offset + bits_per_channel[i] > 32u { + next_32_bits |= meshlet_vertex_positions[lower_word_index + 1u] << (32u - lower_word_bit_offset); + } + vertex_position_packed[i] = extractBits(next_32_bits, 0u, bits_per_channel[i]); + start_bit += bits_per_channel[i]; + } + + // Remap [0, range_max - range_min] vec3 to [range_min, range_max] vec3 + var vertex_position = vec3(vertex_position_packed) + vec3( + (*meshlet).min_vertex_position_channel_x, + (*meshlet).min_vertex_position_channel_y, + (*meshlet).min_vertex_position_channel_z, + ); + + // Reverse vertex quantization + let vertex_position_quantization_factor = unpacked.w; + vertex_position /= f32(1u << vertex_position_quantization_factor) * CENTIMETERS_PER_METER; + + return vertex_position; +} #endif #ifdef MESHLET_MESH_MATERIAL_PASS @@ -121,15 +150,59 @@ fn 
get_meshlet_index(index_id: u32) -> u32 { @group(1) @binding(1) var meshlet_cluster_meshlet_ids: array; // Per cluster @group(1) @binding(2) var meshlets: array; // Per meshlet @group(1) @binding(3) var meshlet_indices: array; // Many per meshlet -@group(1) @binding(4) var meshlet_vertex_ids: array; // Many per meshlet -@group(1) @binding(5) var meshlet_vertex_data: array; // Many per meshlet -@group(1) @binding(6) var meshlet_cluster_instance_ids: array; // Per cluster -@group(1) @binding(7) var meshlet_instance_uniforms: array; // Per entity instance +@group(1) @binding(4) var meshlet_vertex_positions: array; // Many per meshlet +@group(1) @binding(5) var meshlet_vertex_normals: array; // Many per meshlet +@group(1) @binding(6) var meshlet_vertex_uvs: array>; // Many per meshlet +@group(1) @binding(7) var meshlet_cluster_instance_ids: array; // Per cluster +@group(1) @binding(8) var meshlet_instance_uniforms: array; // Per entity instance // TODO: Load only twice, instead of 3x in cases where you load 3 indices per thread? 
-fn get_meshlet_index(index_id: u32) -> u32 { +fn get_meshlet_vertex_id(index_id: u32) -> u32 { let packed_index = meshlet_indices[index_id / 4u]; let bit_offset = (index_id % 4u) * 8u; return extractBits(packed_index, bit_offset, 8u); } + +fn get_meshlet_vertex_position(meshlet: ptr, vertex_id: u32) -> vec3 { + // Get bitstream start for the vertex + let unpacked = unpack4xU8((*meshlet).packed_b); + let bits_per_channel = unpacked.xyz; + let bits_per_vertex = bits_per_channel.x + bits_per_channel.y + bits_per_channel.z; + var start_bit = (*meshlet).start_vertex_position_bit + (vertex_id * bits_per_vertex); + + // Read each vertex channel from the bitstream + var vertex_position_packed = vec3(0u); + for (var i = 0u; i < 3u; i++) { + let lower_word_index = start_bit / 32u; + let lower_word_bit_offset = start_bit & 31u; + var next_32_bits = meshlet_vertex_positions[lower_word_index] >> lower_word_bit_offset; + if lower_word_bit_offset + bits_per_channel[i] > 32u { + next_32_bits |= meshlet_vertex_positions[lower_word_index + 1u] << (32u - lower_word_bit_offset); + } + vertex_position_packed[i] = extractBits(next_32_bits, 0u, bits_per_channel[i]); + start_bit += bits_per_channel[i]; + } + + // Remap [0, range_max - range_min] vec3 to [range_min, range_max] vec3 + var vertex_position = vec3(vertex_position_packed) + vec3( + (*meshlet).min_vertex_position_channel_x, + (*meshlet).min_vertex_position_channel_y, + (*meshlet).min_vertex_position_channel_z, + ); + + // Reverse vertex quantization + let vertex_position_quantization_factor = unpacked.w; + vertex_position /= f32(1u << vertex_position_quantization_factor) * CENTIMETERS_PER_METER; + + return vertex_position; +} + +fn get_meshlet_vertex_normal(meshlet: ptr, vertex_id: u32) -> vec3 { + let packed_normal = meshlet_vertex_normals[(*meshlet).start_vertex_attribute_id + vertex_id]; + return octahedral_decode_signed(unpack2x16snorm(packed_normal)); +} + +fn get_meshlet_vertex_uv(meshlet: ptr, vertex_id: u32) -> vec2 { + 
return meshlet_vertex_uvs[(*meshlet).start_vertex_attribute_id + vertex_id]; +} #endif diff --git a/crates/bevy_pbr/src/meshlet/meshlet_mesh_manager.rs b/crates/bevy_pbr/src/meshlet/meshlet_mesh_manager.rs index 09e065b019..831fe11fb3 100644 --- a/crates/bevy_pbr/src/meshlet/meshlet_mesh_manager.rs +++ b/crates/bevy_pbr/src/meshlet/meshlet_mesh_manager.rs @@ -9,6 +9,7 @@ use bevy_ecs::{ system::{Res, ResMut, Resource}, world::{FromWorld, World}, }; +use bevy_math::Vec2; use bevy_render::{ render_resource::BufferAddress, renderer::{RenderDevice, RenderQueue}, @@ -19,20 +20,22 @@ use core::ops::Range; /// Manages uploading [`MeshletMesh`] asset data to the GPU. #[derive(Resource)] pub struct MeshletMeshManager { - pub vertex_data: PersistentGpuBuffer>, - pub vertex_ids: PersistentGpuBuffer>, + pub vertex_positions: PersistentGpuBuffer>, + pub vertex_normals: PersistentGpuBuffer>, + pub vertex_uvs: PersistentGpuBuffer>, pub indices: PersistentGpuBuffer>, pub meshlets: PersistentGpuBuffer>, pub meshlet_bounding_spheres: PersistentGpuBuffer>, - meshlet_mesh_slices: HashMap, [Range; 5]>, + meshlet_mesh_slices: HashMap, [Range; 6]>, } impl FromWorld for MeshletMeshManager { fn from_world(world: &mut World) -> Self { let render_device = world.resource::(); Self { - vertex_data: PersistentGpuBuffer::new("meshlet_vertex_data", render_device), - vertex_ids: PersistentGpuBuffer::new("meshlet_vertex_ids", render_device), + vertex_positions: PersistentGpuBuffer::new("meshlet_vertex_positions", render_device), + vertex_normals: PersistentGpuBuffer::new("meshlet_vertex_normals", render_device), + vertex_uvs: PersistentGpuBuffer::new("meshlet_vertex_uvs", render_device), indices: PersistentGpuBuffer::new("meshlet_indices", render_device), meshlets: PersistentGpuBuffer::new("meshlets", render_device), meshlet_bounding_spheres: PersistentGpuBuffer::new( @@ -55,27 +58,34 @@ impl MeshletMeshManager { "MeshletMesh asset was already unloaded but is not registered with 
MeshletMeshManager", ); - let vertex_data_slice = self - .vertex_data - .queue_write(Arc::clone(&meshlet_mesh.vertex_data), ()); - let vertex_ids_slice = self.vertex_ids.queue_write( - Arc::clone(&meshlet_mesh.vertex_ids), - vertex_data_slice.start, - ); + let vertex_positions_slice = self + .vertex_positions + .queue_write(Arc::clone(&meshlet_mesh.vertex_positions), ()); + let vertex_normals_slice = self + .vertex_normals + .queue_write(Arc::clone(&meshlet_mesh.vertex_normals), ()); + let vertex_uvs_slice = self + .vertex_uvs + .queue_write(Arc::clone(&meshlet_mesh.vertex_uvs), ()); let indices_slice = self .indices .queue_write(Arc::clone(&meshlet_mesh.indices), ()); let meshlets_slice = self.meshlets.queue_write( Arc::clone(&meshlet_mesh.meshlets), - (vertex_ids_slice.start, indices_slice.start), + ( + vertex_positions_slice.start, + vertex_normals_slice.start, + indices_slice.start, + ), ); let meshlet_bounding_spheres_slice = self .meshlet_bounding_spheres - .queue_write(Arc::clone(&meshlet_mesh.bounding_spheres), ()); + .queue_write(Arc::clone(&meshlet_mesh.meshlet_bounding_spheres), ()); [ - vertex_data_slice, - vertex_ids_slice, + vertex_positions_slice, + vertex_normals_slice, + vertex_uvs_slice, indices_slice, meshlets_slice, meshlet_bounding_spheres_slice, @@ -83,7 +93,7 @@ impl MeshletMeshManager { }; // If the MeshletMesh asset has not been uploaded to the GPU yet, queue it for uploading - let [_, _, _, meshlets_slice, _] = self + let [_, _, _, _, meshlets_slice, _] = self .meshlet_mesh_slices .entry(asset_id) .or_insert_with_key(queue_meshlet_mesh) @@ -96,11 +106,13 @@ impl MeshletMeshManager { pub fn remove(&mut self, asset_id: &AssetId) { if let Some( - [vertex_data_slice, vertex_ids_slice, indices_slice, meshlets_slice, meshlet_bounding_spheres_slice], + [vertex_positions_slice, vertex_normals_slice, vertex_uvs_slice, indices_slice, meshlets_slice, meshlet_bounding_spheres_slice], ) = self.meshlet_mesh_slices.remove(asset_id) { - 
self.vertex_data.mark_slice_unused(vertex_data_slice); - self.vertex_ids.mark_slice_unused(vertex_ids_slice); + self.vertex_positions + .mark_slice_unused(vertex_positions_slice); + self.vertex_normals.mark_slice_unused(vertex_normals_slice); + self.vertex_uvs.mark_slice_unused(vertex_uvs_slice); self.indices.mark_slice_unused(indices_slice); self.meshlets.mark_slice_unused(meshlets_slice); self.meshlet_bounding_spheres @@ -116,10 +128,13 @@ pub fn perform_pending_meshlet_mesh_writes( render_device: Res, ) { meshlet_mesh_manager - .vertex_data + .vertex_positions .perform_writes(&render_queue, &render_device); meshlet_mesh_manager - .vertex_ids + .vertex_normals + .perform_writes(&render_queue, &render_device); + meshlet_mesh_manager + .vertex_uvs .perform_writes(&render_queue, &render_device); meshlet_mesh_manager .indices diff --git a/crates/bevy_pbr/src/meshlet/mod.rs b/crates/bevy_pbr/src/meshlet/mod.rs index d61dc05c56..84c89ff440 100644 --- a/crates/bevy_pbr/src/meshlet/mod.rs +++ b/crates/bevy_pbr/src/meshlet/mod.rs @@ -32,9 +32,11 @@ pub(crate) use self::{ }, }; -pub use self::asset::{MeshletMesh, MeshletMeshSaverLoader}; +pub use self::asset::{MeshletMesh, MeshletMeshLoader, MeshletMeshSaver}; #[cfg(feature = "meshlet_processor")] -pub use self::from_mesh::MeshToMeshletMeshConversionError; +pub use self::from_mesh::{ + MeshToMeshletMeshConversionError, DEFAULT_VERTEX_POSITION_QUANTIZATION_FACTOR, +}; use self::{ graph::NodeMeshlet, @@ -201,7 +203,7 @@ impl Plugin for MeshletPlugin { ); app.init_asset::() - .register_asset_loader(MeshletMeshSaverLoader) + .register_asset_loader(MeshletMeshLoader) .add_systems( PostUpdate, check_visibility::.in_set(VisibilitySystems::CheckVisibility), diff --git a/crates/bevy_pbr/src/meshlet/persistent_buffer_impls.rs b/crates/bevy_pbr/src/meshlet/persistent_buffer_impls.rs index 86054eb675..9fa7058772 100644 --- a/crates/bevy_pbr/src/meshlet/persistent_buffer_impls.rs +++ 
b/crates/bevy_pbr/src/meshlet/persistent_buffer_impls.rs @@ -3,8 +3,7 @@ use super::{ persistent_buffer::PersistentGpuBufferable, }; use alloc::sync::Arc; - -const MESHLET_VERTEX_SIZE_IN_BYTES: u32 = 32; +use bevy_math::Vec2; impl PersistentGpuBufferable for Arc<[u8]> { type Metadata = (); @@ -19,26 +18,31 @@ impl PersistentGpuBufferable for Arc<[u8]> { } impl PersistentGpuBufferable for Arc<[u32]> { - type Metadata = u64; + type Metadata = (); fn size_in_bytes(&self) -> usize { self.len() * size_of::() } - fn write_bytes_le(&self, offset: Self::Metadata, buffer_slice: &mut [u8]) { - let offset = offset as u32 / MESHLET_VERTEX_SIZE_IN_BYTES; + fn write_bytes_le(&self, _: Self::Metadata, buffer_slice: &mut [u8]) { + buffer_slice.clone_from_slice(bytemuck::cast_slice(self)); + } +} - for (i, index) in self.iter().enumerate() { - let size = size_of::(); - let i = i * size; - let bytes = (*index + offset).to_le_bytes(); - buffer_slice[i..(i + size)].clone_from_slice(&bytes); - } +impl PersistentGpuBufferable for Arc<[Vec2]> { + type Metadata = (); + + fn size_in_bytes(&self) -> usize { + self.len() * size_of::() + } + + fn write_bytes_le(&self, _: Self::Metadata, buffer_slice: &mut [u8]) { + buffer_slice.clone_from_slice(bytemuck::cast_slice(self)); } } impl PersistentGpuBufferable for Arc<[Meshlet]> { - type Metadata = (u64, u64); + type Metadata = (u64, u64, u64); fn size_in_bytes(&self) -> usize { self.len() * size_of::() @@ -46,20 +50,23 @@ impl PersistentGpuBufferable for Arc<[Meshlet]> { fn write_bytes_le( &self, - (vertex_offset, index_offset): Self::Metadata, + (vertex_position_offset, vertex_attribute_offset, index_offset): Self::Metadata, buffer_slice: &mut [u8], ) { - let vertex_offset = (vertex_offset as usize / size_of::()) as u32; + let vertex_position_offset = (vertex_position_offset * 8) as u32; + let vertex_attribute_offset = (vertex_attribute_offset as usize / size_of::()) as u32; let index_offset = index_offset as u32; for (i, meshlet) in 
self.iter().enumerate() { let size = size_of::(); let i = i * size; let bytes = bytemuck::cast::<_, [u8; size_of::()]>(Meshlet { - start_vertex_id: meshlet.start_vertex_id + vertex_offset, + start_vertex_position_bit: meshlet.start_vertex_position_bit + + vertex_position_offset, + start_vertex_attribute_id: meshlet.start_vertex_attribute_id + + vertex_attribute_offset, start_index_id: meshlet.start_index_id + index_offset, - vertex_count: meshlet.vertex_count, - triangle_count: meshlet.triangle_count, + ..*meshlet }); buffer_slice[i..(i + size)].clone_from_slice(&bytes); } diff --git a/crates/bevy_pbr/src/meshlet/resource_manager.rs b/crates/bevy_pbr/src/meshlet/resource_manager.rs index aca52d1333..fd95f45b3f 100644 --- a/crates/bevy_pbr/src/meshlet/resource_manager.rs +++ b/crates/bevy_pbr/src/meshlet/resource_manager.rs @@ -185,7 +185,6 @@ impl ResourceManager { storage_buffer_read_only_sized(false, None), storage_buffer_read_only_sized(false, None), storage_buffer_read_only_sized(false, None), - storage_buffer_read_only_sized(false, None), storage_buffer_sized(false, None), uniform_buffer::(true), ), @@ -222,6 +221,7 @@ impl ResourceManager { storage_buffer_read_only_sized(false, None), storage_buffer_read_only_sized(false, None), storage_buffer_read_only_sized(false, None), + storage_buffer_read_only_sized(false, None), ), ), ), @@ -702,8 +702,7 @@ pub fn prepare_meshlet_view_bind_groups( cluster_meshlet_ids.as_entire_binding(), meshlet_mesh_manager.meshlets.binding(), meshlet_mesh_manager.indices.binding(), - meshlet_mesh_manager.vertex_ids.binding(), - meshlet_mesh_manager.vertex_data.binding(), + meshlet_mesh_manager.vertex_positions.binding(), cluster_instance_ids.as_entire_binding(), instance_manager.instance_uniforms.binding().unwrap(), resource_manager @@ -746,8 +745,9 @@ pub fn prepare_meshlet_view_bind_groups( cluster_meshlet_ids.as_entire_binding(), meshlet_mesh_manager.meshlets.binding(), meshlet_mesh_manager.indices.binding(), - 
meshlet_mesh_manager.vertex_ids.binding(), - meshlet_mesh_manager.vertex_data.binding(), + meshlet_mesh_manager.vertex_positions.binding(), + meshlet_mesh_manager.vertex_normals.binding(), + meshlet_mesh_manager.vertex_uvs.binding(), cluster_instance_ids.as_entire_binding(), instance_manager.instance_uniforms.binding().unwrap(), )); diff --git a/crates/bevy_pbr/src/meshlet/visibility_buffer_hardware_raster.wgsl b/crates/bevy_pbr/src/meshlet/visibility_buffer_hardware_raster.wgsl index eb0947f967..65ccb17482 100644 --- a/crates/bevy_pbr/src/meshlet/visibility_buffer_hardware_raster.wgsl +++ b/crates/bevy_pbr/src/meshlet/visibility_buffer_hardware_raster.wgsl @@ -2,15 +2,14 @@ meshlet_bindings::{ meshlet_cluster_meshlet_ids, meshlets, - meshlet_vertex_ids, - meshlet_vertex_data, meshlet_cluster_instance_ids, meshlet_instance_uniforms, meshlet_raster_clusters, meshlet_visibility_buffer, view, - get_meshlet_index, - unpack_meshlet_vertex, + get_meshlet_triangle_count, + get_meshlet_vertex_id, + get_meshlet_vertex_position, }, mesh_functions::mesh_position_local_to_world, } @@ -33,20 +32,19 @@ struct VertexOutput { fn vertex(@builtin(instance_index) instance_index: u32, @builtin(vertex_index) vertex_index: u32) -> VertexOutput { let cluster_id = meshlet_raster_clusters[meshlet_raster_cluster_rightmost_slot - instance_index]; let meshlet_id = meshlet_cluster_meshlet_ids[cluster_id]; - let meshlet = meshlets[meshlet_id]; + var meshlet = meshlets[meshlet_id]; let triangle_id = vertex_index / 3u; - if triangle_id >= meshlet.triangle_count { return dummy_vertex(); } + if triangle_id >= get_meshlet_triangle_count(&meshlet) { return dummy_vertex(); } let index_id = (triangle_id * 3u) + (vertex_index % 3u); - let index = get_meshlet_index(meshlet.start_index_id + index_id); - let vertex_id = meshlet_vertex_ids[meshlet.start_vertex_id + index]; - let vertex = unpack_meshlet_vertex(meshlet_vertex_data[vertex_id]); + let vertex_id = get_meshlet_vertex_id(meshlet.start_index_id + 
index_id); let instance_id = meshlet_cluster_instance_ids[cluster_id]; let instance_uniform = meshlet_instance_uniforms[instance_id]; + let vertex_position = get_meshlet_vertex_position(&meshlet, vertex_id); let world_from_local = affine3_to_square(instance_uniform.world_from_local); - let world_position = mesh_position_local_to_world(world_from_local, vec4(vertex.position, 1.0)); + let world_position = mesh_position_local_to_world(world_from_local, vec4(vertex_position, 1.0)); var clip_position = view.clip_from_world * vec4(world_position.xyz, 1.0); #ifdef DEPTH_CLAMP_ORTHO let unclamped_clip_depth = clip_position.z; diff --git a/crates/bevy_pbr/src/meshlet/visibility_buffer_resolve.wgsl b/crates/bevy_pbr/src/meshlet/visibility_buffer_resolve.wgsl index e685d33866..4346a3b29c 100644 --- a/crates/bevy_pbr/src/meshlet/visibility_buffer_resolve.wgsl +++ b/crates/bevy_pbr/src/meshlet/visibility_buffer_resolve.wgsl @@ -2,15 +2,16 @@ #import bevy_pbr::{ meshlet_bindings::{ + Meshlet, meshlet_visibility_buffer, meshlet_cluster_meshlet_ids, meshlets, - meshlet_vertex_ids, - meshlet_vertex_data, meshlet_cluster_instance_ids, meshlet_instance_uniforms, - get_meshlet_index, - unpack_meshlet_vertex, + get_meshlet_vertex_id, + get_meshlet_vertex_position, + get_meshlet_vertex_normal, + get_meshlet_vertex_uv, }, mesh_view_bindings::view, mesh_functions::mesh_position_local_to_world, @@ -106,59 +107,58 @@ fn resolve_vertex_output(frag_coord: vec4) -> VertexOutput { let packed_ids = u32(meshlet_visibility_buffer[frag_coord_1d]); // TODO: Might be faster to load the correct u32 directly let cluster_id = packed_ids >> 7u; let meshlet_id = meshlet_cluster_meshlet_ids[cluster_id]; - let meshlet = meshlets[meshlet_id]; + var meshlet = meshlets[meshlet_id]; let triangle_id = extractBits(packed_ids, 0u, 7u); let index_ids = meshlet.start_index_id + (triangle_id * 3u) + vec3(0u, 1u, 2u); - let indices = meshlet.start_vertex_id + vec3(get_meshlet_index(index_ids.x), 
get_meshlet_index(index_ids.y), get_meshlet_index(index_ids.z)); - let vertex_ids = vec3(meshlet_vertex_ids[indices.x], meshlet_vertex_ids[indices.y], meshlet_vertex_ids[indices.z]); - let vertex_1 = unpack_meshlet_vertex(meshlet_vertex_data[vertex_ids.x]); - let vertex_2 = unpack_meshlet_vertex(meshlet_vertex_data[vertex_ids.y]); - let vertex_3 = unpack_meshlet_vertex(meshlet_vertex_data[vertex_ids.z]); + let vertex_ids = vec3(get_meshlet_vertex_id(index_ids[0]), get_meshlet_vertex_id(index_ids[1]), get_meshlet_vertex_id(index_ids[2])); + let vertex_0 = load_vertex(&meshlet, vertex_ids[0]); + let vertex_1 = load_vertex(&meshlet, vertex_ids[1]); + let vertex_2 = load_vertex(&meshlet, vertex_ids[2]); let instance_id = meshlet_cluster_instance_ids[cluster_id]; var instance_uniform = meshlet_instance_uniforms[instance_id]; let world_from_local = affine3_to_square(instance_uniform.world_from_local); + let world_position_0 = mesh_position_local_to_world(world_from_local, vec4(vertex_0.position, 1.0)); let world_position_1 = mesh_position_local_to_world(world_from_local, vec4(vertex_1.position, 1.0)); let world_position_2 = mesh_position_local_to_world(world_from_local, vec4(vertex_2.position, 1.0)); - let world_position_3 = mesh_position_local_to_world(world_from_local, vec4(vertex_3.position, 1.0)); let frag_coord_ndc = frag_coord_to_ndc(frag_coord).xy; let partial_derivatives = compute_partial_derivatives( - array(world_position_1, world_position_2, world_position_3), + array(world_position_0, world_position_1, world_position_2), frag_coord_ndc, view.viewport.zw / 2.0, ); - let world_position = mat3x4(world_position_1, world_position_2, world_position_3) * partial_derivatives.barycentrics; + let world_position = mat3x4(world_position_0, world_position_1, world_position_2) * partial_derivatives.barycentrics; let world_positions_camera_relative = mat3x3( + world_position_0.xyz - view.world_position, world_position_1.xyz - view.world_position, world_position_2.xyz - 
view.world_position, - world_position_3.xyz - view.world_position, ); let ddx_world_position = world_positions_camera_relative * partial_derivatives.ddx; let ddy_world_position = world_positions_camera_relative * partial_derivatives.ddy; let world_normal = mat3x3( + normal_local_to_world(vertex_0.normal, &instance_uniform), normal_local_to_world(vertex_1.normal, &instance_uniform), normal_local_to_world(vertex_2.normal, &instance_uniform), - normal_local_to_world(vertex_3.normal, &instance_uniform), ) * partial_derivatives.barycentrics; - let uv = mat3x2(vertex_1.uv, vertex_2.uv, vertex_3.uv) * partial_derivatives.barycentrics; - let ddx_uv = mat3x2(vertex_1.uv, vertex_2.uv, vertex_3.uv) * partial_derivatives.ddx; - let ddy_uv = mat3x2(vertex_1.uv, vertex_2.uv, vertex_3.uv) * partial_derivatives.ddy; + let uv = mat3x2(vertex_0.uv, vertex_1.uv, vertex_2.uv) * partial_derivatives.barycentrics; + let ddx_uv = mat3x2(vertex_0.uv, vertex_1.uv, vertex_2.uv) * partial_derivatives.ddx; + let ddy_uv = mat3x2(vertex_0.uv, vertex_1.uv, vertex_2.uv) * partial_derivatives.ddy; let world_tangent = calculate_world_tangent(world_normal, ddx_world_position, ddy_world_position, ddx_uv, ddy_uv); #ifdef PREPASS_FRAGMENT #ifdef MOTION_VECTOR_PREPASS let previous_world_from_local = affine3_to_square(instance_uniform.previous_world_from_local); + let previous_world_position_0 = mesh_position_local_to_world(previous_world_from_local, vec4(vertex_0.position, 1.0)); let previous_world_position_1 = mesh_position_local_to_world(previous_world_from_local, vec4(vertex_1.position, 1.0)); let previous_world_position_2 = mesh_position_local_to_world(previous_world_from_local, vec4(vertex_2.position, 1.0)); - let previous_world_position_3 = mesh_position_local_to_world(previous_world_from_local, vec4(vertex_3.position, 1.0)); - let previous_world_position = mat3x4(previous_world_position_1, previous_world_position_2, previous_world_position_3) * partial_derivatives.barycentrics; + let 
previous_world_position = mat3x4(previous_world_position_0, previous_world_position_1, previous_world_position_2) * partial_derivatives.barycentrics; let motion_vector = calculate_motion_vector(world_position, previous_world_position); #endif #endif @@ -181,6 +181,20 @@ fn resolve_vertex_output(frag_coord: vec4) -> VertexOutput { ); } +struct MeshletVertex { + position: vec3, + normal: vec3, + uv: vec2, +} + +fn load_vertex(meshlet: ptr, vertex_id: u32) -> MeshletVertex { + return MeshletVertex( + get_meshlet_vertex_position(meshlet, vertex_id), + get_meshlet_vertex_normal(meshlet, vertex_id), + get_meshlet_vertex_uv(meshlet, vertex_id), + ); +} + fn normal_local_to_world(vertex_normal: vec3, instance_uniform: ptr) -> vec3 { if any(vertex_normal != vec3(0.0)) { return normalize( diff --git a/crates/bevy_pbr/src/meshlet/visibility_buffer_software_raster.wgsl b/crates/bevy_pbr/src/meshlet/visibility_buffer_software_raster.wgsl index 853eb853b7..6772ff28f2 100644 --- a/crates/bevy_pbr/src/meshlet/visibility_buffer_software_raster.wgsl +++ b/crates/bevy_pbr/src/meshlet/visibility_buffer_software_raster.wgsl @@ -2,16 +2,16 @@ meshlet_bindings::{ meshlet_cluster_meshlet_ids, meshlets, - meshlet_vertex_ids, - meshlet_vertex_data, meshlet_cluster_instance_ids, meshlet_instance_uniforms, meshlet_raster_clusters, meshlet_software_raster_cluster_count, meshlet_visibility_buffer, view, - get_meshlet_index, - unpack_meshlet_vertex, + get_meshlet_vertex_count, + get_meshlet_triangle_count, + get_meshlet_vertex_id, + get_meshlet_vertex_position, }, mesh_functions::mesh_position_local_to_world, view_transformations::ndc_to_uv, @@ -42,7 +42,7 @@ fn rasterize_cluster( let cluster_id = meshlet_raster_clusters[workgroup_id_1d]; let meshlet_id = meshlet_cluster_meshlet_ids[cluster_id]; - let meshlet = meshlets[meshlet_id]; + var meshlet = meshlets[meshlet_id]; let instance_id = meshlet_cluster_instance_ids[cluster_id]; let instance_uniform = meshlet_instance_uniforms[instance_id]; @@ 
-51,12 +51,11 @@ fn rasterize_cluster( // Load and project 1 vertex per thread, and then again if there are more than 128 vertices in the meshlet for (var i = 0u; i <= 128u; i += 128u) { let vertex_id = local_invocation_index + i; - if vertex_id < meshlet.vertex_count { - let meshlet_vertex_id = meshlet_vertex_ids[meshlet.start_vertex_id + vertex_id]; - let vertex = unpack_meshlet_vertex(meshlet_vertex_data[meshlet_vertex_id]); + if vertex_id < get_meshlet_vertex_count(&meshlet) { + let vertex_position = get_meshlet_vertex_position(&meshlet, vertex_id); // Project vertex to viewport space - let world_position = mesh_position_local_to_world(world_from_local, vec4(vertex.position, 1.0)); + let world_position = mesh_position_local_to_world(world_from_local, vec4(vertex_position, 1.0)); let clip_position = view.clip_from_world * vec4(world_position.xyz, 1.0); var ndc_position = clip_position.xyz / clip_position.w; #ifdef DEPTH_CLAMP_ORTHO @@ -72,9 +71,9 @@ fn rasterize_cluster( // Load 1 triangle's worth of vertex data per thread let triangle_id = local_invocation_index; - if triangle_id >= meshlet.triangle_count { return; } + if triangle_id >= get_meshlet_triangle_count(&meshlet) { return; } let index_ids = meshlet.start_index_id + (triangle_id * 3u) + vec3(0u, 1u, 2u); - let vertex_ids = vec3(get_meshlet_index(index_ids[0]), get_meshlet_index(index_ids[1]), get_meshlet_index(index_ids[2])); + let vertex_ids = vec3(get_meshlet_vertex_id(index_ids[0]), get_meshlet_vertex_id(index_ids[1]), get_meshlet_vertex_id(index_ids[2])); let vertex_0 = viewport_vertices[vertex_ids[2]]; let vertex_1 = viewport_vertices[vertex_ids[1]]; let vertex_2 = viewport_vertices[vertex_ids[0]]; diff --git a/crates/bevy_pbr/src/render/utils.wgsl b/crates/bevy_pbr/src/render/utils.wgsl index dbee281815..c887e3005e 100644 --- a/crates/bevy_pbr/src/render/utils.wgsl +++ b/crates/bevy_pbr/src/render/utils.wgsl @@ -55,7 +55,13 @@ fn octahedral_encode(v: vec3) -> vec2 { // For decoding normals or 
unit direction vectors from octahedral coordinates. fn octahedral_decode(v: vec2) -> vec3 { let f = v * 2.0 - 1.0; - var n = vec3(f.xy, 1.0 - abs(f.x) - abs(f.y)); + var n = octahedral_decode_signed(f); + return normalize(n); +} + +// Like octahedral_decode, but for input in [-1, 1] instead of [0, 1]. +fn octahedral_decode_signed(v: vec2) -> vec3 { + var n = vec3(v.xy, 1.0 - abs(v.x) - abs(v.y)); let t = saturate(-n.z); let w = select(vec2(t), vec2(-t), n.xy >= vec2(0.0)); n = vec3(n.xy + w, n.z); diff --git a/examples/3d/meshlet.rs b/examples/3d/meshlet.rs index 21d1bb0f9b..d8a09a2356 100644 --- a/examples/3d/meshlet.rs +++ b/examples/3d/meshlet.rs @@ -17,7 +17,7 @@ use camera_controller::{CameraController, CameraControllerPlugin}; use std::{f32::consts::PI, path::Path, process::ExitCode}; const ASSET_URL: &str = - "https://raw.githubusercontent.com/JMS55/bevy_meshlet_asset/854eb98353ad94aea1104f355fc24dbe4fda679d/bunny.meshlet_mesh"; + "https://raw.githubusercontent.com/JMS55/bevy_meshlet_asset/8443bbdee0bf517e6c297dede7f6a46ab712ee4c/bunny.meshlet_mesh"; fn main() -> ExitCode { if !Path::new("./assets/models/bunny.meshlet_mesh").exists() {