Live reloading of shaders (#937)

* Add ShaderLoader, rebuild pipelines for modified shader assets
* New example
* Add shader_update_system, ShaderError, remove specialization assets
* Don't panic on shader compilation failure
This commit is contained in:
Al M 2020-12-07 13:32:13 -07:00 committed by GitHub
parent a3bca7e464
commit 2c9b7956d1
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
11 changed files with 329 additions and 52 deletions

View file

@ -274,6 +274,10 @@ path = "examples/reflection/trait_reflection.rs"
name = "scene"
path = "examples/scene/scene.rs"
[[example]]
name = "hot_shader_reloading"
path = "examples/shader/hot_shader_reloading.rs"
[[example]]
name = "mesh_custom_attribute"
path = "examples/shader/mesh_custom_attribute.rs"

11
assets/shaders/hot.frag Normal file
View file

@ -0,0 +1,11 @@
#version 450
layout(location = 0) out vec4 o_Target;
layout(set = 2, binding = 0) uniform MyMaterial_color {
vec4 color;
};
void main() {
o_Target = color * 0.5;
}

15
assets/shaders/hot.vert Normal file
View file

@ -0,0 +1,15 @@
#version 450
layout(location = 0) in vec3 Vertex_Position;
layout(set = 0, binding = 0) uniform Camera {
mat4 ViewProj;
};
layout(set = 1, binding = 0) uniform Transform {
mat4 Model;
};
void main() {
gl_Position = ViewProj * Model * vec4(Vertex_Position, 1.0);
}

View file

@ -44,6 +44,7 @@ use render_graph::{
RenderGraph,
};
use renderer::{AssetRenderResourceBindings, RenderResourceBindings};
use shader::ShaderLoader;
#[cfg(feature = "hdr")]
use texture::HdrTextureLoader;
#[cfg(feature = "png")]
@ -87,6 +88,8 @@ impl Plugin for RenderPlugin {
app.init_asset_loader::<HdrTextureLoader>();
}
app.init_asset_loader::<ShaderLoader>();
if app.resources().get::<ClearColor>().is_none() {
app.resources_mut().insert(ClearColor::default());
}
@ -134,6 +137,7 @@ impl Plugin for RenderPlugin {
camera::visible_entities_system,
)
// TODO: turn these "resource systems" into graph nodes and remove the RENDER_RESOURCE stage
.add_system_to_stage(stage::RENDER_RESOURCE, shader::shader_update_system)
.add_system_to_stage(stage::RENDER_RESOURCE, mesh::mesh_resource_provider_system)
.add_system_to_stage(stage::RENDER_RESOURCE, Texture::texture_resource_system)
.add_system_to_stage(

View file

@ -2,7 +2,7 @@ use super::{state_descriptors::PrimitiveTopology, IndexFormat, PipelineDescripto
use crate::{
pipeline::{BindType, InputStepMode, VertexBufferDescriptor},
renderer::RenderResourceContext,
shader::{Shader, ShaderSource},
shader::{Shader, ShaderError, ShaderSource},
};
use bevy_asset::{Assets, Handle};
use bevy_reflect::Reflect;
@ -60,6 +60,7 @@ struct SpecializedPipeline {
#[derive(Debug, Default)]
pub struct PipelineCompiler {
specialized_shaders: HashMap<Handle<Shader>, Vec<SpecializedShader>>,
specialized_shader_pipelines: HashMap<Handle<Shader>, Vec<Handle<PipelineDescriptor>>>,
specialized_pipelines: HashMap<Handle<PipelineDescriptor>, Vec<SpecializedPipeline>>,
}
@ -70,7 +71,7 @@ impl PipelineCompiler {
shaders: &mut Assets<Shader>,
shader_handle: &Handle<Shader>,
shader_specialization: &ShaderSpecialization,
) -> Handle<Shader> {
) -> Result<Handle<Shader>, ShaderError> {
let specialized_shaders = self
.specialized_shaders
.entry(shader_handle.clone_weak())
@ -80,7 +81,7 @@ impl PipelineCompiler {
// don't produce new shader if the input source is already spirv
if let ShaderSource::Spirv(_) = shader.source {
return shader_handle.clone_weak();
return Ok(shader_handle.clone_weak());
}
if let Some(specialized_shader) =
@ -91,7 +92,7 @@ impl PipelineCompiler {
})
{
// if shader has already been compiled with current configuration, use existing shader
specialized_shader.shader.clone_weak()
Ok(specialized_shader.shader.clone_weak())
} else {
// if no shader exists with the current configuration, create new shader and compile
let shader_def_vec = shader_specialization
@ -100,14 +101,14 @@ impl PipelineCompiler {
.cloned()
.collect::<Vec<String>>();
let compiled_shader =
render_resource_context.get_specialized_shader(shader, Some(&shader_def_vec));
render_resource_context.get_specialized_shader(shader, Some(&shader_def_vec))?;
let specialized_handle = shaders.add(compiled_shader);
let weak_specialized_handle = specialized_handle.clone_weak();
specialized_shaders.push(SpecializedShader {
shader: specialized_handle,
specialization: shader_specialization.clone(),
});
weak_specialized_handle
Ok(weak_specialized_handle)
}
}
@ -138,23 +139,31 @@ impl PipelineCompiler {
) -> Handle<PipelineDescriptor> {
let source_descriptor = pipelines.get(source_pipeline).unwrap();
let mut specialized_descriptor = source_descriptor.clone();
specialized_descriptor.shader_stages.vertex = self.compile_shader(
render_resource_context,
shaders,
&specialized_descriptor.shader_stages.vertex,
&pipeline_specialization.shader_specialization,
);
let specialized_vertex_shader = self
.compile_shader(
render_resource_context,
shaders,
&specialized_descriptor.shader_stages.vertex,
&pipeline_specialization.shader_specialization,
)
.unwrap();
specialized_descriptor.shader_stages.vertex = specialized_vertex_shader.clone_weak();
let mut specialized_fragment_shader = None;
specialized_descriptor.shader_stages.fragment = specialized_descriptor
.shader_stages
.fragment
.as_ref()
.map(|fragment| {
self.compile_shader(
render_resource_context,
shaders,
fragment,
&pipeline_specialization.shader_specialization,
)
let shader = self
.compile_shader(
render_resource_context,
shaders,
fragment,
&pipeline_specialization.shader_specialization,
)
.unwrap();
specialized_fragment_shader = Some(shader.clone_weak());
shader
});
let mut layout = render_resource_context.reflect_pipeline_layout(
@ -244,6 +253,18 @@ impl PipelineCompiler {
&shaders,
);
// track specialized shader pipelines
self.specialized_shader_pipelines
.entry(specialized_vertex_shader)
.or_insert_with(Default::default)
.push(source_pipeline.clone_weak());
if let Some(specialized_fragment_shader) = specialized_fragment_shader {
self.specialized_shader_pipelines
.entry(specialized_fragment_shader)
.or_insert_with(Default::default)
.push(source_pipeline.clone_weak());
}
let specialized_pipelines = self
.specialized_pipelines
.entry(source_pipeline.clone_weak())
@ -282,4 +303,56 @@ impl PipelineCompiler {
})
.flatten()
}
/// Update specialized shaders and remove any related specialized
/// pipelines and assets.
pub fn update_shader(
&mut self,
shader: &Handle<Shader>,
pipelines: &mut Assets<PipelineDescriptor>,
shaders: &mut Assets<Shader>,
render_resource_context: &dyn RenderResourceContext,
) -> Result<(), ShaderError> {
if let Some(specialized_shaders) = self.specialized_shaders.get_mut(shader) {
for specialized_shader in specialized_shaders {
// Recompile specialized shader. If it fails, we bail immediately.
let shader_def_vec = specialized_shader
.specialization
.shader_defs
.iter()
.cloned()
.collect::<Vec<String>>();
let new_handle =
shaders.add(render_resource_context.get_specialized_shader(
shaders.get(shader).unwrap(),
Some(&shader_def_vec),
)?);
// Replace handle and remove old from assets.
let old_handle = std::mem::replace(&mut specialized_shader.shader, new_handle);
shaders.remove(&old_handle);
// Find source pipelines that use the old specialized
// shader, and remove from tracking.
if let Some(source_pipelines) =
self.specialized_shader_pipelines.remove(&old_handle)
{
// Remove all specialized pipelines from tracking
// and asset storage. They will be rebuilt on next
// draw.
for source_pipeline in source_pipelines {
if let Some(specialized_pipelines) =
self.specialized_pipelines.remove(&source_pipeline)
{
for p in specialized_pipelines {
pipelines.remove(p.pipeline);
}
}
}
}
}
}
Ok(())
}
}

View file

@ -2,7 +2,7 @@ use super::RenderResourceContext;
use crate::{
pipeline::{BindGroupDescriptorId, PipelineDescriptor},
renderer::{BindGroup, BufferId, BufferInfo, RenderResourceId, SamplerId, TextureId},
shader::Shader,
shader::{Shader, ShaderError},
texture::{SamplerDescriptor, TextureDescriptor},
};
use bevy_asset::{Assets, Handle, HandleUntyped};
@ -149,8 +149,12 @@ impl RenderResourceContext for HeadlessRenderResourceContext {
size
}
fn get_specialized_shader(&self, shader: &Shader, _macros: Option<&[String]>) -> Shader {
shader.clone()
fn get_specialized_shader(
&self,
shader: &Shader,
_macros: Option<&[String]>,
) -> Result<Shader, ShaderError> {
Ok(shader.clone())
}
fn remove_stale_bind_groups(&self) {}

View file

@ -1,7 +1,7 @@
use crate::{
pipeline::{BindGroupDescriptorId, PipelineDescriptor, PipelineLayout},
renderer::{BindGroup, BufferId, BufferInfo, RenderResourceId, SamplerId, TextureId},
shader::{Shader, ShaderLayout, ShaderStages},
shader::{Shader, ShaderError, ShaderLayout, ShaderStages},
texture::{SamplerDescriptor, TextureDescriptor},
};
use bevy_asset::{Asset, Assets, Handle, HandleUntyped};
@ -29,7 +29,11 @@ pub trait RenderResourceContext: Downcast + Send + Sync + 'static {
fn create_buffer_with_data(&self, buffer_info: BufferInfo, data: &[u8]) -> BufferId;
fn create_shader_module(&self, shader_handle: &Handle<Shader>, shaders: &Assets<Shader>);
fn create_shader_module_from_source(&self, shader_handle: &Handle<Shader>, shader: &Shader);
fn get_specialized_shader(&self, shader: &Shader, macros: Option<&[String]>) -> Shader;
fn get_specialized_shader(
&self,
shader: &Shader,
macros: Option<&[String]>,
) -> Result<Shader, ShaderError>;
fn remove_buffer(&self, buffer: BufferId);
fn remove_texture(&self, texture: TextureId);
fn remove_sampler(&self, sampler: SamplerId);

View file

@ -1,7 +1,16 @@
use crate::{
pipeline::{PipelineCompiler, PipelineDescriptor},
renderer::RenderResourceContext,
};
use super::ShaderLayout;
use bevy_asset::Handle;
use bevy_app::{EventReader, Events};
use bevy_asset::{AssetEvent, AssetLoader, Assets, Handle, LoadContext, LoadedAsset};
use bevy_ecs::{Local, Res, ResMut};
use bevy_reflect::TypeUuid;
use bevy_utils::{tracing::error, BoxedFuture};
use std::marker::Copy;
use thiserror::Error;
/// The stage of a shader
#[derive(Hash, Eq, PartialEq, Copy, Clone, Debug)]
@ -11,6 +20,18 @@ pub enum ShaderStage {
Compute,
}
/// An error that occurs during shader handling.
#[derive(Error, Debug)]
pub enum ShaderError {
/// Shader compilation error.
#[error("Shader compilation error: {0}")]
Compilation(String),
#[cfg(target_os = "ios")]
/// shaderc error.
#[error("shaderc error")]
ShaderC(#[from] shaderc::Error),
}
#[cfg(all(not(target_os = "ios"), not(target_arch = "wasm32")))]
impl Into<bevy_glsl_to_spirv::ShaderType> for ShaderStage {
fn into(self) -> bevy_glsl_to_spirv::ShaderType {
@ -27,8 +48,9 @@ pub fn glsl_to_spirv(
glsl_source: &str,
stage: ShaderStage,
shader_defs: Option<&[String]>,
) -> Vec<u32> {
bevy_glsl_to_spirv::compile(glsl_source, stage.into(), shader_defs).unwrap()
) -> Result<Vec<u32>, ShaderError> {
bevy_glsl_to_spirv::compile(glsl_source, stage.into(), shader_defs)
.map_err(ShaderError::Compilation)
}
#[cfg(target_os = "ios")]
@ -47,26 +69,24 @@ pub fn glsl_to_spirv(
glsl_source: &str,
stage: ShaderStage,
shader_defs: Option<&[String]>,
) -> Vec<u32> {
let mut compiler = shaderc::Compiler::new().unwrap();
let mut options = shaderc::CompileOptions::new().unwrap();
) -> Result<Vec<u32>, ShaderError> {
let mut compiler = shaderc::Compiler::new()?;
let mut options = shaderc::CompileOptions::new()?;
if let Some(shader_defs) = shader_defs {
for def in shader_defs.iter() {
options.add_macro_definition(def, None);
}
}
let binary_result = compiler
.compile_into_spirv(
glsl_source,
stage.into(),
"shader.glsl",
"main",
Some(&options),
)
.unwrap();
let binary_result = compiler.compile_into_spirv(
glsl_source,
stage.into(),
"shader.glsl",
"main",
Some(&options),
)?;
binary_result.as_binary().to_vec()
Ok(binary_result.as_binary().to_vec())
}
fn bytes_to_words(bytes: &[u8]) -> Vec<u32> {
@ -114,19 +134,19 @@ impl Shader {
}
#[cfg(not(target_arch = "wasm32"))]
pub fn get_spirv(&self, macros: Option<&[String]>) -> Vec<u32> {
pub fn get_spirv(&self, macros: Option<&[String]>) -> Result<Vec<u32>, ShaderError> {
match self.source {
ShaderSource::Spirv(ref bytes) => bytes.clone(),
ShaderSource::Spirv(ref bytes) => Ok(bytes.clone()),
ShaderSource::Glsl(ref source) => glsl_to_spirv(&source, self.stage, macros),
}
}
#[cfg(not(target_arch = "wasm32"))]
pub fn get_spirv_shader(&self, macros: Option<&[String]>) -> Shader {
Shader {
source: ShaderSource::Spirv(self.get_spirv(macros)),
pub fn get_spirv_shader(&self, macros: Option<&[String]>) -> Result<Shader, ShaderError> {
Ok(Shader {
source: ShaderSource::Spirv(self.get_spirv(macros)?),
stage: self.stage,
}
})
}
#[cfg(not(target_arch = "wasm32"))]
@ -188,3 +208,60 @@ impl ShaderStages {
}
}
}
#[derive(Default)]
pub struct ShaderLoader;
impl AssetLoader for ShaderLoader {
fn load<'a>(
&'a self,
bytes: &'a [u8],
load_context: &'a mut LoadContext,
) -> BoxedFuture<'a, Result<(), anyhow::Error>> {
Box::pin(async move {
let ext = load_context.path().extension().unwrap().to_str().unwrap();
let shader = match ext {
"vert" => Shader::from_glsl(ShaderStage::Vertex, std::str::from_utf8(bytes)?),
"frag" => Shader::from_glsl(ShaderStage::Fragment, std::str::from_utf8(bytes)?),
_ => panic!("unhandled extension: {}", ext),
};
load_context.set_default_asset(LoadedAsset::new(shader));
Ok(())
})
}
fn extensions(&self) -> &[&str] {
&["vert", "frag"]
}
}
pub fn shader_update_system(
mut shaders: ResMut<Assets<Shader>>,
mut pipelines: ResMut<Assets<PipelineDescriptor>>,
shader_events: Res<Events<AssetEvent<Shader>>>,
mut shader_event_reader: Local<EventReader<AssetEvent<Shader>>>,
mut pipeline_compiler: ResMut<PipelineCompiler>,
render_resource_context: Res<Box<dyn RenderResourceContext>>,
) {
for event in shader_event_reader.iter(&shader_events) {
match event {
AssetEvent::Modified { handle } => {
if let Err(e) = pipeline_compiler.update_shader(
handle,
&mut pipelines,
&mut shaders,
&**render_resource_context,
) {
error!("Failed to update shader: {}", e);
}
}
// Creating shaders on the fly is unhandled since they
// have to exist already when assigned to a pipeline. If a
// shader is removed the pipeline keeps using its
// specialized version. Maybe this should be a warning?
AssetEvent::Created { .. } | AssetEvent::Removed { .. } => (),
}
}
}

View file

@ -328,7 +328,8 @@ mod tests {
}
"#,
)
.get_spirv_shader(None);
.get_spirv_shader(None)
.unwrap();
let layout = vertex_shader.reflect_layout(true).unwrap();
assert_eq!(

View file

@ -12,7 +12,7 @@ use bevy_render::{
BindGroup, BufferId, BufferInfo, RenderResourceBinding, RenderResourceContext,
RenderResourceId, SamplerId, TextureId,
},
shader::{glsl_to_spirv, Shader, ShaderSource},
shader::{glsl_to_spirv, Shader, ShaderError, ShaderSource},
texture::{Extent3d, SamplerDescriptor, TextureDescriptor},
};
use bevy_utils::tracing::trace;
@ -251,7 +251,7 @@ impl RenderResourceContext for WgpuRenderResourceContext {
fn create_shader_module_from_source(&self, shader_handle: &Handle<Shader>, shader: &Shader) {
let mut shader_modules = self.resources.shader_modules.write();
let spirv: Cow<[u32]> = shader.get_spirv(None).into();
let spirv: Cow<[u32]> = shader.get_spirv(None).unwrap().into();
let shader_module = self
.device
.create_shader_module(wgpu::ShaderModuleSource::SpirV(spirv));
@ -574,14 +574,18 @@ impl RenderResourceContext for WgpuRenderResourceContext {
}
}
fn get_specialized_shader(&self, shader: &Shader, macros: Option<&[String]>) -> Shader {
fn get_specialized_shader(
&self,
shader: &Shader,
macros: Option<&[String]>,
) -> Result<Shader, ShaderError> {
let spirv_data = match shader.source {
ShaderSource::Spirv(ref bytes) => bytes.clone(),
ShaderSource::Glsl(ref source) => glsl_to_spirv(&source, shader.stage, macros),
ShaderSource::Glsl(ref source) => glsl_to_spirv(&source, shader.stage, macros)?,
};
Shader {
Ok(Shader {
source: ShaderSource::Spirv(spirv_data),
..*shader
}
})
}
}

View file

@ -0,0 +1,80 @@
use bevy::{
prelude::*,
reflect::TypeUuid,
render::{
mesh::shape,
pipeline::{PipelineDescriptor, RenderPipeline},
render_graph::{base, AssetRenderResourcesNode, RenderGraph},
renderer::RenderResources,
shader::ShaderStages,
},
};
/// This example illustrates how to load shaders such that they can be
/// edited while the example is still running.
fn main() {
App::build()
.add_plugins(DefaultPlugins)
.add_asset::<MyMaterial>()
.add_startup_system(setup)
.run();
}
#[derive(RenderResources, Default, TypeUuid)]
#[uuid = "3bf9e364-f29d-4d6c-92cf-93298466c620"]
struct MyMaterial {
pub color: Color,
}
fn setup(
commands: &mut Commands,
asset_server: ResMut<AssetServer>,
mut pipelines: ResMut<Assets<PipelineDescriptor>>,
mut meshes: ResMut<Assets<Mesh>>,
mut materials: ResMut<Assets<MyMaterial>>,
mut render_graph: ResMut<RenderGraph>,
) {
// Watch for changes
asset_server.watch_for_changes().unwrap();
// Create a new shader pipeline with shaders loaded from the asset directory
let pipeline_handle = pipelines.add(PipelineDescriptor::default_config(ShaderStages {
vertex: asset_server.load::<Shader, _>("shaders/hot.vert"),
fragment: Some(asset_server.load::<Shader, _>("shaders/hot.frag")),
}));
// Add an AssetRenderResourcesNode to our Render Graph. This will bind MyMaterial resources to our shader
render_graph.add_system_node(
"my_material",
AssetRenderResourcesNode::<MyMaterial>::new(true),
);
// Add a Render Graph edge connecting our new "my_material" node to the main pass node. This ensures "my_material" runs before the main pass
render_graph
.add_node_edge("my_material", base::node::MAIN_PASS)
.unwrap();
// Create a new material
let material = materials.add(MyMaterial {
color: Color::rgb(0.0, 0.8, 0.0),
});
// Setup our world
commands
// cube
.spawn(MeshBundle {
mesh: meshes.add(Mesh::from(shape::Cube { size: 2.0 })),
render_pipelines: RenderPipelines::from_pipelines(vec![RenderPipeline::new(
pipeline_handle,
)]),
transform: Transform::from_translation(Vec3::new(0.0, 0.0, 0.0)),
..Default::default()
})
.with(material)
// camera
.spawn(Camera3dBundle {
transform: Transform::from_translation(Vec3::new(3.0, 5.0, -8.0))
.looking_at(Vec3::default(), Vec3::unit_y()),
..Default::default()
});
}