bevy/pipelined/bevy_render2/src/render_resource/shader.rs

Pipeline Specialization, Shader Assets, and Shader Preprocessing (#3031)

## New Features

This adds the following to the new renderer:

* **Shader Assets**
  * Shaders are assets again! Users no longer need to call `include_str!` for their shaders.
  * Shader hot-reloading
* **Shader Defs / Shader Preprocessing**
  * Shaders now support `# ifdef NAME`, `# ifndef NAME`, and `# endif` preprocessor directives.
* **Bevy RenderPipelineDescriptor and RenderPipelineCache**
  * Bevy now provides its own `RenderPipelineDescriptor`, and the wgpu version is now exported as `RawRenderPipelineDescriptor`. This allows users to define pipelines with `Handle<Shader>` instead of manually compiling and referencing `ShaderModules`, enables passing in shader defs to configure the shader preprocessor, makes hot reloading possible (because the descriptor can be owned and used to create new pipelines when a shader changes), and opens the door to pipeline specialization.
  * The `RenderPipelineCache` now handles compiling and re-compiling Bevy `RenderPipelineDescriptor`s. It has internal `PipelineLayout` and `ShaderModule` caches. Users receive a `CachedPipelineId`, which can be used to look up the actual `&RenderPipeline` during rendering.
* **Pipeline Specialization**
  * This enables defining per-entity-configurable pipelines that specialize on arbitrary custom keys. In practice this will involve specializing based on things like MSAA values, shader defs, bind group existence, and vertex layouts.
  * Adds a `SpecializedPipeline` trait and a `SpecializedPipelines<MyPipeline>` resource. This is a simple layer that generates Bevy `RenderPipelineDescriptor`s based on a custom key defined for the pipeline.
  * Specialized pipelines are also hot-reloadable.
  * This was the result of experimenting with two different approaches:
    1. **"generic immediate mode multi-key hash pipeline specialization"**
       * Breaks up the pipeline into multiple "identities" (the core pipeline definition, shader defs, mesh layout, bind group layout). Each of these identities has its own key. Looking up or compiling a specific version of a pipeline requires composing all of these keys together.
       * The benefit of this approach is that it works for all pipelines: the pipeline is fully identified by the keys. The multiple keys allow pre-hashing parts of the pipeline identity where possible (e.g. precomputing the mesh identity for all meshes).
       * The downside is that any per-entity data that informs the values of these keys could require expensive re-hashes. Computing each key for each sprite tanked bevymark performance (sprites don't actually need this level of specialization yet, but things like PBR and future sprite scenarios might).
       * This is the approach rafx used last time I checked.
    2. **"custom key specialization"**
       * Pipelines are not specialized by default.
       * Pipelines that need specialization implement a `SpecializedPipeline` trait with a custom key associated type.
       * This allows specialization keys to encode exactly the amount of information required (instead of needing to be a combined hash of the entire pipeline). Generally this should fit in a small number of bytes. Per-entity specialization barely registers anymore on things like bevymark. It also makes things like shader defs much cheaper to hash, because we can use context-specific bitflags instead of strings.
       * Despite the extra trait, it generally makes pipeline definitions and lookups simpler: managing multiple keys (and making the appropriate calls to manage them) was far more complicated.
  * I opted for custom key specialization. It generally performs better, and in my opinion it is better UX. Fortunately, the way this is implemented also allows for custom caches, because it all builds on a common abstraction: the `RenderPipelineCache`. The built-in custom key trait is just a simple, predefined way to interact with the cache.

## Callouts

* The `SpecializedPipeline` trait makes it easy to inherit pipeline configuration in custom pipelines. The changes to `custom_shader_pipelined` and the new `shader_defs_pipelined` example illustrate how much simpler it is to define custom pipelines based on the `PbrPipeline`.
* The shader preprocessor is currently pretty naive (it just uses regexes to process each line). Ultimately we might want to build a custom parser for better performance and error handling, but for now I'm happy to optimize for "easy to implement and understand".

## Next Steps

* Port compute pipelines to the new system
* Add more preprocessor directives (else, elif, import)
* More flexible vertex attribute specialization / enable cheaply specializing on specific mesh vertex layouts
2021-10-28 19:07:47 +00:00
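The "custom key specialization" approach described in the commit message can be sketched in plain Rust, without Bevy. This is a minimal illustration under stated assumptions: `Specialize`, `SpecializedCache`, `PipelineId`, `SpriteKey`, `SpritePipeline`, and `PipelineDescriptor` are hypothetical stand-ins, not Bevy's actual `SpecializedPipeline`, `SpecializedPipelines`, or `RenderPipelineCache` API. It shows the core idea: each pipeline picks its own small key (here one byte of bitflags), and the cache only hashes that key to decide whether a specialization already exists.

```rust
use std::collections::HashMap;
use std::hash::Hash;

/// Hypothetical, heavily simplified descriptor; a real one also holds shaders,
/// layouts, vertex state, and so on.
#[derive(Debug)]
struct PipelineDescriptor {
    shader_defs: Vec<String>,
    sample_count: u32,
}

/// Hypothetical trait in the spirit of `SpecializedPipeline`: each pipeline
/// chooses its own small, cheap-to-hash key type.
trait Specialize {
    type Key: Copy + Eq + Hash;
    fn specialize(&self, key: Self::Key) -> PipelineDescriptor;
}

/// Hypothetical handle, similar in spirit to `CachedPipelineId`.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
struct PipelineId(usize);

/// Builds each distinct specialization once and hands out ids for reuse.
struct SpecializedCache<P: Specialize> {
    ids: HashMap<P::Key, PipelineId>,
    // Stands in for compiled GPU pipelines.
    compiled: Vec<PipelineDescriptor>,
}

impl<P: Specialize> SpecializedCache<P> {
    fn new() -> Self {
        Self { ids: HashMap::new(), compiled: Vec::new() }
    }

    fn specialize(&mut self, pipeline: &P, key: P::Key) -> PipelineId {
        // Cache hit: only the small key is hashed, never the whole descriptor.
        if let Some(&id) = self.ids.get(&key) {
            return id;
        }
        let id = PipelineId(self.compiled.len());
        self.compiled.push(pipeline.specialize(key));
        self.ids.insert(key, id);
        id
    }
}

/// Hypothetical sprite key packed into one byte:
/// bit 0 = textured, bits 1..=3 = log2(MSAA sample count).
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
struct SpriteKey(u8);

impl SpriteKey {
    const TEXTURED: u8 = 1;

    fn new(textured: bool, msaa_samples: u32) -> Self {
        // Assumes a power-of-two sample count that fits in three bits of log2.
        debug_assert!(msaa_samples.is_power_of_two() && msaa_samples <= 128);
        let textured_bit = if textured { Self::TEXTURED } else { 0 };
        SpriteKey(((msaa_samples.trailing_zeros() as u8) << 1) | textured_bit)
    }

    fn textured(self) -> bool {
        (self.0 & Self::TEXTURED) != 0
    }

    fn msaa_samples(self) -> u32 {
        1u32 << (self.0 >> 1)
    }
}

struct SpritePipeline;

impl Specialize for SpritePipeline {
    type Key = SpriteKey;

    fn specialize(&self, key: SpriteKey) -> PipelineDescriptor {
        let mut shader_defs = Vec::new();
        if key.textured() {
            // Would enable a `# ifdef TEXTURE` block in the shader source.
            shader_defs.push("TEXTURE".to_string());
        }
        PipelineDescriptor { shader_defs, sample_count: key.msaa_samples() }
    }
}

fn main() {
    let pipeline = SpritePipeline;
    let mut cache = SpecializedCache::new();
    let a = cache.specialize(&pipeline, SpriteKey::new(true, 4));
    let b = cache.specialize(&pipeline, SpriteKey::new(true, 4)); // reuses `a`
    let c = cache.specialize(&pipeline, SpriteKey::new(false, 1));
    assert_eq!(a, b);
    assert_ne!(a, c);
    assert_eq!(cache.compiled.len(), 2);
    println!("{:?}", cache.compiled[a.0]);
}
```

Because the key is a single byte, looking up a specialization per entity costs one tiny hash, which matches the commit's argument against re-hashing several composed identity keys.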
use bevy_asset::{AssetLoader, LoadContext, LoadedAsset};
use bevy_reflect::{TypeUuid, Uuid};
use bevy_utils::{tracing::error, BoxedFuture};
use naga::{valid::ModuleInfo, Module};
use regex::Regex;
use std::{borrow::Cow, collections::HashSet};
use thiserror::Error;
use wgpu::{ShaderModuleDescriptor, ShaderSource};
use crate::render_asset::{PrepareAssetError, RenderAsset};

#[derive(Copy, Clone, Hash, Eq, PartialEq, Debug)]
pub struct ShaderId(Uuid);

impl ShaderId {
    #[allow(clippy::new_without_default)]
    pub fn new() -> Self {
        ShaderId(Uuid::new_v4())
    }
}

#[derive(Error, Debug)]
pub enum ShaderReflectError {
    #[error(transparent)]
    WgslParse(#[from] naga::front::wgsl::ParseError),
    #[error("GLSL Parse Error: {0:?}")]
    GlslParse(Vec<naga::front::glsl::Error>),
    #[error(transparent)]
    SpirVParse(#[from] naga::front::spv::Error),
    #[error(transparent)]
    Validation(#[from] naga::valid::ValidationError),
}

/// A shader, as defined by its source code and, for GLSL, its [`naga::ShaderStage`].
///
/// This is an "unprocessed" shader: it may still contain preprocessor directives
/// (see [`ShaderProcessor`]).
#[derive(Debug, Clone, TypeUuid)]
#[uuid = "d95bc916-6c55-4de3-9622-37e7b6969fda"]
pub enum Shader {
    Wgsl(Cow<'static, str>),
    Glsl(Cow<'static, str>, naga::ShaderStage),
    SpirV(Cow<'static, [u8]>),
    // TODO: consider the following
    // PrecompiledSpirVMacros(HashMap<HashSet<String>, Vec<u32>>)
    // NagaModule(Module) ... Module impls Serialize/Deserialize
}

/// A processed [`Shader`]. It cannot contain preprocessor directives and must be "ready to compile".
pub enum ProcessedShader {
    Wgsl(Cow<'static, str>),
    Glsl(Cow<'static, str>, naga::ShaderStage),
    SpirV(Cow<'static, [u8]>),
}

impl ProcessedShader {
    pub fn reflect(&self) -> Result<ShaderReflection, ShaderReflectError> {
        let module = match &self {
            // TODO: process macros here
            ProcessedShader::Wgsl(source) => naga::front::wgsl::parse_str(source)?,
            ProcessedShader::Glsl(source, shader_stage) => {
                let mut parser = naga::front::glsl::Parser::default();
                parser
                    .parse(&naga::front::glsl::Options::from(*shader_stage), source)
                    .map_err(ShaderReflectError::GlslParse)?
            }
            ProcessedShader::SpirV(source) => naga::front::spv::parse_u8_slice(
                source,
                &naga::front::spv::Options {
                    adjust_coordinate_space: false,
                    ..naga::front::spv::Options::default()
                },
            )?,
        };
        let module_info = naga::valid::Validator::new(
            naga::valid::ValidationFlags::default(),
            naga::valid::Capabilities::default(),
        )
        .validate(&module)?;

        Ok(ShaderReflection {
            module,
            module_info,
        })
    }

    pub fn get_module_descriptor(&self) -> Result<ShaderModuleDescriptor, AsModuleDescriptorError> {
        Ok(ShaderModuleDescriptor {
            label: None,
            source: match self {
                ProcessedShader::Wgsl(source) => ShaderSource::Wgsl(source.clone()),
                ProcessedShader::Glsl(_source, _stage) => {
                    let reflection = self.reflect()?;
                    // TODO: it probably makes more sense to convert this to spirv, but as of writing
                    // this comment, naga's spirv conversion is broken
                    let wgsl = reflection.get_wgsl()?;
                    ShaderSource::Wgsl(wgsl.into())
                }
                ProcessedShader::SpirV(_) => {
                    // TODO: we can probably just transmute the u8 array to u32?
                    let reflection = self.reflect()?;
                    let spirv = reflection.get_spirv()?;
                    ShaderSource::SpirV(Cow::Owned(spirv))
                }
            },
        })
    }
}

#[derive(Error, Debug)]
pub enum AsModuleDescriptorError {
    #[error(transparent)]
    ShaderReflectError(#[from] ShaderReflectError),
    #[error(transparent)]
    WgslConversion(#[from] naga::back::wgsl::Error),
    #[error(transparent)]
    SpirVConversion(#[from] naga::back::spv::Error),
}

pub struct ShaderReflection {
    pub module: Module,
    pub module_info: ModuleInfo,
}

impl ShaderReflection {
    pub fn get_spirv(&self) -> Result<Vec<u32>, naga::back::spv::Error> {
        naga::back::spv::write_vec(
            &self.module,
            &self.module_info,
            &naga::back::spv::Options {
                flags: naga::back::spv::WriterFlags::empty(),
                ..naga::back::spv::Options::default()
            },
            None,
        )
    }

    pub fn get_wgsl(&self) -> Result<String, naga::back::wgsl::Error> {
        naga::back::wgsl::write_string(&self.module, &self.module_info)
    }
}

impl Shader {
    pub fn from_wgsl(source: impl Into<Cow<'static, str>>) -> Shader {
        Shader::Wgsl(source.into())
    }

    pub fn from_glsl(source: impl Into<Cow<'static, str>>, stage: naga::ShaderStage) -> Shader {
        Shader::Glsl(source.into(), stage)
    }

    pub fn from_spirv(source: impl Into<Cow<'static, [u8]>>) -> Shader {
        Shader::SpirV(source.into())
    }
}

#[derive(Default)]
pub struct ShaderLoader;

impl AssetLoader for ShaderLoader {
    fn load<'a>(
        &'a self,
        bytes: &'a [u8],
        load_context: &'a mut LoadContext,
    ) -> BoxedFuture<'a, Result<(), anyhow::Error>> {
        Box::pin(async move {
            let ext = load_context.path().extension().unwrap().to_str().unwrap();

            let shader = match ext {
                "spv" => Shader::from_spirv(Vec::from(bytes)),
                "wgsl" => Shader::from_wgsl(String::from_utf8(Vec::from(bytes))?),
                "vert" => Shader::from_glsl(
                    String::from_utf8(Vec::from(bytes))?,
                    naga::ShaderStage::Vertex,
                ),
                "frag" => Shader::from_glsl(
                    String::from_utf8(Vec::from(bytes))?,
                    naga::ShaderStage::Fragment,
                ),
                _ => panic!("unhandled extension: {}", ext),
            };

            load_context.set_default_asset(LoadedAsset::new(shader));
            Ok(())
        })
    }

    fn extensions(&self) -> &[&str] {
        &["spv", "wgsl", "vert", "frag"]
    }
}

impl RenderAsset for Shader {
    type ExtractedAsset = Shader;
    type PreparedAsset = Shader;
    type Param = ();

    fn extract_asset(&self) -> Self::ExtractedAsset {
        self.clone()
    }

    fn prepare_asset(
        extracted_asset: Self::ExtractedAsset,
        _param: &mut bevy_ecs::system::SystemParamItem<Self::Param>,
    ) -> Result<Self::PreparedAsset, PrepareAssetError<Self::ExtractedAsset>> {
        Ok(extracted_asset)
    }
}

#[derive(Error, Debug, PartialEq, Eq)]
pub enum ProcessShaderError {
    #[error("Too many '# endif' lines. Each endif should be preceded by an if statement.")]
    TooManyEndIfs,
    #[error(
        "Not enough '# endif' lines. Each if statement should be followed by an endif statement."
    )]
    NotEnoughEndIfs,
    #[error("This Shader's format does not support processing shader defs.")]
    ShaderFormatDoesNotSupportShaderDefs,
}
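
/// Processes `# ifdef NAME`, `# ifndef NAME`, and `# endif` directives in shader source.
pub struct ShaderProcessor {
    ifdef_regex: Regex,
    ifndef_regex: Regex,
    endif_regex: Regex,
}

impl Default for ShaderProcessor {
    fn default() -> Self {
        Self {
            // `\w` already matches letters, digits, and `_` (a `[\w|\d|_]` class would
            // also accept a literal `|`); `\s+` requires whitespace before the name.
            ifdef_regex: Regex::new(r"^\s*#\s*ifdef\s+(\w+)").unwrap(),
            ifndef_regex: Regex::new(r"^\s*#\s*ifndef\s+(\w+)").unwrap(),
            endif_regex: Regex::new(r"^\s*#\s*endif").unwrap(),
        }
    }
}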

impl ShaderProcessor {
    pub fn process_shader(
        &self,
        shader: &Shader,
        shader_defs: &[String],
    ) -> Result<ProcessedShader, ProcessShaderError> {
        match shader {
            Shader::Wgsl(source) => Ok(ProcessedShader::Wgsl(Cow::from(
                self.process_str(source, shader_defs)?,
            ))),
            Shader::Glsl(source, stage) => Ok(ProcessedShader::Glsl(
                Cow::from(self.process_str(source, shader_defs)?),
                *stage,
            )),
            Shader::SpirV(source) => {
                if shader_defs.is_empty() {
                    Ok(ProcessedShader::SpirV(source.clone()))
                } else {
                    Err(ProcessShaderError::ShaderFormatDoesNotSupportShaderDefs)
                }
            }
        }
    }

    pub fn process_str(
        &self,
        shader: &str,
        shader_defs: &[String],
    ) -> Result<String, ProcessShaderError> {
        let shader_defs: HashSet<&str> = shader_defs.iter().map(String::as_str).collect();
        // Stack of "is this block active?" flags; the bottom entry is the always-active file scope.
        let mut scopes = vec![true];
        let mut final_string = String::new();
        for line in shader.split('\n') {
            // A nested block can only be active if its enclosing block is active.
            let parent_active = *scopes.last().unwrap();
            if let Some(cap) = self.ifdef_regex.captures(line) {
                let def = cap.get(1).unwrap();
                scopes.push(parent_active && shader_defs.contains(def.as_str()));
            } else if let Some(cap) = self.ifndef_regex.captures(line) {
                let def = cap.get(1).unwrap();
                scopes.push(parent_active && !shader_defs.contains(def.as_str()));
            } else if self.endif_regex.is_match(line) {
                scopes.pop();
                if scopes.is_empty() {
                    return Err(ProcessShaderError::TooManyEndIfs);
                }
            } else if parent_active {
                final_string.push_str(line);
                final_string.push('\n');
            }
        }

        // Drop the newline appended after the last kept line.
        final_string.pop();

        if scopes.len() != 1 {
            return Err(ProcessShaderError::NotEnoughEndIfs);
        }

        Ok(final_string)
    }
}

#[cfg(test)]
mod tests {
    use crate::render_resource::{ProcessShaderError, ShaderProcessor};

    #[rustfmt::skip]
    const WGSL: &str = r"
[[block]]
struct View {
view_proj: mat4x4<f32>;
world_position: vec3<f32>;
};
[[group(0), binding(0)]]
var<uniform> view: View;
# ifdef TEXTURE
[[group(1), binding(0)]]
var sprite_texture: texture_2d<f32>;
# endif
struct VertexOutput {
[[location(0)]] uv: vec2<f32>;
[[builtin(position)]] position: vec4<f32>;
};
[[stage(vertex)]]
fn vertex(
[[location(0)]] vertex_position: vec3<f32>,
[[location(1)]] vertex_uv: vec2<f32>
) -> VertexOutput {
var out: VertexOutput;
out.uv = vertex_uv;
out.position = view.view_proj * vec4<f32>(vertex_position, 1.0);
return out;
}
";

    #[test]
    fn process_shader_def_defined() {
        #[rustfmt::skip]
        const EXPECTED: &str = r"
[[block]]
struct View {
view_proj: mat4x4<f32>;
world_position: vec3<f32>;
};
[[group(0), binding(0)]]
var<uniform> view: View;
[[group(1), binding(0)]]
var sprite_texture: texture_2d<f32>;
struct VertexOutput {
[[location(0)]] uv: vec2<f32>;
[[builtin(position)]] position: vec4<f32>;
};
[[stage(vertex)]]
fn vertex(
[[location(0)]] vertex_position: vec3<f32>,
[[location(1)]] vertex_uv: vec2<f32>
) -> VertexOutput {
var out: VertexOutput;
out.uv = vertex_uv;
out.position = view.view_proj * vec4<f32>(vertex_position, 1.0);
return out;
}
";
        let processor = ShaderProcessor::default();
        let result = processor
            .process_str(WGSL, &["TEXTURE".to_string()])
            .unwrap();
        assert_eq!(result, EXPECTED);
    }

    #[test]
    fn process_shader_def_not_defined() {
        #[rustfmt::skip]
        const EXPECTED: &str = r"
[[block]]
struct View {
view_proj: mat4x4<f32>;
world_position: vec3<f32>;
};
[[group(0), binding(0)]]
var<uniform> view: View;
struct VertexOutput {
[[location(0)]] uv: vec2<f32>;
[[builtin(position)]] position: vec4<f32>;
};
[[stage(vertex)]]
fn vertex(
[[location(0)]] vertex_position: vec3<f32>,
[[location(1)]] vertex_uv: vec2<f32>
) -> VertexOutput {
var out: VertexOutput;
out.uv = vertex_uv;
out.position = view.view_proj * vec4<f32>(vertex_position, 1.0);
return out;
}
";
        let processor = ShaderProcessor::default();
        let result = processor.process_str(WGSL, &[]).unwrap();
        assert_eq!(result, EXPECTED);
    }

    #[test]
    fn process_shader_def_unclosed() {
        #[rustfmt::skip]
        const INPUT: &str = r"
# ifdef FOO
";
        let processor = ShaderProcessor::default();
        let result = processor.process_str(INPUT, &[]);
        assert_eq!(result, Err(ProcessShaderError::NotEnoughEndIfs));
    }

    #[test]
    fn process_shader_def_too_closed() {
        #[rustfmt::skip]
        const INPUT: &str = r"
# endif
";
        let processor = ShaderProcessor::default();
        let result = processor.process_str(INPUT, &[]);
        assert_eq!(result, Err(ProcessShaderError::TooManyEndIfs));
    }

    #[test]
    fn process_shader_def_commented() {
        #[rustfmt::skip]
        const INPUT: &str = r"
// # ifdef FOO
fn foo() { }
";
        let processor = ShaderProcessor::default();
        let result = processor.process_str(INPUT, &[]).unwrap();
        assert_eq!(result, INPUT);
    }
}
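The scope-stack algorithm in `process_str` can be illustrated without the `regex` crate. The sketch below is a hypothetical, std-only re-implementation: `preprocess`, `parse_directive`, and `PreprocessError` are illustrative names, and directives are recognized by splitting on whitespace rather than with the regexes used above. It keeps one "active" flag per open block and treats a block nested inside an inactive block as inactive, whichever definitions are set.

```rust
use std::collections::HashSet;

#[derive(Debug, PartialEq, Eq)]
enum PreprocessError {
    TooManyEndIfs,
    NotEnoughEndIfs,
}

/// Splits a directive line such as `# ifdef NAME` or `#endif` into
/// (keyword, argument). Returns `None` for ordinary source lines.
fn parse_directive(line: &str) -> Option<(&str, &str)> {
    let body = line.trim_start().strip_prefix('#')?.trim_start();
    let mut parts = body.split_whitespace();
    let keyword = parts.next()?;
    Some((keyword, parts.next().unwrap_or("")))
}

fn preprocess(source: &str, defs: &[&str]) -> Result<String, PreprocessError> {
    let defs: HashSet<&str> = defs.iter().copied().collect();
    // One "is this block active?" flag per open block; the bottom entry is the file scope.
    let mut scopes = vec![true];
    let mut kept = Vec::new();
    for line in source.split('\n') {
        // A block nested inside an inactive block must stay inactive.
        let parent_active = *scopes.last().unwrap();
        match parse_directive(line) {
            Some(("ifdef", name)) => scopes.push(parent_active && defs.contains(name)),
            Some(("ifndef", name)) => scopes.push(parent_active && !defs.contains(name)),
            Some(("endif", _)) => {
                scopes.pop();
                if scopes.is_empty() {
                    return Err(PreprocessError::TooManyEndIfs);
                }
            }
            // Ordinary lines (and other `#` lines such as `#version`) pass through when active.
            _ if parent_active => kept.push(line),
            _ => {}
        }
    }
    if scopes.len() != 1 {
        return Err(PreprocessError::NotEnoughEndIfs);
    }
    Ok(kept.join("\n"))
}

fn main() {
    let src = "# ifdef A\n# ifdef B\nnested\n# endif\n# endif\nalways";
    // Defining B alone is not enough: the enclosing `A` block is inactive.
    assert_eq!(preprocess(src, &["B"]), Ok("always".to_string()));
    assert_eq!(preprocess(src, &["A", "B"]), Ok("nested\nalways".to_string()));
    assert_eq!(preprocess("# ifdef X", &[]), Err(PreprocessError::NotEnoughEndIfs));
}
```

Keeping one boolean per open block also makes unbalanced directives easy to detect: popping the base scope means there were too many `# endif` lines, and more than one entry left at the end means some block was never closed.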