bevy/pipelined/bevy_render2/src/render_resource/shader.rs
Robert Swain 14ce281904 bevy_render2: Support nested shader defs (#3113)
# Objective

Fix nested shader defs. For example, in:
```rust
#ifdef A
#ifdef B
some code here
#endif
#endif
```
Before this PR, if `A` is *not* defined but `B` *is* defined, `some code here` was still output, even though the enclosing `#ifdef A` block should have excluded it.

## Solution

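- When entering a nested `#ifdef`/`#ifndef`, AND the parent scope's active flag with the child's own condition, and use the result as the child scope's active flag. A nested block is therefore only emitted when every enclosing condition holds as well.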
2021-11-16 03:19:08 +00:00

614 lines
17 KiB
Rust

use bevy_asset::{AssetLoader, LoadContext, LoadedAsset};
use bevy_reflect::{TypeUuid, Uuid};
use bevy_utils::{tracing::error, BoxedFuture};
use naga::{valid::ModuleInfo, Module};
use regex::Regex;
use std::{borrow::Cow, collections::HashSet, marker::Copy};
use thiserror::Error;
use wgpu::{ShaderModuleDescriptor, ShaderSource};
use crate::render_asset::{PrepareAssetError, RenderAsset};
/// Opaque identifier for a shader, backed by a [`Uuid`].
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct ShaderId(Uuid);

impl ShaderId {
    /// Creates a new identifier from a freshly generated random (v4) [`Uuid`].
    ///
    /// Every call produces a different id.
    // Each call yields a new id, so clippy's suggestion to add a `Default` impl is
    // intentionally silenced here.
    #[allow(clippy::new_without_default)]
    pub fn new() -> Self {
        let id = Uuid::new_v4();
        Self(id)
    }
}
/// Errors returned by [`ProcessedShader::reflect`] while parsing or validating a shader
/// with naga.
#[derive(Error, Debug)]
pub enum ShaderReflectError {
/// naga's WGSL front end failed to parse the source.
#[error(transparent)]
WgslParse(#[from] naga::front::wgsl::ParseError),
/// naga's GLSL front end failed to parse the source. It reports a list of errors,
/// which is why this variant has no `#[from]` and is built with `map_err` instead.
#[error("GLSL Parse Error: {0:?}")]
GlslParse(Vec<naga::front::glsl::Error>),
/// naga's SPIR-V front end failed to parse the binary.
#[error(transparent)]
SpirVParse(#[from] naga::front::spv::Error),
/// The module parsed successfully but did not pass naga validation.
#[error(transparent)]
Validation(#[from] naga::valid::ValidationError),
}
/// A shader in one of the supported source formats: WGSL text, GLSL text for a given
/// [`naga::ShaderStage`], or a SPIR-V binary.
///
/// This is an "unprocessed" shader: WGSL and GLSL sources may still contain
/// `#ifdef` / `#ifndef` / `#endif` preprocessor directives, which [`ShaderProcessor`]
/// resolves to produce a [`ProcessedShader`].
#[derive(Debug, Clone, TypeUuid)]
#[uuid = "d95bc916-6c55-4de3-9622-37e7b6969fda"]
pub enum Shader {
/// WGSL source text (may contain preprocessor directives).
Wgsl(Cow<'static, str>),
/// GLSL source text (may contain preprocessor directives) and the stage it targets.
Glsl(Cow<'static, str>, naga::ShaderStage),
/// SPIR-V binary stored as raw bytes; it cannot be preprocessed.
SpirV(Cow<'static, [u8]>),
// TODO: consider the following
// PrecompiledSpirVMacros(HashMap<HashSet<String>, Vec<u32>>)
// NagaModule(Module) ... Module impls Serialize/Deserialize
}
/// A processed [`Shader`]: all preprocessor directives have been resolved, so it is
/// "ready to compile". Produced by [`ShaderProcessor::process_shader`].
pub enum ProcessedShader {
/// WGSL source text with directives resolved.
Wgsl(Cow<'static, str>),
/// GLSL source text with directives resolved, plus the stage it targets.
Glsl(Cow<'static, str>, naga::ShaderStage),
/// SPIR-V binary as raw bytes, passed through from the unprocessed shader.
SpirV(Cow<'static, [u8]>),
}
impl ProcessedShader {
    /// Parses this shader with the naga front end matching its format, then validates the
    /// resulting module with default validation flags and capabilities.
    ///
    /// SPIR-V is parsed with `adjust_coordinate_space` disabled.
    ///
    /// # Errors
    ///
    /// Returns a [`ShaderReflectError`] if the front end rejects the source, or if the parsed
    /// module fails validation.
    pub fn reflect(&self) -> Result<ShaderReflection, ShaderReflectError> {
        let module = match self {
            // TODO: process macros here
            ProcessedShader::Wgsl(source) => naga::front::wgsl::parse_str(source)?,
            ProcessedShader::Glsl(source, shader_stage) => {
                let glsl_options = naga::front::glsl::Options::from(*shader_stage);
                let mut glsl_parser = naga::front::glsl::Parser::default();
                glsl_parser
                    .parse(&glsl_options, source)
                    .map_err(ShaderReflectError::GlslParse)?
            }
            ProcessedShader::SpirV(source) => {
                let spv_options = naga::front::spv::Options {
                    adjust_coordinate_space: false,
                    ..naga::front::spv::Options::default()
                };
                naga::front::spv::parse_u8_slice(source, &spv_options)?
            }
        };

        let flags = naga::valid::ValidationFlags::default();
        let capabilities = naga::valid::Capabilities::default();
        let module_info = naga::valid::Validator::new(flags, capabilities).validate(&module)?;

        Ok(ShaderReflection {
            module,
            module_info,
        })
    }

    /// Builds an unlabeled wgpu [`ShaderModuleDescriptor`] for this shader.
    ///
    /// WGSL is passed through as-is. GLSL is reflected (see [`Self::reflect`]) and
    /// re-emitted as WGSL. SPIR-V is reflected and re-emitted as SPIR-V words.
    ///
    /// # Errors
    ///
    /// Returns an [`AsModuleDescriptorError`] if reflection fails or if naga's WGSL/SPIR-V
    /// back end cannot write the module.
    pub fn get_module_descriptor(&self) -> Result<ShaderModuleDescriptor, AsModuleDescriptorError> {
        let source = match self {
            ProcessedShader::Wgsl(source) => ShaderSource::Wgsl(source.clone()),
            ProcessedShader::Glsl(..) => {
                // TODO: it probably makes more sense to convert this to spirv, but as of writing
                // this comment, naga's spirv conversion is broken
                let wgsl = self.reflect()?.get_wgsl()?;
                ShaderSource::Wgsl(wgsl.into())
            }
            ProcessedShader::SpirV(_) => {
                // TODO: we can probably just transmute the u8 array to u32?
                let words = self.reflect()?.get_spirv()?;
                ShaderSource::SpirV(Cow::Owned(words))
            }
        };

        Ok(ShaderModuleDescriptor {
            label: None,
            source,
        })
    }
}
/// Errors returned by [`ProcessedShader::get_module_descriptor`].
#[derive(Error, Debug)]
pub enum AsModuleDescriptorError {
/// Parsing or validating the shader with naga failed.
#[error(transparent)]
ShaderReflectError(#[from] ShaderReflectError),
/// naga's WGSL back end failed to write the module (used when converting GLSL input).
#[error(transparent)]
WgslConversion(#[from] naga::back::wgsl::Error),
/// naga's SPIR-V back end failed to write the module (used for SPIR-V input).
#[error(transparent)]
SpirVConversion(#[from] naga::back::spv::Error),
}
/// A parsed and validated naga module, as produced by [`ProcessedShader::reflect`].
pub struct ShaderReflection {
/// The parsed shader module.
pub module: Module,
/// Validation results for `module`; passed to naga's back ends alongside it.
pub module_info: ModuleInfo,
}
impl ShaderReflection {
    /// Writes the module out as SPIR-V words with naga's SPIR-V back end.
    ///
    /// No writer flags are set, and every other option is left at its default.
    pub fn get_spirv(&self) -> Result<Vec<u32>, naga::back::spv::Error> {
        let options = naga::back::spv::Options {
            flags: naga::back::spv::WriterFlags::empty(),
            ..naga::back::spv::Options::default()
        };
        naga::back::spv::write_vec(&self.module, &self.module_info, &options, None)
    }

    /// Writes the module out as WGSL source text with naga's WGSL back end.
    pub fn get_wgsl(&self) -> Result<String, naga::back::wgsl::Error> {
        let (module, info) = (&self.module, &self.module_info);
        naga::back::wgsl::write_string(module, info)
    }
}
impl Shader {
    /// Wraps WGSL source text (a `&'static str` or an owned `String`) as [`Shader::Wgsl`].
    pub fn from_wgsl(source: impl Into<Cow<'static, str>>) -> Shader {
        let text: Cow<'static, str> = source.into();
        Self::Wgsl(text)
    }

    /// Wraps GLSL source text for the given `stage` as [`Shader::Glsl`].
    pub fn from_glsl(source: impl Into<Cow<'static, str>>, stage: naga::ShaderStage) -> Shader {
        let text: Cow<'static, str> = source.into();
        Self::Glsl(text, stage)
    }

    /// Wraps a SPIR-V binary (raw bytes) as [`Shader::SpirV`].
    pub fn from_spirv(source: impl Into<Cow<'static, [u8]>>) -> Shader {
        let bytes: Cow<'static, [u8]> = source.into();
        Self::SpirV(bytes)
    }
}
/// [`AssetLoader`] that turns shader files into [`Shader`] assets.
///
/// The shader format is chosen from the file extension:
/// `.spv` (SPIR-V), `.wgsl` (WGSL), `.vert` (GLSL vertex stage), `.frag` (GLSL fragment stage).
#[derive(Default)]
pub struct ShaderLoader;

impl AssetLoader for ShaderLoader {
    /// Decodes `bytes` according to the asset path's extension and stores the resulting
    /// [`Shader`] as the default asset of `load_context`.
    ///
    /// # Errors
    ///
    /// Fails if the path has no UTF-8 extension, if the extension is not one this loader
    /// handles, or if a text shader (`.wgsl`, `.vert`, `.frag`) is not valid UTF-8.
    fn load<'a>(
        &'a self,
        bytes: &'a [u8],
        load_context: &'a mut LoadContext,
    ) -> BoxedFuture<'a, Result<(), anyhow::Error>> {
        Box::pin(async move {
            let path = load_context.path();
            // Report a load error instead of panicking inside the loader task.
            let ext = path
                .extension()
                .and_then(|ext| ext.to_str())
                .ok_or_else(|| {
                    anyhow::anyhow!("shader path has no UTF-8 file extension: {:?}", path)
                })?;
            let shader = match ext {
                "spv" => Shader::from_spirv(Vec::from(bytes)),
                "wgsl" => Shader::from_wgsl(String::from_utf8(Vec::from(bytes))?),
                "vert" => Shader::from_glsl(
                    String::from_utf8(Vec::from(bytes))?,
                    naga::ShaderStage::Vertex,
                ),
                "frag" => Shader::from_glsl(
                    String::from_utf8(Vec::from(bytes))?,
                    naga::ShaderStage::Fragment,
                ),
                // `extensions()` should keep other files away from this loader; if one slips
                // through, surface it as a load error rather than a panic.
                _ => return Err(anyhow::anyhow!("unhandled extension: {}", ext)),
            };
            load_context.set_default_asset(LoadedAsset::new(shader));
            Ok(())
        })
    }

    /// File extensions routed to this loader.
    fn extensions(&self) -> &[&str] {
        &["spv", "wgsl", "vert", "frag"]
    }
}
impl RenderAsset for Shader {
type ExtractedAsset = Shader;
type PreparedAsset = Shader;
type Param = ();
fn extract_asset(&self) -> Self::ExtractedAsset {
self.clone()
}
fn prepare_asset(
extracted_asset: Self::ExtractedAsset,
_param: &mut bevy_ecs::system::SystemParamItem<Self::Param>,
) -> Result<Self::PreparedAsset, PrepareAssetError<Self::ExtractedAsset>> {
Ok(extracted_asset)
}
}
/// Errors returned by [`ShaderProcessor::process_shader`] and [`ShaderProcessor::process_str`].
#[derive(Error, Debug, PartialEq, Eq)]
pub enum ProcessShaderError {
/// An `#endif` line was found with no open `#ifdef`/`#ifndef` scope to close.
#[error("Too many '# endif' lines. Each endif should be preceded by an if statement.")]
TooManyEndIfs,
/// The input ended while at least one `#ifdef`/`#ifndef` scope was still open.
#[error(
"Not enough '# endif' lines. Each if statement should be followed by an endif statement."
)]
NotEnoughEndIfs,
/// Shader defs were supplied for a format that cannot be preprocessed (SPIR-V).
#[error("This Shader's format does not support processing shader defs.")]
ShaderFormatDoesNotSupportShaderDefs,
}
/// Line-based preprocessor that resolves `#ifdef` / `#ifndef` / `#endif` directives in
/// shader source text against a list of enabled shader defs.
pub struct ShaderProcessor {
/// Matches an `#ifdef NAME` line; capture group 1 is `NAME`.
ifdef_regex: Regex,
/// Matches an `#ifndef NAME` line; capture group 1 is `NAME`.
ifndef_regex: Regex,
/// Matches an `#endif` line.
endif_regex: Regex,
}
impl Default for ShaderProcessor {
    /// Builds a processor with its directive regexes compiled once, up front.
    ///
    /// A directive must begin its line, optionally preceded by whitespace, with optional
    /// whitespace between `#` and the keyword. `#ifdef` / `#ifndef` capture the def name
    /// as a run of word characters (letters, digits and `_`).
    fn default() -> Self {
        // `\w` already covers letters, digits and `_`. Do not write `[\w|\d|_]`: inside a
        // character class `|` is a literal, so `# ifdef A|B` would be read as one def "A|B".
        Self {
            ifdef_regex: Regex::new(r"^\s*#\s*ifdef\s*(\w+)").expect("valid #ifdef regex"),
            ifndef_regex: Regex::new(r"^\s*#\s*ifndef\s*(\w+)").expect("valid #ifndef regex"),
            endif_regex: Regex::new(r"^\s*#\s*endif").expect("valid #endif regex"),
        }
    }
}
impl ShaderProcessor {
    /// Runs the preprocessor over `shader` and returns a [`ProcessedShader`] of the same format.
    ///
    /// WGSL and GLSL sources go through [`ShaderProcessor::process_str`]. SPIR-V is binary and
    /// cannot be preprocessed, so it is passed through unchanged, but only when `shader_defs`
    /// is empty.
    ///
    /// # Errors
    ///
    /// Returns [`ProcessShaderError::ShaderFormatDoesNotSupportShaderDefs`] for SPIR-V with a
    /// non-empty `shader_defs`, or any error from [`ShaderProcessor::process_str`].
    pub fn process_shader(
        &self,
        shader: &Shader,
        shader_defs: &[String],
    ) -> Result<ProcessedShader, ProcessShaderError> {
        match shader {
            Shader::Wgsl(source) => Ok(ProcessedShader::Wgsl(Cow::from(
                self.process_str(source, shader_defs)?,
            ))),
            Shader::Glsl(source, stage) => Ok(ProcessedShader::Glsl(
                Cow::from(self.process_str(source, shader_defs)?),
                *stage,
            )),
            Shader::SpirV(source) => {
                if shader_defs.is_empty() {
                    Ok(ProcessedShader::SpirV(source.clone()))
                } else {
                    Err(ProcessShaderError::ShaderFormatDoesNotSupportShaderDefs)
                }
            }
        }
    }

    /// Resolves `#ifdef` / `#ifndef` / `#endif` directives in `shader` against `shader_defs`.
    ///
    /// Input is split on `'\n'`. Directive lines are always dropped. Any other line is kept
    /// only if it sits in an active scope, and a scope is active only when its own condition
    /// AND every enclosing scope's condition hold. Kept lines are joined with `'\n'`.
    ///
    /// # Errors
    ///
    /// - [`ProcessShaderError::TooManyEndIfs`] if an `#endif` has no open scope to close.
    /// - [`ProcessShaderError::NotEnoughEndIfs`] if a scope is still open at end of input.
    pub fn process_str(
        &self,
        shader: &str,
        shader_defs: &[String],
    ) -> Result<String, ProcessShaderError> {
        // Borrow the defs instead of cloning every `String` just to test membership.
        let shader_defs: HashSet<&str> = shader_defs.iter().map(String::as_str).collect();
        // Stack of "is this scope active?" flags. The bottom entry is the always-active
        // top level; popping it is reported as `TooManyEndIfs` below.
        let mut scopes = vec![true];
        let mut final_string = String::with_capacity(shader.len());
        for line in shader.split('\n') {
            let current_active = *scopes
                .last()
                .expect("the top-level scope is never left on the stack popped");
            if let Some(cap) = self.ifdef_regex.captures(line) {
                // A nested scope is only active if the scope containing it is active too.
                scopes.push(current_active && shader_defs.contains(&cap[1]));
            } else if let Some(cap) = self.ifndef_regex.captures(line) {
                scopes.push(current_active && !shader_defs.contains(&cap[1]));
            } else if self.endif_regex.is_match(line) {
                scopes.pop();
                if scopes.is_empty() {
                    return Err(ProcessShaderError::TooManyEndIfs);
                }
            } else if current_active {
                final_string.push_str(line);
                final_string.push('\n');
            }
        }
        // Drop the '\n' appended after the last kept line (no-op when nothing was kept).
        final_string.pop();
        if scopes.len() != 1 {
            return Err(ProcessShaderError::NotEnoughEndIfs);
        }
        Ok(final_string)
    }
}
// Unit tests for `ShaderProcessor::process_str`: a single `# ifdef` block, nested `# ifdef`
// blocks in every defined/undefined combination, unbalanced `# endif` lines, and a
// directive that is commented out.
#[cfg(test)]
mod tests {
use crate::render_resource::{ProcessShaderError, ShaderProcessor};
// Input shader with one `# ifdef TEXTURE` block around the texture binding. Raw-string
// contents (including the leading newline) are compared byte-for-byte with `EXPECTED`.
#[rustfmt::skip]
const WGSL: &str = r"
[[block]]
struct View {
view_proj: mat4x4<f32>;
world_position: vec3<f32>;
};
[[group(0), binding(0)]]
var<uniform> view: View;
# ifdef TEXTURE
[[group(1), binding(0)]]
var sprite_texture: texture_2d<f32>;
# endif
struct VertexOutput {
[[location(0)]] uv: vec2<f32>;
[[builtin(position)]] position: vec4<f32>;
};
[[stage(vertex)]]
fn vertex(
[[location(0)]] vertex_position: vec3<f32>,
[[location(1)]] vertex_uv: vec2<f32>
) -> VertexOutput {
var out: VertexOutput;
out.uv = vertex_uv;
out.position = view.view_proj * vec4<f32>(vertex_position, 1.0);
return out;
}
";
// Same shader, but the texture binding is nested inside both `# ifdef TEXTURE` and
// `# ifdef ATTRIBUTE`, so it should only be emitted when both defs are set.
const WGSL_NESTED_IFDEF: &str = r"
[[block]]
struct View {
view_proj: mat4x4<f32>;
world_position: vec3<f32>;
};
[[group(0), binding(0)]]
var<uniform> view: View;
# ifdef TEXTURE
# ifdef ATTRIBUTE
[[group(1), binding(0)]]
var sprite_texture: texture_2d<f32>;
# endif
# endif
struct VertexOutput {
[[location(0)]] uv: vec2<f32>;
[[builtin(position)]] position: vec4<f32>;
};
[[stage(vertex)]]
fn vertex(
[[location(0)]] vertex_position: vec3<f32>,
[[location(1)]] vertex_uv: vec2<f32>
) -> VertexOutput {
var out: VertexOutput;
out.uv = vertex_uv;
out.position = view.view_proj * vec4<f32>(vertex_position, 1.0);
return out;
}
";
// With `TEXTURE` defined, the guarded binding is kept and both directive lines are removed.
#[test]
fn process_shader_def_defined() {
#[rustfmt::skip]
const EXPECTED: &str = r"
[[block]]
struct View {
view_proj: mat4x4<f32>;
world_position: vec3<f32>;
};
[[group(0), binding(0)]]
var<uniform> view: View;
[[group(1), binding(0)]]
var sprite_texture: texture_2d<f32>;
struct VertexOutput {
[[location(0)]] uv: vec2<f32>;
[[builtin(position)]] position: vec4<f32>;
};
[[stage(vertex)]]
fn vertex(
[[location(0)]] vertex_position: vec3<f32>,
[[location(1)]] vertex_uv: vec2<f32>
) -> VertexOutput {
var out: VertexOutput;
out.uv = vertex_uv;
out.position = view.view_proj * vec4<f32>(vertex_position, 1.0);
return out;
}
";
let processor = ShaderProcessor::default();
let result = processor
.process_str(WGSL, &["TEXTURE".to_string()])
.unwrap();
assert_eq!(result, EXPECTED);
}
// With no defs, the guarded binding and the directive lines are all removed.
#[test]
fn process_shader_def_not_defined() {
#[rustfmt::skip]
const EXPECTED: &str = r"
[[block]]
struct View {
view_proj: mat4x4<f32>;
world_position: vec3<f32>;
};
[[group(0), binding(0)]]
var<uniform> view: View;
struct VertexOutput {
[[location(0)]] uv: vec2<f32>;
[[builtin(position)]] position: vec4<f32>;
};
[[stage(vertex)]]
fn vertex(
[[location(0)]] vertex_position: vec3<f32>,
[[location(1)]] vertex_uv: vec2<f32>
) -> VertexOutput {
var out: VertexOutput;
out.uv = vertex_uv;
out.position = view.view_proj * vec4<f32>(vertex_position, 1.0);
return out;
}
";
let processor = ShaderProcessor::default();
let result = processor.process_str(WGSL, &[]).unwrap();
assert_eq!(result, EXPECTED);
}
// An `# ifdef` that is never closed is reported as `NotEnoughEndIfs`.
#[test]
fn process_shader_def_unclosed() {
#[rustfmt::skip]
const INPUT: &str = r"
# ifdef FOO
";
let processor = ShaderProcessor::default();
let result = processor.process_str(INPUT, &[]);
assert_eq!(result, Err(ProcessShaderError::NotEnoughEndIfs));
}
// An `# endif` with no open scope is reported as `TooManyEndIfs`.
#[test]
fn process_shader_def_too_closed() {
#[rustfmt::skip]
const INPUT: &str = r"
# endif
";
let processor = ShaderProcessor::default();
let result = processor.process_str(INPUT, &[]);
assert_eq!(result, Err(ProcessShaderError::TooManyEndIfs));
}
// The directive regexes are anchored at line start (`^\s*#`), so a directive behind `//`
// is ordinary text and the input comes back unchanged.
#[test]
fn process_shader_def_commented() {
#[rustfmt::skip]
const INPUT: &str = r"
// # ifdef FOO
fn foo() { }
";
let processor = ShaderProcessor::default();
let result = processor.process_str(INPUT, &[]).unwrap();
assert_eq!(result, INPUT);
}
// Outer def set, inner def unset: the doubly-guarded binding is removed.
#[test]
fn process_nested_shader_def_outer_defined_inner_not() {
#[rustfmt::skip]
const EXPECTED: &str = r"
[[block]]
struct View {
view_proj: mat4x4<f32>;
world_position: vec3<f32>;
};
[[group(0), binding(0)]]
var<uniform> view: View;
struct VertexOutput {
[[location(0)]] uv: vec2<f32>;
[[builtin(position)]] position: vec4<f32>;
};
[[stage(vertex)]]
fn vertex(
[[location(0)]] vertex_position: vec3<f32>,
[[location(1)]] vertex_uv: vec2<f32>
) -> VertexOutput {
var out: VertexOutput;
out.uv = vertex_uv;
out.position = view.view_proj * vec4<f32>(vertex_position, 1.0);
return out;
}
";
let processor = ShaderProcessor::default();
let result = processor
.process_str(WGSL_NESTED_IFDEF, &["TEXTURE".to_string()])
.unwrap();
assert_eq!(result, EXPECTED);
}
// Neither def set: the doubly-guarded binding is removed.
#[test]
fn process_nested_shader_def_neither_defined() {
#[rustfmt::skip]
const EXPECTED: &str = r"
[[block]]
struct View {
view_proj: mat4x4<f32>;
world_position: vec3<f32>;
};
[[group(0), binding(0)]]
var<uniform> view: View;
struct VertexOutput {
[[location(0)]] uv: vec2<f32>;
[[builtin(position)]] position: vec4<f32>;
};
[[stage(vertex)]]
fn vertex(
[[location(0)]] vertex_position: vec3<f32>,
[[location(1)]] vertex_uv: vec2<f32>
) -> VertexOutput {
var out: VertexOutput;
out.uv = vertex_uv;
out.position = view.view_proj * vec4<f32>(vertex_position, 1.0);
return out;
}
";
let processor = ShaderProcessor::default();
let result = processor.process_str(WGSL_NESTED_IFDEF, &[]).unwrap();
assert_eq!(result, EXPECTED);
}
// Regression case for nested scopes: the inner def alone must not re-enable code whose
// outer scope is inactive.
#[test]
fn process_nested_shader_def_inner_defined_outer_not() {
#[rustfmt::skip]
const EXPECTED: &str = r"
[[block]]
struct View {
view_proj: mat4x4<f32>;
world_position: vec3<f32>;
};
[[group(0), binding(0)]]
var<uniform> view: View;
struct VertexOutput {
[[location(0)]] uv: vec2<f32>;
[[builtin(position)]] position: vec4<f32>;
};
[[stage(vertex)]]
fn vertex(
[[location(0)]] vertex_position: vec3<f32>,
[[location(1)]] vertex_uv: vec2<f32>
) -> VertexOutput {
var out: VertexOutput;
out.uv = vertex_uv;
out.position = view.view_proj * vec4<f32>(vertex_position, 1.0);
return out;
}
";
let processor = ShaderProcessor::default();
let result = processor
.process_str(WGSL_NESTED_IFDEF, &["ATTRIBUTE".to_string()])
.unwrap();
assert_eq!(result, EXPECTED);
}
// Both defs set: the doubly-guarded binding is kept.
#[test]
fn process_nested_shader_def_both_defined() {
#[rustfmt::skip]
const EXPECTED: &str = r"
[[block]]
struct View {
view_proj: mat4x4<f32>;
world_position: vec3<f32>;
};
[[group(0), binding(0)]]
var<uniform> view: View;
[[group(1), binding(0)]]
var sprite_texture: texture_2d<f32>;
struct VertexOutput {
[[location(0)]] uv: vec2<f32>;
[[builtin(position)]] position: vec4<f32>;
};
[[stage(vertex)]]
fn vertex(
[[location(0)]] vertex_position: vec3<f32>,
[[location(1)]] vertex_uv: vec2<f32>
) -> VertexOutput {
var out: VertexOutput;
out.uv = vertex_uv;
out.position = view.view_proj * vec4<f32>(vertex_position, 1.0);
return out;
}
";
let processor = ShaderProcessor::default();
let result = processor
.process_str(
WGSL_NESTED_IFDEF,
&["TEXTURE".to_string(), "ATTRIBUTE".to_string()],
)
.unwrap();
assert_eq!(result, EXPECTED);
}
}