From d44e86507f66d2cbe2472557d2a2abb9c4f80a83 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois?= Date: Mon, 21 Nov 2022 22:38:29 +0000 Subject: [PATCH] Shader defs can now have a value (#5900) # Objective - shaders defs can now have a `bool` or `int` value - `#if SHADER_DEF 3` - ok if `SHADER_DEF` is defined, has the correct type and pass the comparison - `==`, `!=`, `>=`, `>`, `<`, `<=` supported - `#SHADER_DEF` or `#{SHADER_DEF}` - will be replaced by the value in the shader code --- ## Migration Guide - replace `shader_defs.push(String::from("NAME"));` by `shader_defs.push("NAME".into());` - if you used shader def `NO_STORAGE_BUFFERS_SUPPORT`, check how `AVAILABLE_STORAGE_BUFFER_BINDINGS` is now used in Bevy default shaders --- crates/bevy_core_pipeline/src/fxaa/mod.rs | 4 +- .../bevy_core_pipeline/src/tonemapping/mod.rs | 2 +- crates/bevy_pbr/src/pbr_material.rs | 2 +- .../src/render/clustered_forward.wgsl | 12 +- crates/bevy_pbr/src/render/light.rs | 6 +- crates/bevy_pbr/src/render/mesh.rs | 23 +- .../src/render/mesh_view_bindings.wgsl | 16 +- .../bevy_pbr/src/render/mesh_view_types.wgsl | 24 +- .../src/render_resource/pipeline.rs | 7 +- .../src/render_resource/pipeline_cache.rs | 41 +- .../bevy_render/src/render_resource/shader.rs | 647 +++++++++++++++++- crates/bevy_sprite/src/mesh2d/mesh.rs | 14 +- crates/bevy_sprite/src/render/mod.rs | 6 +- examples/shader/shader_defs.rs | 2 +- 14 files changed, 721 insertions(+), 85 deletions(-) diff --git a/crates/bevy_core_pipeline/src/fxaa/mod.rs b/crates/bevy_core_pipeline/src/fxaa/mod.rs index e2bdce339f..f49425e725 100644 --- a/crates/bevy_core_pipeline/src/fxaa/mod.rs +++ b/crates/bevy_core_pipeline/src/fxaa/mod.rs @@ -203,8 +203,8 @@ impl SpecializedRenderPipeline for FxaaPipeline { fragment: Some(FragmentState { shader: FXAA_SHADER_HANDLE.typed(), shader_defs: vec![ - format!("EDGE_THRESH_{}", key.edge_threshold.get_str()), - format!("EDGE_THRESH_MIN_{}", key.edge_threshold_min.get_str()), + format!("EDGE_THRESH_{}", key.edge_threshold.get_str()).into(), + format!("EDGE_THRESH_MIN_{}", key.edge_threshold_min.get_str()).into(), ], entry_point: "fragment".into(), targets: vec![Some(ColorTargetState { diff --git a/crates/bevy_core_pipeline/src/tonemapping/mod.rs b/crates/bevy_core_pipeline/src/tonemapping/mod.rs index 9e3957a41f..785c231a43 100644 --- a/crates/bevy_core_pipeline/src/tonemapping/mod.rs +++ b/crates/bevy_core_pipeline/src/tonemapping/mod.rs @@ -66,7 +66,7 @@ impl SpecializedRenderPipeline for TonemappingPipeline { fn specialize(&self, key: Self::Key) -> RenderPipelineDescriptor { let mut shader_defs = Vec::new(); if key.deband_dither { - shader_defs.push("DEBAND_DITHER".to_string()); + shader_defs.push("DEBAND_DITHER".into()); } RenderPipelineDescriptor { label: Some("tonemapping pipeline".into()), diff --git a/crates/bevy_pbr/src/pbr_material.rs b/crates/bevy_pbr/src/pbr_material.rs index 8e48acf582..da179b7ccd 100644 --- a/crates/bevy_pbr/src/pbr_material.rs +++ b/crates/bevy_pbr/src/pbr_material.rs @@ -419,7 +419,7 @@ impl Material for StandardMaterial { .as_mut() .unwrap() .shader_defs - .push(String::from("STANDARDMATERIAL_NORMAL_MAP")); + .push("STANDARDMATERIAL_NORMAL_MAP".into()); } descriptor.primitive.cull_mode = key.bind_group_data.cull_mode; if let Some(label) = &mut descriptor.label { diff --git a/crates/bevy_pbr/src/render/clustered_forward.wgsl b/crates/bevy_pbr/src/render/clustered_forward.wgsl index 674f996dd5..c291aee2ff 100644 --- a/crates/bevy_pbr/src/render/clustered_forward.wgsl +++ b/crates/bevy_pbr/src/render/clustered_forward.wgsl @@ -29,7 +29,9 @@ fn fragment_cluster_index(frag_coord: vec2, view_z: f32, is_orthographic: b // this must match CLUSTER_COUNT_SIZE in light.rs let CLUSTER_COUNT_SIZE = 9u; fn unpack_offset_and_counts(cluster_index: u32) -> vec3 { -#ifdef NO_STORAGE_BUFFERS_SUPPORT +#if AVAILABLE_STORAGE_BUFFER_BINDINGS >= 3 + return cluster_offsets_and_counts.data[cluster_index].xyz; +#else let offset_and_counts = cluster_offsets_and_counts.data[cluster_index >> 2u][cluster_index & ((1u << 2u) - 1u)]; // [ 31 .. 18 | 17 .. 9 | 8 .. 0 ] // [ offset | point light count | spot light count ] @@ -38,20 +40,18 @@ fn unpack_offset_and_counts(cluster_index: u32) -> vec3 { (offset_and_counts >> CLUSTER_COUNT_SIZE) & ((1u << CLUSTER_COUNT_SIZE) - 1u), offset_and_counts & ((1u << CLUSTER_COUNT_SIZE) - 1u), ); -#else - return cluster_offsets_and_counts.data[cluster_index].xyz; #endif } fn get_light_id(index: u32) -> u32 { -#ifdef NO_STORAGE_BUFFERS_SUPPORT +#if AVAILABLE_STORAGE_BUFFER_BINDINGS >= 3 + return cluster_light_index_lists.data[index]; +#else // The index is correct but in cluster_light_index_lists we pack 4 u8s into a u32 // This means the index into cluster_light_index_lists is index / 4 let indices = cluster_light_index_lists.data[index >> 4u][(index >> 2u) & ((1u << 2u) - 1u)]; // And index % 4 gives the sub-index of the u8 within the u32 so we shift by 8 * sub-index return (indices >> (8u * (index & ((1u << 2u) - 1u)))) & ((1u << 8u) - 1u); -#else - return cluster_light_index_lists.data[index]; #endif } diff --git a/crates/bevy_pbr/src/render/light.rs b/crates/bevy_pbr/src/render/light.rs index b11aecff4a..f973c4f646 100644 --- a/crates/bevy_pbr/src/render/light.rs +++ b/crates/bevy_pbr/src/render/light.rs @@ -321,11 +321,15 @@ impl SpecializedMeshPipeline for ShadowPipeline { let mut bind_group_layout = vec![self.view_layout.clone()]; let mut shader_defs = Vec::new(); + shader_defs.push(ShaderDefVal::Int( + "MAX_DIRECTIONAL_LIGHTS".to_string(), + MAX_DIRECTIONAL_LIGHTS as i32, + )); if layout.contains(Mesh::ATTRIBUTE_JOINT_INDEX) && layout.contains(Mesh::ATTRIBUTE_JOINT_WEIGHT) { - shader_defs.push(String::from("SKINNED")); + shader_defs.push("SKINNED".into()); vertex_attributes.push(Mesh::ATTRIBUTE_JOINT_INDEX.at_shader_location(4)); vertex_attributes.push(Mesh::ATTRIBUTE_JOINT_WEIGHT.at_shader_location(5)); bind_group_layout.push(self.skinned_mesh_layout.clone()); diff --git a/crates/bevy_pbr/src/render/mesh.rs b/crates/bevy_pbr/src/render/mesh.rs index 73e7eea7b0..15fa70b81e 100644 --- a/crates/bevy_pbr/src/render/mesh.rs +++ b/crates/bevy_pbr/src/render/mesh.rs @@ -1,7 +1,7 @@ use crate::{ GlobalLightMeta, GpuLights, GpuPointLights, LightMeta, NotShadowCaster, NotShadowReceiver, ShadowPipeline, ViewClusterBindings, ViewLightsUniformOffset, ViewShadowBindings, - CLUSTERED_FORWARD_STORAGE_BUFFER_COUNT, + CLUSTERED_FORWARD_STORAGE_BUFFER_COUNT, MAX_DIRECTIONAL_LIGHTS, }; use bevy_app::Plugin; use bevy_asset::{load_internal_asset, Assets, Handle, HandleUntyped}; @@ -581,27 +581,32 @@ impl SpecializedMeshPipeline for MeshPipeline { let mut vertex_attributes = Vec::new(); if layout.contains(Mesh::ATTRIBUTE_POSITION) { - shader_defs.push(String::from("VERTEX_POSITIONS")); + shader_defs.push("VERTEX_POSITIONS".into()); vertex_attributes.push(Mesh::ATTRIBUTE_POSITION.at_shader_location(0)); } if layout.contains(Mesh::ATTRIBUTE_NORMAL) { - shader_defs.push(String::from("VERTEX_NORMALS")); + shader_defs.push("VERTEX_NORMALS".into()); vertex_attributes.push(Mesh::ATTRIBUTE_NORMAL.at_shader_location(1)); } + shader_defs.push(ShaderDefVal::Int( + "MAX_DIRECTIONAL_LIGHTS".to_string(), + MAX_DIRECTIONAL_LIGHTS as i32, + )); + if layout.contains(Mesh::ATTRIBUTE_UV_0) { - shader_defs.push(String::from("VERTEX_UVS")); + shader_defs.push("VERTEX_UVS".into()); vertex_attributes.push(Mesh::ATTRIBUTE_UV_0.at_shader_location(2)); } if layout.contains(Mesh::ATTRIBUTE_TANGENT) { - shader_defs.push(String::from("VERTEX_TANGENTS")); + shader_defs.push("VERTEX_TANGENTS".into()); vertex_attributes.push(Mesh::ATTRIBUTE_TANGENT.at_shader_location(3)); } if layout.contains(Mesh::ATTRIBUTE_COLOR) { - shader_defs.push(String::from("VERTEX_COLORS")); + shader_defs.push("VERTEX_COLORS".into()); vertex_attributes.push(Mesh::ATTRIBUTE_COLOR.at_shader_location(4)); } @@ -609,7 +614,7 @@ impl SpecializedMeshPipeline for MeshPipeline { if layout.contains(Mesh::ATTRIBUTE_JOINT_INDEX) && layout.contains(Mesh::ATTRIBUTE_JOINT_WEIGHT) { - shader_defs.push(String::from("SKINNED")); + shader_defs.push("SKINNED".into()); vertex_attributes.push(Mesh::ATTRIBUTE_JOINT_INDEX.at_shader_location(5)); vertex_attributes.push(Mesh::ATTRIBUTE_JOINT_WEIGHT.at_shader_location(6)); bind_group_layout.push(self.skinned_mesh_layout.clone()); @@ -636,11 +641,11 @@ impl SpecializedMeshPipeline for MeshPipeline { } if key.contains(MeshPipelineKey::TONEMAP_IN_SHADER) { - shader_defs.push("TONEMAP_IN_SHADER".to_string()); + shader_defs.push("TONEMAP_IN_SHADER".into()); // Debanding is tied to tonemapping in the shader, cannot run without it. if key.contains(MeshPipelineKey::DEBAND_DITHER) { - shader_defs.push("DEBAND_DITHER".to_string()); + shader_defs.push("DEBAND_DITHER".into()); } } diff --git a/crates/bevy_pbr/src/render/mesh_view_bindings.wgsl b/crates/bevy_pbr/src/render/mesh_view_bindings.wgsl index 9756ff05d6..cfe9ae87ef 100644 --- a/crates/bevy_pbr/src/render/mesh_view_bindings.wgsl +++ b/crates/bevy_pbr/src/render/mesh_view_bindings.wgsl @@ -25,20 +25,20 @@ var directional_shadow_textures: texture_depth_2d_array; @group(0) @binding(5) var directional_shadow_textures_sampler: sampler_comparison; -#ifdef NO_STORAGE_BUFFERS_SUPPORT -@group(0) @binding(6) -var point_lights: PointLights; -@group(0) @binding(7) -var cluster_light_index_lists: ClusterLightIndexLists; -@group(0) @binding(8) -var cluster_offsets_and_counts: ClusterOffsetsAndCounts; -#else +#if AVAILABLE_STORAGE_BUFFER_BINDINGS >= 3 @group(0) @binding(6) var point_lights: PointLights; @group(0) @binding(7) var cluster_light_index_lists: ClusterLightIndexLists; @group(0) @binding(8) var cluster_offsets_and_counts: ClusterOffsetsAndCounts; +#else +@group(0) @binding(6) +var point_lights: PointLights; +@group(0) @binding(7) +var cluster_light_index_lists: ClusterLightIndexLists; +@group(0) @binding(8) +var cluster_offsets_and_counts: ClusterOffsetsAndCounts; #endif @group(0) @binding(9) diff --git a/crates/bevy_pbr/src/render/mesh_view_types.wgsl b/crates/bevy_pbr/src/render/mesh_view_types.wgsl index ab595b9e01..c6316f5965 100644 --- a/crates/bevy_pbr/src/render/mesh_view_types.wgsl +++ b/crates/bevy_pbr/src/render/mesh_view_types.wgsl @@ -42,7 +42,7 @@ let DIRECTIONAL_LIGHT_FLAGS_SHADOWS_ENABLED_BIT: u32 = 1u; struct Lights { // NOTE: this array size must be kept in sync with the constants defined in bevy_pbr/src/render/light.rs - directional_lights: array, + directional_lights: array, ambient_color: vec4, // x/y/z dimensions and n_clusters in w cluster_dimensions: vec4, @@ -61,7 +61,17 @@ struct Lights { spot_light_shadowmap_offset: i32, }; -#ifdef NO_STORAGE_BUFFERS_SUPPORT +#if AVAILABLE_STORAGE_BUFFER_BINDINGS >= 3 +struct PointLights { + data: array, +}; +struct ClusterLightIndexLists { + data: array, +}; +struct ClusterOffsetsAndCounts { + data: array>, +}; +#else struct PointLights { data: array, }; @@ -74,16 +84,6 @@ struct ClusterOffsetsAndCounts { // and an 8-bit count of the number of lights in the low 8 bits data: array, 1024u>, }; -#else -struct PointLights { - data: array, -}; -struct ClusterLightIndexLists { - data: array, -}; -struct ClusterOffsetsAndCounts { - data: array>, -}; #endif struct Globals { diff --git a/crates/bevy_render/src/render_resource/pipeline.rs b/crates/bevy_render/src/render_resource/pipeline.rs index b63febfd53..3a379b6052 100644 --- a/crates/bevy_render/src/render_resource/pipeline.rs +++ b/crates/bevy_render/src/render_resource/pipeline.rs @@ -7,6 +7,7 @@ use wgpu::{ VertexAttribute, VertexFormat, VertexStepMode, }; +use super::ShaderDefVal; use crate::render_resource::resource_macros::*; /// A [`RenderPipeline`] identifier. @@ -115,7 +116,7 @@ pub struct RenderPipelineDescriptor { pub struct VertexState { /// The compiled shader module for this stage. pub shader: Handle, - pub shader_defs: Vec, + pub shader_defs: Vec, /// The name of the entry point in the compiled shader. There must be a /// function with this name in the shader. pub entry_point: Cow<'static, str>, @@ -167,7 +168,7 @@ impl VertexBufferLayout { pub struct FragmentState { /// The compiled shader module for this stage. pub shader: Handle, - pub shader_defs: Vec, + pub shader_defs: Vec, /// The name of the entry point in the compiled shader. There must be a /// function with this name in the shader. pub entry_point: Cow<'static, str>, @@ -182,7 +183,7 @@ pub struct ComputePipelineDescriptor { pub layout: Option>, /// The compiled shader module for this stage. pub shader: Handle, - pub shader_defs: Vec, + pub shader_defs: Vec, /// The name of the entry point in the compiled shader. There must be a /// function with this name in the shader. pub entry_point: Cow<'static, str>, diff --git a/crates/bevy_render/src/render_resource/pipeline_cache.rs b/crates/bevy_render/src/render_resource/pipeline_cache.rs index 659844d93e..ce5faf9099 100644 --- a/crates/bevy_render/src/render_resource/pipeline_cache.rs +++ b/crates/bevy_render/src/render_resource/pipeline_cache.rs @@ -107,7 +107,7 @@ impl CachedPipelineState { #[derive(Default)] struct ShaderData { pipelines: HashSet, - processed_shaders: HashMap, ErasedShaderModule>, + processed_shaders: HashMap, ErasedShaderModule>, resolved_imports: HashMap>, dependents: HashSet>, } @@ -121,13 +121,31 @@ struct ShaderCache { processor: ShaderProcessor, } +#[derive(Clone, PartialEq, Eq, Debug, Hash)] +pub enum ShaderDefVal { + Bool(String, bool), + Int(String, i32), +} + +impl From<&str> for ShaderDefVal { + fn from(key: &str) -> Self { + ShaderDefVal::Bool(key.to_string(), true) + } +} + +impl From for ShaderDefVal { + fn from(key: String) -> Self { + ShaderDefVal::Bool(key, true) + } +} + impl ShaderCache { fn get( &mut self, render_device: &RenderDevice, pipeline: CachedPipelineId, handle: &Handle, - shader_defs: &[String], + shader_defs: &[ShaderDefVal], ) -> Result { let shader = self .shaders @@ -156,21 +174,20 @@ impl ShaderCache { let mut shader_defs = shader_defs.to_vec(); #[cfg(feature = "webgl")] { - shader_defs.push(String::from("NO_ARRAY_TEXTURES_SUPPORT")); - shader_defs.push(String::from("SIXTEEN_BYTE_ALIGNMENT")); + shader_defs.push("NO_ARRAY_TEXTURES_SUPPORT".into()); + shader_defs.push("SIXTEEN_BYTE_ALIGNMENT".into()); } - // TODO: 3 is the value from CLUSTERED_FORWARD_STORAGE_BUFFER_COUNT declared in bevy_pbr - // consider exposing this in shaders in a more generally useful way, such as: - // # if AVAILABLE_STORAGE_BUFFER_BINDINGS == 3 - // /* use storage buffers here */ - // # elif - // /* use uniforms here */ - if !matches!( + // 3 is the value from CLUSTERED_FORWARD_STORAGE_BUFFER_COUNT declared in bevy_pbr + // Using the value directly here to avoid the cyclic dependency + if matches!( render_device.get_supported_read_only_binding_type(3), BufferBindingType::Storage { .. } ) { - shader_defs.push(String::from("NO_STORAGE_BUFFERS_SUPPORT")); + shader_defs.push(ShaderDefVal::Int( + String::from("AVAILABLE_STORAGE_BUFFER_BINDINGS"), + 3, + )); } debug!( diff --git a/crates/bevy_render/src/render_resource/shader.rs b/crates/bevy_render/src/render_resource/shader.rs index b216c44558..74df14a9c7 100644 --- a/crates/bevy_render/src/render_resource/shader.rs +++ b/crates/bevy_render/src/render_resource/shader.rs @@ -6,13 +6,13 @@ use naga::valid::Capabilities; use naga::{valid::ModuleInfo, Module}; use once_cell::sync::Lazy; use regex::Regex; -use std::{ - borrow::Cow, collections::HashSet, marker::Copy, ops::Deref, path::PathBuf, str::FromStr, -}; +use std::{borrow::Cow, marker::Copy, ops::Deref, path::PathBuf, str::FromStr}; use thiserror::Error; use wgpu::Features; use wgpu::{util::make_spirv, ShaderModuleDescriptor, ShaderSource}; +use super::ShaderDefVal; + #[derive(Copy, Clone, Hash, Eq, PartialEq, Debug)] pub struct ShaderId(Uuid); @@ -297,12 +297,24 @@ pub enum ProcessShaderError { NotEnoughEndIfs, #[error("This Shader's format does not support processing shader defs.")] ShaderFormatDoesNotSupportShaderDefs, - #[error("This Shader's formatdoes not support imports.")] + #[error("This Shader's format does not support imports.")] ShaderFormatDoesNotSupportImports, #[error("Unresolved import: {0:?}.")] UnresolvedImport(ShaderImport), #[error("The shader import {0:?} does not match the source file type. Support for this might be added in the future.")] MismatchedImportFormat(ShaderImport), + #[error("Unknown shader def operator: '{operator}'")] + UnknownShaderDefOperator { operator: String }, + #[error("Unknown shader def: '{shader_def_name}'")] + UnknownShaderDef { shader_def_name: String }, + #[error( + "Invalid shader def comparison for '{shader_def_name}': expected {expected}, got {value}" + )] + InvalidShaderDefComparisonValue { + shader_def_name: String, + expected: String, + value: String, + }, } pub struct ShaderImportProcessor { @@ -371,8 +383,11 @@ pub static SHADER_IMPORT_PROCESSOR: Lazy = pub struct ShaderProcessor { ifdef_regex: Regex, ifndef_regex: Regex, + ifop_regex: Regex, else_regex: Regex, endif_regex: Regex, + def_regex: Regex, + def_regex_delimited: Regex, } impl Default for ShaderProcessor { @@ -380,8 +395,11 @@ impl Default for ShaderProcessor { Self { ifdef_regex: Regex::new(r"^\s*#\s*ifdef\s*([\w|\d|_]+)").unwrap(), ifndef_regex: Regex::new(r"^\s*#\s*ifndef\s*([\w|\d|_]+)").unwrap(), + ifop_regex: Regex::new(r"^\s*#\s*if\s*([\w|\d|_]+)\s*([^\s]*)\s*([\w|\d]+)").unwrap(), else_regex: Regex::new(r"^\s*#\s*else").unwrap(), endif_regex: Regex::new(r"^\s*#\s*endif").unwrap(), + def_regex: Regex::new(r"#\s*([\w|\d|_]+)").unwrap(), + def_regex_delimited: Regex::new(r"#\s*\{([\w|\d|_]+)\}").unwrap(), } } } @@ -390,7 +408,7 @@ impl ShaderProcessor { pub fn process( &self, shader: &Shader, - shader_defs: &[String], + shader_defs: &[ShaderDefVal], shaders: &HashMap, Shader>, import_handles: &HashMap>, ) -> Result { @@ -405,16 +423,69 @@ impl ShaderProcessor { } }; - let shader_defs_unique = HashSet::::from_iter(shader_defs.iter().cloned()); + let shader_defs_unique = + HashMap::::from_iter(shader_defs.iter().map(|v| match v { + ShaderDefVal::Bool(k, _) | ShaderDefVal::Int(k, _) => (k.clone(), v.clone()), + })); let mut scopes = vec![true]; let mut final_string = String::new(); for line in shader_str.lines() { if let Some(cap) = self.ifdef_regex.captures(line) { let def = cap.get(1).unwrap(); - scopes.push(*scopes.last().unwrap() && shader_defs_unique.contains(def.as_str())); + scopes + .push(*scopes.last().unwrap() && shader_defs_unique.contains_key(def.as_str())); } else if let Some(cap) = self.ifndef_regex.captures(line) { let def = cap.get(1).unwrap(); - scopes.push(*scopes.last().unwrap() && !shader_defs_unique.contains(def.as_str())); + scopes.push( + *scopes.last().unwrap() && !shader_defs_unique.contains_key(def.as_str()), + ); + } else if let Some(cap) = self.ifop_regex.captures(line) { + let def = cap.get(1).unwrap(); + let op = cap.get(2).unwrap(); + let val = cap.get(3).unwrap(); + + fn act_on(a: T, b: T, op: &str) -> Result { + match op { + "==" => Ok(a == b), + "!=" => Ok(a != b), + ">" => Ok(a > b), + ">=" => Ok(a >= b), + "<" => Ok(a < b), + "<=" => Ok(a <= b), + _ => Err(ProcessShaderError::UnknownShaderDefOperator { + operator: op.to_string(), + }), + } + } + + let def = shader_defs_unique.get(def.as_str()).ok_or( + ProcessShaderError::UnknownShaderDef { + shader_def_name: def.as_str().to_string(), + }, + )?; + let new_scope = match def { + ShaderDefVal::Bool(name, def) => { + let val = val.as_str().parse().map_err(|_| { + ProcessShaderError::InvalidShaderDefComparisonValue { + shader_def_name: name.clone(), + value: val.as_str().to_string(), + expected: "bool".to_string(), + } + })?; + act_on(*def, val, op.as_str())? + } + ShaderDefVal::Int(name, def) => { + let val = val.as_str().parse().map_err(|_| { + ProcessShaderError::InvalidShaderDefComparisonValue { + shader_def_name: name.clone(), + value: val.as_str().to_string(), + expected: "int".to_string(), + } + })?; + act_on(*def, val, op.as_str())? + } + }; + scopes.push(*scopes.last().unwrap() && new_scope); } else if self.else_regex.is_match(line) { let mut is_parent_scope_truthy = true; if scopes.len() > 1 { @@ -461,7 +532,32 @@ impl ShaderProcessor { { // ignore import path lines } else { - final_string.push_str(line); + let mut line_with_defs = line.to_string(); + for capture in self.def_regex.captures_iter(line) { + let def = capture.get(1).unwrap(); + if let Some(def) = shader_defs_unique.get(def.as_str()) { + let def = match def { + ShaderDefVal::Bool(_, def) => def.to_string(), + ShaderDefVal::Int(_, def) => def.to_string(), + }; + line_with_defs = + self.def_regex.replace(&line_with_defs, def).to_string(); + } + } + for capture in self.def_regex_delimited.captures_iter(line) { + let def = capture.get(1).unwrap(); + if let Some(def) = shader_defs_unique.get(def.as_str()) { + let def = match def { + ShaderDefVal::Bool(_, def) => def.to_string(), + ShaderDefVal::Int(_, def) => def.to_string(), + }; + line_with_defs = self + .def_regex_delimited + .replace(&line_with_defs, def) + .to_string(); + } + } + final_string.push_str(&line_with_defs); final_string.push('\n'); } } @@ -488,7 +584,7 @@ impl ShaderProcessor { shaders: &HashMap, Shader>, import: &ShaderImport, shader: &Shader, - shader_defs: &[String], + shader_defs: &[ShaderDefVal], final_string: &mut String, ) -> Result<(), ProcessShaderError> { let imported_shader = import_handles @@ -557,7 +653,9 @@ mod tests { use bevy_utils::HashMap; use naga::ShaderStage; - use crate::render_resource::{ProcessShaderError, Shader, ShaderImport, ShaderProcessor}; + use crate::render_resource::{ + ProcessShaderError, Shader, ShaderDefVal, ShaderImport, ShaderProcessor, + }; #[rustfmt::skip] const WGSL: &str = r" struct View { @@ -723,7 +821,7 @@ fn vertex( let result = processor .process( &Shader::from_wgsl(WGSL), - &["TEXTURE".to_string()], + &["TEXTURE".into()], &HashMap::default(), &HashMap::default(), ) @@ -968,7 +1066,7 @@ fn vertex( let result = processor .process( &Shader::from_wgsl(WGSL_NESTED_IFDEF), - &["TEXTURE".to_string()], + &["TEXTURE".into()], &HashMap::default(), &HashMap::default(), ) @@ -1010,7 +1108,7 @@ fn vertex( let result = processor .process( &Shader::from_wgsl(WGSL_NESTED_IFDEF_ELSE), - &["TEXTURE".to_string()], + &["TEXTURE".into()], &HashMap::default(), &HashMap::default(), ) @@ -1130,7 +1228,7 @@ fn vertex( let result = processor .process( &Shader::from_wgsl(WGSL_NESTED_IFDEF), - &["ATTRIBUTE".to_string()], + &["ATTRIBUTE".into()], &HashMap::default(), &HashMap::default(), ) @@ -1172,7 +1270,7 @@ fn vertex( let result = processor .process( &Shader::from_wgsl(WGSL_NESTED_IFDEF), - &["TEXTURE".to_string(), "ATTRIBUTE".to_string()], + &["TEXTURE".into(), "ATTRIBUTE".into()], &HashMap::default(), &HashMap::default(), ) @@ -1219,7 +1317,7 @@ fn in_main_present() { } let result = processor .process( &Shader::from_wgsl(INPUT), - &["MAIN_PRESENT".to_string(), "IMPORT_PRESENT".to_string()], + &["MAIN_PRESENT".into(), "IMPORT_PRESENT".into()], &shaders, &import_handles, ) @@ -1274,7 +1372,7 @@ fn in_main() { } let result = processor .process( &Shader::from_wgsl(INPUT), - &["DEEP".to_string()], + &["DEEP".into()], &shaders, &import_handles, ) @@ -1332,7 +1430,7 @@ fn baz() { } let result = processor .process( &Shader::from_wgsl(INPUT), - &["FOO".to_string()], + &["FOO".into()], &shaders, &import_handles, ) @@ -1344,4 +1442,515 @@ fn baz() { } .unwrap(); assert_eq!(result.get_wgsl_source().unwrap(), EXPECTED); } + + #[test] + fn process_shader_def_unknown_operator() { + #[rustfmt::skip] + const WGSL: &str = r" +struct View { + view_proj: mat4x4, + world_position: vec3, +}; +@group(0) @binding(0) +var view: View; + +#if TEXTURE !! true +@group(1) @binding(0) +var sprite_texture: texture_2d; +#endif + +struct VertexOutput { + @location(0) uv: vec2, + @builtin(position) position: vec4, +}; + +@vertex +fn vertex( + @location(0) vertex_position: vec3, + @location(1) vertex_uv: vec2 +) -> VertexOutput { + var out: VertexOutput; + out.uv = vertex_uv; + out.position = view.view_proj * vec4(vertex_position, 1.0); + return out; +} +"; + + let processor = ShaderProcessor::default(); + + let result_missing = processor.process( + &Shader::from_wgsl(WGSL), + &["TEXTURE".into()], + &HashMap::default(), + &HashMap::default(), + ); + assert_eq!( + result_missing, + Err(ProcessShaderError::UnknownShaderDefOperator { + operator: "!!".to_string() + }) + ); + } + #[test] + fn process_shader_def_equal_int() { + #[rustfmt::skip] + const WGSL: &str = r" +struct View { + view_proj: mat4x4, + world_position: vec3, +}; +@group(0) @binding(0) +var view: View; + +#if TEXTURE == 3 +@group(1) @binding(0) +var sprite_texture: texture_2d; +#endif + +struct VertexOutput { + @location(0) uv: vec2, + @builtin(position) position: vec4, +}; + +@vertex +fn vertex( + @location(0) vertex_position: vec3, + @location(1) vertex_uv: vec2 +) -> VertexOutput { + var out: VertexOutput; + out.uv = vertex_uv; + out.position = view.view_proj * vec4(vertex_position, 1.0); + return out; +} +"; + + #[rustfmt::skip] + const EXPECTED_EQ: &str = r" +struct View { + view_proj: mat4x4, + world_position: vec3, +}; +@group(0) @binding(0) +var view: View; + +@group(1) @binding(0) +var sprite_texture: texture_2d; + +struct VertexOutput { + @location(0) uv: vec2, + @builtin(position) position: vec4, +}; + +@vertex +fn vertex( + @location(0) vertex_position: vec3, + @location(1) vertex_uv: vec2 +) -> VertexOutput { + var out: VertexOutput; + out.uv = vertex_uv; + out.position = view.view_proj * vec4(vertex_position, 1.0); + return out; +} +"; + + #[rustfmt::skip] + const EXPECTED_NEQ: &str = r" +struct View { + view_proj: mat4x4, + world_position: vec3, +}; +@group(0) @binding(0) +var view: View; + + +struct VertexOutput { + @location(0) uv: vec2, + @builtin(position) position: vec4, +}; + +@vertex +fn vertex( + @location(0) vertex_position: vec3, + @location(1) vertex_uv: vec2 +) -> VertexOutput { + var out: VertexOutput; + out.uv = vertex_uv; + out.position = view.view_proj * vec4(vertex_position, 1.0); + return out; +} +"; + let processor = ShaderProcessor::default(); + let result_eq = processor + .process( + &Shader::from_wgsl(WGSL), + &[ShaderDefVal::Int("TEXTURE".to_string(), 3)], + &HashMap::default(), + &HashMap::default(), + ) + .unwrap(); + assert_eq!(result_eq.get_wgsl_source().unwrap(), EXPECTED_EQ); + + let result_neq = processor + .process( + &Shader::from_wgsl(WGSL), + &[ShaderDefVal::Int("TEXTURE".to_string(), 7)], + &HashMap::default(), + &HashMap::default(), + ) + .unwrap(); + assert_eq!(result_neq.get_wgsl_source().unwrap(), EXPECTED_NEQ); + + let result_missing = processor.process( + &Shader::from_wgsl(WGSL), + &[], + &HashMap::default(), + &HashMap::default(), + ); + assert_eq!( + result_missing, + Err(ProcessShaderError::UnknownShaderDef { + shader_def_name: "TEXTURE".to_string() + }) + ); + + let result_wrong_type = processor.process( + &Shader::from_wgsl(WGSL), + &[ShaderDefVal::Bool("TEXTURE".to_string(), true)], + &HashMap::default(), + &HashMap::default(), + ); + assert_eq!( + result_wrong_type, + Err(ProcessShaderError::InvalidShaderDefComparisonValue { + shader_def_name: "TEXTURE".to_string(), + expected: "bool".to_string(), + value: "3".to_string() + }) + ); + } + + #[test] + fn process_shader_def_equal_bool() { + #[rustfmt::skip] + const WGSL: &str = r" +struct View { + view_proj: mat4x4, + world_position: vec3, +}; +@group(0) @binding(0) +var view: View; + +#if TEXTURE == true +@group(1) @binding(0) +var sprite_texture: texture_2d; +#endif + +struct VertexOutput { + @location(0) uv: vec2, + @builtin(position) position: vec4, +}; + +@vertex +fn vertex( + @location(0) vertex_position: vec3, + @location(1) vertex_uv: vec2 +) -> VertexOutput { + var out: VertexOutput; + out.uv = vertex_uv; + out.position = view.view_proj * vec4(vertex_position, 1.0); + return out; +} +"; + + #[rustfmt::skip] + const EXPECTED_EQ: &str = r" +struct View { + view_proj: mat4x4, + world_position: vec3, +}; +@group(0) @binding(0) +var view: View; + +@group(1) @binding(0) +var sprite_texture: texture_2d; + +struct VertexOutput { + @location(0) uv: vec2, + @builtin(position) position: vec4, +}; + +@vertex +fn vertex( + @location(0) vertex_position: vec3, + @location(1) vertex_uv: vec2 +) -> VertexOutput { + var out: VertexOutput; + out.uv = vertex_uv; + out.position = view.view_proj * vec4(vertex_position, 1.0); + return out; +} +"; + + #[rustfmt::skip] + const EXPECTED_NEQ: &str = r" +struct View { + view_proj: mat4x4, + world_position: vec3, +}; +@group(0) @binding(0) +var view: View; + + +struct VertexOutput { + @location(0) uv: vec2, + @builtin(position) position: vec4, +}; + +@vertex +fn vertex( + @location(0) vertex_position: vec3, + @location(1) vertex_uv: vec2 +) -> VertexOutput { + var out: VertexOutput; + out.uv = vertex_uv; + out.position = view.view_proj * vec4(vertex_position, 1.0); + return out; +} +"; + let processor = ShaderProcessor::default(); + let result_eq = processor + .process( + &Shader::from_wgsl(WGSL), + &[ShaderDefVal::Bool("TEXTURE".to_string(), true)], + &HashMap::default(), + &HashMap::default(), + ) + .unwrap(); + assert_eq!(result_eq.get_wgsl_source().unwrap(), EXPECTED_EQ); + + let result_neq = processor + .process( + &Shader::from_wgsl(WGSL), + &[ShaderDefVal::Bool("TEXTURE".to_string(), false)], + &HashMap::default(), + &HashMap::default(), + ) + .unwrap(); + assert_eq!(result_neq.get_wgsl_source().unwrap(), EXPECTED_NEQ); + } + + #[test] + fn process_shader_def_not_equal_bool() { + #[rustfmt::skip] + const WGSL: &str = r" +struct View { + view_proj: mat4x4, + world_position: vec3, +}; +@group(0) @binding(0) +var view: View; + +#if TEXTURE != false +@group(1) @binding(0) +var sprite_texture: texture_2d; +#endif + +struct VertexOutput { + @location(0) uv: vec2, + @builtin(position) position: vec4, +}; + +@vertex +fn vertex( + @location(0) vertex_position: vec3, + @location(1) vertex_uv: vec2 +) -> VertexOutput { + var out: VertexOutput; + out.uv = vertex_uv; + out.position = view.view_proj * vec4(vertex_position, 1.0); + return out; +} +"; + + #[rustfmt::skip] + const EXPECTED_EQ: &str = r" +struct View { + view_proj: mat4x4, + world_position: vec3, +}; +@group(0) @binding(0) +var view: View; + +@group(1) @binding(0) +var sprite_texture: texture_2d; + +struct VertexOutput { + @location(0) uv: vec2, + @builtin(position) position: vec4, +}; + +@vertex +fn vertex( + @location(0) vertex_position: vec3, + @location(1) vertex_uv: vec2 +) -> VertexOutput { + var out: VertexOutput; + out.uv = vertex_uv; + out.position = view.view_proj * vec4(vertex_position, 1.0); + return out; +} +"; + + #[rustfmt::skip] + const EXPECTED_NEQ: &str = r" +struct View { + view_proj: mat4x4, + world_position: vec3, +}; +@group(0) @binding(0) +var view: View; + + +struct VertexOutput { + @location(0) uv: vec2, + @builtin(position) position: vec4, +}; + +@vertex +fn vertex( + @location(0) vertex_position: vec3, + @location(1) vertex_uv: vec2 +) -> VertexOutput { + var out: VertexOutput; + out.uv = vertex_uv; + out.position = view.view_proj * vec4(vertex_position, 1.0); + return out; +} +"; + let processor = ShaderProcessor::default(); + let result_eq = processor + .process( + &Shader::from_wgsl(WGSL), + &[ShaderDefVal::Bool("TEXTURE".to_string(), true)], + &HashMap::default(), + &HashMap::default(), + ) + .unwrap(); + assert_eq!(result_eq.get_wgsl_source().unwrap(), EXPECTED_EQ); + + let result_neq = processor + .process( + &Shader::from_wgsl(WGSL), + &[ShaderDefVal::Bool("TEXTURE".to_string(), false)], + &HashMap::default(), + &HashMap::default(), + ) + .unwrap(); + assert_eq!(result_neq.get_wgsl_source().unwrap(), EXPECTED_NEQ); + + let result_missing = processor.process( + &Shader::from_wgsl(WGSL), + &[], + &HashMap::default(), + &HashMap::default(), + ); + assert_eq!( + result_missing, + Err(ProcessShaderError::UnknownShaderDef { + shader_def_name: "TEXTURE".to_string() + }) + ); + + let result_wrong_type = processor.process( + &Shader::from_wgsl(WGSL), + &[ShaderDefVal::Int("TEXTURE".to_string(), 7)], + &HashMap::default(), + &HashMap::default(), + ); + assert_eq!( + result_wrong_type, + Err(ProcessShaderError::InvalidShaderDefComparisonValue { + shader_def_name: "TEXTURE".to_string(), + expected: "int".to_string(), + value: "false".to_string() + }) + ); + } + + #[test] + fn process_shader_def_replace() { + #[rustfmt::skip] + const WGSL: &str = r" +struct View { + view_proj: mat4x4, + world_position: vec3, +}; +@group(0) @binding(0) +var view: View; + +struct VertexOutput { + @location(0) uv: vec2, + @builtin(position) position: vec4, +}; + +@vertex +fn vertex( + @location(0) vertex_position: vec3, + @location(1) vertex_uv: vec2 +) -> VertexOutput { + var out: VertexOutput; + out.uv = vertex_uv; + var a: i32 = #FIRST_VALUE; + var b: i32 = #FIRST_VALUE * #SECOND_VALUE; + var c: i32 = #MISSING_VALUE; + var d: bool = #BOOL_VALUE; + out.position = view.view_proj * vec4(vertex_position, 1.0); + return out; +} +"; + + #[rustfmt::skip] + const EXPECTED_REPLACED: &str = r" +struct View { + view_proj: mat4x4, + world_position: vec3, +}; +@group(0) @binding(0) +var view: View; + +struct VertexOutput { + @location(0) uv: vec2, + @builtin(position) position: vec4, +}; + +@vertex +fn vertex( + @location(0) vertex_position: vec3, + @location(1) vertex_uv: vec2 +) -> VertexOutput { + var out: VertexOutput; + out.uv = vertex_uv; + var a: i32 = 5; + var b: i32 = 5 * 3; + var c: i32 = #MISSING_VALUE; + var d: bool = true; + out.position = view.view_proj * vec4(vertex_position, 1.0); + return out; +} +"; + let processor = ShaderProcessor::default(); + let result = processor + .process( + &Shader::from_wgsl(WGSL), + &[ + ShaderDefVal::Bool("BOOL_VALUE".to_string(), true), + ShaderDefVal::Int("FIRST_VALUE".to_string(), 5), + ShaderDefVal::Int("SECOND_VALUE".to_string(), 3), + ], + &HashMap::default(), + &HashMap::default(), + ) + .unwrap(); + assert_eq!(result.get_wgsl_source().unwrap(), EXPECTED_REPLACED); + } } diff --git a/crates/bevy_sprite/src/mesh2d/mesh.rs b/crates/bevy_sprite/src/mesh2d/mesh.rs index 5483c2b98d..0910463f26 100644 --- a/crates/bevy_sprite/src/mesh2d/mesh.rs +++ b/crates/bevy_sprite/src/mesh2d/mesh.rs @@ -351,36 +351,36 @@ impl SpecializedMeshPipeline for Mesh2dPipeline { let mut vertex_attributes = Vec::new(); if layout.contains(Mesh::ATTRIBUTE_POSITION) { - shader_defs.push(String::from("VERTEX_POSITIONS")); + shader_defs.push("VERTEX_POSITIONS".into()); vertex_attributes.push(Mesh::ATTRIBUTE_POSITION.at_shader_location(0)); } if layout.contains(Mesh::ATTRIBUTE_NORMAL) { - shader_defs.push(String::from("VERTEX_NORMALS")); + shader_defs.push("VERTEX_NORMALS".into()); vertex_attributes.push(Mesh::ATTRIBUTE_NORMAL.at_shader_location(1)); } if layout.contains(Mesh::ATTRIBUTE_UV_0) { - shader_defs.push(String::from("VERTEX_UVS")); + shader_defs.push("VERTEX_UVS".into()); vertex_attributes.push(Mesh::ATTRIBUTE_UV_0.at_shader_location(2)); } if layout.contains(Mesh::ATTRIBUTE_TANGENT) { - shader_defs.push(String::from("VERTEX_TANGENTS")); + shader_defs.push("VERTEX_TANGENTS".into()); vertex_attributes.push(Mesh::ATTRIBUTE_TANGENT.at_shader_location(3)); } if layout.contains(Mesh::ATTRIBUTE_COLOR) { - shader_defs.push(String::from("VERTEX_COLORS")); + shader_defs.push("VERTEX_COLORS".into()); vertex_attributes.push(Mesh::ATTRIBUTE_COLOR.at_shader_location(4)); } if key.contains(Mesh2dPipelineKey::TONEMAP_IN_SHADER) { - shader_defs.push("TONEMAP_IN_SHADER".to_string()); + shader_defs.push("TONEMAP_IN_SHADER".into()); // Debanding is tied to tonemapping in the shader, cannot run without it. if key.contains(Mesh2dPipelineKey::DEBAND_DITHER) { - shader_defs.push("DEBAND_DITHER".to_string()); + shader_defs.push("DEBAND_DITHER".into()); } } diff --git a/crates/bevy_sprite/src/render/mod.rs b/crates/bevy_sprite/src/render/mod.rs index 8b71558813..d99fdc8d4d 100644 --- a/crates/bevy_sprite/src/render/mod.rs +++ b/crates/bevy_sprite/src/render/mod.rs @@ -208,15 +208,15 @@ impl SpecializedRenderPipeline for SpritePipeline { let mut shader_defs = Vec::new(); if key.contains(SpritePipelineKey::COLORED) { - shader_defs.push("COLORED".to_string()); + shader_defs.push("COLORED".into()); } if key.contains(SpritePipelineKey::TONEMAP_IN_SHADER) { - shader_defs.push("TONEMAP_IN_SHADER".to_string()); + shader_defs.push("TONEMAP_IN_SHADER".into()); // Debanding is tied to tonemapping in the shader, cannot run without it. if key.contains(SpritePipelineKey::DEBAND_DITHER) { - shader_defs.push("DEBAND_DITHER".to_string()); + shader_defs.push("DEBAND_DITHER".into()); } } diff --git a/examples/shader/shader_defs.rs b/examples/shader/shader_defs.rs index ca9d953d8f..4735406ace 100644 --- a/examples/shader/shader_defs.rs +++ b/examples/shader/shader_defs.rs @@ -68,7 +68,7 @@ impl Material for CustomMaterial { ) -> Result<(), SpecializedMeshPipelineError> { if key.bind_group_data.is_red { let fragment = descriptor.fragment.as_mut().unwrap(); - fragment.shader_defs.push("IS_RED".to_string()); + fragment.shader_defs.push("IS_RED".into()); } Ok(()) }