Fix meshlet vertex attribute interpolation (#13775)

# Objective

- Mikktspace requires that we normalize world normals/tangents _before_
interpolation across vertices, and then do _not_ normalize after. I had
it backwards.
- We do not (am not supposed to?) need a second set of barycentrics for
motion vectors. If you think about the typical raster pipeline, in the
vertex shader we calculate previous_world_position, and then it gets
interpolated using the current triangle's barycentrics.

## Solution

- Fix normal/tangent processing 
- Reuse barycentrics for motion vector calculations
- Not implementing this for 0.14, but long term I aim to remove explicit
vertex tangents and calculate them in the shader on the fly.

## Testing

- I tested out some of the normal maps we have in repo. Didn't seem to
make a difference, but mikktspace is all about correctness across
various baking tools. I probably just didn't have any of the ones that
would cause it to break.
- Didn't test motion vectors as there's a known bug with the depth
buffer and meshlets that I'm waiting on the render graph rewrite to fix.
This commit is contained in:
JMS55 2024-06-10 13:18:43 -07:00 committed by François
parent 7cd90990f9
commit 93f48edbc3
No known key found for this signature in database
2 changed files with 50 additions and 35 deletions

View file

@ -13,8 +13,8 @@
unpack_meshlet_vertex,
},
mesh_view_bindings::view,
mesh_functions::mesh_position_local_to_world,
mesh_types::MESH_FLAGS_SIGN_DETERMINANT_MODEL_3X3_BIT,
mesh_functions::{mesh_position_local_to_world, sign_determinant_model_3x3m},
mesh_types::{Mesh, MESH_FLAGS_SIGN_DETERMINANT_MODEL_3X3_BIT},
view_transformations::{position_world_to_clip, frag_coord_to_ndc},
}
#import bevy_render::maths::{affine3_to_square, mat2x4_f32_to_mat3x3_unpack}
@ -99,6 +99,7 @@ fn resolve_vertex_output(frag_coord: vec4<f32>) -> VertexOutput {
let cluster_id = packed_ids >> 6u;
let meshlet_id = meshlet_cluster_meshlet_ids[cluster_id];
let meshlet = meshlets[meshlet_id];
let triangle_id = extractBits(packed_ids, 0u, 6u);
let index_ids = meshlet.start_index_id + vec3(triangle_id * 3u) + vec3(0u, 1u, 2u);
let indices = meshlet.start_vertex_id + vec3(get_meshlet_index(index_ids.x), get_meshlet_index(index_ids.y), get_meshlet_index(index_ids.z));
@ -108,9 +109,9 @@ fn resolve_vertex_output(frag_coord: vec4<f32>) -> VertexOutput {
let vertex_3 = unpack_meshlet_vertex(meshlet_vertex_data[vertex_ids.z]);
let instance_id = meshlet_cluster_instance_ids[cluster_id];
let instance_uniform = meshlet_instance_uniforms[instance_id];
let world_from_local = affine3_to_square(instance_uniform.world_from_local);
var instance_uniform = meshlet_instance_uniforms[instance_id];
let world_from_local = affine3_to_square(instance_uniform.world_from_local);
let world_position_1 = mesh_position_local_to_world(world_from_local, vec4(vertex_1.position, 1.0));
let world_position_2 = mesh_position_local_to_world(world_from_local, vec4(vertex_2.position, 1.0));
let world_position_3 = mesh_position_local_to_world(world_from_local, vec4(vertex_3.position, 1.0));
@ -126,27 +127,19 @@ fn resolve_vertex_output(frag_coord: vec4<f32>) -> VertexOutput {
);
let world_position = mat3x4(world_position_1, world_position_2, world_position_3) * partial_derivatives.barycentrics;
let vertex_normal = mat3x3(vertex_1.normal, vertex_2.normal, vertex_3.normal) * partial_derivatives.barycentrics;
let world_normal = normalize(
mat2x4_f32_to_mat3x3_unpack(
instance_uniform.local_from_world_transpose_a,
instance_uniform.local_from_world_transpose_b,
) * vertex_normal
);
let world_normal = mat3x3(
normal_local_to_world(vertex_1.normal, &instance_uniform),
normal_local_to_world(vertex_2.normal, &instance_uniform),
normal_local_to_world(vertex_3.normal, &instance_uniform),
) * partial_derivatives.barycentrics;
let uv = mat3x2(vertex_1.uv, vertex_2.uv, vertex_3.uv) * partial_derivatives.barycentrics;
let ddx_uv = mat3x2(vertex_1.uv, vertex_2.uv, vertex_3.uv) * partial_derivatives.ddx;
let ddy_uv = mat3x2(vertex_1.uv, vertex_2.uv, vertex_3.uv) * partial_derivatives.ddy;
let vertex_tangent = mat3x4(vertex_1.tangent, vertex_2.tangent, vertex_3.tangent) * partial_derivatives.barycentrics;
let world_tangent = vec4(
normalize(
mat3x3(
world_from_local[0].xyz,
world_from_local[1].xyz,
world_from_local[2].xyz
) * vertex_tangent.xyz
),
vertex_tangent.w * (f32(bool(instance_uniform.flags & MESH_FLAGS_SIGN_DETERMINANT_MODEL_3X3_BIT)) * 2.0 - 1.0)
);
let world_tangent = mat3x4(
tangent_local_to_world(vertex_1.tangent, world_from_local, instance_uniform.flags),
tangent_local_to_world(vertex_2.tangent, world_from_local, instance_uniform.flags),
tangent_local_to_world(vertex_3.tangent, world_from_local, instance_uniform.flags),
) * partial_derivatives.barycentrics;
#ifdef PREPASS_FRAGMENT
#ifdef MOTION_VECTOR_PREPASS
@ -154,15 +147,7 @@ fn resolve_vertex_output(frag_coord: vec4<f32>) -> VertexOutput {
let previous_world_position_1 = mesh_position_local_to_world(previous_world_from_local, vec4(vertex_1.position, 1.0));
let previous_world_position_2 = mesh_position_local_to_world(previous_world_from_local, vec4(vertex_2.position, 1.0));
let previous_world_position_3 = mesh_position_local_to_world(previous_world_from_local, vec4(vertex_3.position, 1.0));
let previous_clip_position_1 = previous_view_uniforms.clip_from_world * vec4(previous_world_position_1.xyz, 1.0);
let previous_clip_position_2 = previous_view_uniforms.clip_from_world * vec4(previous_world_position_2.xyz, 1.0);
let previous_clip_position_3 = previous_view_uniforms.clip_from_world * vec4(previous_world_position_3.xyz, 1.0);
let previous_partial_derivatives = compute_partial_derivatives(
array(previous_clip_position_1, previous_clip_position_2, previous_clip_position_3),
frag_coord_ndc,
view.viewport.zw,
);
let previous_world_position = mat3x4(previous_world_position_1, previous_world_position_2, previous_world_position_3) * previous_partial_derivatives.barycentrics;
let previous_world_position = mat3x4(previous_world_position_1, previous_world_position_2, previous_world_position_3) * partial_derivatives.barycentrics;
let motion_vector = calculate_motion_vector(world_position, previous_world_position);
#endif
#endif
@ -184,4 +169,34 @@ fn resolve_vertex_output(frag_coord: vec4<f32>) -> VertexOutput {
#endif
);
}
fn normal_local_to_world(vertex_normal: vec3<f32>, instance_uniform: ptr<function, Mesh>) -> vec3<f32> {
if any(vertex_normal != vec3<f32>(0.0)) {
return normalize(
mat2x4_f32_to_mat3x3_unpack(
(*instance_uniform).local_from_world_transpose_a,
(*instance_uniform).local_from_world_transpose_b,
) * vertex_normal
);
} else {
return vertex_normal;
}
}
fn tangent_local_to_world(vertex_tangent: vec4<f32>, world_from_local: mat4x4<f32>, mesh_flags: u32) -> vec4<f32> {
if any(vertex_tangent != vec4<f32>(0.0)) {
return vec4<f32>(
normalize(
mat3x3<f32>(
world_from_local[0].xyz,
world_from_local[1].xyz,
world_from_local[2].xyz,
) * vertex_tangent.xyz
),
vertex_tangent.w * sign_determinant_model_3x3m(mesh_flags)
);
} else {
return vertex_tangent;
}
}
#endif

View file

@ -55,11 +55,11 @@ fn mesh_normal_local_to_world(vertex_normal: vec3<f32>, instance_index: u32) ->
// Calculates the sign of the determinant of the 3x3 model matrix based on a
// mesh flag
fn sign_determinant_model_3x3m(instance_index: u32) -> f32 {
fn sign_determinant_model_3x3m(mesh_flags: u32) -> f32 {
// bool(u32) is false if 0u else true
// f32(bool) is 1.0 if true else 0.0
// * 2.0 - 1.0 remaps 0.0 or 1.0 to -1.0 or 1.0 respectively
return f32(bool(mesh[instance_index].flags & MESH_FLAGS_SIGN_DETERMINANT_MODEL_3X3_BIT)) * 2.0 - 1.0;
return f32(bool(mesh_flags & MESH_FLAGS_SIGN_DETERMINANT_MODEL_3X3_BIT)) * 2.0 - 1.0;
}
fn mesh_tangent_local_to_world(world_from_local: mat4x4<f32>, vertex_tangent: vec4<f32>, instance_index: u32) -> vec4<f32> {
@ -76,12 +76,12 @@ fn mesh_tangent_local_to_world(world_from_local: mat4x4<f32>, vertex_tangent: ve
mat3x3<f32>(
world_from_local[0].xyz,
world_from_local[1].xyz,
world_from_local[2].xyz
world_from_local[2].xyz,
) * vertex_tangent.xyz
),
// NOTE: Multiplying by the sign of the determinant of the 3x3 model matrix accounts for
// situations such as negative scaling.
vertex_tangent.w * sign_determinant_model_3x3m(instance_index)
vertex_tangent.w * sign_determinant_model_3x3m(mesh[instance_index].flags)
);
} else {
return vertex_tangent;