mirror of
https://github.com/bevyengine/bevy
synced 2025-01-07 02:38:59 +00:00
436 lines
13 KiB
Rust
436 lines
13 KiB
Rust
|
use bevy_asset::{AssetLoader, LoadContext, LoadedAsset};
|
||
|
use bevy_reflect::{TypeUuid, Uuid};
|
||
|
use bevy_utils::{tracing::error, BoxedFuture};
|
||
|
use naga::{valid::ModuleInfo, Module};
|
||
|
use regex::Regex;
|
||
|
use std::{borrow::Cow, collections::HashSet, marker::Copy};
|
||
|
use thiserror::Error;
|
||
|
use wgpu::{ShaderModuleDescriptor, ShaderSource};
|
||
|
|
||
|
use crate::render_asset::{PrepareAssetError, RenderAsset};
|
||
|
|
||
|
#[derive(Copy, Clone, Hash, Eq, PartialEq, Debug)]
|
||
|
pub struct ShaderId(Uuid);
|
||
|
|
||
|
impl ShaderId {
|
||
|
#[allow(clippy::new_without_default)]
|
||
|
pub fn new() -> Self {
|
||
|
ShaderId(Uuid::new_v4())
|
||
|
}
|
||
|
}
|
||
|
|
||
|
#[derive(Error, Debug)]
|
||
|
pub enum ShaderReflectError {
|
||
|
#[error(transparent)]
|
||
|
WgslParse(#[from] naga::front::wgsl::ParseError),
|
||
|
#[error("GLSL Parse Error: {0:?}")]
|
||
|
GlslParse(Vec<naga::front::glsl::Error>),
|
||
|
#[error(transparent)]
|
||
|
SpirVParse(#[from] naga::front::spv::Error),
|
||
|
#[error(transparent)]
|
||
|
Validation(#[from] naga::valid::ValidationError),
|
||
|
}
|
||
|
|
||
|
/// A shader, as defined by its [ShaderSource] and [ShaderStage]
|
||
|
/// This is an "unprocessed" shader. It can contain preprocessor directives.
|
||
|
#[derive(Debug, Clone, TypeUuid)]
|
||
|
#[uuid = "d95bc916-6c55-4de3-9622-37e7b6969fda"]
|
||
|
pub enum Shader {
|
||
|
Wgsl(Cow<'static, str>),
|
||
|
Glsl(Cow<'static, str>, naga::ShaderStage),
|
||
|
SpirV(Cow<'static, [u8]>),
|
||
|
// TODO: consider the following
|
||
|
// PrecompiledSpirVMacros(HashMap<HashSet<String>, Vec<u32>>)
|
||
|
// NagaModule(Module) ... Module impls Serialize/Deserialize
|
||
|
}
|
||
|
|
||
|
/// A processed [Shader]. This cannot contain preprocessor directions. It must be "ready to compile"
|
||
|
pub enum ProcessedShader {
|
||
|
Wgsl(Cow<'static, str>),
|
||
|
Glsl(Cow<'static, str>, naga::ShaderStage),
|
||
|
SpirV(Cow<'static, [u8]>),
|
||
|
}
|
||
|
|
||
|
impl ProcessedShader {
|
||
|
pub fn reflect(&self) -> Result<ShaderReflection, ShaderReflectError> {
|
||
|
let module = match &self {
|
||
|
// TODO: process macros here
|
||
|
ProcessedShader::Wgsl(source) => naga::front::wgsl::parse_str(source)?,
|
||
|
ProcessedShader::Glsl(source, shader_stage) => {
|
||
|
let mut parser = naga::front::glsl::Parser::default();
|
||
|
parser
|
||
|
.parse(&naga::front::glsl::Options::from(*shader_stage), source)
|
||
|
.map_err(ShaderReflectError::GlslParse)?
|
||
|
}
|
||
|
ProcessedShader::SpirV(source) => naga::front::spv::parse_u8_slice(
|
||
|
source,
|
||
|
&naga::front::spv::Options {
|
||
|
adjust_coordinate_space: false,
|
||
|
..naga::front::spv::Options::default()
|
||
|
},
|
||
|
)?,
|
||
|
};
|
||
|
let module_info = naga::valid::Validator::new(
|
||
|
naga::valid::ValidationFlags::default(),
|
||
|
naga::valid::Capabilities::default(),
|
||
|
)
|
||
|
.validate(&module)?;
|
||
|
|
||
|
Ok(ShaderReflection {
|
||
|
module,
|
||
|
module_info,
|
||
|
})
|
||
|
}
|
||
|
|
||
|
pub fn get_module_descriptor(&self) -> Result<ShaderModuleDescriptor, AsModuleDescriptorError> {
|
||
|
Ok(ShaderModuleDescriptor {
|
||
|
label: None,
|
||
|
source: match self {
|
||
|
ProcessedShader::Wgsl(source) => ShaderSource::Wgsl(source.clone()),
|
||
|
ProcessedShader::Glsl(_source, _stage) => {
|
||
|
let reflection = self.reflect()?;
|
||
|
// TODO: it probably makes more sense to convert this to spirv, but as of writing
|
||
|
// this comment, naga's spirv conversion is broken
|
||
|
let wgsl = reflection.get_wgsl()?;
|
||
|
ShaderSource::Wgsl(wgsl.into())
|
||
|
}
|
||
|
ProcessedShader::SpirV(_) => {
|
||
|
// TODO: we can probably just transmute the u8 array to u32?
|
||
|
let reflection = self.reflect()?;
|
||
|
let spirv = reflection.get_spirv()?;
|
||
|
ShaderSource::SpirV(Cow::Owned(spirv))
|
||
|
}
|
||
|
},
|
||
|
})
|
||
|
}
|
||
|
}
|
||
|
|
||
|
#[derive(Error, Debug)]
|
||
|
pub enum AsModuleDescriptorError {
|
||
|
#[error(transparent)]
|
||
|
ShaderReflectError(#[from] ShaderReflectError),
|
||
|
#[error(transparent)]
|
||
|
WgslConversion(#[from] naga::back::wgsl::Error),
|
||
|
#[error(transparent)]
|
||
|
SpirVConversion(#[from] naga::back::spv::Error),
|
||
|
}
|
||
|
|
||
|
pub struct ShaderReflection {
|
||
|
pub module: Module,
|
||
|
pub module_info: ModuleInfo,
|
||
|
}
|
||
|
|
||
|
impl ShaderReflection {
|
||
|
pub fn get_spirv(&self) -> Result<Vec<u32>, naga::back::spv::Error> {
|
||
|
naga::back::spv::write_vec(
|
||
|
&self.module,
|
||
|
&self.module_info,
|
||
|
&naga::back::spv::Options {
|
||
|
flags: naga::back::spv::WriterFlags::empty(),
|
||
|
..naga::back::spv::Options::default()
|
||
|
},
|
||
|
None,
|
||
|
)
|
||
|
}
|
||
|
|
||
|
pub fn get_wgsl(&self) -> Result<String, naga::back::wgsl::Error> {
|
||
|
naga::back::wgsl::write_string(&self.module, &self.module_info)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
impl Shader {
|
||
|
pub fn from_wgsl(source: impl Into<Cow<'static, str>>) -> Shader {
|
||
|
Shader::Wgsl(source.into())
|
||
|
}
|
||
|
|
||
|
pub fn from_glsl(source: impl Into<Cow<'static, str>>, stage: naga::ShaderStage) -> Shader {
|
||
|
Shader::Glsl(source.into(), stage)
|
||
|
}
|
||
|
|
||
|
pub fn from_spirv(source: impl Into<Cow<'static, [u8]>>) -> Shader {
|
||
|
Shader::SpirV(source.into())
|
||
|
}
|
||
|
}
|
||
|
|
||
|
#[derive(Default)]
|
||
|
pub struct ShaderLoader;
|
||
|
|
||
|
impl AssetLoader for ShaderLoader {
|
||
|
fn load<'a>(
|
||
|
&'a self,
|
||
|
bytes: &'a [u8],
|
||
|
load_context: &'a mut LoadContext,
|
||
|
) -> BoxedFuture<'a, Result<(), anyhow::Error>> {
|
||
|
Box::pin(async move {
|
||
|
let ext = load_context.path().extension().unwrap().to_str().unwrap();
|
||
|
|
||
|
let shader = match ext {
|
||
|
"spv" => Shader::from_spirv(Vec::from(bytes)),
|
||
|
"wgsl" => Shader::from_wgsl(String::from_utf8(Vec::from(bytes))?),
|
||
|
"vert" => Shader::from_glsl(
|
||
|
String::from_utf8(Vec::from(bytes))?,
|
||
|
naga::ShaderStage::Vertex,
|
||
|
),
|
||
|
"frag" => Shader::from_glsl(
|
||
|
String::from_utf8(Vec::from(bytes))?,
|
||
|
naga::ShaderStage::Fragment,
|
||
|
),
|
||
|
_ => panic!("unhandled extension: {}", ext),
|
||
|
};
|
||
|
|
||
|
load_context.set_default_asset(LoadedAsset::new(shader));
|
||
|
Ok(())
|
||
|
})
|
||
|
}
|
||
|
|
||
|
fn extensions(&self) -> &[&str] {
|
||
|
&["spv", "wgsl", "vert", "frag"]
|
||
|
}
|
||
|
}
|
||
|
|
||
|
impl RenderAsset for Shader {
|
||
|
type ExtractedAsset = Shader;
|
||
|
type PreparedAsset = Shader;
|
||
|
type Param = ();
|
||
|
|
||
|
fn extract_asset(&self) -> Self::ExtractedAsset {
|
||
|
self.clone()
|
||
|
}
|
||
|
|
||
|
fn prepare_asset(
|
||
|
extracted_asset: Self::ExtractedAsset,
|
||
|
_param: &mut bevy_ecs::system::SystemParamItem<Self::Param>,
|
||
|
) -> Result<Self::PreparedAsset, PrepareAssetError<Self::ExtractedAsset>> {
|
||
|
Ok(extracted_asset)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
#[derive(Error, Debug, PartialEq, Eq)]
|
||
|
pub enum ProcessShaderError {
|
||
|
#[error("Too many '# endif' lines. Each endif should be preceded by an if statement.")]
|
||
|
TooManyEndIfs,
|
||
|
#[error(
|
||
|
"Not enough '# endif' lines. Each if statement should be followed by an endif statement."
|
||
|
)]
|
||
|
NotEnoughEndIfs,
|
||
|
#[error("This Shader's format does not support processing shader defs.")]
|
||
|
ShaderFormatDoesNotSupportShaderDefs,
|
||
|
}
|
||
|
|
||
|
pub struct ShaderProcessor {
|
||
|
ifdef_regex: Regex,
|
||
|
ifndef_regex: Regex,
|
||
|
endif_regex: Regex,
|
||
|
}
|
||
|
|
||
|
impl Default for ShaderProcessor {
|
||
|
fn default() -> Self {
|
||
|
Self {
|
||
|
ifdef_regex: Regex::new(r"^\s*#\s*ifdef\s*([\w|\d|_]+)").unwrap(),
|
||
|
ifndef_regex: Regex::new(r"^\s*#\s*ifndef\s*([\w|\d|_]+)").unwrap(),
|
||
|
endif_regex: Regex::new(r"^\s*#\s*endif").unwrap(),
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
impl ShaderProcessor {
|
||
|
pub fn process_shader(
|
||
|
&self,
|
||
|
shader: &Shader,
|
||
|
shader_defs: &[String],
|
||
|
) -> Result<ProcessedShader, ProcessShaderError> {
|
||
|
match shader {
|
||
|
Shader::Wgsl(source) => Ok(ProcessedShader::Wgsl(Cow::from(
|
||
|
self.process_str(source, shader_defs)?,
|
||
|
))),
|
||
|
Shader::Glsl(source, stage) => Ok(ProcessedShader::Glsl(
|
||
|
Cow::from(self.process_str(source, shader_defs)?),
|
||
|
*stage,
|
||
|
)),
|
||
|
Shader::SpirV(source) => {
|
||
|
if shader_defs.is_empty() {
|
||
|
Ok(ProcessedShader::SpirV(source.clone()))
|
||
|
} else {
|
||
|
Err(ProcessShaderError::ShaderFormatDoesNotSupportShaderDefs)
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
pub fn process_str(
|
||
|
&self,
|
||
|
shader: &str,
|
||
|
shader_defs: &[String],
|
||
|
) -> Result<String, ProcessShaderError> {
|
||
|
let shader_defs = HashSet::<String>::from_iter(shader_defs.iter().cloned());
|
||
|
let mut scopes = vec![true];
|
||
|
let mut final_string = String::new();
|
||
|
for line in shader.split('\n') {
|
||
|
if let Some(cap) = self.ifdef_regex.captures(line) {
|
||
|
let def = cap.get(1).unwrap();
|
||
|
scopes.push(shader_defs.contains(def.as_str()));
|
||
|
} else if let Some(cap) = self.ifndef_regex.captures(line) {
|
||
|
let def = cap.get(1).unwrap();
|
||
|
scopes.push(!shader_defs.contains(def.as_str()));
|
||
|
} else if self.endif_regex.is_match(line) {
|
||
|
scopes.pop();
|
||
|
if scopes.is_empty() {
|
||
|
return Err(ProcessShaderError::TooManyEndIfs);
|
||
|
}
|
||
|
} else if *scopes.last().unwrap() {
|
||
|
final_string.push_str(line);
|
||
|
final_string.push('\n');
|
||
|
}
|
||
|
}
|
||
|
|
||
|
final_string.pop();
|
||
|
|
||
|
if scopes.len() != 1 {
|
||
|
return Err(ProcessShaderError::NotEnoughEndIfs);
|
||
|
}
|
||
|
Ok(final_string)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
#[cfg(test)]
mod tests {
    // Tests for `ShaderProcessor::process_str`.
    use crate::render_resource::{ProcessShaderError, ShaderProcessor};

    #[rustfmt::skip]
    const WGSL: &str = r"
[[block]]
struct View {
view_proj: mat4x4<f32>;
world_position: vec3<f32>;
};
[[group(0), binding(0)]]
var<uniform> view: View;

# ifdef TEXTURE
[[group(1), binding(0)]]
var sprite_texture: texture_2d<f32>;
# endif

struct VertexOutput {
[[location(0)]] uv: vec2<f32>;
[[builtin(position)]] position: vec4<f32>;
};

[[stage(vertex)]]
fn vertex(
[[location(0)]] vertex_position: vec3<f32>,
[[location(1)]] vertex_uv: vec2<f32>
) -> VertexOutput {
var out: VertexOutput;
out.uv = vertex_uv;
out.position = view.view_proj * vec4<f32>(vertex_position, 1.0);
return out;
}
";

    // Shared input for the `# ifndef` tests.
    #[rustfmt::skip]
    const IFNDEF: &str = r"
# ifndef FOO
fn foo() { }
# endif
";

    #[test]
    fn process_shader_def_defined() {
        #[rustfmt::skip]
        const EXPECTED: &str = r"
[[block]]
struct View {
view_proj: mat4x4<f32>;
world_position: vec3<f32>;
};
[[group(0), binding(0)]]
var<uniform> view: View;

[[group(1), binding(0)]]
var sprite_texture: texture_2d<f32>;

struct VertexOutput {
[[location(0)]] uv: vec2<f32>;
[[builtin(position)]] position: vec4<f32>;
};

[[stage(vertex)]]
fn vertex(
[[location(0)]] vertex_position: vec3<f32>,
[[location(1)]] vertex_uv: vec2<f32>
) -> VertexOutput {
var out: VertexOutput;
out.uv = vertex_uv;
out.position = view.view_proj * vec4<f32>(vertex_position, 1.0);
return out;
}
";
        let processor = ShaderProcessor::default();
        let result = processor
            .process_str(WGSL, &["TEXTURE".to_string()])
            .unwrap();
        assert_eq!(result, EXPECTED);
    }

    #[test]
    fn process_shader_def_not_defined() {
        #[rustfmt::skip]
        const EXPECTED: &str = r"
[[block]]
struct View {
view_proj: mat4x4<f32>;
world_position: vec3<f32>;
};
[[group(0), binding(0)]]
var<uniform> view: View;


struct VertexOutput {
[[location(0)]] uv: vec2<f32>;
[[builtin(position)]] position: vec4<f32>;
};

[[stage(vertex)]]
fn vertex(
[[location(0)]] vertex_position: vec3<f32>,
[[location(1)]] vertex_uv: vec2<f32>
) -> VertexOutput {
var out: VertexOutput;
out.uv = vertex_uv;
out.position = view.view_proj * vec4<f32>(vertex_position, 1.0);
return out;
}
";
        let processor = ShaderProcessor::default();
        let result = processor.process_str(WGSL, &[]).unwrap();
        assert_eq!(result, EXPECTED);
    }

    #[test]
    fn process_shader_ifndef_not_defined() {
        #[rustfmt::skip]
        const EXPECTED: &str = r"
fn foo() { }
";
        let processor = ShaderProcessor::default();
        let result = processor.process_str(IFNDEF, &[]).unwrap();
        assert_eq!(result, EXPECTED);
    }

    #[test]
    fn process_shader_ifndef_defined() {
        let processor = ShaderProcessor::default();
        let result = processor
            .process_str(IFNDEF, &["FOO".to_string()])
            .unwrap();
        // Only the empty first and last lines survive, joined by one newline.
        assert_eq!(result, "\n");
    }

    #[test]
    fn process_shader_def_unclosed() {
        #[rustfmt::skip]
        const INPUT: &str = r"
# ifdef FOO
";
        let processor = ShaderProcessor::default();
        let result = processor.process_str(INPUT, &[]);
        assert_eq!(result, Err(ProcessShaderError::NotEnoughEndIfs));
    }

    #[test]
    fn process_shader_def_too_closed() {
        #[rustfmt::skip]
        const INPUT: &str = r"
# endif
";
        let processor = ShaderProcessor::default();
        let result = processor.process_str(INPUT, &[]);
        assert_eq!(result, Err(ProcessShaderError::TooManyEndIfs));
    }

    #[test]
    fn process_shader_def_commented() {
        #[rustfmt::skip]
        const INPUT: &str = r"
// # ifdef FOO
fn foo() { }
";
        let processor = ShaderProcessor::default();
        let result = processor.process_str(INPUT, &[]).unwrap();
        assert_eq!(result, INPUT);
    }
}
|