Shader defs can now have a value (#5900)

# Objective

- shaders defs can now have a `bool` or `int` value
- `#if SHADER_DEF <operator> 3`
  - ok if `SHADER_DEF` is defined, has the correct type and pass the comparison
  - `==`, `!=`, `>=`, `>`, `<`, `<=` supported
- `#SHADER_DEF` or `#{SHADER_DEF}`
  - will be replaced by the value in the shader code
---

## Migration Guide

- replace `shader_defs.push(String::from("NAME"));` by `shader_defs.push("NAME".into());`
- if you used shader def `NO_STORAGE_BUFFERS_SUPPORT`, check how `AVAILABLE_STORAGE_BUFFER_BINDINGS` is now used in Bevy default shaders
This commit is contained in:
François 2022-11-21 22:38:29 +00:00
parent daa57fe489
commit d44e86507f
14 changed files with 721 additions and 85 deletions

View file

@ -203,8 +203,8 @@ impl SpecializedRenderPipeline for FxaaPipeline {
fragment: Some(FragmentState {
shader: FXAA_SHADER_HANDLE.typed(),
shader_defs: vec![
format!("EDGE_THRESH_{}", key.edge_threshold.get_str()),
format!("EDGE_THRESH_MIN_{}", key.edge_threshold_min.get_str()),
format!("EDGE_THRESH_{}", key.edge_threshold.get_str()).into(),
format!("EDGE_THRESH_MIN_{}", key.edge_threshold_min.get_str()).into(),
],
entry_point: "fragment".into(),
targets: vec![Some(ColorTargetState {

View file

@ -66,7 +66,7 @@ impl SpecializedRenderPipeline for TonemappingPipeline {
fn specialize(&self, key: Self::Key) -> RenderPipelineDescriptor {
let mut shader_defs = Vec::new();
if key.deband_dither {
shader_defs.push("DEBAND_DITHER".to_string());
shader_defs.push("DEBAND_DITHER".into());
}
RenderPipelineDescriptor {
label: Some("tonemapping pipeline".into()),

View file

@ -419,7 +419,7 @@ impl Material for StandardMaterial {
.as_mut()
.unwrap()
.shader_defs
.push(String::from("STANDARDMATERIAL_NORMAL_MAP"));
.push("STANDARDMATERIAL_NORMAL_MAP".into());
}
descriptor.primitive.cull_mode = key.bind_group_data.cull_mode;
if let Some(label) = &mut descriptor.label {

View file

@ -29,7 +29,9 @@ fn fragment_cluster_index(frag_coord: vec2<f32>, view_z: f32, is_orthographic: b
// this must match CLUSTER_COUNT_SIZE in light.rs
let CLUSTER_COUNT_SIZE = 9u;
fn unpack_offset_and_counts(cluster_index: u32) -> vec3<u32> {
#ifdef NO_STORAGE_BUFFERS_SUPPORT
#if AVAILABLE_STORAGE_BUFFER_BINDINGS >= 3
return cluster_offsets_and_counts.data[cluster_index].xyz;
#else
let offset_and_counts = cluster_offsets_and_counts.data[cluster_index >> 2u][cluster_index & ((1u << 2u) - 1u)];
// [ 31 .. 18 | 17 .. 9 | 8 .. 0 ]
// [ offset | point light count | spot light count ]
@ -38,20 +40,18 @@ fn unpack_offset_and_counts(cluster_index: u32) -> vec3<u32> {
(offset_and_counts >> CLUSTER_COUNT_SIZE) & ((1u << CLUSTER_COUNT_SIZE) - 1u),
offset_and_counts & ((1u << CLUSTER_COUNT_SIZE) - 1u),
);
#else
return cluster_offsets_and_counts.data[cluster_index].xyz;
#endif
}
fn get_light_id(index: u32) -> u32 {
#ifdef NO_STORAGE_BUFFERS_SUPPORT
#if AVAILABLE_STORAGE_BUFFER_BINDINGS >= 3
return cluster_light_index_lists.data[index];
#else
// The index is correct but in cluster_light_index_lists we pack 4 u8s into a u32
// This means the index into cluster_light_index_lists is index / 4
let indices = cluster_light_index_lists.data[index >> 4u][(index >> 2u) & ((1u << 2u) - 1u)];
// And index % 4 gives the sub-index of the u8 within the u32 so we shift by 8 * sub-index
return (indices >> (8u * (index & ((1u << 2u) - 1u)))) & ((1u << 8u) - 1u);
#else
return cluster_light_index_lists.data[index];
#endif
}

View file

@ -321,11 +321,15 @@ impl SpecializedMeshPipeline for ShadowPipeline {
let mut bind_group_layout = vec![self.view_layout.clone()];
let mut shader_defs = Vec::new();
shader_defs.push(ShaderDefVal::Int(
"MAX_DIRECTIONAL_LIGHTS".to_string(),
MAX_DIRECTIONAL_LIGHTS as i32,
));
if layout.contains(Mesh::ATTRIBUTE_JOINT_INDEX)
&& layout.contains(Mesh::ATTRIBUTE_JOINT_WEIGHT)
{
shader_defs.push(String::from("SKINNED"));
shader_defs.push("SKINNED".into());
vertex_attributes.push(Mesh::ATTRIBUTE_JOINT_INDEX.at_shader_location(4));
vertex_attributes.push(Mesh::ATTRIBUTE_JOINT_WEIGHT.at_shader_location(5));
bind_group_layout.push(self.skinned_mesh_layout.clone());

View file

@ -1,7 +1,7 @@
use crate::{
GlobalLightMeta, GpuLights, GpuPointLights, LightMeta, NotShadowCaster, NotShadowReceiver,
ShadowPipeline, ViewClusterBindings, ViewLightsUniformOffset, ViewShadowBindings,
CLUSTERED_FORWARD_STORAGE_BUFFER_COUNT,
CLUSTERED_FORWARD_STORAGE_BUFFER_COUNT, MAX_DIRECTIONAL_LIGHTS,
};
use bevy_app::Plugin;
use bevy_asset::{load_internal_asset, Assets, Handle, HandleUntyped};
@ -581,27 +581,32 @@ impl SpecializedMeshPipeline for MeshPipeline {
let mut vertex_attributes = Vec::new();
if layout.contains(Mesh::ATTRIBUTE_POSITION) {
shader_defs.push(String::from("VERTEX_POSITIONS"));
shader_defs.push("VERTEX_POSITIONS".into());
vertex_attributes.push(Mesh::ATTRIBUTE_POSITION.at_shader_location(0));
}
if layout.contains(Mesh::ATTRIBUTE_NORMAL) {
shader_defs.push(String::from("VERTEX_NORMALS"));
shader_defs.push("VERTEX_NORMALS".into());
vertex_attributes.push(Mesh::ATTRIBUTE_NORMAL.at_shader_location(1));
}
shader_defs.push(ShaderDefVal::Int(
"MAX_DIRECTIONAL_LIGHTS".to_string(),
MAX_DIRECTIONAL_LIGHTS as i32,
));
if layout.contains(Mesh::ATTRIBUTE_UV_0) {
shader_defs.push(String::from("VERTEX_UVS"));
shader_defs.push("VERTEX_UVS".into());
vertex_attributes.push(Mesh::ATTRIBUTE_UV_0.at_shader_location(2));
}
if layout.contains(Mesh::ATTRIBUTE_TANGENT) {
shader_defs.push(String::from("VERTEX_TANGENTS"));
shader_defs.push("VERTEX_TANGENTS".into());
vertex_attributes.push(Mesh::ATTRIBUTE_TANGENT.at_shader_location(3));
}
if layout.contains(Mesh::ATTRIBUTE_COLOR) {
shader_defs.push(String::from("VERTEX_COLORS"));
shader_defs.push("VERTEX_COLORS".into());
vertex_attributes.push(Mesh::ATTRIBUTE_COLOR.at_shader_location(4));
}
@ -609,7 +614,7 @@ impl SpecializedMeshPipeline for MeshPipeline {
if layout.contains(Mesh::ATTRIBUTE_JOINT_INDEX)
&& layout.contains(Mesh::ATTRIBUTE_JOINT_WEIGHT)
{
shader_defs.push(String::from("SKINNED"));
shader_defs.push("SKINNED".into());
vertex_attributes.push(Mesh::ATTRIBUTE_JOINT_INDEX.at_shader_location(5));
vertex_attributes.push(Mesh::ATTRIBUTE_JOINT_WEIGHT.at_shader_location(6));
bind_group_layout.push(self.skinned_mesh_layout.clone());
@ -636,11 +641,11 @@ impl SpecializedMeshPipeline for MeshPipeline {
}
if key.contains(MeshPipelineKey::TONEMAP_IN_SHADER) {
shader_defs.push("TONEMAP_IN_SHADER".to_string());
shader_defs.push("TONEMAP_IN_SHADER".into());
// Debanding is tied to tonemapping in the shader, cannot run without it.
if key.contains(MeshPipelineKey::DEBAND_DITHER) {
shader_defs.push("DEBAND_DITHER".to_string());
shader_defs.push("DEBAND_DITHER".into());
}
}

View file

@ -25,20 +25,20 @@ var directional_shadow_textures: texture_depth_2d_array;
@group(0) @binding(5)
var directional_shadow_textures_sampler: sampler_comparison;
#ifdef NO_STORAGE_BUFFERS_SUPPORT
@group(0) @binding(6)
var<uniform> point_lights: PointLights;
@group(0) @binding(7)
var<uniform> cluster_light_index_lists: ClusterLightIndexLists;
@group(0) @binding(8)
var<uniform> cluster_offsets_and_counts: ClusterOffsetsAndCounts;
#else
#if AVAILABLE_STORAGE_BUFFER_BINDINGS >= 3
@group(0) @binding(6)
var<storage> point_lights: PointLights;
@group(0) @binding(7)
var<storage> cluster_light_index_lists: ClusterLightIndexLists;
@group(0) @binding(8)
var<storage> cluster_offsets_and_counts: ClusterOffsetsAndCounts;
#else
@group(0) @binding(6)
var<uniform> point_lights: PointLights;
@group(0) @binding(7)
var<uniform> cluster_light_index_lists: ClusterLightIndexLists;
@group(0) @binding(8)
var<uniform> cluster_offsets_and_counts: ClusterOffsetsAndCounts;
#endif
@group(0) @binding(9)

View file

@ -42,7 +42,7 @@ let DIRECTIONAL_LIGHT_FLAGS_SHADOWS_ENABLED_BIT: u32 = 1u;
struct Lights {
// NOTE: this array size must be kept in sync with the constants defined in bevy_pbr/src/render/light.rs
directional_lights: array<DirectionalLight, 10u>,
directional_lights: array<DirectionalLight, #{MAX_DIRECTIONAL_LIGHTS}u>,
ambient_color: vec4<f32>,
// x/y/z dimensions and n_clusters in w
cluster_dimensions: vec4<u32>,
@ -61,7 +61,17 @@ struct Lights {
spot_light_shadowmap_offset: i32,
};
#ifdef NO_STORAGE_BUFFERS_SUPPORT
#if AVAILABLE_STORAGE_BUFFER_BINDINGS >= 3
struct PointLights {
data: array<PointLight>,
};
struct ClusterLightIndexLists {
data: array<u32>,
};
struct ClusterOffsetsAndCounts {
data: array<vec4<u32>>,
};
#else
struct PointLights {
data: array<PointLight, 256u>,
};
@ -74,16 +84,6 @@ struct ClusterOffsetsAndCounts {
// and an 8-bit count of the number of lights in the low 8 bits
data: array<vec4<u32>, 1024u>,
};
#else
struct PointLights {
data: array<PointLight>,
};
struct ClusterLightIndexLists {
data: array<u32>,
};
struct ClusterOffsetsAndCounts {
data: array<vec4<u32>>,
};
#endif
struct Globals {

View file

@ -7,6 +7,7 @@ use wgpu::{
VertexAttribute, VertexFormat, VertexStepMode,
};
use super::ShaderDefVal;
use crate::render_resource::resource_macros::*;
/// A [`RenderPipeline`] identifier.
@ -115,7 +116,7 @@ pub struct RenderPipelineDescriptor {
pub struct VertexState {
/// The compiled shader module for this stage.
pub shader: Handle<Shader>,
pub shader_defs: Vec<String>,
pub shader_defs: Vec<ShaderDefVal>,
/// The name of the entry point in the compiled shader. There must be a
/// function with this name in the shader.
pub entry_point: Cow<'static, str>,
@ -167,7 +168,7 @@ impl VertexBufferLayout {
pub struct FragmentState {
/// The compiled shader module for this stage.
pub shader: Handle<Shader>,
pub shader_defs: Vec<String>,
pub shader_defs: Vec<ShaderDefVal>,
/// The name of the entry point in the compiled shader. There must be a
/// function with this name in the shader.
pub entry_point: Cow<'static, str>,
@ -182,7 +183,7 @@ pub struct ComputePipelineDescriptor {
pub layout: Option<Vec<BindGroupLayout>>,
/// The compiled shader module for this stage.
pub shader: Handle<Shader>,
pub shader_defs: Vec<String>,
pub shader_defs: Vec<ShaderDefVal>,
/// The name of the entry point in the compiled shader. There must be a
/// function with this name in the shader.
pub entry_point: Cow<'static, str>,

View file

@ -107,7 +107,7 @@ impl CachedPipelineState {
#[derive(Default)]
struct ShaderData {
pipelines: HashSet<CachedPipelineId>,
processed_shaders: HashMap<Vec<String>, ErasedShaderModule>,
processed_shaders: HashMap<Vec<ShaderDefVal>, ErasedShaderModule>,
resolved_imports: HashMap<ShaderImport, Handle<Shader>>,
dependents: HashSet<Handle<Shader>>,
}
@ -121,13 +121,31 @@ struct ShaderCache {
processor: ShaderProcessor,
}
#[derive(Clone, PartialEq, Eq, Debug, Hash)]
pub enum ShaderDefVal {
Bool(String, bool),
Int(String, i32),
}
impl From<&str> for ShaderDefVal {
fn from(key: &str) -> Self {
ShaderDefVal::Bool(key.to_string(), true)
}
}
impl From<String> for ShaderDefVal {
fn from(key: String) -> Self {
ShaderDefVal::Bool(key, true)
}
}
impl ShaderCache {
fn get(
&mut self,
render_device: &RenderDevice,
pipeline: CachedPipelineId,
handle: &Handle<Shader>,
shader_defs: &[String],
shader_defs: &[ShaderDefVal],
) -> Result<ErasedShaderModule, PipelineCacheError> {
let shader = self
.shaders
@ -156,21 +174,20 @@ impl ShaderCache {
let mut shader_defs = shader_defs.to_vec();
#[cfg(feature = "webgl")]
{
shader_defs.push(String::from("NO_ARRAY_TEXTURES_SUPPORT"));
shader_defs.push(String::from("SIXTEEN_BYTE_ALIGNMENT"));
shader_defs.push("NO_ARRAY_TEXTURES_SUPPORT".into());
shader_defs.push("SIXTEEN_BYTE_ALIGNMENT".into());
}
// TODO: 3 is the value from CLUSTERED_FORWARD_STORAGE_BUFFER_COUNT declared in bevy_pbr
// consider exposing this in shaders in a more generally useful way, such as:
// # if AVAILABLE_STORAGE_BUFFER_BINDINGS == 3
// /* use storage buffers here */
// # elif
// /* use uniforms here */
if !matches!(
// 3 is the value from CLUSTERED_FORWARD_STORAGE_BUFFER_COUNT declared in bevy_pbr
// Using the value directly here to avoid the cyclic dependency
if matches!(
render_device.get_supported_read_only_binding_type(3),
BufferBindingType::Storage { .. }
) {
shader_defs.push(String::from("NO_STORAGE_BUFFERS_SUPPORT"));
shader_defs.push(ShaderDefVal::Int(
String::from("AVAILABLE_STORAGE_BUFFER_BINDINGS"),
3,
));
}
debug!(

View file

@ -6,13 +6,13 @@ use naga::valid::Capabilities;
use naga::{valid::ModuleInfo, Module};
use once_cell::sync::Lazy;
use regex::Regex;
use std::{
borrow::Cow, collections::HashSet, marker::Copy, ops::Deref, path::PathBuf, str::FromStr,
};
use std::{borrow::Cow, marker::Copy, ops::Deref, path::PathBuf, str::FromStr};
use thiserror::Error;
use wgpu::Features;
use wgpu::{util::make_spirv, ShaderModuleDescriptor, ShaderSource};
use super::ShaderDefVal;
#[derive(Copy, Clone, Hash, Eq, PartialEq, Debug)]
pub struct ShaderId(Uuid);
@ -303,6 +303,18 @@ pub enum ProcessShaderError {
UnresolvedImport(ShaderImport),
#[error("The shader import {0:?} does not match the source file type. Support for this might be added in the future.")]
MismatchedImportFormat(ShaderImport),
#[error("Unknown shader def operator: '{operator}'")]
UnknownShaderDefOperator { operator: String },
#[error("Unknown shader def: '{shader_def_name}'")]
UnknownShaderDef { shader_def_name: String },
#[error(
"Invalid shader def comparison for '{shader_def_name}': expected {expected}, got {value}"
)]
InvalidShaderDefComparisonValue {
shader_def_name: String,
expected: String,
value: String,
},
}
pub struct ShaderImportProcessor {
@ -371,8 +383,11 @@ pub static SHADER_IMPORT_PROCESSOR: Lazy<ShaderImportProcessor> =
pub struct ShaderProcessor {
ifdef_regex: Regex,
ifndef_regex: Regex,
ifop_regex: Regex,
else_regex: Regex,
endif_regex: Regex,
def_regex: Regex,
def_regex_delimited: Regex,
}
impl Default for ShaderProcessor {
@ -380,8 +395,11 @@ impl Default for ShaderProcessor {
Self {
ifdef_regex: Regex::new(r"^\s*#\s*ifdef\s*([\w|\d|_]+)").unwrap(),
ifndef_regex: Regex::new(r"^\s*#\s*ifndef\s*([\w|\d|_]+)").unwrap(),
ifop_regex: Regex::new(r"^\s*#\s*if\s*([\w|\d|_]+)\s*([^\s]*)\s*([\w|\d]+)").unwrap(),
else_regex: Regex::new(r"^\s*#\s*else").unwrap(),
endif_regex: Regex::new(r"^\s*#\s*endif").unwrap(),
def_regex: Regex::new(r"#\s*([\w|\d|_]+)").unwrap(),
def_regex_delimited: Regex::new(r"#\s*\{([\w|\d|_]+)\}").unwrap(),
}
}
}
@ -390,7 +408,7 @@ impl ShaderProcessor {
pub fn process(
&self,
shader: &Shader,
shader_defs: &[String],
shader_defs: &[ShaderDefVal],
shaders: &HashMap<Handle<Shader>, Shader>,
import_handles: &HashMap<ShaderImport, Handle<Shader>>,
) -> Result<ProcessedShader, ProcessShaderError> {
@ -405,16 +423,69 @@ impl ShaderProcessor {
}
};
let shader_defs_unique = HashSet::<String>::from_iter(shader_defs.iter().cloned());
let shader_defs_unique =
HashMap::<String, ShaderDefVal>::from_iter(shader_defs.iter().map(|v| match v {
ShaderDefVal::Bool(k, _) | ShaderDefVal::Int(k, _) => (k.clone(), v.clone()),
}));
let mut scopes = vec![true];
let mut final_string = String::new();
for line in shader_str.lines() {
if let Some(cap) = self.ifdef_regex.captures(line) {
let def = cap.get(1).unwrap();
scopes.push(*scopes.last().unwrap() && shader_defs_unique.contains(def.as_str()));
scopes
.push(*scopes.last().unwrap() && shader_defs_unique.contains_key(def.as_str()));
} else if let Some(cap) = self.ifndef_regex.captures(line) {
let def = cap.get(1).unwrap();
scopes.push(*scopes.last().unwrap() && !shader_defs_unique.contains(def.as_str()));
scopes.push(
*scopes.last().unwrap() && !shader_defs_unique.contains_key(def.as_str()),
);
} else if let Some(cap) = self.ifop_regex.captures(line) {
let def = cap.get(1).unwrap();
let op = cap.get(2).unwrap();
let val = cap.get(3).unwrap();
fn act_on<T: Eq + Ord>(a: T, b: T, op: &str) -> Result<bool, ProcessShaderError> {
match op {
"==" => Ok(a == b),
"!=" => Ok(a != b),
">" => Ok(a > b),
">=" => Ok(a >= b),
"<" => Ok(a < b),
"<=" => Ok(a <= b),
_ => Err(ProcessShaderError::UnknownShaderDefOperator {
operator: op.to_string(),
}),
}
}
let def = shader_defs_unique.get(def.as_str()).ok_or(
ProcessShaderError::UnknownShaderDef {
shader_def_name: def.as_str().to_string(),
},
)?;
let new_scope = match def {
ShaderDefVal::Bool(name, def) => {
let val = val.as_str().parse().map_err(|_| {
ProcessShaderError::InvalidShaderDefComparisonValue {
shader_def_name: name.clone(),
value: val.as_str().to_string(),
expected: "bool".to_string(),
}
})?;
act_on(*def, val, op.as_str())?
}
ShaderDefVal::Int(name, def) => {
let val = val.as_str().parse().map_err(|_| {
ProcessShaderError::InvalidShaderDefComparisonValue {
shader_def_name: name.clone(),
value: val.as_str().to_string(),
expected: "int".to_string(),
}
})?;
act_on(*def, val, op.as_str())?
}
};
scopes.push(*scopes.last().unwrap() && new_scope);
} else if self.else_regex.is_match(line) {
let mut is_parent_scope_truthy = true;
if scopes.len() > 1 {
@ -461,7 +532,32 @@ impl ShaderProcessor {
{
// ignore import path lines
} else {
final_string.push_str(line);
let mut line_with_defs = line.to_string();
for capture in self.def_regex.captures_iter(line) {
let def = capture.get(1).unwrap();
if let Some(def) = shader_defs_unique.get(def.as_str()) {
let def = match def {
ShaderDefVal::Bool(_, def) => def.to_string(),
ShaderDefVal::Int(_, def) => def.to_string(),
};
line_with_defs =
self.def_regex.replace(&line_with_defs, def).to_string();
}
}
for capture in self.def_regex_delimited.captures_iter(line) {
let def = capture.get(1).unwrap();
if let Some(def) = shader_defs_unique.get(def.as_str()) {
let def = match def {
ShaderDefVal::Bool(_, def) => def.to_string(),
ShaderDefVal::Int(_, def) => def.to_string(),
};
line_with_defs = self
.def_regex_delimited
.replace(&line_with_defs, def)
.to_string();
}
}
final_string.push_str(&line_with_defs);
final_string.push('\n');
}
}
@ -488,7 +584,7 @@ impl ShaderProcessor {
shaders: &HashMap<Handle<Shader>, Shader>,
import: &ShaderImport,
shader: &Shader,
shader_defs: &[String],
shader_defs: &[ShaderDefVal],
final_string: &mut String,
) -> Result<(), ProcessShaderError> {
let imported_shader = import_handles
@ -557,7 +653,9 @@ mod tests {
use bevy_utils::HashMap;
use naga::ShaderStage;
use crate::render_resource::{ProcessShaderError, Shader, ShaderImport, ShaderProcessor};
use crate::render_resource::{
ProcessShaderError, Shader, ShaderDefVal, ShaderImport, ShaderProcessor,
};
#[rustfmt::skip]
const WGSL: &str = r"
struct View {
@ -723,7 +821,7 @@ fn vertex(
let result = processor
.process(
&Shader::from_wgsl(WGSL),
&["TEXTURE".to_string()],
&["TEXTURE".into()],
&HashMap::default(),
&HashMap::default(),
)
@ -968,7 +1066,7 @@ fn vertex(
let result = processor
.process(
&Shader::from_wgsl(WGSL_NESTED_IFDEF),
&["TEXTURE".to_string()],
&["TEXTURE".into()],
&HashMap::default(),
&HashMap::default(),
)
@ -1010,7 +1108,7 @@ fn vertex(
let result = processor
.process(
&Shader::from_wgsl(WGSL_NESTED_IFDEF_ELSE),
&["TEXTURE".to_string()],
&["TEXTURE".into()],
&HashMap::default(),
&HashMap::default(),
)
@ -1130,7 +1228,7 @@ fn vertex(
let result = processor
.process(
&Shader::from_wgsl(WGSL_NESTED_IFDEF),
&["ATTRIBUTE".to_string()],
&["ATTRIBUTE".into()],
&HashMap::default(),
&HashMap::default(),
)
@ -1172,7 +1270,7 @@ fn vertex(
let result = processor
.process(
&Shader::from_wgsl(WGSL_NESTED_IFDEF),
&["TEXTURE".to_string(), "ATTRIBUTE".to_string()],
&["TEXTURE".into(), "ATTRIBUTE".into()],
&HashMap::default(),
&HashMap::default(),
)
@ -1219,7 +1317,7 @@ fn in_main_present() { }
let result = processor
.process(
&Shader::from_wgsl(INPUT),
&["MAIN_PRESENT".to_string(), "IMPORT_PRESENT".to_string()],
&["MAIN_PRESENT".into(), "IMPORT_PRESENT".into()],
&shaders,
&import_handles,
)
@ -1274,7 +1372,7 @@ fn in_main() { }
let result = processor
.process(
&Shader::from_wgsl(INPUT),
&["DEEP".to_string()],
&["DEEP".into()],
&shaders,
&import_handles,
)
@ -1332,7 +1430,7 @@ fn baz() { }
let result = processor
.process(
&Shader::from_wgsl(INPUT),
&["FOO".to_string()],
&["FOO".into()],
&shaders,
&import_handles,
)
@ -1344,4 +1442,515 @@ fn baz() { }
.unwrap();
assert_eq!(result.get_wgsl_source().unwrap(), EXPECTED);
}
#[test]
fn process_shader_def_unknown_operator() {
#[rustfmt::skip]
const WGSL: &str = r"
struct View {
view_proj: mat4x4<f32>,
world_position: vec3<f32>,
};
@group(0) @binding(0)
var<uniform> view: View;
#if TEXTURE !! true
@group(1) @binding(0)
var sprite_texture: texture_2d<f32>;
#endif
struct VertexOutput {
@location(0) uv: vec2<f32>,
@builtin(position) position: vec4<f32>,
};
@vertex
fn vertex(
@location(0) vertex_position: vec3<f32>,
@location(1) vertex_uv: vec2<f32>
) -> VertexOutput {
var out: VertexOutput;
out.uv = vertex_uv;
out.position = view.view_proj * vec4<f32>(vertex_position, 1.0);
return out;
}
";
let processor = ShaderProcessor::default();
let result_missing = processor.process(
&Shader::from_wgsl(WGSL),
&["TEXTURE".into()],
&HashMap::default(),
&HashMap::default(),
);
assert_eq!(
result_missing,
Err(ProcessShaderError::UnknownShaderDefOperator {
operator: "!!".to_string()
})
);
}
#[test]
fn process_shader_def_equal_int() {
#[rustfmt::skip]
const WGSL: &str = r"
struct View {
view_proj: mat4x4<f32>,
world_position: vec3<f32>,
};
@group(0) @binding(0)
var<uniform> view: View;
#if TEXTURE == 3
@group(1) @binding(0)
var sprite_texture: texture_2d<f32>;
#endif
struct VertexOutput {
@location(0) uv: vec2<f32>,
@builtin(position) position: vec4<f32>,
};
@vertex
fn vertex(
@location(0) vertex_position: vec3<f32>,
@location(1) vertex_uv: vec2<f32>
) -> VertexOutput {
var out: VertexOutput;
out.uv = vertex_uv;
out.position = view.view_proj * vec4<f32>(vertex_position, 1.0);
return out;
}
";
#[rustfmt::skip]
const EXPECTED_EQ: &str = r"
struct View {
view_proj: mat4x4<f32>,
world_position: vec3<f32>,
};
@group(0) @binding(0)
var<uniform> view: View;
@group(1) @binding(0)
var sprite_texture: texture_2d<f32>;
struct VertexOutput {
@location(0) uv: vec2<f32>,
@builtin(position) position: vec4<f32>,
};
@vertex
fn vertex(
@location(0) vertex_position: vec3<f32>,
@location(1) vertex_uv: vec2<f32>
) -> VertexOutput {
var out: VertexOutput;
out.uv = vertex_uv;
out.position = view.view_proj * vec4<f32>(vertex_position, 1.0);
return out;
}
";
#[rustfmt::skip]
const EXPECTED_NEQ: &str = r"
struct View {
view_proj: mat4x4<f32>,
world_position: vec3<f32>,
};
@group(0) @binding(0)
var<uniform> view: View;
struct VertexOutput {
@location(0) uv: vec2<f32>,
@builtin(position) position: vec4<f32>,
};
@vertex
fn vertex(
@location(0) vertex_position: vec3<f32>,
@location(1) vertex_uv: vec2<f32>
) -> VertexOutput {
var out: VertexOutput;
out.uv = vertex_uv;
out.position = view.view_proj * vec4<f32>(vertex_position, 1.0);
return out;
}
";
let processor = ShaderProcessor::default();
let result_eq = processor
.process(
&Shader::from_wgsl(WGSL),
&[ShaderDefVal::Int("TEXTURE".to_string(), 3)],
&HashMap::default(),
&HashMap::default(),
)
.unwrap();
assert_eq!(result_eq.get_wgsl_source().unwrap(), EXPECTED_EQ);
let result_neq = processor
.process(
&Shader::from_wgsl(WGSL),
&[ShaderDefVal::Int("TEXTURE".to_string(), 7)],
&HashMap::default(),
&HashMap::default(),
)
.unwrap();
assert_eq!(result_neq.get_wgsl_source().unwrap(), EXPECTED_NEQ);
let result_missing = processor.process(
&Shader::from_wgsl(WGSL),
&[],
&HashMap::default(),
&HashMap::default(),
);
assert_eq!(
result_missing,
Err(ProcessShaderError::UnknownShaderDef {
shader_def_name: "TEXTURE".to_string()
})
);
let result_wrong_type = processor.process(
&Shader::from_wgsl(WGSL),
&[ShaderDefVal::Bool("TEXTURE".to_string(), true)],
&HashMap::default(),
&HashMap::default(),
);
assert_eq!(
result_wrong_type,
Err(ProcessShaderError::InvalidShaderDefComparisonValue {
shader_def_name: "TEXTURE".to_string(),
expected: "bool".to_string(),
value: "3".to_string()
})
);
}
#[test]
fn process_shader_def_equal_bool() {
#[rustfmt::skip]
const WGSL: &str = r"
struct View {
view_proj: mat4x4<f32>,
world_position: vec3<f32>,
};
@group(0) @binding(0)
var<uniform> view: View;
#if TEXTURE == true
@group(1) @binding(0)
var sprite_texture: texture_2d<f32>;
#endif
struct VertexOutput {
@location(0) uv: vec2<f32>,
@builtin(position) position: vec4<f32>,
};
@vertex
fn vertex(
@location(0) vertex_position: vec3<f32>,
@location(1) vertex_uv: vec2<f32>
) -> VertexOutput {
var out: VertexOutput;
out.uv = vertex_uv;
out.position = view.view_proj * vec4<f32>(vertex_position, 1.0);
return out;
}
";
#[rustfmt::skip]
const EXPECTED_EQ: &str = r"
struct View {
view_proj: mat4x4<f32>,
world_position: vec3<f32>,
};
@group(0) @binding(0)
var<uniform> view: View;
@group(1) @binding(0)
var sprite_texture: texture_2d<f32>;
struct VertexOutput {
@location(0) uv: vec2<f32>,
@builtin(position) position: vec4<f32>,
};
@vertex
fn vertex(
@location(0) vertex_position: vec3<f32>,
@location(1) vertex_uv: vec2<f32>
) -> VertexOutput {
var out: VertexOutput;
out.uv = vertex_uv;
out.position = view.view_proj * vec4<f32>(vertex_position, 1.0);
return out;
}
";
#[rustfmt::skip]
const EXPECTED_NEQ: &str = r"
struct View {
view_proj: mat4x4<f32>,
world_position: vec3<f32>,
};
@group(0) @binding(0)
var<uniform> view: View;
struct VertexOutput {
@location(0) uv: vec2<f32>,
@builtin(position) position: vec4<f32>,
};
@vertex
fn vertex(
@location(0) vertex_position: vec3<f32>,
@location(1) vertex_uv: vec2<f32>
) -> VertexOutput {
var out: VertexOutput;
out.uv = vertex_uv;
out.position = view.view_proj * vec4<f32>(vertex_position, 1.0);
return out;
}
";
let processor = ShaderProcessor::default();
let result_eq = processor
.process(
&Shader::from_wgsl(WGSL),
&[ShaderDefVal::Bool("TEXTURE".to_string(), true)],
&HashMap::default(),
&HashMap::default(),
)
.unwrap();
assert_eq!(result_eq.get_wgsl_source().unwrap(), EXPECTED_EQ);
let result_neq = processor
.process(
&Shader::from_wgsl(WGSL),
&[ShaderDefVal::Bool("TEXTURE".to_string(), false)],
&HashMap::default(),
&HashMap::default(),
)
.unwrap();
assert_eq!(result_neq.get_wgsl_source().unwrap(), EXPECTED_NEQ);
}
#[test]
fn process_shader_def_not_equal_bool() {
#[rustfmt::skip]
const WGSL: &str = r"
struct View {
view_proj: mat4x4<f32>,
world_position: vec3<f32>,
};
@group(0) @binding(0)
var<uniform> view: View;
#if TEXTURE != false
@group(1) @binding(0)
var sprite_texture: texture_2d<f32>;
#endif
struct VertexOutput {
@location(0) uv: vec2<f32>,
@builtin(position) position: vec4<f32>,
};
@vertex
fn vertex(
@location(0) vertex_position: vec3<f32>,
@location(1) vertex_uv: vec2<f32>
) -> VertexOutput {
var out: VertexOutput;
out.uv = vertex_uv;
out.position = view.view_proj * vec4<f32>(vertex_position, 1.0);
return out;
}
";
#[rustfmt::skip]
const EXPECTED_EQ: &str = r"
struct View {
view_proj: mat4x4<f32>,
world_position: vec3<f32>,
};
@group(0) @binding(0)
var<uniform> view: View;
@group(1) @binding(0)
var sprite_texture: texture_2d<f32>;
struct VertexOutput {
@location(0) uv: vec2<f32>,
@builtin(position) position: vec4<f32>,
};
@vertex
fn vertex(
@location(0) vertex_position: vec3<f32>,
@location(1) vertex_uv: vec2<f32>
) -> VertexOutput {
var out: VertexOutput;
out.uv = vertex_uv;
out.position = view.view_proj * vec4<f32>(vertex_position, 1.0);
return out;
}
";
#[rustfmt::skip]
const EXPECTED_NEQ: &str = r"
struct View {
view_proj: mat4x4<f32>,
world_position: vec3<f32>,
};
@group(0) @binding(0)
var<uniform> view: View;
struct VertexOutput {
@location(0) uv: vec2<f32>,
@builtin(position) position: vec4<f32>,
};
@vertex
fn vertex(
@location(0) vertex_position: vec3<f32>,
@location(1) vertex_uv: vec2<f32>
) -> VertexOutput {
var out: VertexOutput;
out.uv = vertex_uv;
out.position = view.view_proj * vec4<f32>(vertex_position, 1.0);
return out;
}
";
let processor = ShaderProcessor::default();
let result_eq = processor
.process(
&Shader::from_wgsl(WGSL),
&[ShaderDefVal::Bool("TEXTURE".to_string(), true)],
&HashMap::default(),
&HashMap::default(),
)
.unwrap();
assert_eq!(result_eq.get_wgsl_source().unwrap(), EXPECTED_EQ);
let result_neq = processor
.process(
&Shader::from_wgsl(WGSL),
&[ShaderDefVal::Bool("TEXTURE".to_string(), false)],
&HashMap::default(),
&HashMap::default(),
)
.unwrap();
assert_eq!(result_neq.get_wgsl_source().unwrap(), EXPECTED_NEQ);
let result_missing = processor.process(
&Shader::from_wgsl(WGSL),
&[],
&HashMap::default(),
&HashMap::default(),
);
assert_eq!(
result_missing,
Err(ProcessShaderError::UnknownShaderDef {
shader_def_name: "TEXTURE".to_string()
})
);
let result_wrong_type = processor.process(
&Shader::from_wgsl(WGSL),
&[ShaderDefVal::Int("TEXTURE".to_string(), 7)],
&HashMap::default(),
&HashMap::default(),
);
assert_eq!(
result_wrong_type,
Err(ProcessShaderError::InvalidShaderDefComparisonValue {
shader_def_name: "TEXTURE".to_string(),
expected: "int".to_string(),
value: "false".to_string()
})
);
}
#[test]
fn process_shader_def_replace() {
#[rustfmt::skip]
const WGSL: &str = r"
struct View {
view_proj: mat4x4<f32>,
world_position: vec3<f32>,
};
@group(0) @binding(0)
var<uniform> view: View;
struct VertexOutput {
@location(0) uv: vec2<f32>,
@builtin(position) position: vec4<f32>,
};
@vertex
fn vertex(
@location(0) vertex_position: vec3<f32>,
@location(1) vertex_uv: vec2<f32>
) -> VertexOutput {
var out: VertexOutput;
out.uv = vertex_uv;
var a: i32 = #FIRST_VALUE;
var b: i32 = #FIRST_VALUE * #SECOND_VALUE;
var c: i32 = #MISSING_VALUE;
var d: bool = #BOOL_VALUE;
out.position = view.view_proj * vec4<f32>(vertex_position, 1.0);
return out;
}
";
#[rustfmt::skip]
const EXPECTED_REPLACED: &str = r"
struct View {
view_proj: mat4x4<f32>,
world_position: vec3<f32>,
};
@group(0) @binding(0)
var<uniform> view: View;
struct VertexOutput {
@location(0) uv: vec2<f32>,
@builtin(position) position: vec4<f32>,
};
@vertex
fn vertex(
@location(0) vertex_position: vec3<f32>,
@location(1) vertex_uv: vec2<f32>
) -> VertexOutput {
var out: VertexOutput;
out.uv = vertex_uv;
var a: i32 = 5;
var b: i32 = 5 * 3;
var c: i32 = #MISSING_VALUE;
var d: bool = true;
out.position = view.view_proj * vec4<f32>(vertex_position, 1.0);
return out;
}
";
let processor = ShaderProcessor::default();
let result = processor
.process(
&Shader::from_wgsl(WGSL),
&[
ShaderDefVal::Bool("BOOL_VALUE".to_string(), true),
ShaderDefVal::Int("FIRST_VALUE".to_string(), 5),
ShaderDefVal::Int("SECOND_VALUE".to_string(), 3),
],
&HashMap::default(),
&HashMap::default(),
)
.unwrap();
assert_eq!(result.get_wgsl_source().unwrap(), EXPECTED_REPLACED);
}
}

View file

@ -351,36 +351,36 @@ impl SpecializedMeshPipeline for Mesh2dPipeline {
let mut vertex_attributes = Vec::new();
if layout.contains(Mesh::ATTRIBUTE_POSITION) {
shader_defs.push(String::from("VERTEX_POSITIONS"));
shader_defs.push("VERTEX_POSITIONS".into());
vertex_attributes.push(Mesh::ATTRIBUTE_POSITION.at_shader_location(0));
}
if layout.contains(Mesh::ATTRIBUTE_NORMAL) {
shader_defs.push(String::from("VERTEX_NORMALS"));
shader_defs.push("VERTEX_NORMALS".into());
vertex_attributes.push(Mesh::ATTRIBUTE_NORMAL.at_shader_location(1));
}
if layout.contains(Mesh::ATTRIBUTE_UV_0) {
shader_defs.push(String::from("VERTEX_UVS"));
shader_defs.push("VERTEX_UVS".into());
vertex_attributes.push(Mesh::ATTRIBUTE_UV_0.at_shader_location(2));
}
if layout.contains(Mesh::ATTRIBUTE_TANGENT) {
shader_defs.push(String::from("VERTEX_TANGENTS"));
shader_defs.push("VERTEX_TANGENTS".into());
vertex_attributes.push(Mesh::ATTRIBUTE_TANGENT.at_shader_location(3));
}
if layout.contains(Mesh::ATTRIBUTE_COLOR) {
shader_defs.push(String::from("VERTEX_COLORS"));
shader_defs.push("VERTEX_COLORS".into());
vertex_attributes.push(Mesh::ATTRIBUTE_COLOR.at_shader_location(4));
}
if key.contains(Mesh2dPipelineKey::TONEMAP_IN_SHADER) {
shader_defs.push("TONEMAP_IN_SHADER".to_string());
shader_defs.push("TONEMAP_IN_SHADER".into());
// Debanding is tied to tonemapping in the shader, cannot run without it.
if key.contains(Mesh2dPipelineKey::DEBAND_DITHER) {
shader_defs.push("DEBAND_DITHER".to_string());
shader_defs.push("DEBAND_DITHER".into());
}
}

View file

@ -208,15 +208,15 @@ impl SpecializedRenderPipeline for SpritePipeline {
let mut shader_defs = Vec::new();
if key.contains(SpritePipelineKey::COLORED) {
shader_defs.push("COLORED".to_string());
shader_defs.push("COLORED".into());
}
if key.contains(SpritePipelineKey::TONEMAP_IN_SHADER) {
shader_defs.push("TONEMAP_IN_SHADER".to_string());
shader_defs.push("TONEMAP_IN_SHADER".into());
// Debanding is tied to tonemapping in the shader, cannot run without it.
if key.contains(SpritePipelineKey::DEBAND_DITHER) {
shader_defs.push("DEBAND_DITHER".to_string());
shader_defs.push("DEBAND_DITHER".into());
}
}

View file

@ -68,7 +68,7 @@ impl Material for CustomMaterial {
) -> Result<(), SpecializedMeshPipelineError> {
if key.bind_group_data.is_red {
let fragment = descriptor.fragment.as_mut().unwrap();
fragment.shader_defs.push("IS_RED".to_string());
fragment.shader_defs.push("IS_RED".into());
}
Ok(())
}