use bevy_asset::{AssetLoader, LoadContext, LoadedAsset}; use bevy_reflect::{TypeUuid, Uuid}; use bevy_utils::{tracing::error, BoxedFuture}; use naga::{valid::ModuleInfo, Module}; use regex::Regex; use std::{borrow::Cow, collections::HashSet, marker::Copy}; use thiserror::Error; use wgpu::{ShaderModuleDescriptor, ShaderSource}; use crate::render_asset::{PrepareAssetError, RenderAsset}; #[derive(Copy, Clone, Hash, Eq, PartialEq, Debug)] pub struct ShaderId(Uuid); impl ShaderId { #[allow(clippy::new_without_default)] pub fn new() -> Self { ShaderId(Uuid::new_v4()) } } #[derive(Error, Debug)] pub enum ShaderReflectError { #[error(transparent)] WgslParse(#[from] naga::front::wgsl::ParseError), #[error("GLSL Parse Error: {0:?}")] GlslParse(Vec), #[error(transparent)] SpirVParse(#[from] naga::front::spv::Error), #[error(transparent)] Validation(#[from] naga::valid::ValidationError), } /// A shader, as defined by its [ShaderSource] and [ShaderStage] /// This is an "unprocessed" shader. It can contain preprocessor directives. #[derive(Debug, Clone, TypeUuid)] #[uuid = "d95bc916-6c55-4de3-9622-37e7b6969fda"] pub enum Shader { Wgsl(Cow<'static, str>), Glsl(Cow<'static, str>, naga::ShaderStage), SpirV(Cow<'static, [u8]>), // TODO: consider the following // PrecompiledSpirVMacros(HashMap, Vec>) // NagaModule(Module) ... Module impls Serialize/Deserialize } /// A processed [Shader]. This cannot contain preprocessor directions. It must be "ready to compile" pub enum ProcessedShader { Wgsl(Cow<'static, str>), Glsl(Cow<'static, str>, naga::ShaderStage), SpirV(Cow<'static, [u8]>), } impl ProcessedShader { pub fn reflect(&self) -> Result { let module = match &self { // TODO: process macros here ProcessedShader::Wgsl(source) => naga::front::wgsl::parse_str(source)?, ProcessedShader::Glsl(source, shader_stage) => { let mut parser = naga::front::glsl::Parser::default(); parser .parse(&naga::front::glsl::Options::from(*shader_stage), source) .map_err(ShaderReflectError::GlslParse)? } ProcessedShader::SpirV(source) => naga::front::spv::parse_u8_slice( source, &naga::front::spv::Options { adjust_coordinate_space: false, ..naga::front::spv::Options::default() }, )?, }; let module_info = naga::valid::Validator::new( naga::valid::ValidationFlags::default(), naga::valid::Capabilities::default(), ) .validate(&module)?; Ok(ShaderReflection { module, module_info, }) } pub fn get_module_descriptor(&self) -> Result { Ok(ShaderModuleDescriptor { label: None, source: match self { ProcessedShader::Wgsl(source) => ShaderSource::Wgsl(source.clone()), ProcessedShader::Glsl(_source, _stage) => { let reflection = self.reflect()?; // TODO: it probably makes more sense to convert this to spirv, but as of writing // this comment, naga's spirv conversion is broken let wgsl = reflection.get_wgsl()?; ShaderSource::Wgsl(wgsl.into()) } ProcessedShader::SpirV(_) => { // TODO: we can probably just transmute the u8 array to u32? let reflection = self.reflect()?; let spirv = reflection.get_spirv()?; ShaderSource::SpirV(Cow::Owned(spirv)) } }, }) } } #[derive(Error, Debug)] pub enum AsModuleDescriptorError { #[error(transparent)] ShaderReflectError(#[from] ShaderReflectError), #[error(transparent)] WgslConversion(#[from] naga::back::wgsl::Error), #[error(transparent)] SpirVConversion(#[from] naga::back::spv::Error), } pub struct ShaderReflection { pub module: Module, pub module_info: ModuleInfo, } impl ShaderReflection { pub fn get_spirv(&self) -> Result, naga::back::spv::Error> { naga::back::spv::write_vec( &self.module, &self.module_info, &naga::back::spv::Options { flags: naga::back::spv::WriterFlags::empty(), ..naga::back::spv::Options::default() }, None, ) } pub fn get_wgsl(&self) -> Result { naga::back::wgsl::write_string(&self.module, &self.module_info) } } impl Shader { pub fn from_wgsl(source: impl Into>) -> Shader { Shader::Wgsl(source.into()) } pub fn from_glsl(source: impl Into>, stage: naga::ShaderStage) -> Shader { Shader::Glsl(source.into(), stage) } pub fn from_spirv(source: impl Into>) -> Shader { Shader::SpirV(source.into()) } } #[derive(Default)] pub struct ShaderLoader; impl AssetLoader for ShaderLoader { fn load<'a>( &'a self, bytes: &'a [u8], load_context: &'a mut LoadContext, ) -> BoxedFuture<'a, Result<(), anyhow::Error>> { Box::pin(async move { let ext = load_context.path().extension().unwrap().to_str().unwrap(); let shader = match ext { "spv" => Shader::from_spirv(Vec::from(bytes)), "wgsl" => Shader::from_wgsl(String::from_utf8(Vec::from(bytes))?), "vert" => Shader::from_glsl( String::from_utf8(Vec::from(bytes))?, naga::ShaderStage::Vertex, ), "frag" => Shader::from_glsl( String::from_utf8(Vec::from(bytes))?, naga::ShaderStage::Fragment, ), _ => panic!("unhandled extension: {}", ext), }; load_context.set_default_asset(LoadedAsset::new(shader)); Ok(()) }) } fn extensions(&self) -> &[&str] { &["spv", "wgsl", "vert", "frag"] } } impl RenderAsset for Shader { type ExtractedAsset = Shader; type PreparedAsset = Shader; type Param = (); fn extract_asset(&self) -> Self::ExtractedAsset { self.clone() } fn prepare_asset( extracted_asset: Self::ExtractedAsset, _param: &mut bevy_ecs::system::SystemParamItem, ) -> Result> { Ok(extracted_asset) } } #[derive(Error, Debug, PartialEq, Eq)] pub enum ProcessShaderError { #[error("Too many '# endif' lines. Each endif should be preceded by an if statement.")] TooManyEndIfs, #[error( "Not enough '# endif' lines. Each if statement should be followed by an endif statement." )] NotEnoughEndIfs, #[error("This Shader's format does not support processing shader defs.")] ShaderFormatDoesNotSupportShaderDefs, } pub struct ShaderProcessor { ifdef_regex: Regex, ifndef_regex: Regex, endif_regex: Regex, } impl Default for ShaderProcessor { fn default() -> Self { Self { ifdef_regex: Regex::new(r"^\s*#\s*ifdef\s*([\w|\d|_]+)").unwrap(), ifndef_regex: Regex::new(r"^\s*#\s*ifndef\s*([\w|\d|_]+)").unwrap(), endif_regex: Regex::new(r"^\s*#\s*endif").unwrap(), } } } impl ShaderProcessor { pub fn process_shader( &self, shader: &Shader, shader_defs: &[String], ) -> Result { match shader { Shader::Wgsl(source) => Ok(ProcessedShader::Wgsl(Cow::from( self.process_str(source, shader_defs)?, ))), Shader::Glsl(source, stage) => Ok(ProcessedShader::Glsl( Cow::from(self.process_str(source, shader_defs)?), *stage, )), Shader::SpirV(source) => { if shader_defs.is_empty() { Ok(ProcessedShader::SpirV(source.clone())) } else { Err(ProcessShaderError::ShaderFormatDoesNotSupportShaderDefs) } } } } pub fn process_str( &self, shader: &str, shader_defs: &[String], ) -> Result { let shader_defs = HashSet::::from_iter(shader_defs.iter().cloned()); let mut scopes = vec![true]; let mut final_string = String::new(); for line in shader.split('\n') { if let Some(cap) = self.ifdef_regex.captures(line) { let def = cap.get(1).unwrap(); scopes.push(*scopes.last().unwrap() && shader_defs.contains(def.as_str())); } else if let Some(cap) = self.ifndef_regex.captures(line) { let def = cap.get(1).unwrap(); scopes.push(*scopes.last().unwrap() && !shader_defs.contains(def.as_str())); } else if self.endif_regex.is_match(line) { scopes.pop(); if scopes.is_empty() { return Err(ProcessShaderError::TooManyEndIfs); } } else if *scopes.last().unwrap() { final_string.push_str(line); final_string.push('\n'); } } final_string.pop(); if scopes.len() != 1 { return Err(ProcessShaderError::NotEnoughEndIfs); } Ok(final_string) } } #[cfg(test)] mod tests { use crate::render_resource::{ProcessShaderError, ShaderProcessor}; #[rustfmt::skip] const WGSL: &str = r" [[block]] struct View { view_proj: mat4x4; world_position: vec3; }; [[group(0), binding(0)]] var view: View; # ifdef TEXTURE [[group(1), binding(0)]] var sprite_texture: texture_2d; # endif struct VertexOutput { [[location(0)]] uv: vec2; [[builtin(position)]] position: vec4; }; [[stage(vertex)]] fn vertex( [[location(0)]] vertex_position: vec3, [[location(1)]] vertex_uv: vec2 ) -> VertexOutput { var out: VertexOutput; out.uv = vertex_uv; out.position = view.view_proj * vec4(vertex_position, 1.0); return out; } "; const WGSL_NESTED_IFDEF: &str = r" [[block]] struct View { view_proj: mat4x4; world_position: vec3; }; [[group(0), binding(0)]] var view: View; # ifdef TEXTURE # ifdef ATTRIBUTE [[group(1), binding(0)]] var sprite_texture: texture_2d; # endif # endif struct VertexOutput { [[location(0)]] uv: vec2; [[builtin(position)]] position: vec4; }; [[stage(vertex)]] fn vertex( [[location(0)]] vertex_position: vec3, [[location(1)]] vertex_uv: vec2 ) -> VertexOutput { var out: VertexOutput; out.uv = vertex_uv; out.position = view.view_proj * vec4(vertex_position, 1.0); return out; } "; #[test] fn process_shader_def_defined() { #[rustfmt::skip] const EXPECTED: &str = r" [[block]] struct View { view_proj: mat4x4; world_position: vec3; }; [[group(0), binding(0)]] var view: View; [[group(1), binding(0)]] var sprite_texture: texture_2d; struct VertexOutput { [[location(0)]] uv: vec2; [[builtin(position)]] position: vec4; }; [[stage(vertex)]] fn vertex( [[location(0)]] vertex_position: vec3, [[location(1)]] vertex_uv: vec2 ) -> VertexOutput { var out: VertexOutput; out.uv = vertex_uv; out.position = view.view_proj * vec4(vertex_position, 1.0); return out; } "; let processor = ShaderProcessor::default(); let result = processor .process_str(WGSL, &["TEXTURE".to_string()]) .unwrap(); assert_eq!(result, EXPECTED); } #[test] fn process_shader_def_not_defined() { #[rustfmt::skip] const EXPECTED: &str = r" [[block]] struct View { view_proj: mat4x4; world_position: vec3; }; [[group(0), binding(0)]] var view: View; struct VertexOutput { [[location(0)]] uv: vec2; [[builtin(position)]] position: vec4; }; [[stage(vertex)]] fn vertex( [[location(0)]] vertex_position: vec3, [[location(1)]] vertex_uv: vec2 ) -> VertexOutput { var out: VertexOutput; out.uv = vertex_uv; out.position = view.view_proj * vec4(vertex_position, 1.0); return out; } "; let processor = ShaderProcessor::default(); let result = processor.process_str(WGSL, &[]).unwrap(); assert_eq!(result, EXPECTED); } #[test] fn process_shader_def_unclosed() { #[rustfmt::skip] const INPUT: &str = r" # ifdef FOO "; let processor = ShaderProcessor::default(); let result = processor.process_str(INPUT, &[]); assert_eq!(result, Err(ProcessShaderError::NotEnoughEndIfs)); } #[test] fn process_shader_def_too_closed() { #[rustfmt::skip] const INPUT: &str = r" # endif "; let processor = ShaderProcessor::default(); let result = processor.process_str(INPUT, &[]); assert_eq!(result, Err(ProcessShaderError::TooManyEndIfs)); } #[test] fn process_shader_def_commented() { #[rustfmt::skip] const INPUT: &str = r" // # ifdef FOO fn foo() { } "; let processor = ShaderProcessor::default(); let result = processor.process_str(INPUT, &[]).unwrap(); assert_eq!(result, INPUT); } #[test] fn process_nested_shader_def_outer_defined_inner_not() { #[rustfmt::skip] const EXPECTED: &str = r" [[block]] struct View { view_proj: mat4x4; world_position: vec3; }; [[group(0), binding(0)]] var view: View; struct VertexOutput { [[location(0)]] uv: vec2; [[builtin(position)]] position: vec4; }; [[stage(vertex)]] fn vertex( [[location(0)]] vertex_position: vec3, [[location(1)]] vertex_uv: vec2 ) -> VertexOutput { var out: VertexOutput; out.uv = vertex_uv; out.position = view.view_proj * vec4(vertex_position, 1.0); return out; } "; let processor = ShaderProcessor::default(); let result = processor .process_str(WGSL_NESTED_IFDEF, &["TEXTURE".to_string()]) .unwrap(); assert_eq!(result, EXPECTED); } #[test] fn process_nested_shader_def_neither_defined() { #[rustfmt::skip] const EXPECTED: &str = r" [[block]] struct View { view_proj: mat4x4; world_position: vec3; }; [[group(0), binding(0)]] var view: View; struct VertexOutput { [[location(0)]] uv: vec2; [[builtin(position)]] position: vec4; }; [[stage(vertex)]] fn vertex( [[location(0)]] vertex_position: vec3, [[location(1)]] vertex_uv: vec2 ) -> VertexOutput { var out: VertexOutput; out.uv = vertex_uv; out.position = view.view_proj * vec4(vertex_position, 1.0); return out; } "; let processor = ShaderProcessor::default(); let result = processor.process_str(WGSL_NESTED_IFDEF, &[]).unwrap(); assert_eq!(result, EXPECTED); } #[test] fn process_nested_shader_def_inner_defined_outer_not() { #[rustfmt::skip] const EXPECTED: &str = r" [[block]] struct View { view_proj: mat4x4; world_position: vec3; }; [[group(0), binding(0)]] var view: View; struct VertexOutput { [[location(0)]] uv: vec2; [[builtin(position)]] position: vec4; }; [[stage(vertex)]] fn vertex( [[location(0)]] vertex_position: vec3, [[location(1)]] vertex_uv: vec2 ) -> VertexOutput { var out: VertexOutput; out.uv = vertex_uv; out.position = view.view_proj * vec4(vertex_position, 1.0); return out; } "; let processor = ShaderProcessor::default(); let result = processor .process_str(WGSL_NESTED_IFDEF, &["ATTRIBUTE".to_string()]) .unwrap(); assert_eq!(result, EXPECTED); } #[test] fn process_nested_shader_def_both_defined() { #[rustfmt::skip] const EXPECTED: &str = r" [[block]] struct View { view_proj: mat4x4; world_position: vec3; }; [[group(0), binding(0)]] var view: View; [[group(1), binding(0)]] var sprite_texture: texture_2d; struct VertexOutput { [[location(0)]] uv: vec2; [[builtin(position)]] position: vec4; }; [[stage(vertex)]] fn vertex( [[location(0)]] vertex_position: vec3, [[location(1)]] vertex_uv: vec2 ) -> VertexOutput { var out: VertexOutput; out.uv = vertex_uv; out.position = view.view_proj * vec4(vertex_position, 1.0); return out; } "; let processor = ShaderProcessor::default(); let result = processor .process_str( WGSL_NESTED_IFDEF, &["TEXTURE".to_string(), "ATTRIBUTE".to_string()], ) .unwrap(); assert_eq!(result, EXPECTED); } }