Pipeline Specialization, Shader Assets, and Shader Preprocessing (#3031)

## New Features
This adds the following to the new renderer:

* **Shader Assets**
  * Shaders are assets again! Users no longer need to call `include_str!` for their shaders
  * Shader hot-reloading
* **Shader Defs / Shader Preprocessing**
  * Shaders now support `# ifdef NAME`, `# ifndef NAME`, and `# endif` preprocessor directives
* **Bevy RenderPipelineDescriptor and RenderPipelineCache**
  * Bevy now provides its own `RenderPipelineDescriptor` and the wgpu version is now exported as `RawRenderPipelineDescriptor`. This allows users to define pipelines with `Handle<Shader>` instead of needing to manually compile and reference `ShaderModules`, enables passing in shader defs to configure the shader preprocessor, makes hot reloading possible (because the descriptor can be owned and used to create new pipelines when a shader changes), and opens the doors to pipeline specialization.
  * The `RenderPipelineCache` now handles compiling and re-compiling Bevy RenderPipelineDescriptors. It has internal PipelineLayout and ShaderModule caches. Users receive a `CachedPipelineId`, which can be used to look up the actual `&RenderPipeline` during rendering. 
* **Pipeline Specialization**
  * This enables defining per-entity-configurable pipelines that specialize on arbitrary custom keys. In practice this will involve specializing based on things like MSAA values, Shader Defs, Bind Group existence, and Vertex Layouts.
  * Adds a `SpecializedPipeline` trait and `SpecializedPipelines<MyPipeline>` resource. This is a simple layer that generates Bevy RenderPipelineDescriptors based on a custom key defined for the pipeline.
  * Specialized pipelines are also hot-reloadable.
  * This was the result of experimentation with two different approaches:
    1. **"generic immediate mode multi-key hash pipeline specialization"**
      * breaks up the pipeline into multiple "identities" (the core pipeline definition, shader defs, mesh layout, bind group layout). each of these identities has its own key. looking up / compiling a specific version of a pipeline requires composing all of these keys together
      * the benefit of this approach is that it works for all pipelines / the pipeline is fully identified by the keys. the multiple keys allow pre-hashing parts of the pipeline identity where possible (ex: pre compute the mesh identity for all meshes)
      * the downside is that any per-entity data that informs the values of these keys could require expensive re-hashes. computing each key for each sprite tanked bevymark performance (sprites don't actually need this level of specialization yet ... but things like pbr and future sprite scenarios might). 
      * this is the approach rafx used last time i checked
    2. **"custom key specialization"**
      * Pipelines by default are not specialized
      * Pipelines that need specialization implement a SpecializedPipeline trait with a custom key associated type
      * This allows specialization keys to encode exactly the amount of information required (instead of needing to be a combined hash of the entire pipeline). Generally this should fit in a small number of bytes. Per-entity specialization barely registers anymore on things like bevymark. It also makes things like "shader defs" way cheaper to hash because we can use context specific bitflags instead of strings.
      * Despite the extra trait, it actually generally makes pipeline definitions + lookups simpler: managing multiple keys (and making the appropriate calls to manage these keys) was way more complicated.
  * I opted for custom key specialization. It performs better generally and in my opinion is better UX. Fortunately the way this is implemented also allows for custom caches as this all builds on a common abstraction: the RenderPipelineCache. The built in custom key trait is just a simple / pre-defined way to interact with the cache 

## Callouts

* The SpecializedPipeline trait makes it easy to inherit pipeline configuration in custom pipelines. The changes to `custom_shader_pipelined` and the new `shader_defs_pipelined` example illustrate how much simpler it is to define custom pipelines based on the PbrPipeline.
* The shader preprocessor is currently pretty naive (it just uses regexes to process each line). Ultimately we might want to build a more custom parser for more performance + better error handling, but for now I'm happy to optimize for "easy to implement and understand". 

## Next Steps

* Port compute pipelines to the new system
* Add more preprocessor directives (else, elif, import)
* More flexible vertex attribute specialization / enable cheaply specializing on specific mesh vertex layouts
This commit is contained in:
Carter Anderson 2021-10-28 19:07:47 +00:00
parent 9f47697e40
commit 015617a774
24 changed files with 1529 additions and 592 deletions

View file

@ -471,6 +471,10 @@ path = "examples/shader/shader_defs.rs"
name = "custom_shader_pipelined"
path = "examples/shader/custom_shader_pipelined.rs"
[[example]]
name = "shader_defs_pipelined"
path = "examples/shader/shader_defs_pipelined.rs"
# Tools
[[example]]
name = "bevymark"

View file

@ -0,0 +1,43 @@
[[block]]
struct View {
view_proj: mat4x4<f32>;
projection: mat4x4<f32>;
world_position: vec3<f32>;
};
[[group(0), binding(0)]]
var<uniform> view: View;
[[block]]
struct Mesh {
transform: mat4x4<f32>;
};
[[group(1), binding(0)]]
var<uniform> mesh: Mesh;
struct Vertex {
[[location(0)]] position: vec3<f32>;
[[location(1)]] normal: vec3<f32>;
[[location(2)]] uv: vec2<f32>;
};
struct VertexOutput {
[[builtin(position)]] clip_position: vec4<f32>;
};
[[stage(vertex)]]
fn vertex(vertex: Vertex) -> VertexOutput {
let world_position = mesh.transform * vec4<f32>(vertex.position, 1.0);
var out: VertexOutput;
out.clip_position = view.view_proj * world_position;
return out;
}
[[stage(fragment)]]
fn fragment() -> [[location(0)]] vec4<f32> {
var color = vec4<f32>(0.0, 0.0, 1.0, 1.0);
# ifdef IS_RED
color = vec4<f32>(1.0, 0.0, 0.0, 1.0);
# endif
return color;
}

View file

@ -228,6 +228,7 @@ Example | File | Description
`mesh_custom_attribute` | [`shader/mesh_custom_attribute.rs`](./shader/mesh_custom_attribute.rs) | Illustrates how to add a custom attribute to a mesh and use it in a custom shader
`shader_custom_material` | [`shader/shader_custom_material.rs`](./shader/shader_custom_material.rs) | Illustrates creating a custom material and a shader that uses it
`shader_defs` | [`shader/shader_defs.rs`](./shader/shader_defs.rs) | Demonstrates creating a custom material that uses "shaders defs" (a tool to selectively toggle parts of a shader)
`shader_defs_pipelined` | [`shader/shader_defs_pipelined.rs`](./shader/shader_defs_pipelined.rs) | Demonstrates creating a custom material that uses "shaders defs" (a tool to selectively toggle parts of a shader)
## Tests

View file

@ -1,13 +1,16 @@
use bevy::{
core_pipeline::Transparent3d,
core_pipeline::{SetItemPipeline, Transparent3d},
diagnostic::{FrameTimeDiagnosticsPlugin, LogDiagnosticsPlugin},
ecs::{
prelude::*,
system::{lifetimeless::*, SystemParamItem},
},
math::{Vec3, Vec4},
pbr2::{DrawMesh, MeshUniform, PbrShaders, SetMeshViewBindGroup, SetTransformBindGroup},
prelude::{AddAsset, App, Assets, GlobalTransform, Handle, Plugin, Transform},
pbr2::{
DrawMesh, MeshUniform, PbrPipeline, PbrPipelineKey, SetMeshViewBindGroup,
SetTransformBindGroup,
},
prelude::{AddAsset, App, AssetServer, Assets, GlobalTransform, Handle, Plugin, Transform},
reflect::TypeUuid,
render2::{
camera::PerspectiveCameraBundle,
@ -20,8 +23,6 @@ use bevy::{
},
render_resource::*,
renderer::RenderDevice,
shader::Shader,
texture::BevyDefault,
view::ExtractedView,
RenderApp, RenderStage,
},
@ -29,6 +30,39 @@ use bevy::{
};
use crevice::std140::{AsStd140, Std140};
fn main() {
App::new()
.add_plugins(PipelinedDefaultPlugins)
.add_plugin(FrameTimeDiagnosticsPlugin::default())
.add_plugin(LogDiagnosticsPlugin::default())
.add_plugin(CustomMaterialPlugin)
.add_startup_system(setup)
.run();
}
/// set up a simple 3D scene
fn setup(
mut commands: Commands,
mut meshes: ResMut<Assets<Mesh>>,
mut materials: ResMut<Assets<CustomMaterial>>,
) {
// cube
commands.spawn().insert_bundle((
meshes.add(Mesh::from(shape::Cube { size: 1.0 })),
Transform::from_xyz(0.0, 0.5, 0.0),
GlobalTransform::default(),
materials.add(CustomMaterial {
color: Color::GREEN,
}),
));
// camera
commands.spawn_bundle(PerspectiveCameraBundle {
transform: Transform::from_xyz(-2.0, 2.5, 5.0).looking_at(Vec3::ZERO, Vec3::Y),
..Default::default()
});
}
#[derive(Debug, Clone, TypeUuid)]
#[uuid = "4ee9c363-1124-4113-890e-199d81b00281"]
pub struct CustomMaterial {
@ -88,51 +122,18 @@ impl Plugin for CustomMaterialPlugin {
}
}
fn main() {
App::new()
.add_plugins(PipelinedDefaultPlugins)
.add_plugin(FrameTimeDiagnosticsPlugin::default())
.add_plugin(LogDiagnosticsPlugin::default())
.add_plugin(CustomMaterialPlugin)
.add_startup_system(setup)
.run();
}
/// set up a simple 3D scene
fn setup(
mut commands: Commands,
mut meshes: ResMut<Assets<Mesh>>,
mut materials: ResMut<Assets<CustomMaterial>>,
) {
// cube
commands.spawn().insert_bundle((
meshes.add(Mesh::from(shape::Cube { size: 1.0 })),
Transform::from_xyz(0.0, 0.5, 0.0),
GlobalTransform::default(),
materials.add(CustomMaterial {
color: Color::GREEN,
}),
));
// camera
commands.spawn_bundle(PerspectiveCameraBundle {
transform: Transform::from_xyz(-2.0, 2.5, 5.0).looking_at(Vec3::ZERO, Vec3::Y),
..Default::default()
});
}
pub struct CustomPipeline {
material_layout: BindGroupLayout,
pipeline: RenderPipeline,
pipeline: CachedPipelineId,
}
// TODO: this pattern for initializing the shaders / pipeline isn't ideal. this should be handled by the asset system
impl FromWorld for CustomPipeline {
fn from_world(world: &mut World) -> Self {
let render_device = world.get_resource::<RenderDevice>().unwrap();
let shader = Shader::from_wgsl(include_str!("../../assets/shaders/custom.wgsl"));
let shader_module = render_device.create_shader_module(&shader);
let world = world.cell();
let asset_server = world.get_resource::<AssetServer>().unwrap();
let shader = asset_server.load("shaders/custom.wgsl");
let render_device = world.get_resource::<RenderDevice>().unwrap();
let material_layout = render_device.create_bind_group_layout(&BindGroupLayoutDescriptor {
entries: &[BindGroupLayoutEntry {
binding: 0,
@ -146,99 +147,20 @@ impl FromWorld for CustomPipeline {
}],
label: None,
});
let pbr_pipeline = world.get_resource::<PbrShaders>().unwrap();
let pipeline_layout = render_device.create_pipeline_layout(&PipelineLayoutDescriptor {
label: None,
push_constant_ranges: &[],
bind_group_layouts: &[
&pbr_pipeline.view_layout,
&material_layout,
&pbr_pipeline.mesh_layout,
],
});
let pipeline = render_device.create_render_pipeline(&RenderPipelineDescriptor {
label: None,
vertex: VertexState {
buffers: &[VertexBufferLayout {
array_stride: 32,
step_mode: VertexStepMode::Vertex,
attributes: &[
// Position (GOTCHA! Vertex_Position isn't first in the buffer due to how Mesh sorts attributes (alphabetically))
VertexAttribute {
format: VertexFormat::Float32x3,
offset: 12,
shader_location: 0,
},
// Normal
VertexAttribute {
format: VertexFormat::Float32x3,
offset: 0,
shader_location: 1,
},
// Uv
VertexAttribute {
format: VertexFormat::Float32x2,
offset: 24,
shader_location: 2,
},
],
}],
module: &shader_module,
entry_point: "vertex",
},
fragment: Some(FragmentState {
module: &shader_module,
entry_point: "fragment",
targets: &[ColorTargetState {
format: TextureFormat::bevy_default(),
blend: Some(BlendState {
color: BlendComponent {
src_factor: BlendFactor::SrcAlpha,
dst_factor: BlendFactor::OneMinusSrcAlpha,
operation: BlendOperation::Add,
},
alpha: BlendComponent {
src_factor: BlendFactor::One,
dst_factor: BlendFactor::One,
operation: BlendOperation::Add,
},
}),
write_mask: ColorWrites::ALL,
}],
}),
depth_stencil: Some(DepthStencilState {
format: TextureFormat::Depth32Float,
depth_write_enabled: true,
depth_compare: CompareFunction::Greater,
stencil: StencilState {
front: StencilFaceState::IGNORE,
back: StencilFaceState::IGNORE,
read_mask: 0,
write_mask: 0,
},
bias: DepthBiasState {
constant: 0,
slope_scale: 0.0,
clamp: 0.0,
},
}),
layout: Some(&pipeline_layout),
multisample: MultisampleState::default(),
primitive: PrimitiveState {
topology: PrimitiveTopology::TriangleList,
strip_index_format: None,
front_face: FrontFace::Ccw,
cull_mode: Some(Face::Back),
polygon_mode: PolygonMode::Fill,
clamp_depth: false,
conservative: false,
},
});
let pbr_pipeline = world.get_resource::<PbrPipeline>().unwrap();
let mut descriptor = pbr_pipeline.specialize(PbrPipelineKey::empty());
descriptor.vertex.shader = shader.clone();
descriptor.fragment.as_mut().unwrap().shader = shader;
descriptor.layout = Some(vec![
pbr_pipeline.view_layout.clone(),
material_layout.clone(),
pbr_pipeline.mesh_layout.clone(),
]);
let mut pipeline_cache = world.get_resource_mut::<RenderPipelineCache>().unwrap();
CustomPipeline {
pipeline,
pipeline: pipeline_cache.queue(descriptor),
material_layout,
}
}
@ -247,6 +169,7 @@ impl FromWorld for CustomPipeline {
pub fn queue_custom(
transparent_3d_draw_functions: Res<DrawFunctions<Transparent3d>>,
materials: Res<RenderAssets<CustomMaterial>>,
custom_pipeline: Res<CustomPipeline>,
material_meshes: Query<(Entity, &Handle<CustomMaterial>, &MeshUniform), With<Handle<Mesh>>>,
mut views: Query<(&ExtractedView, &mut RenderPhase<Transparent3d>)>,
) {
@ -261,6 +184,7 @@ pub fn queue_custom(
if materials.contains_key(material_handle) {
transparent_phase.add(Transparent3d {
entity,
pipeline: custom_pipeline.pipeline,
draw_function: draw_custom,
distance: view_row_2.dot(mesh_uniform.transform.col(3)),
});
@ -270,28 +194,27 @@ pub fn queue_custom(
}
type DrawCustom = (
SetCustomMaterialPipeline,
SetItemPipeline,
SetMeshViewBindGroup<0>,
SetCustomMaterialBindGroup,
SetTransformBindGroup<2>,
DrawMesh,
);
struct SetCustomMaterialPipeline;
impl RenderCommand<Transparent3d> for SetCustomMaterialPipeline {
struct SetCustomMaterialBindGroup;
impl RenderCommand<Transparent3d> for SetCustomMaterialBindGroup {
type Param = (
SRes<RenderAssets<CustomMaterial>>,
SRes<CustomPipeline>,
SQuery<Read<Handle<CustomMaterial>>>,
);
fn render<'w>(
_view: Entity,
item: &Transparent3d,
(materials, custom_pipeline, query): SystemParamItem<'w, '_, Self::Param>,
(materials, query): SystemParamItem<'w, '_, Self::Param>,
pass: &mut TrackedRenderPass<'w>,
) {
let material_handle = query.get(item.entity).unwrap();
let material = materials.into_inner().get(material_handle).unwrap();
pass.set_render_pipeline(&custom_pipeline.into_inner().pipeline);
pass.set_bind_group(1, &material.bind_group, &[]);
}
}

View file

@ -0,0 +1,155 @@
use bevy::{
core_pipeline::{SetItemPipeline, Transparent3d},
diagnostic::{FrameTimeDiagnosticsPlugin, LogDiagnosticsPlugin},
ecs::prelude::*,
math::Vec3,
pbr2::{
DrawMesh, MeshUniform, PbrPipeline, PbrPipelineKey, SetMeshViewBindGroup,
SetTransformBindGroup,
},
prelude::{App, AssetServer, Assets, GlobalTransform, Handle, Plugin, Transform},
render2::{
camera::PerspectiveCameraBundle,
mesh::{shape, Mesh},
render_component::{ExtractComponent, ExtractComponentPlugin},
render_phase::{AddRenderCommand, DrawFunctions, RenderPhase},
render_resource::*,
view::ExtractedView,
RenderApp, RenderStage,
},
PipelinedDefaultPlugins,
};
pub struct IsRedPlugin;
impl Plugin for IsRedPlugin {
fn build(&self, app: &mut App) {
app.add_plugin(ExtractComponentPlugin::<IsRed>::default());
app.sub_app(RenderApp)
.add_render_command::<Transparent3d, DrawIsRed>()
.init_resource::<IsRedPipeline>()
.init_resource::<SpecializedPipelines<IsRedPipeline>>()
.add_system_to_stage(RenderStage::Queue, queue_custom);
}
}
fn main() {
App::new()
.add_plugins(PipelinedDefaultPlugins)
.add_plugin(FrameTimeDiagnosticsPlugin::default())
.add_plugin(LogDiagnosticsPlugin::default())
.add_plugin(IsRedPlugin)
.add_startup_system(setup)
.run();
}
#[derive(Hash, PartialEq, Eq, Copy, Clone)]
struct IsRed(bool);
impl ExtractComponent for IsRed {
type Query = &'static IsRed;
type Filter = ();
fn extract_component(item: bevy::ecs::query::QueryItem<Self::Query>) -> Self {
*item
}
}
/// set up a simple 3D scene
fn setup(mut commands: Commands, mut meshes: ResMut<Assets<Mesh>>) {
// red cube
commands.spawn().insert_bundle((
meshes.add(Mesh::from(shape::Cube { size: 1.0 })),
IsRed(true),
Transform::from_xyz(-1.0, 0.5, 0.0),
GlobalTransform::default(),
));
// blue cube
commands.spawn().insert_bundle((
meshes.add(Mesh::from(shape::Cube { size: 1.0 })),
IsRed(false),
Transform::from_xyz(1.0, 0.5, 0.0),
GlobalTransform::default(),
));
// camera
commands.spawn_bundle(PerspectiveCameraBundle {
transform: Transform::from_xyz(-2.0, 2.5, 5.0).looking_at(Vec3::ZERO, Vec3::Y),
..Default::default()
});
}
struct IsRedPipeline {
shader: Handle<Shader>,
pbr_pipeline: PbrPipeline,
}
impl FromWorld for IsRedPipeline {
fn from_world(world: &mut World) -> Self {
let asset_server = world.get_resource::<AssetServer>().unwrap();
let pbr_pipeline = world.get_resource::<PbrPipeline>().unwrap();
let shader = asset_server.load("shaders/shader_defs.wgsl");
IsRedPipeline {
shader,
pbr_pipeline: pbr_pipeline.clone(),
}
}
}
impl SpecializedPipeline for IsRedPipeline {
type Key = IsRed;
fn specialize(&self, key: Self::Key) -> RenderPipelineDescriptor {
let mut shader_defs = Vec::new();
if key.0 {
shader_defs.push("IS_RED".to_string());
}
let mut descriptor = self.pbr_pipeline.specialize(PbrPipelineKey::empty());
descriptor.vertex.shader = self.shader.clone();
descriptor.vertex.shader_defs = shader_defs.clone();
let fragment = descriptor.fragment.as_mut().unwrap();
fragment.shader = self.shader.clone();
fragment.shader_defs = shader_defs;
descriptor.layout = Some(vec![
self.pbr_pipeline.view_layout.clone(),
self.pbr_pipeline.mesh_layout.clone(),
]);
descriptor
}
}
type DrawIsRed = (
SetItemPipeline,
SetMeshViewBindGroup<0>,
SetTransformBindGroup<1>,
DrawMesh,
);
fn queue_custom(
transparent_3d_draw_functions: Res<DrawFunctions<Transparent3d>>,
custom_pipeline: Res<IsRedPipeline>,
mut pipelines: ResMut<SpecializedPipelines<IsRedPipeline>>,
mut pipeline_cache: ResMut<RenderPipelineCache>,
material_meshes: Query<(Entity, &MeshUniform, &IsRed), With<Handle<Mesh>>>,
mut views: Query<(&ExtractedView, &mut RenderPhase<Transparent3d>)>,
) {
let draw_custom = transparent_3d_draw_functions
.read()
.get_id::<DrawIsRed>()
.unwrap();
for (view, mut transparent_phase) in views.iter_mut() {
let view_matrix = view.transform.compute_matrix();
let view_row_2 = view_matrix.row(2);
for (entity, mesh_uniform, is_red) in material_meshes.iter() {
let pipeline = pipelines.specialize(&mut pipeline_cache, &custom_pipeline, *is_red);
transparent_phase.add(Transparent3d {
entity,
pipeline,
draw_function: draw_custom,
distance: view_row_2.dot(mesh_uniform.transform.col(3)),
});
}
}
}

View file

@ -9,12 +9,18 @@ pub use main_pass_driver::*;
use bevy_app::{App, Plugin};
use bevy_asset::Handle;
use bevy_core::FloatOrd;
use bevy_ecs::prelude::*;
use bevy_ecs::{
prelude::*,
system::{lifetimeless::SRes, SystemParamItem},
};
use bevy_render2::{
camera::{ActiveCameras, CameraPlugin},
color::Color,
render_graph::{EmptyNode, RenderGraph, SlotInfo, SlotType},
render_phase::{sort_phase_system, DrawFunctionId, DrawFunctions, PhaseItem, RenderPhase},
render_phase::{
sort_phase_system, DrawFunctionId, DrawFunctions, PhaseItem, RenderCommand, RenderPhase,
TrackedRenderPass,
},
render_resource::*,
renderer::RenderDevice,
texture::{Image, TextureCache},
@ -155,6 +161,7 @@ impl Plugin for CorePipelinePlugin {
pub struct Transparent2d {
pub sort_key: Handle<Image>,
pub entity: Entity,
pub pipeline: CachedPipelineId,
pub draw_function: DrawFunctionId,
}
@ -174,6 +181,7 @@ impl PhaseItem for Transparent2d {
pub struct Transparent3d {
pub distance: f32,
pub pipeline: CachedPipelineId,
pub entity: Entity,
pub draw_function: DrawFunctionId,
}
@ -192,6 +200,41 @@ impl PhaseItem for Transparent3d {
}
}
pub struct SetItemPipeline;
impl RenderCommand<Transparent3d> for SetItemPipeline {
type Param = SRes<RenderPipelineCache>;
#[inline]
fn render<'w>(
_view: Entity,
item: &Transparent3d,
pipeline_cache: SystemParamItem<'w, '_, Self::Param>,
pass: &mut TrackedRenderPass<'w>,
) {
let pipeline = pipeline_cache
.into_inner()
.get_state(item.pipeline)
.unwrap();
pass.set_render_pipeline(pipeline);
}
}
impl RenderCommand<Transparent2d> for SetItemPipeline {
type Param = SRes<RenderPipelineCache>;
#[inline]
fn render<'w>(
_view: Entity,
item: &Transparent2d,
pipeline_cache: SystemParamItem<'w, '_, Self::Param>,
pass: &mut TrackedRenderPass<'w>,
) {
let pipeline = pipeline_cache
.into_inner()
.get_state(item.pipeline)
.unwrap();
pass.set_render_pipeline(pipeline);
}
}
pub struct ViewDepthTexture {
pub texture: Texture,
pub view: TextureView,

View file

@ -73,7 +73,6 @@ impl Node for MainPass3dNode {
let draw_functions = world
.get_resource::<DrawFunctions<Transparent3d>>()
.unwrap();
let transparent_phase = self
.query
.get_manual(world, view_entity)

View file

@ -9,13 +9,15 @@ pub use material::*;
pub use render::*;
use bevy_app::prelude::*;
use bevy_asset::Handle;
use bevy_asset::{Assets, Handle, HandleUntyped};
use bevy_core_pipeline::Transparent3d;
use bevy_ecs::prelude::*;
use bevy_reflect::TypeUuid;
use bevy_render2::{
render_component::{ExtractComponentPlugin, UniformComponentPlugin},
render_graph::RenderGraph,
render_phase::{sort_phase_system, AddRenderCommand, DrawFunctions},
render_resource::{Shader, SpecializedPipelines},
RenderApp, RenderStage,
};
@ -25,11 +27,22 @@ pub mod draw_3d_graph {
}
}
pub const PBR_SHADER_HANDLE: HandleUntyped =
HandleUntyped::weak_from_u64(Shader::TYPE_UUID, 4805239651767701046);
pub const SHADOW_SHADER_HANDLE: HandleUntyped =
HandleUntyped::weak_from_u64(Shader::TYPE_UUID, 1836745567947005696);
#[derive(Default)]
pub struct PbrPlugin;
impl Plugin for PbrPlugin {
fn build(&self, app: &mut App) {
let mut shaders = app.world.get_resource_mut::<Assets<Shader>>().unwrap();
let pbr_shader = Shader::from_wgsl(include_str!("render/pbr.wgsl"));
shaders.set_untracked(PBR_SHADER_HANDLE, pbr_shader);
let shadow_shader = Shader::from_wgsl(include_str!("render/depth.wgsl"));
shaders.set_untracked(SHADOW_SHADER_HANDLE, shadow_shader);
app.add_plugin(StandardMaterialPlugin)
.add_plugin(ExtractComponentPlugin::<Handle<StandardMaterial>>::default())
.add_plugin(UniformComponentPlugin::<MeshUniform>::default())
@ -53,10 +66,11 @@ impl Plugin for PbrPlugin {
.add_system_to_stage(RenderStage::Queue, render::queue_shadow_view_bind_group)
.add_system_to_stage(RenderStage::Queue, render::queue_transform_bind_group)
.add_system_to_stage(RenderStage::PhaseSort, sort_phase_system::<Shadow>)
.init_resource::<PbrShaders>()
.init_resource::<ShadowShaders>()
.init_resource::<PbrPipeline>()
.init_resource::<ShadowPipeline>()
.init_resource::<DrawFunctions<Shadow>>()
.init_resource::<LightMeta>();
.init_resource::<LightMeta>()
.init_resource::<SpecializedPipelines<PbrPipeline>>();
let draw_shadow_mesh = DrawShadowMesh::new(&mut render_app.world);
let shadow_pass_node = ShadowPassNode::new(&mut render_app.world);

View file

@ -1,3 +1,4 @@
use crate::PbrPipeline;
use bevy_app::{App, Plugin};
use bevy_asset::{AddAsset, Handle};
use bevy_ecs::system::{lifetimeless::SRes, SystemParamItem};
@ -15,8 +16,6 @@ use bevy_render2::{
use crevice::std140::{AsStd140, Std140};
use wgpu::{BindGroupDescriptor, BindGroupEntry, BindingResource};
use crate::PbrShaders;
// NOTE: These must match the bit flags in bevy_pbr2/src/render/pbr.frag!
bitflags::bitflags! {
#[repr(transparent)]
@ -148,7 +147,7 @@ impl RenderAsset for StandardMaterial {
type PreparedAsset = GpuStandardMaterial;
type Param = (
SRes<RenderDevice>,
SRes<PbrShaders>,
SRes<PbrPipeline>,
SRes<RenderAssets<Image>>,
);
@ -158,10 +157,10 @@ impl RenderAsset for StandardMaterial {
fn prepare_asset(
material: Self::ExtractedAsset,
(render_device, pbr_shaders, gpu_images): &mut SystemParamItem<Self::Param>,
(render_device, pbr_pipeline, gpu_images): &mut SystemParamItem<Self::Param>,
) -> Result<Self::PreparedAsset, PrepareAssetError<Self::ExtractedAsset>> {
let (base_color_texture_view, base_color_sampler) = if let Some(result) =
image_handle_to_view_sampler(pbr_shaders, gpu_images, &material.base_color_texture)
image_handle_to_view_sampler(pbr_pipeline, gpu_images, &material.base_color_texture)
{
result
} else {
@ -169,7 +168,7 @@ impl RenderAsset for StandardMaterial {
};
let (emissive_texture_view, emissive_sampler) = if let Some(result) =
image_handle_to_view_sampler(pbr_shaders, gpu_images, &material.emissive_texture)
image_handle_to_view_sampler(pbr_pipeline, gpu_images, &material.emissive_texture)
{
result
} else {
@ -178,7 +177,7 @@ impl RenderAsset for StandardMaterial {
let (metallic_roughness_texture_view, metallic_roughness_sampler) = if let Some(result) =
image_handle_to_view_sampler(
pbr_shaders,
pbr_pipeline,
gpu_images,
&material.metallic_roughness_texture,
) {
@ -187,7 +186,7 @@ impl RenderAsset for StandardMaterial {
return Err(PrepareAssetError::RetryNextUpdate(material));
};
let (occlusion_texture_view, occlusion_sampler) = if let Some(result) =
image_handle_to_view_sampler(pbr_shaders, gpu_images, &material.occlusion_texture)
image_handle_to_view_sampler(pbr_pipeline, gpu_images, &material.occlusion_texture)
{
result
} else {
@ -267,7 +266,7 @@ impl RenderAsset for StandardMaterial {
},
],
label: Some("pbr_standard_material_bind_group"),
layout: &pbr_shaders.material_layout,
layout: &pbr_pipeline.material_layout,
});
Ok(GpuStandardMaterial { buffer, bind_group })
@ -275,7 +274,7 @@ impl RenderAsset for StandardMaterial {
}
fn image_handle_to_view_sampler<'a>(
pbr_pipeline: &'a PbrShaders,
pbr_pipeline: &'a PbrPipeline,
gpu_images: &'a RenderAssets<Image>,
handle_option: &Option<Handle<Image>>,
) -> Option<(&'a TextureView, &'a Sampler)> {

View file

@ -1,6 +1,6 @@
use crate::{
AmbientLight, DirectionalLight, DirectionalLightShadowMap, MeshUniform, NotShadowCaster,
PbrShaders, PointLight, PointLightShadowMap, TransformBindGroup,
PbrPipeline, PointLight, PointLightShadowMap, TransformBindGroup, SHADOW_SHADER_HANDLE,
};
use bevy_asset::Handle;
use bevy_core::FloatOrd;
@ -22,7 +22,6 @@ use bevy_render2::{
},
render_resource::*,
renderer::{RenderContext, RenderDevice, RenderQueue},
shader::Shader,
texture::*,
view::{ExtractedView, ViewUniformOffset, ViewUniforms},
};
@ -101,21 +100,18 @@ pub const POINT_SHADOW_LAYERS: u32 = (6 * MAX_POINT_LIGHTS) as u32;
pub const DIRECTIONAL_SHADOW_LAYERS: u32 = MAX_DIRECTIONAL_LIGHTS as u32;
pub const SHADOW_FORMAT: TextureFormat = TextureFormat::Depth32Float;
pub struct ShadowShaders {
pub shader_module: ShaderModule,
pub pipeline: RenderPipeline,
pub struct ShadowPipeline {
pub pipeline: CachedPipelineId,
pub view_layout: BindGroupLayout,
pub point_light_sampler: Sampler,
pub directional_light_sampler: Sampler,
}
// TODO: this pattern for initializing the shaders / pipeline isn't ideal. this should be handled by the asset system
impl FromWorld for ShadowShaders {
impl FromWorld for ShadowPipeline {
fn from_world(world: &mut World) -> Self {
let world = world.cell();
let render_device = world.get_resource::<RenderDevice>().unwrap();
let pbr_shaders = world.get_resource::<PbrShaders>().unwrap();
let shader = Shader::from_wgsl(include_str!("depth.wgsl"));
let shader_module = render_device.create_shader_module(&shader);
let view_layout = render_device.create_bind_group_layout(&BindGroupLayoutDescriptor {
entries: &[
@ -136,19 +132,16 @@ impl FromWorld for ShadowShaders {
label: Some("shadow_view_layout"),
});
let pipeline_layout = render_device.create_pipeline_layout(&PipelineLayoutDescriptor {
label: Some("shadow_pipeline_layout"),
push_constant_ranges: &[],
bind_group_layouts: &[&view_layout, &pbr_shaders.mesh_layout],
});
let pipeline = render_device.create_render_pipeline(&RenderPipelineDescriptor {
label: Some("shadow_pipeline"),
let pbr_pipeline = world.get_resource::<PbrPipeline>().unwrap();
let descriptor = RenderPipelineDescriptor {
vertex: VertexState {
buffers: &[VertexBufferLayout {
shader: SHADOW_SHADER_HANDLE.typed::<Shader>(),
entry_point: "vertex".into(),
shader_defs: vec![],
buffers: vec![VertexBufferLayout {
array_stride: 32,
step_mode: VertexStepMode::Vertex,
attributes: &[
attributes: vec![
// Position (GOTCHA! Vertex_Position isn't first in the buffer due to how Mesh sorts attributes (alphabetically))
VertexAttribute {
format: VertexFormat::Float32x3,
@ -169,10 +162,18 @@ impl FromWorld for ShadowShaders {
},
],
}],
module: &shader_module,
entry_point: "vertex",
},
fragment: None,
layout: Some(vec![view_layout.clone(), pbr_pipeline.mesh_layout.clone()]),
primitive: PrimitiveState {
topology: PrimitiveTopology::TriangleList,
strip_index_format: None,
front_face: FrontFace::Ccw,
cull_mode: None,
polygon_mode: PolygonMode::Fill,
clamp_depth: false,
conservative: false,
},
depth_stencil: Some(DepthStencilState {
format: SHADOW_FORMAT,
depth_write_enabled: true,
@ -189,22 +190,13 @@ impl FromWorld for ShadowShaders {
clamp: 0.0,
},
}),
layout: Some(&pipeline_layout),
multisample: MultisampleState::default(),
primitive: PrimitiveState {
topology: PrimitiveTopology::TriangleList,
strip_index_format: None,
front_face: FrontFace::Ccw,
cull_mode: None,
polygon_mode: PolygonMode::Fill,
clamp_depth: false,
conservative: false,
},
});
label: Some("shadow_pipeline".into()),
};
ShadowShaders {
shader_module,
pipeline,
let mut render_pipeline_cache = world.get_resource_mut::<RenderPipelineCache>().unwrap();
ShadowPipeline {
pipeline: render_pipeline_cache.queue(descriptor),
view_layout,
point_light_sampler: render_device.create_sampler(&SamplerDescriptor {
address_mode_u: AddressMode::ClampToEdge,
@ -618,7 +610,7 @@ pub fn prepare_lights(
pub fn queue_shadow_view_bind_group(
render_device: Res<RenderDevice>,
shadow_shaders: Res<ShadowShaders>,
shadow_pipeline: Res<ShadowPipeline>,
mut light_meta: ResMut<LightMeta>,
view_uniforms: Res<ViewUniforms>,
) {
@ -630,13 +622,14 @@ pub fn queue_shadow_view_bind_group(
resource: view_binding,
}],
label: Some("shadow_view_bind_group"),
layout: &shadow_shaders.view_layout,
layout: &shadow_pipeline.view_layout,
}));
}
}
pub fn queue_shadows(
shadow_draw_functions: Res<DrawFunctions<Shadow>>,
shadow_pipeline: Res<ShadowPipeline>,
casting_meshes: Query<Entity, (With<Handle<Mesh>>, Without<NotShadowCaster>)>,
mut view_lights: Query<&ViewLights>,
mut view_light_shadow_phases: Query<&mut RenderPhase<Shadow>>,
@ -653,6 +646,7 @@ pub fn queue_shadows(
for entity in casting_meshes.iter() {
shadow_phase.add(Shadow {
draw_function: draw_shadow_mesh,
pipeline: shadow_pipeline.pipeline,
entity,
distance: 0.0, // TODO: sort back-to-front
})
@ -664,6 +658,7 @@ pub fn queue_shadows(
pub struct Shadow {
pub distance: f32,
pub entity: Entity,
pub pipeline: CachedPipelineId,
pub draw_function: DrawFunctionId,
}
@ -734,7 +729,6 @@ impl Node for ShadowPassNode {
};
let draw_functions = world.get_resource::<DrawFunctions<Shadow>>().unwrap();
let render_pass = render_context
.command_encoder
.begin_render_pass(&pass_descriptor);
@ -753,7 +747,7 @@ impl Node for ShadowPassNode {
pub struct DrawShadowMesh {
params: SystemState<(
SRes<ShadowShaders>,
SRes<RenderPipelineCache>,
SRes<LightMeta>,
SRes<TransformBindGroup>,
SRes<RenderAssets<Mesh>>,
@ -778,11 +772,12 @@ impl Draw<Shadow> for DrawShadowMesh {
view: Entity,
item: &Shadow,
) {
let (shadow_shaders, light_meta, transform_bind_group, meshes, items, views) =
let (pipeline_cache, light_meta, transform_bind_group, meshes, items, views) =
self.params.get(world);
let (transform_index, mesh_handle) = items.get(item.entity).unwrap();
let view_uniform_offset = views.get(view).unwrap();
pass.set_render_pipeline(&shadow_shaders.into_inner().pipeline);
if let Some(pipeline) = pipeline_cache.into_inner().get(item.pipeline) {
pass.set_render_pipeline(pipeline);
pass.set_bind_group(
0,
light_meta
@ -809,3 +804,4 @@ impl Draw<Shadow> for DrawShadowMesh {
}
}
}
}

View file

@ -2,9 +2,12 @@ mod light;
pub use light::*;
use crate::{NotShadowCaster, NotShadowReceiver, StandardMaterial, StandardMaterialUniformData};
use crate::{
NotShadowCaster, NotShadowReceiver, StandardMaterial, StandardMaterialUniformData,
PBR_SHADER_HANDLE,
};
use bevy_asset::Handle;
use bevy_core_pipeline::Transparent3d;
use bevy_core_pipeline::{SetItemPipeline, Transparent3d};
use bevy_ecs::{
prelude::*,
system::{lifetimeless::*, SystemParamItem},
@ -17,7 +20,6 @@ use bevy_render2::{
render_phase::{DrawFunctions, RenderCommand, RenderPhase, TrackedRenderPass},
render_resource::*,
renderer::{RenderDevice, RenderQueue},
shader::Shader,
texture::{BevyDefault, GpuImage, Image, TextureFormatPixelInfo},
view::{ExtractedView, ViewUniformOffset, ViewUniforms},
};
@ -114,9 +116,8 @@ pub fn extract_meshes(
commands.insert_or_spawn_batch(not_caster_values);
}
pub struct PbrShaders {
pub pipeline: RenderPipeline,
pub shader_module: ShaderModule,
#[derive(Clone)]
pub struct PbrPipeline {
pub view_layout: BindGroupLayout,
pub material_layout: BindGroupLayout,
pub mesh_layout: BindGroupLayout,
@ -124,13 +125,9 @@ pub struct PbrShaders {
pub dummy_white_gpu_image: GpuImage,
}
// TODO: this pattern for initializing the shaders / pipeline isn't ideal. this should be handled by the asset system
impl FromWorld for PbrShaders {
impl FromWorld for PbrPipeline {
fn from_world(world: &mut World) -> Self {
let render_device = world.get_resource::<RenderDevice>().unwrap();
let shader = Shader::from_wgsl(include_str!("pbr.wgsl"));
let shader_module = render_device.create_shader_module(&shader);
let view_layout = render_device.create_bind_group_layout(&BindGroupLayoutDescriptor {
entries: &[
// View
@ -322,92 +319,6 @@ impl FromWorld for PbrShaders {
}],
label: Some("pbr_mesh_layout"),
});
let pipeline_layout = render_device.create_pipeline_layout(&PipelineLayoutDescriptor {
label: Some("pbr_pipeline_layout"),
push_constant_ranges: &[],
bind_group_layouts: &[&view_layout, &material_layout, &mesh_layout],
});
let pipeline = render_device.create_render_pipeline(&RenderPipelineDescriptor {
label: Some("pbr_pipeline"),
vertex: VertexState {
buffers: &[VertexBufferLayout {
array_stride: 32,
step_mode: VertexStepMode::Vertex,
attributes: &[
// Position (GOTCHA! Vertex_Position isn't first in the buffer due to how Mesh sorts attributes (alphabetically))
VertexAttribute {
format: VertexFormat::Float32x3,
offset: 12,
shader_location: 0,
},
// Normal
VertexAttribute {
format: VertexFormat::Float32x3,
offset: 0,
shader_location: 1,
},
// Uv
VertexAttribute {
format: VertexFormat::Float32x2,
offset: 24,
shader_location: 2,
},
],
}],
module: &shader_module,
entry_point: "vertex",
},
fragment: Some(FragmentState {
module: &shader_module,
entry_point: "fragment",
targets: &[ColorTargetState {
format: TextureFormat::bevy_default(),
blend: Some(BlendState {
color: BlendComponent {
src_factor: BlendFactor::SrcAlpha,
dst_factor: BlendFactor::OneMinusSrcAlpha,
operation: BlendOperation::Add,
},
alpha: BlendComponent {
src_factor: BlendFactor::One,
dst_factor: BlendFactor::One,
operation: BlendOperation::Add,
},
}),
write_mask: ColorWrites::ALL,
}],
}),
depth_stencil: Some(DepthStencilState {
format: TextureFormat::Depth32Float,
depth_write_enabled: true,
depth_compare: CompareFunction::Greater,
stencil: StencilState {
front: StencilFaceState::IGNORE,
back: StencilFaceState::IGNORE,
read_mask: 0,
write_mask: 0,
},
bias: DepthBiasState {
constant: 0,
slope_scale: 0.0,
clamp: 0.0,
},
}),
layout: Some(&pipeline_layout),
multisample: MultisampleState::default(),
primitive: PrimitiveState {
topology: PrimitiveTopology::TriangleList,
strip_index_format: None,
front_face: FrontFace::Ccw,
cull_mode: Some(Face::Back),
polygon_mode: PolygonMode::Fill,
clamp_depth: false,
conservative: false,
},
});
// A 1x1x1 'all 1.0' texture to use as a dummy texture to use in place of optional StandardMaterial textures
let dummy_white_gpu_image = {
let image = Image::new_fill(
@ -449,9 +360,7 @@ impl FromWorld for PbrShaders {
sampler,
}
};
PbrShaders {
pipeline,
shader_module,
PbrPipeline {
view_layout,
material_layout,
mesh_layout,
@ -460,13 +369,114 @@ impl FromWorld for PbrShaders {
}
}
// TODO: add actual specialization key: MSAA, normal maps, shadeless, etc
bitflags::bitflags! {
#[repr(transparent)]
pub struct PbrPipelineKey: u32 { }
}
impl SpecializedPipeline for PbrPipeline {
type Key = PbrPipelineKey;
fn specialize(&self, _key: Self::Key) -> RenderPipelineDescriptor {
RenderPipelineDescriptor {
vertex: VertexState {
shader: PBR_SHADER_HANDLE.typed::<Shader>(),
entry_point: "vertex".into(),
shader_defs: vec![],
buffers: vec![VertexBufferLayout {
array_stride: 32,
step_mode: VertexStepMode::Vertex,
attributes: vec![
// Position (GOTCHA! Vertex_Position isn't first in the buffer due to how Mesh sorts attributes (alphabetically))
VertexAttribute {
format: VertexFormat::Float32x3,
offset: 12,
shader_location: 0,
},
// Normal
VertexAttribute {
format: VertexFormat::Float32x3,
offset: 0,
shader_location: 1,
},
// Uv
VertexAttribute {
format: VertexFormat::Float32x2,
offset: 24,
shader_location: 2,
},
],
}],
},
fragment: Some(FragmentState {
shader: PBR_SHADER_HANDLE.typed::<Shader>(),
shader_defs: vec![],
entry_point: "fragment".into(),
targets: vec![ColorTargetState {
format: TextureFormat::bevy_default(),
blend: Some(BlendState {
color: BlendComponent {
src_factor: BlendFactor::SrcAlpha,
dst_factor: BlendFactor::OneMinusSrcAlpha,
operation: BlendOperation::Add,
},
alpha: BlendComponent {
src_factor: BlendFactor::One,
dst_factor: BlendFactor::One,
operation: BlendOperation::Add,
},
}),
write_mask: ColorWrites::ALL,
}],
}),
layout: Some(vec![
self.view_layout.clone(),
self.material_layout.clone(),
self.mesh_layout.clone(),
]),
primitive: PrimitiveState {
front_face: FrontFace::Ccw,
cull_mode: Some(Face::Back),
polygon_mode: PolygonMode::Fill,
clamp_depth: false,
conservative: false,
topology: PrimitiveTopology::TriangleList,
strip_index_format: None,
},
depth_stencil: Some(DepthStencilState {
format: TextureFormat::Depth32Float,
depth_write_enabled: true,
depth_compare: CompareFunction::Greater,
stencil: StencilState {
front: StencilFaceState::IGNORE,
back: StencilFaceState::IGNORE,
read_mask: 0,
write_mask: 0,
},
bias: DepthBiasState {
constant: 0,
slope_scale: 0.0,
clamp: 0.0,
},
}),
multisample: MultisampleState {
count: 1,
mask: !0,
alpha_to_coverage_enabled: false,
},
label: Some("pbr_pipeline".into()),
}
}
}
pub struct TransformBindGroup {
pub value: BindGroup,
}
pub fn queue_transform_bind_group(
mut commands: Commands,
pbr_shaders: Res<PbrShaders>,
pbr_pipeline: Res<PbrPipeline>,
render_device: Res<RenderDevice>,
transform_uniforms: Res<ComponentUniforms<MeshUniform>>,
) {
@ -478,7 +488,7 @@ pub fn queue_transform_bind_group(
resource: binding,
}],
label: Some("transform_bind_group"),
layout: &pbr_shaders.mesh_layout,
layout: &pbr_pipeline.mesh_layout,
}),
});
}
@ -493,8 +503,10 @@ pub fn queue_meshes(
mut commands: Commands,
transparent_3d_draw_functions: Res<DrawFunctions<Transparent3d>>,
render_device: Res<RenderDevice>,
pbr_shaders: Res<PbrShaders>,
shadow_shaders: Res<ShadowShaders>,
pbr_pipeline: Res<PbrPipeline>,
shadow_pipeline: Res<ShadowPipeline>,
mut pipelines: ResMut<SpecializedPipelines<PbrPipeline>>,
mut pipeline_cache: ResMut<RenderPipelineCache>,
light_meta: Res<LightMeta>,
view_uniforms: Res<ViewUniforms>,
render_materials: Res<RenderAssets<StandardMaterial>>,
@ -532,7 +544,7 @@ pub fn queue_meshes(
},
BindGroupEntry {
binding: 3,
resource: BindingResource::Sampler(&shadow_shaders.point_light_sampler),
resource: BindingResource::Sampler(&shadow_pipeline.point_light_sampler),
},
BindGroupEntry {
binding: 4,
@ -543,12 +555,12 @@ pub fn queue_meshes(
BindGroupEntry {
binding: 5,
resource: BindingResource::Sampler(
&shadow_shaders.directional_light_sampler,
&shadow_pipeline.directional_light_sampler,
),
},
],
label: Some("pbr_view_bind_group"),
layout: &pbr_shaders.view_layout,
layout: &pbr_pipeline.view_layout,
});
commands.entity(entity).insert(PbrViewBindGroup {
@ -567,6 +579,9 @@ pub fn queue_meshes(
if !render_materials.contains_key(material_handle) {
continue;
}
let key = PbrPipelineKey::empty();
let pipeline_id = pipelines.specialize(&mut pipeline_cache, &pbr_pipeline, key);
// NOTE: row 2 of the view matrix dotted with column 3 of the model matrix
// gives the z component of translation of the mesh in view space
let mesh_z = view_row_2.dot(mesh_uniform.transform.col(3));
@ -574,6 +589,7 @@ pub fn queue_meshes(
transparent_phase.add(Transparent3d {
entity,
draw_function: draw_pbr,
pipeline: pipeline_id,
distance: mesh_z,
});
}
@ -582,27 +598,13 @@ pub fn queue_meshes(
}
pub type DrawPbr = (
SetPbrPipeline,
SetItemPipeline,
SetMeshViewBindGroup<0>,
SetStandardMaterialBindGroup<1>,
SetTransformBindGroup<2>,
DrawMesh,
);
pub struct SetPbrPipeline;
impl RenderCommand<Transparent3d> for SetPbrPipeline {
type Param = SRes<PbrShaders>;
#[inline]
fn render<'w>(
_view: Entity,
_item: &Transparent3d,
pbr_shaders: SystemParamItem<'w, '_, Self::Param>,
pass: &mut TrackedRenderPass<'w>,
) {
pass.set_render_pipeline(&pbr_shaders.into_inner().pipeline);
}
}
pub struct SetMeshViewBindGroup<const I: usize>;
impl<const I: usize> RenderCommand<Transparent3d> for SetMeshViewBindGroup<I> {
type Param = SQuery<(

View file

@ -42,6 +42,7 @@ anyhow = "1.0"
hex = "0.4.2"
hexasphere = "4.0"
parking_lot = "0.11.0"
regex = "1.5"
crevice = { path = "../../crates/crevice", version = "0.6.0" }
[features]

View file

@ -7,7 +7,6 @@ pub mod render_graph;
pub mod render_phase;
pub mod render_resource;
pub mod renderer;
pub mod shader;
pub mod texture;
pub mod view;
@ -16,13 +15,15 @@ pub use once_cell;
use crate::{
camera::CameraPlugin,
mesh::MeshPlugin,
render_asset::RenderAssetPlugin,
render_graph::RenderGraph,
render_resource::{RenderPipelineCache, Shader, ShaderLoader},
renderer::render_system,
texture::ImagePlugin,
view::{ViewPlugin, WindowRenderPlugin},
};
use bevy_app::{App, AppLabel, Plugin};
use bevy_asset::AssetServer;
use bevy_asset::{AddAsset, AssetServer};
use bevy_ecs::prelude::*;
use std::ops::{Deref, DerefMut};
use wgpu::Backends;
@ -95,11 +96,15 @@ impl Plugin for RenderPlugin {
));
app.insert_resource(device.clone())
.insert_resource(queue.clone())
.add_asset::<Shader>()
.init_asset_loader::<ShaderLoader>()
.init_resource::<ScratchRenderWorld>();
let render_pipeline_cache = RenderPipelineCache::new(device.clone());
let asset_server = app.world.get_resource::<AssetServer>().unwrap().clone();
let mut render_app = App::empty();
let mut extract_stage = SystemStage::parallel();
let mut extract_stage =
SystemStage::parallel().with_system(RenderPipelineCache::extract_dirty_shaders);
// don't apply buffers when the stage finishes running
// extract stage runs on the app world, but the buffers are applied to the render world
extract_stage.set_apply_buffers(false);
@ -110,12 +115,15 @@ impl Plugin for RenderPlugin {
.add_stage(RenderStage::PhaseSort, SystemStage::parallel())
.add_stage(
RenderStage::Render,
SystemStage::parallel().with_system(render_system.exclusive_system()),
SystemStage::parallel()
.with_system(RenderPipelineCache::process_pipeline_queue_system)
.with_system(render_system.exclusive_system().at_end()),
)
.add_stage(RenderStage::Cleanup, SystemStage::parallel())
.insert_resource(instance)
.insert_resource(device)
.insert_resource(queue)
.insert_resource(render_pipeline_cache)
.insert_resource(asset_server)
.init_resource::<RenderGraph>();
@ -232,7 +240,8 @@ impl Plugin for RenderPlugin {
.add_plugin(CameraPlugin)
.add_plugin(ViewPlugin)
.add_plugin(MeshPlugin)
.add_plugin(ImagePlugin);
.add_plugin(ImagePlugin)
.add_plugin(RenderAssetPlugin::<Shader>::default());
}
}

View file

@ -0,0 +1,41 @@
use bevy_reflect::Uuid;
use std::{ops::Deref, sync::Arc};
#[derive(Copy, Clone, Hash, Eq, PartialEq, Debug)]
pub struct BindGroupLayoutId(Uuid);
#[derive(Clone, Debug)]
pub struct BindGroupLayout {
id: BindGroupLayoutId,
value: Arc<wgpu::BindGroupLayout>,
}
impl BindGroupLayout {
#[inline]
pub fn id(&self) -> BindGroupLayoutId {
self.id
}
#[inline]
pub fn value(&self) -> &wgpu::BindGroupLayout {
&self.value
}
}
impl From<wgpu::BindGroupLayout> for BindGroupLayout {
fn from(value: wgpu::BindGroupLayout) -> Self {
BindGroupLayout {
id: BindGroupLayoutId(Uuid::new_v4()),
value: Arc::new(value),
}
}
}
impl Deref for BindGroupLayout {
type Target = wgpu::BindGroupLayout;
#[inline]
fn deref(&self) -> &Self::Target {
&self.value
}
}

View file

@ -1,30 +1,40 @@
mod bind_group;
mod bind_group_layout;
mod buffer;
mod buffer_vec;
mod pipeline;
mod pipeline_cache;
mod pipeline_specializer;
mod shader;
mod texture;
mod uniform_vec;
pub use bind_group::*;
pub use bind_group_layout::*;
pub use buffer::*;
pub use buffer_vec::*;
pub use pipeline::*;
pub use pipeline_cache::*;
pub use pipeline_specializer::*;
pub use shader::*;
pub use texture::*;
pub use uniform_vec::*;
// TODO: decide where re-exports should go
pub use wgpu::{
util::BufferInitDescriptor, AddressMode, BindGroupDescriptor, BindGroupEntry, BindGroupLayout,
util::BufferInitDescriptor, AddressMode, BindGroupDescriptor, BindGroupEntry,
BindGroupLayoutDescriptor, BindGroupLayoutEntry, BindingResource, BindingType, BlendComponent,
BlendFactor, BlendOperation, BlendState, BufferAddress, BufferBindingType, BufferSize,
BufferUsages, ColorTargetState, ColorWrites, CompareFunction, ComputePassDescriptor,
ComputePipelineDescriptor, DepthBiasState, DepthStencilState, Extent3d, Face, FilterMode,
FragmentState, FrontFace, IndexFormat, LoadOp, MultisampleState, Operations, PipelineLayout,
PipelineLayoutDescriptor, PolygonMode, PrimitiveState, PrimitiveTopology,
RenderPassColorAttachment, RenderPassDepthStencilAttachment, RenderPassDescriptor,
RenderPipelineDescriptor, SamplerDescriptor, ShaderModule, ShaderModuleDescriptor,
ShaderSource, ShaderStages, StencilFaceState, StencilOperation, StencilState,
StorageTextureAccess, TextureAspect, TextureDescriptor, TextureDimension, TextureFormat,
TextureSampleType, TextureUsages, TextureViewDescriptor, TextureViewDimension, VertexAttribute,
VertexBufferLayout, VertexFormat, VertexState, VertexStepMode,
FragmentState as RawFragmentState, FrontFace, IndexFormat, LoadOp, MultisampleState,
Operations, PipelineLayout, PipelineLayoutDescriptor, PolygonMode, PrimitiveState,
PrimitiveTopology, RenderPassColorAttachment, RenderPassDepthStencilAttachment,
RenderPassDescriptor, RenderPipelineDescriptor as RawRenderPipelineDescriptor,
SamplerDescriptor, ShaderModule, ShaderModuleDescriptor, ShaderSource, ShaderStages,
StencilFaceState, StencilOperation, StencilState, StorageTextureAccess, TextureAspect,
TextureDescriptor, TextureDimension, TextureFormat, TextureSampleType, TextureUsages,
TextureViewDescriptor, TextureViewDimension, VertexAttribute,
VertexBufferLayout as RawVertexBufferLayout, VertexFormat, VertexState as RawVertexState,
VertexStepMode,
};

View file

@ -1,5 +1,11 @@
use crate::render_resource::{BindGroupLayout, Shader};
use bevy_asset::Handle;
use bevy_reflect::Uuid;
use std::{ops::Deref, sync::Arc};
use std::{borrow::Cow, ops::Deref, sync::Arc};
use wgpu::{
BufferAddress, ColorTargetState, DepthStencilState, MultisampleState, PrimitiveState,
VertexAttribute, VertexStepMode,
};
#[derive(Copy, Clone, Hash, Eq, PartialEq, Debug)]
pub struct RenderPipelineId(Uuid);
@ -68,3 +74,58 @@ impl Deref for ComputePipeline {
&self.value
}
}
/// Describes a render (graphics) pipeline.
#[derive(Clone, Debug)]
pub struct RenderPipelineDescriptor {
/// Debug label of the pipeline. This will show up in graphics debuggers for easy identification.
pub label: Option<Cow<'static, str>>,
/// The layout of bind groups for this pipeline.
pub layout: Option<Vec<BindGroupLayout>>,
/// The compiled vertex stage, its entry point, and the input buffers layout.
pub vertex: VertexState,
/// The properties of the pipeline at the primitive assembly and rasterization level.
pub primitive: PrimitiveState,
/// The effect of draw calls on the depth and stencil aspects of the output target, if any.
pub depth_stencil: Option<DepthStencilState>,
/// The multi-sampling properties of the pipeline.
pub multisample: MultisampleState,
/// The compiled fragment stage, its entry point, and the color targets.
pub fragment: Option<FragmentState>,
}
#[derive(Clone, Debug)]
pub struct VertexState {
/// The compiled shader module for this stage.
pub shader: Handle<Shader>,
pub shader_defs: Vec<String>,
/// The name of the entry point in the compiled shader. There must be a function that returns
/// void with this name in the shader.
pub entry_point: Cow<'static, str>,
/// The format of any vertex buffers used with this pipeline.
pub buffers: Vec<VertexBufferLayout>,
}
/// Describes how the vertex buffer is interpreted.
#[derive(Clone, Debug, Hash, Eq, PartialEq)]
pub struct VertexBufferLayout {
/// The stride, in bytes, between elements of this buffer.
pub array_stride: BufferAddress,
/// How often this vertex buffer is "stepped" forward.
pub step_mode: VertexStepMode,
/// The list of attributes which comprise a single vertex.
pub attributes: Vec<VertexAttribute>,
}
/// Describes the fragment process in a render pipeline.
#[derive(Clone, Debug)]
pub struct FragmentState {
/// The compiled shader module for this stage.
pub shader: Handle<Shader>,
pub shader_defs: Vec<String>,
/// The name of the entry point in the compiled shader. There must be a function that returns
/// void with this name in the shader.
pub entry_point: Cow<'static, str>,
/// The color state of the render targets.
pub targets: Vec<ColorTargetState>,
}

View file

@ -0,0 +1,324 @@
use crate::{
render_asset::RenderAssets,
render_resource::{
AsModuleDescriptorError, BindGroupLayout, BindGroupLayoutId, ProcessShaderError,
RawFragmentState, RawRenderPipelineDescriptor, RawVertexState, RenderPipeline,
RenderPipelineDescriptor, Shader, ShaderProcessor,
},
renderer::RenderDevice,
RenderWorld,
};
use bevy_app::EventReader;
use bevy_asset::{AssetEvent, Handle};
use bevy_ecs::system::{Res, ResMut};
use bevy_utils::{HashMap, HashSet};
use std::{collections::hash_map::Entry, hash::Hash, ops::Deref, sync::Arc};
use thiserror::Error;
use wgpu::{PipelineLayoutDescriptor, ShaderModule, VertexBufferLayout};
#[derive(Default)]
pub struct ShaderData {
pipelines: HashSet<CachedPipelineId>,
processed_shaders: HashMap<Vec<String>, Arc<ShaderModule>>,
}
#[derive(Copy, Clone, Debug, Hash, Eq, PartialEq)]
pub struct CachedPipelineId(usize);
impl CachedPipelineId {
pub const INVALID: Self = CachedPipelineId(usize::MAX);
}
#[derive(Default)]
struct ShaderCache {
data: HashMap<Handle<Shader>, ShaderData>,
processor: ShaderProcessor,
}
impl ShaderCache {
fn get(
&mut self,
render_device: &RenderDevice,
shaders: &RenderAssets<Shader>,
pipeline: CachedPipelineId,
handle: &Handle<Shader>,
shader_defs: &[String],
) -> Result<Arc<ShaderModule>, RenderPipelineError> {
let shader = shaders
.get(handle)
.ok_or_else(|| RenderPipelineError::ShaderNotLoaded(handle.clone_weak()))?;
let data = self.data.entry(handle.clone_weak()).or_default();
data.pipelines.insert(pipeline);
// PERF: this shader_defs clone isn't great. use raw_entry_mut when it stabilizes
let module = match data.processed_shaders.entry(shader_defs.to_vec()) {
Entry::Occupied(entry) => entry.into_mut(),
Entry::Vacant(entry) => {
let processed = self.processor.process_shader(shader, shader_defs)?;
let module_descriptor = processed.get_module_descriptor()?;
entry.insert(Arc::new(
render_device.create_shader_module(&module_descriptor),
))
}
};
Ok(module.clone())
}
fn clear(
&mut self,
handle: &Handle<Shader>,
) -> Option<impl Iterator<Item = CachedPipelineId> + '_> {
let data = self.data.get_mut(handle)?;
data.processed_shaders.clear();
Some(data.pipelines.drain())
}
fn remove(&mut self, handle: &Handle<Shader>) {
self.data.remove(handle);
}
}
#[derive(Default)]
struct LayoutCache {
layouts: HashMap<Vec<BindGroupLayoutId>, wgpu::PipelineLayout>,
}
impl LayoutCache {
fn get(
&mut self,
render_device: &RenderDevice,
bind_group_layouts: &[BindGroupLayout],
) -> &wgpu::PipelineLayout {
let key = bind_group_layouts.iter().map(|l| l.id()).collect();
self.layouts.entry(key).or_insert_with(|| {
let bind_group_layouts = bind_group_layouts
.iter()
.map(|l| l.value())
.collect::<Vec<_>>();
render_device.create_pipeline_layout(&PipelineLayoutDescriptor {
bind_group_layouts: &bind_group_layouts,
..Default::default()
})
})
}
}
pub struct RenderPipelineCache {
layout_cache: LayoutCache,
shader_cache: ShaderCache,
device: RenderDevice,
pipelines: Vec<CachedPipeline>,
waiting_pipelines: Vec<CachedPipelineId>,
}
struct CachedPipeline {
descriptor: RenderPipelineDescriptor,
state: CachedPipelineState,
}
#[derive(Debug)]
pub enum CachedPipelineState {
Queued,
Ok(RenderPipeline),
Err(RenderPipelineError),
}
impl CachedPipelineState {
pub fn unwrap(&self) -> &RenderPipeline {
match self {
CachedPipelineState::Ok(pipeline) => pipeline,
CachedPipelineState::Queued => {
panic!("Pipeline has not been compiled yet. It is still in the 'Queued' state.")
}
CachedPipelineState::Err(err) => panic!("{}", err),
}
}
}
#[derive(Error, Debug)]
pub enum RenderPipelineError {
#[error(
"Pipeline cound not be compiled because the following shader is not loaded yet: {0:?}"
)]
ShaderNotLoaded(Handle<Shader>),
#[error(transparent)]
ProcessShaderError(#[from] ProcessShaderError),
#[error(transparent)]
AsModuleDescriptorError(#[from] AsModuleDescriptorError),
}
impl RenderPipelineCache {
pub fn new(device: RenderDevice) -> Self {
Self {
device,
layout_cache: Default::default(),
shader_cache: Default::default(),
waiting_pipelines: Default::default(),
pipelines: Default::default(),
}
}
#[inline]
pub fn get_state(&self, id: CachedPipelineId) -> &CachedPipelineState {
&self.pipelines[id.0].state
}
#[inline]
pub fn get(&self, id: CachedPipelineId) -> Option<&RenderPipeline> {
if let CachedPipelineState::Ok(pipeline) = &self.pipelines[id.0].state {
Some(pipeline)
} else {
None
}
}
pub fn queue(&mut self, descriptor: RenderPipelineDescriptor) -> CachedPipelineId {
let id = CachedPipelineId(self.pipelines.len());
self.pipelines.push(CachedPipeline {
descriptor,
state: CachedPipelineState::Queued,
});
self.waiting_pipelines.push(id);
id
}
pub fn mark_shader_dirty(&mut self, shader: &Handle<Shader>) {
if let Some(cached_pipelines) = self.shader_cache.clear(shader) {
for cached_pipeline in cached_pipelines {
self.pipelines[cached_pipeline.0].state = CachedPipelineState::Queued;
self.waiting_pipelines.push(cached_pipeline);
}
}
}
pub fn remove_shader(&mut self, shader: &Handle<Shader>) {
if let Some(cached_pipelines) = self.shader_cache.clear(shader) {
for cached_pipeline in cached_pipelines {
self.pipelines[cached_pipeline.0].state = CachedPipelineState::Queued;
self.waiting_pipelines.push(cached_pipeline);
}
}
}
pub fn process_queue(&mut self, shaders: &RenderAssets<Shader>) {
let pipelines = std::mem::take(&mut self.waiting_pipelines);
for id in pipelines {
let state = &mut self.pipelines[id.0];
match &state.state {
CachedPipelineState::Ok(_) => continue,
CachedPipelineState::Queued => {}
CachedPipelineState::Err(err) => {
match err {
RenderPipelineError::ShaderNotLoaded(_) => { /* retry */ }
RenderPipelineError::ProcessShaderError(_)
| RenderPipelineError::AsModuleDescriptorError(_) => {
// shader could not be processed ... retrying won't help
continue;
}
}
}
}
let descriptor = &state.descriptor;
let vertex_module = match self.shader_cache.get(
&self.device,
shaders,
id,
&descriptor.vertex.shader,
&descriptor.vertex.shader_defs,
) {
Ok(module) => module,
Err(err) => {
state.state = CachedPipelineState::Err(err);
self.waiting_pipelines.push(id);
continue;
}
};
let fragment_data = if let Some(fragment) = &descriptor.fragment {
let fragment_module = match self.shader_cache.get(
&self.device,
shaders,
id,
&fragment.shader,
&fragment.shader_defs,
) {
Ok(module) => module,
Err(err) => {
state.state = CachedPipelineState::Err(err);
self.waiting_pipelines.push(id);
continue;
}
};
Some((
fragment_module,
fragment.entry_point.deref(),
&fragment.targets,
))
} else {
None
};
let vertex_buffer_layouts = descriptor
.vertex
.buffers
.iter()
.map(|layout| VertexBufferLayout {
array_stride: layout.array_stride,
attributes: &layout.attributes,
step_mode: layout.step_mode,
})
.collect::<Vec<_>>();
let layout = if let Some(layout) = &descriptor.layout {
Some(self.layout_cache.get(&self.device, layout))
} else {
None
};
let descriptor = RawRenderPipelineDescriptor {
depth_stencil: descriptor.depth_stencil.clone(),
label: descriptor.label.as_deref(),
layout,
multisample: descriptor.multisample,
primitive: descriptor.primitive,
vertex: RawVertexState {
buffers: &vertex_buffer_layouts,
entry_point: descriptor.vertex.entry_point.deref(),
module: &vertex_module,
},
fragment: fragment_data
.as_ref()
.map(|(module, entry_point, targets)| RawFragmentState {
entry_point,
module,
targets,
}),
};
let pipeline = self.device.create_render_pipeline(&descriptor);
state.state = CachedPipelineState::Ok(pipeline);
}
}
pub(crate) fn process_pipeline_queue_system(
mut cache: ResMut<Self>,
shaders: Res<RenderAssets<Shader>>,
) {
cache.process_queue(&shaders);
}
pub(crate) fn extract_dirty_shaders(
mut world: ResMut<RenderWorld>,
mut events: EventReader<AssetEvent<Shader>>,
) {
let mut cache = world.get_resource_mut::<Self>().unwrap();
for event in events.iter() {
match event {
AssetEvent::Created { handle } | AssetEvent::Modified { handle } => {
cache.mark_shader_dirty(handle)
}
AssetEvent::Removed { handle } => cache.shader_cache.remove(handle),
}
}
}
}

View file

@ -0,0 +1,34 @@
use crate::render_resource::{CachedPipelineId, RenderPipelineCache, RenderPipelineDescriptor};
use bevy_utils::HashMap;
use std::hash::Hash;
pub struct SpecializedPipelines<S: SpecializedPipeline> {
cache: HashMap<S::Key, CachedPipelineId>,
}
impl<S: SpecializedPipeline> Default for SpecializedPipelines<S> {
fn default() -> Self {
Self {
cache: Default::default(),
}
}
}
impl<S: SpecializedPipeline> SpecializedPipelines<S> {
pub fn specialize(
&mut self,
cache: &mut RenderPipelineCache,
specialize_pipeline: &S,
key: S::Key,
) -> CachedPipelineId {
*self.cache.entry(key.clone()).or_insert_with(|| {
let descriptor = specialize_pipeline.specialize(key);
cache.queue(descriptor)
})
}
}
pub trait SpecializedPipeline {
type Key: Clone + Hash + PartialEq + Eq;
fn specialize(&self, key: Self::Key) -> RenderPipelineDescriptor;
}

View file

@ -0,0 +1,435 @@
use bevy_asset::{AssetLoader, LoadContext, LoadedAsset};
use bevy_reflect::{TypeUuid, Uuid};
use bevy_utils::{tracing::error, BoxedFuture};
use naga::{valid::ModuleInfo, Module};
use regex::Regex;
use std::{borrow::Cow, collections::HashSet, marker::Copy};
use thiserror::Error;
use wgpu::{ShaderModuleDescriptor, ShaderSource};
use crate::render_asset::{PrepareAssetError, RenderAsset};
#[derive(Copy, Clone, Hash, Eq, PartialEq, Debug)]
pub struct ShaderId(Uuid);
impl ShaderId {
#[allow(clippy::new_without_default)]
pub fn new() -> Self {
ShaderId(Uuid::new_v4())
}
}
#[derive(Error, Debug)]
pub enum ShaderReflectError {
#[error(transparent)]
WgslParse(#[from] naga::front::wgsl::ParseError),
#[error("GLSL Parse Error: {0:?}")]
GlslParse(Vec<naga::front::glsl::Error>),
#[error(transparent)]
SpirVParse(#[from] naga::front::spv::Error),
#[error(transparent)]
Validation(#[from] naga::valid::ValidationError),
}
/// A shader, as defined by its [ShaderSource] and [ShaderStage]
/// This is an "unprocessed" shader. It can contain preprocessor directives.
#[derive(Debug, Clone, TypeUuid)]
#[uuid = "d95bc916-6c55-4de3-9622-37e7b6969fda"]
pub enum Shader {
Wgsl(Cow<'static, str>),
Glsl(Cow<'static, str>, naga::ShaderStage),
SpirV(Cow<'static, [u8]>),
// TODO: consider the following
// PrecompiledSpirVMacros(HashMap<HashSet<String>, Vec<u32>>)
// NagaModule(Module) ... Module impls Serialize/Deserialize
}
/// A processed [Shader]. This cannot contain preprocessor directions. It must be "ready to compile"
pub enum ProcessedShader {
Wgsl(Cow<'static, str>),
Glsl(Cow<'static, str>, naga::ShaderStage),
SpirV(Cow<'static, [u8]>),
}
impl ProcessedShader {
pub fn reflect(&self) -> Result<ShaderReflection, ShaderReflectError> {
let module = match &self {
// TODO: process macros here
ProcessedShader::Wgsl(source) => naga::front::wgsl::parse_str(source)?,
ProcessedShader::Glsl(source, shader_stage) => {
let mut parser = naga::front::glsl::Parser::default();
parser
.parse(&naga::front::glsl::Options::from(*shader_stage), source)
.map_err(ShaderReflectError::GlslParse)?
}
ProcessedShader::SpirV(source) => naga::front::spv::parse_u8_slice(
source,
&naga::front::spv::Options {
adjust_coordinate_space: false,
..naga::front::spv::Options::default()
},
)?,
};
let module_info = naga::valid::Validator::new(
naga::valid::ValidationFlags::default(),
naga::valid::Capabilities::default(),
)
.validate(&module)?;
Ok(ShaderReflection {
module,
module_info,
})
}
pub fn get_module_descriptor(&self) -> Result<ShaderModuleDescriptor, AsModuleDescriptorError> {
Ok(ShaderModuleDescriptor {
label: None,
source: match self {
ProcessedShader::Wgsl(source) => ShaderSource::Wgsl(source.clone()),
ProcessedShader::Glsl(_source, _stage) => {
let reflection = self.reflect()?;
// TODO: it probably makes more sense to convert this to spirv, but as of writing
// this comment, naga's spirv conversion is broken
let wgsl = reflection.get_wgsl()?;
ShaderSource::Wgsl(wgsl.into())
}
ProcessedShader::SpirV(_) => {
// TODO: we can probably just transmute the u8 array to u32?
let reflection = self.reflect()?;
let spirv = reflection.get_spirv()?;
ShaderSource::SpirV(Cow::Owned(spirv))
}
},
})
}
}
#[derive(Error, Debug)]
pub enum AsModuleDescriptorError {
#[error(transparent)]
ShaderReflectError(#[from] ShaderReflectError),
#[error(transparent)]
WgslConversion(#[from] naga::back::wgsl::Error),
#[error(transparent)]
SpirVConversion(#[from] naga::back::spv::Error),
}
pub struct ShaderReflection {
pub module: Module,
pub module_info: ModuleInfo,
}
impl ShaderReflection {
pub fn get_spirv(&self) -> Result<Vec<u32>, naga::back::spv::Error> {
naga::back::spv::write_vec(
&self.module,
&self.module_info,
&naga::back::spv::Options {
flags: naga::back::spv::WriterFlags::empty(),
..naga::back::spv::Options::default()
},
None,
)
}
pub fn get_wgsl(&self) -> Result<String, naga::back::wgsl::Error> {
naga::back::wgsl::write_string(&self.module, &self.module_info)
}
}
impl Shader {
pub fn from_wgsl(source: impl Into<Cow<'static, str>>) -> Shader {
Shader::Wgsl(source.into())
}
pub fn from_glsl(source: impl Into<Cow<'static, str>>, stage: naga::ShaderStage) -> Shader {
Shader::Glsl(source.into(), stage)
}
pub fn from_spirv(source: impl Into<Cow<'static, [u8]>>) -> Shader {
Shader::SpirV(source.into())
}
}
#[derive(Default)]
pub struct ShaderLoader;
impl AssetLoader for ShaderLoader {
fn load<'a>(
&'a self,
bytes: &'a [u8],
load_context: &'a mut LoadContext,
) -> BoxedFuture<'a, Result<(), anyhow::Error>> {
Box::pin(async move {
let ext = load_context.path().extension().unwrap().to_str().unwrap();
let shader = match ext {
"spv" => Shader::from_spirv(Vec::from(bytes)),
"wgsl" => Shader::from_wgsl(String::from_utf8(Vec::from(bytes))?),
"vert" => Shader::from_glsl(
String::from_utf8(Vec::from(bytes))?,
naga::ShaderStage::Vertex,
),
"frag" => Shader::from_glsl(
String::from_utf8(Vec::from(bytes))?,
naga::ShaderStage::Fragment,
),
_ => panic!("unhandled extension: {}", ext),
};
load_context.set_default_asset(LoadedAsset::new(shader));
Ok(())
})
}
fn extensions(&self) -> &[&str] {
&["spv", "wgsl", "vert", "frag"]
}
}
impl RenderAsset for Shader {
type ExtractedAsset = Shader;
type PreparedAsset = Shader;
type Param = ();
fn extract_asset(&self) -> Self::ExtractedAsset {
self.clone()
}
fn prepare_asset(
extracted_asset: Self::ExtractedAsset,
_param: &mut bevy_ecs::system::SystemParamItem<Self::Param>,
) -> Result<Self::PreparedAsset, PrepareAssetError<Self::ExtractedAsset>> {
Ok(extracted_asset)
}
}
#[derive(Error, Debug, PartialEq, Eq)]
pub enum ProcessShaderError {
#[error("Too many '# endif' lines. Each endif should be preceded by an if statement.")]
TooManyEndIfs,
#[error(
"Not enough '# endif' lines. Each if statement should be followed by an endif statement."
)]
NotEnoughEndIfs,
#[error("This Shader's format does not support processing shader defs.")]
ShaderFormatDoesNotSupportShaderDefs,
}
pub struct ShaderProcessor {
ifdef_regex: Regex,
ifndef_regex: Regex,
endif_regex: Regex,
}
impl Default for ShaderProcessor {
fn default() -> Self {
Self {
ifdef_regex: Regex::new(r"^\s*#\s*ifdef\s*([\w|\d|_]+)").unwrap(),
ifndef_regex: Regex::new(r"^\s*#\s*ifndef\s*([\w|\d|_]+)").unwrap(),
endif_regex: Regex::new(r"^\s*#\s*endif").unwrap(),
}
}
}
impl ShaderProcessor {
pub fn process_shader(
&self,
shader: &Shader,
shader_defs: &[String],
) -> Result<ProcessedShader, ProcessShaderError> {
match shader {
Shader::Wgsl(source) => Ok(ProcessedShader::Wgsl(Cow::from(
self.process_str(source, shader_defs)?,
))),
Shader::Glsl(source, stage) => Ok(ProcessedShader::Glsl(
Cow::from(self.process_str(source, shader_defs)?),
*stage,
)),
Shader::SpirV(source) => {
if shader_defs.is_empty() {
Ok(ProcessedShader::SpirV(source.clone()))
} else {
Err(ProcessShaderError::ShaderFormatDoesNotSupportShaderDefs)
}
}
}
}
pub fn process_str(
&self,
shader: &str,
shader_defs: &[String],
) -> Result<String, ProcessShaderError> {
let shader_defs = HashSet::<String>::from_iter(shader_defs.iter().cloned());
let mut scopes = vec![true];
let mut final_string = String::new();
for line in shader.split('\n') {
if let Some(cap) = self.ifdef_regex.captures(line) {
let def = cap.get(1).unwrap();
scopes.push(shader_defs.contains(def.as_str()));
} else if let Some(cap) = self.ifndef_regex.captures(line) {
let def = cap.get(1).unwrap();
scopes.push(!shader_defs.contains(def.as_str()));
} else if self.endif_regex.is_match(line) {
scopes.pop();
if scopes.is_empty() {
return Err(ProcessShaderError::TooManyEndIfs);
}
} else if *scopes.last().unwrap() {
final_string.push_str(line);
final_string.push('\n');
}
}
final_string.pop();
if scopes.len() != 1 {
return Err(ProcessShaderError::NotEnoughEndIfs);
}
Ok(final_string)
}
}
#[cfg(test)]
mod tests {
use crate::render_resource::{ProcessShaderError, ShaderProcessor};
#[rustfmt::skip]
const WGSL: &str = r"
[[block]]
struct View {
view_proj: mat4x4<f32>;
world_position: vec3<f32>;
};
[[group(0), binding(0)]]
var<uniform> view: View;
# ifdef TEXTURE
[[group(1), binding(0)]]
var sprite_texture: texture_2d<f32>;
# endif
struct VertexOutput {
[[location(0)]] uv: vec2<f32>;
[[builtin(position)]] position: vec4<f32>;
};
[[stage(vertex)]]
fn vertex(
[[location(0)]] vertex_position: vec3<f32>,
[[location(1)]] vertex_uv: vec2<f32>
) -> VertexOutput {
var out: VertexOutput;
out.uv = vertex_uv;
out.position = view.view_proj * vec4<f32>(vertex_position, 1.0);
return out;
}
";
#[test]
fn process_shader_def_defined() {
#[rustfmt::skip]
const EXPECTED: &str = r"
[[block]]
struct View {
view_proj: mat4x4<f32>;
world_position: vec3<f32>;
};
[[group(0), binding(0)]]
var<uniform> view: View;
[[group(1), binding(0)]]
var sprite_texture: texture_2d<f32>;
struct VertexOutput {
[[location(0)]] uv: vec2<f32>;
[[builtin(position)]] position: vec4<f32>;
};
[[stage(vertex)]]
fn vertex(
[[location(0)]] vertex_position: vec3<f32>,
[[location(1)]] vertex_uv: vec2<f32>
) -> VertexOutput {
var out: VertexOutput;
out.uv = vertex_uv;
out.position = view.view_proj * vec4<f32>(vertex_position, 1.0);
return out;
}
";
let processor = ShaderProcessor::default();
let result = processor
.process_str(WGSL, &["TEXTURE".to_string()])
.unwrap();
assert_eq!(result, EXPECTED);
}
#[test]
fn process_shader_def_not_defined() {
#[rustfmt::skip]
const EXPECTED: &str = r"
[[block]]
struct View {
view_proj: mat4x4<f32>;
world_position: vec3<f32>;
};
[[group(0), binding(0)]]
var<uniform> view: View;
struct VertexOutput {
[[location(0)]] uv: vec2<f32>;
[[builtin(position)]] position: vec4<f32>;
};
[[stage(vertex)]]
fn vertex(
[[location(0)]] vertex_position: vec3<f32>,
[[location(1)]] vertex_uv: vec2<f32>
) -> VertexOutput {
var out: VertexOutput;
out.uv = vertex_uv;
out.position = view.view_proj * vec4<f32>(vertex_position, 1.0);
return out;
}
";
let processor = ShaderProcessor::default();
let result = processor.process_str(WGSL, &[]).unwrap();
assert_eq!(result, EXPECTED);
}
#[test]
fn process_shader_def_unclosed() {
#[rustfmt::skip]
const INPUT: &str = r"
# ifdef FOO
";
let processor = ShaderProcessor::default();
let result = processor.process_str(INPUT, &[]);
assert_eq!(result, Err(ProcessShaderError::NotEnoughEndIfs));
}
#[test]
fn process_shader_def_too_closed() {
#[rustfmt::skip]
const INPUT: &str = r"
# endif
";
let processor = ShaderProcessor::default();
let result = processor.process_str(INPUT, &[]);
assert_eq!(result, Err(ProcessShaderError::TooManyEndIfs));
}
#[test]
fn process_shader_def_commented() {
#[rustfmt::skip]
const INPUT: &str = r"
// # ifdef FOO
fn foo() { }
";
let processor = ShaderProcessor::default();
let result = processor.process_str(INPUT, &[]).unwrap();
assert_eq!(result, INPUT);
}
}

View file

@ -2,7 +2,8 @@ use futures_lite::future;
use wgpu::util::DeviceExt;
use crate::render_resource::{
BindGroup, Buffer, ComputePipeline, RenderPipeline, Sampler, Texture,
BindGroup, BindGroupLayout, Buffer, ComputePipeline, RawRenderPipelineDescriptor,
RenderPipeline, Sampler, Texture,
};
use std::sync::Arc;
@ -36,11 +37,8 @@ impl RenderDevice {
/// Creates a shader module from either SPIR-V or WGSL source code.
#[inline]
pub fn create_shader_module<'a>(
&self,
desc: impl Into<wgpu::ShaderModuleDescriptor<'a>>,
) -> wgpu::ShaderModule {
self.device.create_shader_module(&desc.into())
pub fn create_shader_module(&self, desc: &wgpu::ShaderModuleDescriptor) -> wgpu::ShaderModule {
self.device.create_shader_module(desc)
}
/// Check for resource cleanups and mapping callbacks.
@ -81,8 +79,8 @@ impl RenderDevice {
pub fn create_bind_group_layout(
&self,
desc: &wgpu::BindGroupLayoutDescriptor,
) -> wgpu::BindGroupLayout {
self.device.create_bind_group_layout(desc)
) -> BindGroupLayout {
BindGroupLayout::from(self.device.create_bind_group_layout(desc))
}
/// Creates a [`PipelineLayout`].
@ -96,7 +94,7 @@ impl RenderDevice {
/// Creates a [`RenderPipeline`].
#[inline]
pub fn create_render_pipeline(&self, desc: &wgpu::RenderPipelineDescriptor) -> RenderPipeline {
pub fn create_render_pipeline(&self, desc: &RawRenderPipelineDescriptor) -> RenderPipeline {
let wgpu_render_pipeline = self.device.create_render_pipeline(desc);
RenderPipeline::from(wgpu_render_pipeline)
}

View file

@ -1,4 +0,0 @@
#[allow(clippy::module_inception)]
mod shader;
pub use shader::*;

View file

@ -1,158 +0,0 @@
use bevy_asset::{AssetLoader, LoadContext, LoadedAsset};
use bevy_reflect::{TypeUuid, Uuid};
use bevy_utils::{tracing::error, BoxedFuture};
use naga::{valid::ModuleInfo, Module};
use std::{borrow::Cow, marker::Copy};
use thiserror::Error;
use wgpu::{ShaderModuleDescriptor, ShaderSource};
#[derive(Copy, Clone, Hash, Eq, PartialEq, Debug)]
pub struct ShaderId(Uuid);
impl ShaderId {
#[allow(clippy::new_without_default)]
pub fn new() -> Self {
ShaderId(Uuid::new_v4())
}
}
#[derive(Error, Debug)]
pub enum ShaderReflectError {
#[error(transparent)]
WgslParse(#[from] naga::front::wgsl::ParseError),
#[error("GLSL Parse Error: {0:?}")]
GlslParse(Vec<naga::front::glsl::Error>),
#[error(transparent)]
SpirVParse(#[from] naga::front::spv::Error),
#[error(transparent)]
Validation(#[from] naga::valid::ValidationError),
}
/// A shader, as defined by its [ShaderSource] and [ShaderStage]
#[derive(Debug, TypeUuid)]
#[uuid = "d95bc916-6c55-4de3-9622-37e7b6969fda"]
pub enum Shader {
Wgsl(Cow<'static, str>),
Glsl(Cow<'static, str>, naga::ShaderStage),
SpirV(Vec<u8>),
// TODO: consider the following
// PrecompiledSpirVMacros(HashMap<HashSet<String>, Vec<u32>>)
// NagaModule(Module) ... Module impls Serialize/Deserialize
}
pub struct ShaderReflection {
pub module: Module,
pub module_info: ModuleInfo,
}
impl ShaderReflection {
pub fn get_spirv(&self) -> Result<Vec<u32>, naga::back::spv::Error> {
naga::back::spv::write_vec(
&self.module,
&self.module_info,
&naga::back::spv::Options {
flags: naga::back::spv::WriterFlags::empty(),
..naga::back::spv::Options::default()
},
None,
)
}
pub fn get_wgsl(&self) -> Result<String, naga::back::wgsl::Error> {
naga::back::wgsl::write_string(&self.module, &self.module_info)
}
}
impl Shader {
pub fn reflect(&self) -> Result<ShaderReflection, ShaderReflectError> {
let module = match &self {
// TODO: process macros here
Shader::Wgsl(source) => naga::front::wgsl::parse_str(source)?,
Shader::Glsl(source, shader_stage) => {
let mut parser = naga::front::glsl::Parser::default();
parser
.parse(&naga::front::glsl::Options::from(*shader_stage), source)
.map_err(ShaderReflectError::GlslParse)?
}
Shader::SpirV(source) => naga::front::spv::parse_u8_slice(
source,
&naga::front::spv::Options {
adjust_coordinate_space: false,
..naga::front::spv::Options::default()
},
)?,
};
let module_info = naga::valid::Validator::new(
naga::valid::ValidationFlags::default(),
naga::valid::Capabilities::default(),
)
.validate(&module)?;
Ok(ShaderReflection {
module,
module_info,
})
}
pub fn from_wgsl(source: impl Into<Cow<'static, str>>) -> Shader {
Shader::Wgsl(source.into())
}
pub fn from_glsl(source: impl Into<Cow<'static, str>>, stage: naga::ShaderStage) -> Shader {
Shader::Glsl(source.into(), stage)
}
pub fn from_spirv(source: Vec<u8>) -> Shader {
Shader::SpirV(source)
}
}
#[derive(Default)]
pub struct ShaderLoader;
impl AssetLoader for ShaderLoader {
fn load<'a>(
&'a self,
bytes: &'a [u8],
load_context: &'a mut LoadContext,
) -> BoxedFuture<'a, Result<(), anyhow::Error>> {
Box::pin(async move {
let ext = load_context.path().extension().unwrap().to_str().unwrap();
let shader = match ext {
"spv" => Shader::from_spirv(Vec::from(bytes)),
"wgsl" => Shader::from_wgsl(String::from_utf8(Vec::from(bytes))?),
_ => panic!("unhandled extension: {}", ext),
};
load_context.set_default_asset(LoadedAsset::new(shader));
Ok(())
})
}
fn extensions(&self) -> &[&str] {
&["spv", "wgsl"]
}
}
impl<'a> From<&'a Shader> for ShaderModuleDescriptor<'a> {
fn from(shader: &'a Shader) -> Self {
ShaderModuleDescriptor {
label: None,
source: match shader {
Shader::Wgsl(source) => ShaderSource::Wgsl(source.clone()),
Shader::Glsl(_source, _stage) => {
let reflection = shader.reflect().unwrap();
let wgsl = reflection.get_wgsl().unwrap();
ShaderSource::Wgsl(wgsl.into())
}
Shader::SpirV(_) => {
// TODO: we can probably just transmute the u8 array to u32?
let reflection = shader.reflect().unwrap();
let spirv = reflection.get_spirv().unwrap();
ShaderSource::SpirV(Cow::Owned(spirv))
}
},
}
}
}

View file

@ -15,20 +15,27 @@ pub use texture_atlas::*;
pub use texture_atlas_builder::*;
use bevy_app::prelude::*;
use bevy_asset::AddAsset;
use bevy_asset::{AddAsset, Assets, HandleUntyped};
use bevy_core_pipeline::Transparent2d;
use bevy_render2::{render_phase::DrawFunctions, RenderApp, RenderStage};
use bevy_reflect::TypeUuid;
use bevy_render2::{render_phase::DrawFunctions, render_resource::Shader, RenderApp, RenderStage};
#[derive(Default)]
pub struct SpritePlugin;
pub const SPRITE_SHADER_HANDLE: HandleUntyped =
HandleUntyped::weak_from_u64(Shader::TYPE_UUID, 2763343953151597127);
impl Plugin for SpritePlugin {
fn build(&self, app: &mut App) {
let mut shaders = app.world.get_resource_mut::<Assets<Shader>>().unwrap();
let sprite_shader = Shader::from_wgsl(include_str!("render/sprite.wgsl"));
shaders.set_untracked(SPRITE_SHADER_HANDLE, sprite_shader);
app.add_asset::<TextureAtlas>().register_type::<Sprite>();
let render_app = app.sub_app(RenderApp);
render_app
.init_resource::<ImageBindGroups>()
.init_resource::<SpriteShaders>()
.init_resource::<SpritePipeline>()
.init_resource::<SpriteMeta>()
.add_system_to_stage(RenderStage::Extract, render::extract_atlases)
.add_system_to_stage(RenderStage::Extract, render::extract_sprites)

View file

@ -1,6 +1,6 @@
use crate::{
texture_atlas::{TextureAtlas, TextureAtlasSprite},
Rect, Sprite,
Rect, Sprite, SPRITE_SHADER_HANDLE,
};
use bevy_asset::{Assets, Handle};
use bevy_core_pipeline::Transparent2d;
@ -15,7 +15,6 @@ use bevy_render2::{
render_phase::{Draw, DrawFunctions, RenderPhase, TrackedRenderPass},
render_resource::*,
renderer::{RenderDevice, RenderQueue},
shader::Shader,
texture::{BevyDefault, Image},
view::{ViewUniformOffset, ViewUniforms},
};
@ -23,18 +22,17 @@ use bevy_transform::components::GlobalTransform;
use bevy_utils::HashMap;
use bytemuck::{Pod, Zeroable};
pub struct SpriteShaders {
pipeline: RenderPipeline,
pub struct SpritePipeline {
view_layout: BindGroupLayout,
material_layout: BindGroupLayout,
pipeline: CachedPipelineId,
}
// TODO: this pattern for initializing the shaders / pipeline isn't ideal. this should be handled by the asset system
impl FromWorld for SpriteShaders {
impl FromWorld for SpritePipeline {
fn from_world(world: &mut World) -> Self {
let world = world.cell();
let render_device = world.get_resource::<RenderDevice>().unwrap();
let shader = Shader::from_wgsl(include_str!("sprite.wgsl"));
let shader_module = render_device.create_shader_module(&shader);
let mut pipeline_cache = world.get_resource_mut::<RenderPipelineCache>().unwrap();
let view_layout = render_device.create_bind_group_layout(&BindGroupLayoutDescriptor {
entries: &[BindGroupLayoutEntry {
@ -77,20 +75,15 @@ impl FromWorld for SpriteShaders {
label: Some("sprite_material_layout"),
});
let pipeline_layout = render_device.create_pipeline_layout(&PipelineLayoutDescriptor {
label: Some("sprite_pipeline_layout"),
push_constant_ranges: &[],
bind_group_layouts: &[&view_layout, &material_layout],
});
let pipeline = render_device.create_render_pipeline(&RenderPipelineDescriptor {
label: Some("sprite_pipeline"),
depth_stencil: None,
let descriptor = RenderPipelineDescriptor {
vertex: VertexState {
buffers: &[VertexBufferLayout {
shader: SPRITE_SHADER_HANDLE.typed::<Shader>(),
entry_point: "vertex".into(),
shader_defs: vec![],
buffers: vec![VertexBufferLayout {
array_stride: 20,
step_mode: VertexStepMode::Vertex,
attributes: &[
attributes: vec![
VertexAttribute {
format: VertexFormat::Float32x3,
offset: 0,
@ -103,13 +96,12 @@ impl FromWorld for SpriteShaders {
},
],
}],
module: &shader_module,
entry_point: "vertex",
},
fragment: Some(FragmentState {
module: &shader_module,
entry_point: "fragment",
targets: &[ColorTargetState {
shader: SPRITE_SHADER_HANDLE.typed::<Shader>(),
shader_defs: vec![],
entry_point: "fragment".into(),
targets: vec![ColorTargetState {
format: TextureFormat::bevy_default(),
blend: Some(BlendState {
color: BlendComponent {
@ -126,21 +118,27 @@ impl FromWorld for SpriteShaders {
write_mask: ColorWrites::ALL,
}],
}),
layout: Some(&pipeline_layout),
multisample: MultisampleState::default(),
layout: Some(vec![view_layout.clone(), material_layout.clone()]),
primitive: PrimitiveState {
topology: PrimitiveTopology::TriangleList,
strip_index_format: None,
front_face: FrontFace::Ccw,
cull_mode: None,
polygon_mode: PolygonMode::Fill,
clamp_depth: false,
conservative: false,
topology: PrimitiveTopology::TriangleList,
strip_index_format: None,
},
});
depth_stencil: None,
multisample: MultisampleState {
count: 1,
mask: !0,
alpha_to_coverage_enabled: false,
},
label: Some("sprite_pipeline".into()),
};
SpriteShaders {
pipeline,
SpritePipeline {
pipeline: pipeline_cache.queue(descriptor),
view_layout,
material_layout,
}
@ -327,7 +325,7 @@ pub fn queue_sprites(
render_device: Res<RenderDevice>,
mut sprite_meta: ResMut<SpriteMeta>,
view_uniforms: Res<ViewUniforms>,
sprite_shaders: Res<SpriteShaders>,
sprite_pipeline: Res<SpritePipeline>,
mut image_bind_groups: ResMut<ImageBindGroups>,
gpu_images: Res<RenderAssets<Image>>,
mut extracted_sprites: Query<(Entity, &ExtractedSprite)>,
@ -340,7 +338,7 @@ pub fn queue_sprites(
resource: view_binding,
}],
label: Some("sprite_view_bind_group"),
layout: &sprite_shaders.view_layout,
layout: &sprite_pipeline.view_layout,
}));
let draw_sprite_function = draw_functions.read().get_id::<DrawSprite>().unwrap();
for mut transparent_phase in views.iter_mut() {
@ -362,11 +360,12 @@ pub fn queue_sprites(
},
],
label: Some("sprite_material_bind_group"),
layout: &sprite_shaders.material_layout,
layout: &sprite_pipeline.material_layout,
})
});
transparent_phase.add(Transparent2d {
draw_function: draw_sprite_function,
pipeline: sprite_pipeline.pipeline,
entity,
sort_key: sprite.handle.clone_weak(),
});
@ -377,9 +376,9 @@ pub fn queue_sprites(
pub struct DrawSprite {
params: SystemState<(
SRes<SpriteShaders>,
SRes<SpriteMeta>,
SRes<ImageBindGroups>,
SRes<RenderPipelineCache>,
SQuery<Read<ViewUniformOffset>>,
SQuery<Read<ExtractedSprite>>,
)>,
@ -402,13 +401,13 @@ impl Draw<Transparent2d> for DrawSprite {
item: &Transparent2d,
) {
const INDICES: usize = 6;
let (sprite_shaders, sprite_meta, image_bind_groups, views, sprites) =
self.params.get(world);
let (sprite_meta, image_bind_groups, pipelines, views, sprites) = self.params.get(world);
let view_uniform = views.get(view).unwrap();
let sprite_meta = sprite_meta.into_inner();
let image_bind_groups = image_bind_groups.into_inner();
let extracted_sprite = sprites.get(item.entity).unwrap();
pass.set_render_pipeline(&sprite_shaders.into_inner().pipeline);
if let Some(pipeline) = pipelines.into_inner().get(item.pipeline) {
pass.set_render_pipeline(pipeline);
pass.set_vertex_buffer(0, sprite_meta.vertices.buffer().unwrap().slice(..));
pass.set_index_buffer(
sprite_meta.indices.buffer().unwrap().slice(..),
@ -437,3 +436,4 @@ impl Draw<Transparent2d> for DrawSprite {
);
}
}
}