Shader Imports. Decouple Mesh logic from PBR (#3137)

## Shader Imports

This adds "whole file" shader imports. These come in two flavors:

### Asset Path Imports

```rust
// /assets/shaders/custom.wgsl

#import "shaders/custom_material.wgsl"

[[stage(fragment)]]
fn fragment() -> [[location(0)]] vec4<f32> {
    return get_color();
}
```

```rust
// /assets/shaders/custom_material.wgsl

[[block]]
struct CustomMaterial {
    color: vec4<f32>;
};
[[group(1), binding(0)]]
var<uniform> material: CustomMaterial;
```

### Custom Path Imports

Enables defining custom import paths. These are intended to be used by crates to export shader functionality:

```rust
// bevy_pbr2/src/render/pbr.wgsl

#import bevy_pbr::mesh_view_bind_group
#import bevy_pbr::mesh_bind_group

[[block]]
struct StandardMaterial {
    base_color: vec4<f32>;
    emissive: vec4<f32>;
    perceptual_roughness: f32;
    metallic: f32;
    reflectance: f32;
    flags: u32;
};

/* rest of PBR fragment shader here */
```

```rust
impl Plugin for MeshRenderPlugin {
    fn build(&self, app: &mut bevy_app::App) {
        let mut shaders = app.world.get_resource_mut::<Assets<Shader>>().unwrap();
        shaders.set_untracked(
            MESH_BIND_GROUP_HANDLE,
            Shader::from_wgsl(include_str!("mesh_bind_group.wgsl"))
                .with_import_path("bevy_pbr::mesh_bind_group"),
        );
        shaders.set_untracked(
            MESH_VIEW_BIND_GROUP_HANDLE,
            Shader::from_wgsl(include_str!("mesh_view_bind_group.wgsl"))
                .with_import_path("bevy_pbr::mesh_view_bind_group"),
        );
```

By convention these should use rust-style module paths that start with the crate name. Ultimately we might enforce this convention.

Note that this feature implements _run time_ import resolution. Ultimately we should move the import logic into an asset preprocessor once Bevy gets support for that.

## Decouple Mesh Logic from PBR Logic via MeshRenderPlugin

This breaks out mesh rendering code from PBR material code, which improves the legibility of the code, decouples mesh logic from PBR logic, and opens the door for a future `MaterialPlugin<T: Material>` that handles all of the pipeline setup for arbitrary shader materials.

## Removed `RenderAsset<Shader>` in favor of extracting shaders into RenderPipelineCache

This simplifies the shader import implementation and removes the need to pass around `RenderAssets<Shader>`.

##  RenderCommands are now fallible

This allows us to cleanly handle pipelines+shaders not being ready yet. We can abort a render command early in these cases, preventing bevy from trying to bind group / do draw calls for pipelines that couldn't be bound. This could also be used in the future for things like "components not existing on entities yet". 

# Next Steps

* Investigate using Naga for "partial typed imports" (ex: `#import bevy_pbr::material::StandardMaterial`, which would import only the StandardMaterial struct)
* Implement `MaterialPlugin<T: Material>` for low-boilerplate custom material shaders
* Move shader import logic into the asset preprocessor once bevy gets support for that.

Fixes #3132
This commit is contained in:
Carter Anderson 2021-11-18 03:45:02 +00:00
parent 1076a8f2b5
commit 2e79951659
21 changed files with 1421 additions and 1001 deletions

View file

@ -1,46 +0,0 @@
[[block]]
struct View {
view_proj: mat4x4<f32>;
projection: mat4x4<f32>;
world_position: vec3<f32>;
};
[[group(0), binding(0)]]
var<uniform> view: View;
[[block]]
struct Mesh {
transform: mat4x4<f32>;
};
[[group(1), binding(0)]]
var<uniform> mesh: Mesh;
struct Vertex {
[[location(0)]] position: vec3<f32>;
[[location(1)]] normal: vec3<f32>;
[[location(2)]] uv: vec2<f32>;
};
struct VertexOutput {
[[builtin(position)]] clip_position: vec4<f32>;
};
[[stage(vertex)]]
fn vertex(vertex: Vertex) -> VertexOutput {
let world_position = mesh.transform * vec4<f32>(vertex.position, 1.0);
var out: VertexOutput;
out.clip_position = view.view_proj * world_position;
return out;
}
[[block]]
struct CustomMaterial {
color: vec4<f32>;
};
[[group(2), binding(0)]]
var<uniform> material: CustomMaterial;
[[stage(fragment)]]
fn fragment() -> [[location(0)]] vec4<f32> {
return material.color;
}

View file

@ -0,0 +1,11 @@
[[block]]
struct CustomMaterial {
color: vec4<f32>;
};
[[group(1), binding(0)]]
var<uniform> material: CustomMaterial;
[[stage(fragment)]]
fn fragment() -> [[location(0)]] vec4<f32> {
return material.color;
}

View file

@ -1,16 +1,6 @@
[[block]]
struct View {
view_proj: mat4x4<f32>;
projection: mat4x4<f32>;
world_position: vec3<f32>;
};
[[group(0), binding(0)]]
var<uniform> view: View;
#import bevy_pbr::mesh_view_bind_group
#import bevy_pbr::mesh_struct
[[block]]
struct Mesh {
transform: mat4x4<f32>;
};
[[group(1), binding(0)]]
var<uniform> mesh: Mesh;
@ -26,7 +16,7 @@ struct VertexOutput {
[[stage(vertex)]]
fn vertex(vertex: Vertex) -> VertexOutput {
let world_position = mesh.transform * vec4<f32>(vertex.position, 1.0);
let world_position = mesh.model * vec4<f32>(vertex.position, 1.0);
var out: VertexOutput;
out.clip_position = view.view_proj * world_position;

View file

@ -46,13 +46,19 @@ impl<T: Asset> LoadedAsset<T> {
}
}
pub fn with_dependency(mut self, asset_path: AssetPath) -> Self {
pub fn add_dependency(&mut self, asset_path: AssetPath) {
self.dependencies.push(asset_path.to_owned());
}
pub fn with_dependency(mut self, asset_path: AssetPath) -> Self {
self.add_dependency(asset_path);
self
}
pub fn with_dependencies(mut self, asset_paths: Vec<AssetPath<'static>>) -> Self {
self.dependencies.extend(asset_paths);
pub fn with_dependencies(mut self, mut asset_paths: Vec<AssetPath<'static>>) -> Self {
for asset_path in asset_paths.drain(..) {
self.add_dependency(asset_path);
}
self
}
}

View file

@ -7,8 +7,8 @@ use bevy::{
},
math::{Vec3, Vec4},
pbr2::{
DrawMesh, MeshUniform, PbrPipeline, PbrPipelineKey, SetMeshViewBindGroup,
SetTransformBindGroup,
DrawMesh, MeshPipeline, MeshPipelineKey, MeshUniform, SetMeshBindGroup,
SetMeshViewBindGroup,
},
prelude::{AddAsset, App, AssetServer, Assets, GlobalTransform, Handle, Plugin, Transform},
reflect::TypeUuid,
@ -19,8 +19,8 @@ use bevy::{
render_asset::{PrepareAssetError, RenderAsset, RenderAssetPlugin, RenderAssets},
render_component::ExtractComponentPlugin,
render_phase::{
AddRenderCommand, DrawFunctions, EntityRenderCommand, RenderPhase, SetItemPipeline,
TrackedRenderPass,
AddRenderCommand, DrawFunctions, EntityRenderCommand, RenderCommandResult, RenderPhase,
SetItemPipeline, TrackedRenderPass,
},
render_resource::*,
renderer::RenderDevice,
@ -127,22 +127,21 @@ impl Plugin for CustomMaterialPlugin {
}
pub struct CustomPipeline {
mesh_pipeline: MeshPipeline,
material_layout: BindGroupLayout,
shader: Handle<Shader>,
pbr_pipeline: PbrPipeline,
}
impl SpecializedPipeline for CustomPipeline {
type Key = PbrPipelineKey;
type Key = MeshPipelineKey;
fn specialize(&self, key: Self::Key) -> RenderPipelineDescriptor {
let mut descriptor = self.pbr_pipeline.specialize(key);
descriptor.vertex.shader = self.shader.clone();
let mut descriptor = self.mesh_pipeline.specialize(key);
descriptor.fragment.as_mut().unwrap().shader = self.shader.clone();
descriptor.layout = Some(vec![
self.pbr_pipeline.view_layout.clone(),
self.pbr_pipeline.mesh_layout.clone(),
self.mesh_pipeline.view_layout.clone(),
self.material_layout.clone(),
self.mesh_pipeline.mesh_layout.clone(),
]);
descriptor
}
@ -167,8 +166,8 @@ impl FromWorld for CustomPipeline {
});
CustomPipeline {
pbr_pipeline: world.get_resource::<PbrPipeline>().unwrap().clone(),
shader: asset_server.load("shaders/custom.wgsl"),
mesh_pipeline: world.get_resource::<MeshPipeline>().unwrap().clone(),
shader: asset_server.load("shaders/custom_material.wgsl"),
material_layout,
}
}
@ -189,7 +188,7 @@ pub fn queue_custom(
.read()
.get_id::<DrawCustom>()
.unwrap();
let key = PbrPipelineKey::from_msaa_samples(msaa.samples);
let key = MeshPipelineKey::from_msaa_samples(msaa.samples);
for (view, mut transparent_phase) in views.iter_mut() {
let view_matrix = view.transform.compute_matrix();
let view_row_2 = view_matrix.row(2);
@ -213,8 +212,8 @@ pub fn queue_custom(
type DrawCustom = (
SetItemPipeline,
SetMeshViewBindGroup<0>,
SetTransformBindGroup<1>,
SetCustomMaterialBindGroup,
SetMeshBindGroup<2>,
DrawMesh,
);
@ -229,9 +228,10 @@ impl EntityRenderCommand for SetCustomMaterialBindGroup {
item: Entity,
(materials, query): SystemParamItem<'w, '_, Self::Param>,
pass: &mut TrackedRenderPass<'w>,
) {
) -> RenderCommandResult {
let material_handle = query.get(item).unwrap();
let material = materials.into_inner().get(material_handle).unwrap();
pass.set_bind_group(2, &material.bind_group, &[]);
pass.set_bind_group(1, &material.bind_group, &[]);
RenderCommandResult::Success
}
}

View file

@ -4,8 +4,8 @@ use bevy::{
ecs::prelude::*,
math::Vec3,
pbr2::{
DrawMesh, MeshUniform, PbrPipeline, PbrPipelineKey, SetMeshViewBindGroup,
SetTransformBindGroup,
DrawMesh, MeshPipeline, MeshPipelineKey, MeshUniform, SetMeshBindGroup,
SetMeshViewBindGroup,
},
prelude::{App, AssetServer, Assets, GlobalTransform, Handle, Plugin, Transform},
render2::{
@ -86,39 +86,39 @@ fn setup(mut commands: Commands, mut meshes: ResMut<Assets<Mesh>>) {
}
struct IsRedPipeline {
mesh_pipline: MeshPipeline,
shader: Handle<Shader>,
pbr_pipeline: PbrPipeline,
}
impl FromWorld for IsRedPipeline {
fn from_world(world: &mut World) -> Self {
let asset_server = world.get_resource::<AssetServer>().unwrap();
let pbr_pipeline = world.get_resource::<PbrPipeline>().unwrap();
let mesh_pipeline = world.get_resource::<MeshPipeline>().unwrap();
let shader = asset_server.load("shaders/shader_defs.wgsl");
IsRedPipeline {
mesh_pipline: mesh_pipeline.clone(),
shader,
pbr_pipeline: pbr_pipeline.clone(),
}
}
}
impl SpecializedPipeline for IsRedPipeline {
type Key = (IsRed, PbrPipelineKey);
type Key = (IsRed, MeshPipelineKey);
fn specialize(&self, (is_red, pbr_pipeline_key): Self::Key) -> RenderPipelineDescriptor {
let mut shader_defs = Vec::new();
if is_red.0 {
shader_defs.push("IS_RED".to_string());
}
let mut descriptor = self.pbr_pipeline.specialize(pbr_pipeline_key);
let mut descriptor = self.mesh_pipline.specialize(pbr_pipeline_key);
descriptor.vertex.shader = self.shader.clone();
descriptor.vertex.shader_defs = shader_defs.clone();
let fragment = descriptor.fragment.as_mut().unwrap();
fragment.shader = self.shader.clone();
fragment.shader_defs = shader_defs;
descriptor.layout = Some(vec![
self.pbr_pipeline.view_layout.clone(),
self.pbr_pipeline.mesh_layout.clone(),
self.mesh_pipline.view_layout.clone(),
self.mesh_pipline.mesh_layout.clone(),
]);
descriptor
}
@ -127,7 +127,7 @@ impl SpecializedPipeline for IsRedPipeline {
type DrawIsRed = (
SetItemPipeline,
SetMeshViewBindGroup<0>,
SetTransformBindGroup<1>,
SetMeshBindGroup<1>,
DrawMesh,
);
@ -144,7 +144,7 @@ fn queue_custom(
.read()
.get_id::<DrawIsRed>()
.unwrap();
let key = PbrPipelineKey::from_msaa_samples(msaa.samples);
let key = MeshPipelineKey::from_msaa_samples(msaa.samples);
for (view, mut transparent_phase) in views.iter_mut() {
let view_matrix = view.transform.compute_matrix();
let view_row_2 = view_matrix.row(2);

View file

@ -3,7 +3,7 @@ use bevy_reflect::Reflect;
// FIXME: This should probably be part of bevy_render2!
/// Alpha mode
#[derive(Debug, Reflect, Clone, PartialEq)]
#[derive(Debug, Reflect, Copy, Clone, PartialEq)]
#[reflect(Component)]
pub enum AlphaMode {
Opaque,

View file

@ -16,7 +16,7 @@ use bevy_core_pipeline::{AlphaMask3d, Opaque3d, Transparent3d};
use bevy_ecs::prelude::*;
use bevy_reflect::TypeUuid;
use bevy_render2::{
render_component::{ExtractComponentPlugin, UniformComponentPlugin},
render_component::ExtractComponentPlugin,
render_graph::RenderGraph,
render_phase::{sort_phase_system, AddRenderCommand, DrawFunctions},
render_resource::{Shader, SpecializedPipelines},
@ -44,14 +44,18 @@ pub struct PbrPlugin;
impl Plugin for PbrPlugin {
fn build(&self, app: &mut App) {
let mut shaders = app.world.get_resource_mut::<Assets<Shader>>().unwrap();
let pbr_shader = Shader::from_wgsl(include_str!("render/pbr.wgsl"));
shaders.set_untracked(PBR_SHADER_HANDLE, pbr_shader);
let shadow_shader = Shader::from_wgsl(include_str!("render/depth.wgsl"));
shaders.set_untracked(SHADOW_SHADER_HANDLE, shadow_shader);
shaders.set_untracked(
PBR_SHADER_HANDLE,
Shader::from_wgsl(include_str!("render/pbr.wgsl")),
);
shaders.set_untracked(
SHADOW_SHADER_HANDLE,
Shader::from_wgsl(include_str!("render/depth.wgsl")),
);
app.add_plugin(StandardMaterialPlugin)
.add_plugin(MeshRenderPlugin)
.add_plugin(ExtractComponentPlugin::<Handle<StandardMaterial>>::default())
.add_plugin(UniformComponentPlugin::<MeshUniform>::default())
.init_resource::<AmbientLight>()
.init_resource::<DirectionalLightShadowMap>()
.init_resource::<PointLightShadowMap>()
@ -84,7 +88,6 @@ impl Plugin for PbrPlugin {
let render_app = app.sub_app(RenderApp);
render_app
.add_system_to_stage(RenderStage::Extract, render::extract_meshes)
.add_system_to_stage(
RenderStage::Extract,
render::extract_lights.label(RenderLightSystems::ExtractLights),
@ -97,13 +100,12 @@ impl Plugin for PbrPlugin {
.exclusive_system()
.label(RenderLightSystems::PrepareLights),
)
.add_system_to_stage(RenderStage::Queue, render::queue_meshes)
.add_system_to_stage(
RenderStage::Queue,
render::queue_shadows.label(RenderLightSystems::QueueShadows),
)
.add_system_to_stage(RenderStage::Queue, queue_meshes)
.add_system_to_stage(RenderStage::Queue, render::queue_shadow_view_bind_group)
.add_system_to_stage(RenderStage::Queue, render::queue_transform_bind_group)
.add_system_to_stage(RenderStage::PhaseSort, sort_phase_system::<Shadow>)
.init_resource::<PbrPipeline>()
.init_resource::<ShadowPipeline>()

View file

@ -163,16 +163,18 @@ impl RenderAsset for StandardMaterial {
material: Self::ExtractedAsset,
(render_device, pbr_pipeline, gpu_images): &mut SystemParamItem<Self::Param>,
) -> Result<Self::PreparedAsset, PrepareAssetError<Self::ExtractedAsset>> {
let (base_color_texture_view, base_color_sampler) = if let Some(result) =
pbr_pipeline.image_handle_to_texture(gpu_images, &material.base_color_texture)
let (base_color_texture_view, base_color_sampler) = if let Some(result) = pbr_pipeline
.mesh_pipeline
.get_image_texture(gpu_images, &material.base_color_texture)
{
result
} else {
return Err(PrepareAssetError::RetryNextUpdate(material));
};
let (emissive_texture_view, emissive_sampler) = if let Some(result) =
pbr_pipeline.image_handle_to_texture(gpu_images, &material.emissive_texture)
let (emissive_texture_view, emissive_sampler) = if let Some(result) = pbr_pipeline
.mesh_pipeline
.get_image_texture(gpu_images, &material.emissive_texture)
{
result
} else {
@ -180,21 +182,25 @@ impl RenderAsset for StandardMaterial {
};
let (metallic_roughness_texture_view, metallic_roughness_sampler) = if let Some(result) =
pbr_pipeline.image_handle_to_texture(gpu_images, &material.metallic_roughness_texture)
pbr_pipeline
.mesh_pipeline
.get_image_texture(gpu_images, &material.metallic_roughness_texture)
{
result
} else {
return Err(PrepareAssetError::RetryNextUpdate(material));
};
let (normal_map_texture_view, normal_map_sampler) = if let Some(result) =
pbr_pipeline.image_handle_to_texture(gpu_images, &material.normal_map_texture)
let (normal_map_texture_view, normal_map_sampler) = if let Some(result) = pbr_pipeline
.mesh_pipeline
.get_image_texture(gpu_images, &material.normal_map_texture)
{
result
} else {
return Err(PrepareAssetError::RetryNextUpdate(material));
};
let (occlusion_texture_view, occlusion_sampler) = if let Some(result) =
pbr_pipeline.image_handle_to_texture(gpu_images, &material.occlusion_texture)
let (occlusion_texture_view, occlusion_sampler) = if let Some(result) = pbr_pipeline
.mesh_pipeline
.get_image_texture(gpu_images, &material.occlusion_texture)
{
result
} else {

View file

@ -1,3 +1,5 @@
#import bevy_pbr::mesh_struct
// NOTE: Keep in sync with pbr.wgsl
[[block]]
struct View {
@ -8,15 +10,6 @@ struct View {
[[group(0), binding(0)]]
var<uniform> view: View;
[[block]]
struct Mesh {
model: mat4x4<f32>;
inverse_transpose_model: mat4x4<f32>;
// 'flags' is a bit field indicating various options. u32 is 32 bits so we have up to 32 options.
flags: u32;
};
[[group(1), binding(0)]]
var<uniform> mesh: Mesh;

View file

@ -1,6 +1,6 @@
use crate::{
AmbientLight, CubemapVisibleEntities, DirectionalLight, DirectionalLightShadowMap, DrawMesh,
NotShadowCaster, PbrPipeline, PointLight, PointLightShadowMap, SetTransformBindGroup,
MeshPipeline, NotShadowCaster, PointLight, PointLightShadowMap, SetMeshBindGroup,
SHADOW_SHADER_HANDLE,
};
use bevy_asset::Handle;
@ -19,7 +19,8 @@ use bevy_render2::{
render_graph::{Node, NodeRunError, RenderGraphContext, SlotInfo, SlotType},
render_phase::{
CachedPipelinePhaseItem, DrawFunctionId, DrawFunctions, EntityPhaseItem,
EntityRenderCommand, PhaseItem, RenderPhase, SetItemPipeline, TrackedRenderPass,
EntityRenderCommand, PhaseItem, RenderCommandResult, RenderPhase, SetItemPipeline,
TrackedRenderPass,
},
render_resource::*,
renderer::{RenderContext, RenderDevice, RenderQueue},
@ -140,11 +141,11 @@ impl FromWorld for ShadowPipeline {
label: Some("shadow_view_layout"),
});
let pbr_pipeline = world.get_resource::<PbrPipeline>().unwrap();
let mesh_pipeline = world.get_resource::<MeshPipeline>().unwrap();
ShadowPipeline {
view_layout,
mesh_layout: pbr_pipeline.mesh_layout.clone(),
mesh_layout: mesh_pipeline.mesh_layout.clone(),
point_light_sampler: render_device.create_sampler(&SamplerDescriptor {
address_mode_u: AddressMode::ClampToEdge,
address_mode_v: AddressMode::ClampToEdge,
@ -881,7 +882,7 @@ impl Node for ShadowPassNode {
pub type DrawShadowMesh = (
SetItemPipeline,
SetShadowViewBindGroup<0>,
SetTransformBindGroup<1>,
SetMeshBindGroup<1>,
DrawMesh,
);
@ -894,7 +895,7 @@ impl<const I: usize> EntityRenderCommand for SetShadowViewBindGroup<I> {
_item: Entity,
(light_meta, view_query): SystemParamItem<'w, '_, Self::Param>,
pass: &mut TrackedRenderPass<'w>,
) {
) -> RenderCommandResult {
let view_uniform_offset = view_query.get(view).unwrap();
pass.set_bind_group(
I,
@ -905,5 +906,7 @@ impl<const I: usize> EntityRenderCommand for SetShadowViewBindGroup<I> {
.unwrap(),
&[view_uniform_offset.offset],
);
RenderCommandResult::Success
}
}

View file

@ -0,0 +1,658 @@
use crate::{LightMeta, NotShadowCaster, NotShadowReceiver, ShadowPipeline, ViewLights};
use bevy_app::Plugin;
use bevy_asset::{Assets, Handle, HandleUntyped};
use bevy_ecs::{
prelude::*,
system::{lifetimeless::*, SystemParamItem},
};
use bevy_math::Mat4;
use bevy_reflect::TypeUuid;
use bevy_render2::{
mesh::Mesh,
render_asset::RenderAssets,
render_component::{ComponentUniforms, DynamicUniformIndex, UniformComponentPlugin},
render_phase::{EntityRenderCommand, RenderCommandResult, TrackedRenderPass},
render_resource::*,
renderer::{RenderDevice, RenderQueue},
texture::{BevyDefault, GpuImage, Image, TextureFormatPixelInfo},
view::{ComputedVisibility, ViewUniformOffset, ViewUniforms},
RenderApp, RenderStage,
};
use bevy_transform::components::GlobalTransform;
use crevice::std140::AsStd140;
use wgpu::{
Extent3d, ImageCopyTexture, ImageDataLayout, Origin3d, TextureDimension, TextureFormat,
TextureViewDescriptor,
};
#[derive(Default)]
pub struct MeshRenderPlugin;
pub const MESH_VIEW_BIND_GROUP_HANDLE: HandleUntyped =
HandleUntyped::weak_from_u64(Shader::TYPE_UUID, 9076678235888822571);
pub const MESH_STRUCT_HANDLE: HandleUntyped =
HandleUntyped::weak_from_u64(Shader::TYPE_UUID, 2506024101911992377);
pub const MESH_SHADER_HANDLE: HandleUntyped =
HandleUntyped::weak_from_u64(Shader::TYPE_UUID, 3252377289100772450);
impl Plugin for MeshRenderPlugin {
fn build(&self, app: &mut bevy_app::App) {
let mut shaders = app.world.get_resource_mut::<Assets<Shader>>().unwrap();
shaders.set_untracked(
MESH_SHADER_HANDLE,
Shader::from_wgsl(include_str!("mesh.wgsl")),
);
shaders.set_untracked(
MESH_STRUCT_HANDLE,
Shader::from_wgsl(include_str!("mesh_struct.wgsl"))
.with_import_path("bevy_pbr::mesh_struct"),
);
shaders.set_untracked(
MESH_VIEW_BIND_GROUP_HANDLE,
Shader::from_wgsl(include_str!("mesh_view_bind_group.wgsl"))
.with_import_path("bevy_pbr::mesh_view_bind_group"),
);
app.add_plugin(UniformComponentPlugin::<MeshUniform>::default());
app.sub_app(RenderApp)
.init_resource::<MeshPipeline>()
.add_system_to_stage(RenderStage::Extract, extract_meshes)
.add_system_to_stage(RenderStage::Queue, queue_mesh_bind_group)
.add_system_to_stage(RenderStage::Queue, queue_mesh_view_bind_groups);
}
}
#[derive(AsStd140, Clone)]
pub struct MeshUniform {
pub transform: Mat4,
pub inverse_transpose_model: Mat4,
pub flags: u32,
}
// NOTE: These must match the bit flags in bevy_pbr2/src/render/mesh.wgsl!
bitflags::bitflags! {
#[repr(transparent)]
struct MeshFlags: u32 {
const SHADOW_RECEIVER = (1 << 0);
const NONE = 0;
const UNINITIALIZED = 0xFFFF;
}
}
pub fn extract_meshes(
mut commands: Commands,
mut previous_caster_len: Local<usize>,
mut previous_not_caster_len: Local<usize>,
caster_query: Query<
(
Entity,
&ComputedVisibility,
&GlobalTransform,
&Handle<Mesh>,
Option<&NotShadowReceiver>,
),
Without<NotShadowCaster>,
>,
not_caster_query: Query<
(
Entity,
&ComputedVisibility,
&GlobalTransform,
&Handle<Mesh>,
Option<&NotShadowReceiver>,
),
With<NotShadowCaster>,
>,
) {
let mut caster_values = Vec::with_capacity(*previous_caster_len);
for (entity, computed_visibility, transform, handle, not_receiver) in caster_query.iter() {
if !computed_visibility.is_visible {
continue;
}
let transform = transform.compute_matrix();
caster_values.push((
entity,
(
handle.clone_weak(),
MeshUniform {
flags: if not_receiver.is_some() {
MeshFlags::empty().bits
} else {
MeshFlags::SHADOW_RECEIVER.bits
},
transform,
inverse_transpose_model: transform.inverse().transpose(),
},
),
));
}
*previous_caster_len = caster_values.len();
commands.insert_or_spawn_batch(caster_values);
let mut not_caster_values = Vec::with_capacity(*previous_not_caster_len);
for (entity, computed_visibility, transform, handle, not_receiver) in not_caster_query.iter() {
if !computed_visibility.is_visible {
continue;
}
let transform = transform.compute_matrix();
not_caster_values.push((
entity,
(
handle.clone_weak(),
MeshUniform {
flags: if not_receiver.is_some() {
MeshFlags::empty().bits
} else {
MeshFlags::SHADOW_RECEIVER.bits
},
transform,
inverse_transpose_model: transform.inverse().transpose(),
},
NotShadowCaster,
),
));
}
*previous_not_caster_len = not_caster_values.len();
commands.insert_or_spawn_batch(not_caster_values);
}
#[derive(Clone)]
pub struct MeshPipeline {
pub view_layout: BindGroupLayout,
pub mesh_layout: BindGroupLayout,
// This dummy white texture is to be used in place of optional StandardMaterial textures
pub dummy_white_gpu_image: GpuImage,
}
impl FromWorld for MeshPipeline {
fn from_world(world: &mut World) -> Self {
let render_device = world.get_resource::<RenderDevice>().unwrap();
let view_layout = render_device.create_bind_group_layout(&BindGroupLayoutDescriptor {
entries: &[
// View
BindGroupLayoutEntry {
binding: 0,
visibility: ShaderStages::VERTEX | ShaderStages::FRAGMENT,
ty: BindingType::Buffer {
ty: BufferBindingType::Uniform,
has_dynamic_offset: true,
// TODO: change this to ViewUniform::std140_size_static once crevice fixes this!
// Context: https://github.com/LPGhatguy/crevice/issues/29
min_binding_size: BufferSize::new(144),
},
count: None,
},
// Lights
BindGroupLayoutEntry {
binding: 1,
visibility: ShaderStages::FRAGMENT,
ty: BindingType::Buffer {
ty: BufferBindingType::Uniform,
has_dynamic_offset: true,
// TODO: change this to GpuLights::std140_size_static once crevice fixes this!
// Context: https://github.com/LPGhatguy/crevice/issues/29
min_binding_size: BufferSize::new(1424),
},
count: None,
},
// Point Shadow Texture Cube Array
BindGroupLayoutEntry {
binding: 2,
visibility: ShaderStages::FRAGMENT,
ty: BindingType::Texture {
multisampled: false,
sample_type: TextureSampleType::Depth,
view_dimension: TextureViewDimension::CubeArray,
},
count: None,
},
// Point Shadow Texture Array Sampler
BindGroupLayoutEntry {
binding: 3,
visibility: ShaderStages::FRAGMENT,
ty: BindingType::Sampler {
comparison: true,
filtering: true,
},
count: None,
},
// Directional Shadow Texture Array
BindGroupLayoutEntry {
binding: 4,
visibility: ShaderStages::FRAGMENT,
ty: BindingType::Texture {
multisampled: false,
sample_type: TextureSampleType::Depth,
view_dimension: TextureViewDimension::D2Array,
},
count: None,
},
// Directional Shadow Texture Array Sampler
BindGroupLayoutEntry {
binding: 5,
visibility: ShaderStages::FRAGMENT,
ty: BindingType::Sampler {
comparison: true,
filtering: true,
},
count: None,
},
],
label: Some("mesh_view_layout"),
});
let mesh_layout = render_device.create_bind_group_layout(&BindGroupLayoutDescriptor {
entries: &[BindGroupLayoutEntry {
binding: 0,
visibility: ShaderStages::VERTEX | ShaderStages::FRAGMENT,
ty: BindingType::Buffer {
ty: BufferBindingType::Uniform,
has_dynamic_offset: true,
// TODO: change this to MeshUniform::std140_size_static once crevice fixes this!
// Context: https://github.com/LPGhatguy/crevice/issues/29
min_binding_size: BufferSize::new(144),
},
count: None,
}],
label: Some("mesh_layout"),
});
// A 1x1x1 'all 1.0' texture to use as a dummy texture to use in place of optional StandardMaterial textures
let dummy_white_gpu_image = {
let image = Image::new_fill(
Extent3d::default(),
TextureDimension::D2,
&[255u8; 4],
TextureFormat::bevy_default(),
);
let texture = render_device.create_texture(&image.texture_descriptor);
let sampler = render_device.create_sampler(&image.sampler_descriptor);
let format_size = image.texture_descriptor.format.pixel_size();
let render_queue = world.get_resource_mut::<RenderQueue>().unwrap();
render_queue.write_texture(
ImageCopyTexture {
texture: &texture,
mip_level: 0,
origin: Origin3d::ZERO,
aspect: wgpu::TextureAspect::All,
},
&image.data,
ImageDataLayout {
offset: 0,
bytes_per_row: Some(
std::num::NonZeroU32::new(
image.texture_descriptor.size.width * format_size as u32,
)
.unwrap(),
),
rows_per_image: None,
},
image.texture_descriptor.size,
);
let texture_view = texture.create_view(&TextureViewDescriptor::default());
GpuImage {
texture,
texture_view,
sampler,
}
};
MeshPipeline {
view_layout,
mesh_layout,
dummy_white_gpu_image,
}
}
}
impl MeshPipeline {
pub fn get_image_texture<'a>(
&'a self,
gpu_images: &'a RenderAssets<Image>,
handle_option: &Option<Handle<Image>>,
) -> Option<(&'a TextureView, &'a Sampler)> {
if let Some(handle) = handle_option {
let gpu_image = gpu_images.get(handle)?;
Some((&gpu_image.texture_view, &gpu_image.sampler))
} else {
Some((
&self.dummy_white_gpu_image.texture_view,
&self.dummy_white_gpu_image.sampler,
))
}
}
}
bitflags::bitflags! {
#[repr(transparent)]
// NOTE: Apparently quadro drivers support up to 64x MSAA.
/// MSAA uses the highest 6 bits for the MSAA sample count - 1 to support up to 64x MSAA.
pub struct MeshPipelineKey: u32 {
const NONE = 0;
const VERTEX_TANGENTS = (1 << 0);
const TRANSPARENT_MAIN_PASS = (1 << 1);
const MSAA_RESERVED_BITS = MeshPipelineKey::MSAA_MASK_BITS << MeshPipelineKey::MSAA_SHIFT_BITS;
}
}
impl MeshPipelineKey {
const MSAA_MASK_BITS: u32 = 0b111111;
const MSAA_SHIFT_BITS: u32 = 32 - 6;
pub fn from_msaa_samples(msaa_samples: u32) -> Self {
let msaa_bits = ((msaa_samples - 1) & Self::MSAA_MASK_BITS) << Self::MSAA_SHIFT_BITS;
MeshPipelineKey::from_bits(msaa_bits).unwrap()
}
pub fn msaa_samples(&self) -> u32 {
((self.bits >> Self::MSAA_SHIFT_BITS) & Self::MSAA_MASK_BITS) + 1
}
}
impl SpecializedPipeline for MeshPipeline {
type Key = MeshPipelineKey;
fn specialize(&self, key: Self::Key) -> RenderPipelineDescriptor {
let (vertex_array_stride, vertex_attributes) =
if key.contains(MeshPipelineKey::VERTEX_TANGENTS) {
(
48,
vec![
// Position (GOTCHA! Vertex_Position isn't first in the buffer due to how Mesh sorts attributes (alphabetically))
VertexAttribute {
format: VertexFormat::Float32x3,
offset: 12,
shader_location: 0,
},
// Normal
VertexAttribute {
format: VertexFormat::Float32x3,
offset: 0,
shader_location: 1,
},
// Uv (GOTCHA! uv is no longer third in the buffer due to how Mesh sorts attributes (alphabetically))
VertexAttribute {
format: VertexFormat::Float32x2,
offset: 40,
shader_location: 2,
},
// Tangent
VertexAttribute {
format: VertexFormat::Float32x4,
offset: 24,
shader_location: 3,
},
],
)
} else {
(
32,
vec![
// Position (GOTCHA! Vertex_Position isn't first in the buffer due to how Mesh sorts attributes (alphabetically))
VertexAttribute {
format: VertexFormat::Float32x3,
offset: 12,
shader_location: 0,
},
// Normal
VertexAttribute {
format: VertexFormat::Float32x3,
offset: 0,
shader_location: 1,
},
// Uv
VertexAttribute {
format: VertexFormat::Float32x2,
offset: 24,
shader_location: 2,
},
],
)
};
let mut shader_defs = Vec::new();
if key.contains(MeshPipelineKey::VERTEX_TANGENTS) {
shader_defs.push(String::from("VERTEX_TANGENTS"));
}
let (label, blend, depth_write_enabled);
if key.contains(MeshPipelineKey::TRANSPARENT_MAIN_PASS) {
label = "transparent_mesh_pipeline".into();
blend = Some(BlendState::ALPHA_BLENDING);
// For the transparent pass, fragments that are closer will be alpha blended
// but their depth is not written to the depth buffer
depth_write_enabled = false;
} else {
label = "opaque_mesh_pipeline".into();
blend = Some(BlendState::REPLACE);
// For the opaque and alpha mask passes, fragments that are closer will replace
// the current fragment value in the output and the depth is written to the
// depth buffer
depth_write_enabled = true;
}
RenderPipelineDescriptor {
vertex: VertexState {
shader: MESH_SHADER_HANDLE.typed::<Shader>(),
entry_point: "vertex".into(),
shader_defs: shader_defs.clone(),
buffers: vec![VertexBufferLayout {
array_stride: vertex_array_stride,
step_mode: VertexStepMode::Vertex,
attributes: vertex_attributes,
}],
},
fragment: Some(FragmentState {
shader: MESH_SHADER_HANDLE.typed::<Shader>(),
shader_defs,
entry_point: "fragment".into(),
targets: vec![ColorTargetState {
format: TextureFormat::bevy_default(),
blend,
write_mask: ColorWrites::ALL,
}],
}),
layout: Some(vec![self.view_layout.clone(), self.mesh_layout.clone()]),
primitive: PrimitiveState {
front_face: FrontFace::Ccw,
cull_mode: Some(Face::Back),
polygon_mode: PolygonMode::Fill,
clamp_depth: false,
conservative: false,
topology: PrimitiveTopology::TriangleList,
strip_index_format: None,
},
depth_stencil: Some(DepthStencilState {
format: TextureFormat::Depth32Float,
depth_write_enabled,
depth_compare: CompareFunction::Greater,
stencil: StencilState {
front: StencilFaceState::IGNORE,
back: StencilFaceState::IGNORE,
read_mask: 0,
write_mask: 0,
},
bias: DepthBiasState {
constant: 0,
slope_scale: 0.0,
clamp: 0.0,
},
}),
multisample: MultisampleState {
count: key.msaa_samples(),
mask: !0,
alpha_to_coverage_enabled: false,
},
label: Some(label),
}
}
}
pub struct MeshBindGroup {
pub value: BindGroup,
}
pub fn queue_mesh_bind_group(
mut commands: Commands,
mesh_pipeline: Res<MeshPipeline>,
render_device: Res<RenderDevice>,
mesh_uniforms: Res<ComponentUniforms<MeshUniform>>,
) {
if let Some(binding) = mesh_uniforms.uniforms().binding() {
commands.insert_resource(MeshBindGroup {
value: render_device.create_bind_group(&BindGroupDescriptor {
entries: &[BindGroupEntry {
binding: 0,
resource: binding,
}],
label: Some("mesh_bind_group"),
layout: &mesh_pipeline.mesh_layout,
}),
});
}
}
pub struct MeshViewBindGroup {
pub value: BindGroup,
}
pub fn queue_mesh_view_bind_groups(
mut commands: Commands,
render_device: Res<RenderDevice>,
mesh_pipeline: Res<MeshPipeline>,
shadow_pipeline: Res<ShadowPipeline>,
light_meta: Res<LightMeta>,
view_uniforms: Res<ViewUniforms>,
mut views: Query<(Entity, &ViewLights)>,
) {
if let (Some(view_binding), Some(light_binding)) = (
view_uniforms.uniforms.binding(),
light_meta.view_gpu_lights.binding(),
) {
for (entity, view_lights) in views.iter_mut() {
let view_bind_group = render_device.create_bind_group(&BindGroupDescriptor {
entries: &[
BindGroupEntry {
binding: 0,
resource: view_binding.clone(),
},
BindGroupEntry {
binding: 1,
resource: light_binding.clone(),
},
BindGroupEntry {
binding: 2,
resource: BindingResource::TextureView(
&view_lights.point_light_depth_texture_view,
),
},
BindGroupEntry {
binding: 3,
resource: BindingResource::Sampler(&shadow_pipeline.point_light_sampler),
},
BindGroupEntry {
binding: 4,
resource: BindingResource::TextureView(
&view_lights.directional_light_depth_texture_view,
),
},
BindGroupEntry {
binding: 5,
resource: BindingResource::Sampler(
&shadow_pipeline.directional_light_sampler,
),
},
],
label: Some("mesh_view_bind_group"),
layout: &mesh_pipeline.view_layout,
});
commands.entity(entity).insert(MeshViewBindGroup {
value: view_bind_group,
});
}
}
}
pub struct SetMeshViewBindGroup<const I: usize>;
impl<const I: usize> EntityRenderCommand for SetMeshViewBindGroup<I> {
type Param = SQuery<(
Read<ViewUniformOffset>,
Read<ViewLights>,
Read<MeshViewBindGroup>,
)>;
#[inline]
fn render<'w>(
view: Entity,
_item: Entity,
view_query: SystemParamItem<'w, '_, Self::Param>,
pass: &mut TrackedRenderPass<'w>,
) -> RenderCommandResult {
let (view_uniform, view_lights, mesh_view_bind_group) = view_query.get(view).unwrap();
pass.set_bind_group(
I,
&mesh_view_bind_group.value,
&[view_uniform.offset, view_lights.gpu_light_binding_index],
);
RenderCommandResult::Success
}
}
pub struct SetMeshBindGroup<const I: usize>;
impl<const I: usize> EntityRenderCommand for SetMeshBindGroup<I> {
type Param = (
SRes<MeshBindGroup>,
SQuery<Read<DynamicUniformIndex<MeshUniform>>>,
);
#[inline]
fn render<'w>(
_view: Entity,
item: Entity,
(mesh_bind_group, mesh_query): SystemParamItem<'w, '_, Self::Param>,
pass: &mut TrackedRenderPass<'w>,
) -> RenderCommandResult {
let mesh_index = mesh_query.get(item).unwrap();
pass.set_bind_group(
I,
&mesh_bind_group.into_inner().value,
&[mesh_index.index()],
);
RenderCommandResult::Success
}
}
pub struct DrawMesh;
impl EntityRenderCommand for DrawMesh {
type Param = (SRes<RenderAssets<Mesh>>, SQuery<Read<Handle<Mesh>>>);
#[inline]
fn render<'w>(
_view: Entity,
item: Entity,
(meshes, mesh_query): SystemParamItem<'w, '_, Self::Param>,
pass: &mut TrackedRenderPass<'w>,
) -> RenderCommandResult {
let mesh_handle = mesh_query.get(item).unwrap();
let gpu_mesh = meshes.into_inner().get(mesh_handle).unwrap();
pass.set_vertex_buffer(0, gpu_mesh.vertex_buffer.slice(..));
if let Some(index_info) = &gpu_mesh.index_info {
pass.set_index_buffer(index_info.buffer.slice(..), 0, index_info.index_format);
pass.draw_indexed(0..index_info.count, 0, 0..1);
} else {
panic!("non-indexed drawing not supported yet")
}
RenderCommandResult::Success
}
}
#[cfg(test)]
mod tests {
use super::MeshPipelineKey;
#[test]
fn mesh_key_msaa_samples() {
for i in 1..=64 {
assert_eq!(MeshPipelineKey::from_msaa_samples(i).msaa_samples(), i);
}
}
}

View file

@ -0,0 +1,65 @@
#import bevy_pbr::mesh_view_bind_group
#import bevy_pbr::mesh_struct
struct Vertex {
[[location(0)]] position: vec3<f32>;
[[location(1)]] normal: vec3<f32>;
[[location(2)]] uv: vec2<f32>;
#ifdef VERTEX_TANGENTS
[[location(3)]] tangent: vec4<f32>;
#endif
};
struct VertexOutput {
[[builtin(position)]] clip_position: vec4<f32>;
[[location(0)]] world_position: vec4<f32>;
[[location(1)]] world_normal: vec3<f32>;
[[location(2)]] uv: vec2<f32>;
#ifdef VERTEX_TANGENTS
[[location(3)]] world_tangent: vec4<f32>;
#endif
};
[[group(2), binding(0)]]
var<uniform> mesh: Mesh;
[[stage(vertex)]]
fn vertex(vertex: Vertex) -> VertexOutput {
let world_position = mesh.model * vec4<f32>(vertex.position, 1.0);
var out: VertexOutput;
out.uv = vertex.uv;
out.world_position = world_position;
out.clip_position = view.view_proj * world_position;
out.world_normal = mat3x3<f32>(
mesh.inverse_transpose_model[0].xyz,
mesh.inverse_transpose_model[1].xyz,
mesh.inverse_transpose_model[2].xyz
) * vertex.normal;
#ifdef VERTEX_TANGENTS
out.world_tangent = vec4<f32>(
mat3x3<f32>(
mesh.model[0].xyz,
mesh.model[1].xyz,
mesh.model[2].xyz
) * vertex.tangent.xyz,
vertex.tangent.w
);
#endif
return out;
}
struct FragmentInput {
[[builtin(front_facing)]] is_front: bool;
[[location(0)]] world_position: vec4<f32>;
[[location(1)]] world_normal: vec3<f32>;
[[location(2)]] uv: vec2<f32>;
#ifdef VERTEX_TANGENTS
[[location(3)]] world_tangent: vec4<f32>;
#endif
};
[[stage(fragment)]]
fn fragment(in: FragmentInput) -> [[location(0)]] vec4<f32> {
return vec4<f32>(1.0, 0.0, 1.0, 1.0);
}

View file

@ -0,0 +1,9 @@
[[block]]
struct Mesh {
model: mat4x4<f32>;
inverse_transpose_model: mat4x4<f32>;
// 'flags' is a bit field indicating various options. u32 is 32 bits so we have up to 32 options.
flags: u32;
};
let MESH_FLAGS_SHADOW_RECEIVER_BIT: u32 = 1u;

View file

@ -0,0 +1,50 @@
[[block]]
struct View {
view_proj: mat4x4<f32>;
projection: mat4x4<f32>;
world_position: vec3<f32>;
};
struct PointLight {
projection: mat4x4<f32>;
color: vec4<f32>;
position: vec3<f32>;
inverse_square_range: f32;
radius: f32;
near: f32;
far: f32;
shadow_depth_bias: f32;
shadow_normal_bias: f32;
};
struct DirectionalLight {
view_projection: mat4x4<f32>;
color: vec4<f32>;
direction_to_light: vec3<f32>;
shadow_depth_bias: f32;
shadow_normal_bias: f32;
};
[[block]]
struct Lights {
// NOTE: this array size must be kept in sync with the constants defined bevy_pbr2/src/render/light.rs
// TODO: this can be removed if we move to storage buffers for light arrays
point_lights: array<PointLight, 10>;
directional_lights: array<DirectionalLight, 1>;
ambient_color: vec4<f32>;
n_point_lights: u32;
n_directional_lights: u32;
};
[[group(0), binding(0)]]
var<uniform> view: View;
[[group(0), binding(1)]]
var<uniform> lights: Lights;
[[group(0), binding(2)]]
var point_shadow_textures: texture_depth_cube_array;
[[group(0), binding(3)]]
var point_shadow_textures_sampler: sampler_comparison;
[[group(0), binding(4)]]
var directional_shadow_textures: texture_depth_2d_array;
[[group(0), binding(5)]]
var directional_shadow_textures_sampler: sampler_comparison;

View file

@ -1,55 +1,28 @@
mod light;
mod mesh;
pub use light::*;
pub use mesh::*;
use crate::{
AlphaMode, NotShadowCaster, NotShadowReceiver, StandardMaterial, StandardMaterialUniformData,
PBR_SHADER_HANDLE,
};
use crate::{AlphaMode, StandardMaterial, StandardMaterialUniformData, PBR_SHADER_HANDLE};
use bevy_asset::Handle;
use bevy_core_pipeline::{AlphaMask3d, Opaque3d, Transparent3d};
use bevy_ecs::{
prelude::*,
system::{lifetimeless::*, SystemParamItem},
};
use bevy_math::Mat4;
use bevy_render2::{
mesh::Mesh,
render_asset::RenderAssets,
render_component::{ComponentUniforms, DynamicUniformIndex},
render_phase::{
DrawFunctions, EntityRenderCommand, RenderPhase, SetItemPipeline, TrackedRenderPass,
DrawFunctions, EntityRenderCommand, RenderCommandResult, RenderPhase, SetItemPipeline,
TrackedRenderPass,
},
render_resource::*,
renderer::{RenderDevice, RenderQueue},
texture::{BevyDefault, GpuImage, Image, TextureFormatPixelInfo},
view::{
ComputedVisibility, ExtractedView, Msaa, ViewUniformOffset, ViewUniforms, VisibleEntities,
},
renderer::RenderDevice,
view::{ExtractedView, Msaa, VisibleEntities},
};
use bevy_transform::components::GlobalTransform;
use crevice::std140::AsStd140;
use wgpu::{
Extent3d, ImageCopyTexture, ImageDataLayout, Origin3d, TextureDimension, TextureFormat,
TextureViewDescriptor,
};
#[derive(AsStd140, Clone)]
pub struct MeshUniform {
pub transform: Mat4,
pub inverse_transpose_model: Mat4,
pub flags: u32,
}
// NOTE: These must match the bit flags in bevy_pbr2/src/render/pbr.wgsl!
bitflags::bitflags! {
#[repr(transparent)]
struct MeshFlags: u32 {
const SHADOW_RECEIVER = (1 << 0);
const NONE = 0;
const UNINITIALIZED = 0xFFFF;
}
}
// NOTE: These must match the bit flags in bevy_pbr2/src/render/pbr.wgsl!
bitflags::bitflags! {
@ -69,186 +42,15 @@ bitflags::bitflags! {
}
}
pub fn extract_meshes(
mut commands: Commands,
mut previous_caster_len: Local<usize>,
mut previous_not_caster_len: Local<usize>,
caster_query: Query<
(
Entity,
&ComputedVisibility,
&GlobalTransform,
&Handle<Mesh>,
Option<&NotShadowReceiver>,
),
Without<NotShadowCaster>,
>,
not_caster_query: Query<
(
Entity,
&ComputedVisibility,
&GlobalTransform,
&Handle<Mesh>,
Option<&NotShadowReceiver>,
),
With<NotShadowCaster>,
>,
) {
let mut caster_values = Vec::with_capacity(*previous_caster_len);
for (entity, computed_visibility, transform, handle, not_receiver) in caster_query.iter() {
if !computed_visibility.is_visible {
continue;
}
let transform = transform.compute_matrix();
caster_values.push((
entity,
(
handle.clone_weak(),
MeshUniform {
flags: if not_receiver.is_some() {
MeshFlags::empty().bits
} else {
MeshFlags::SHADOW_RECEIVER.bits
},
transform,
inverse_transpose_model: transform.inverse().transpose(),
},
),
));
}
*previous_caster_len = caster_values.len();
commands.insert_or_spawn_batch(caster_values);
let mut not_caster_values = Vec::with_capacity(*previous_not_caster_len);
for (entity, computed_visibility, transform, handle, not_receiver) in not_caster_query.iter() {
if !computed_visibility.is_visible {
continue;
}
let transform = transform.compute_matrix();
not_caster_values.push((
entity,
(
handle.clone_weak(),
MeshUniform {
flags: if not_receiver.is_some() {
MeshFlags::empty().bits
} else {
MeshFlags::SHADOW_RECEIVER.bits
},
transform,
inverse_transpose_model: transform.inverse().transpose(),
},
NotShadowCaster,
),
));
}
*previous_not_caster_len = not_caster_values.len();
commands.insert_or_spawn_batch(not_caster_values);
}
#[derive(Clone)]
pub struct PbrPipeline {
pub view_layout: BindGroupLayout,
pub mesh_pipeline: MeshPipeline,
pub material_layout: BindGroupLayout,
pub mesh_layout: BindGroupLayout,
// This dummy white texture is to be used in place of optional StandardMaterial textures
pub dummy_white_gpu_image: GpuImage,
}
impl PbrPipeline {
pub fn image_handle_to_texture<'a>(
&'a self,
gpu_images: &'a RenderAssets<Image>,
handle_option: &Option<Handle<Image>>,
) -> Option<(&'a TextureView, &'a Sampler)> {
if let Some(handle) = handle_option {
let gpu_image = gpu_images.get(handle)?;
Some((&gpu_image.texture_view, &gpu_image.sampler))
} else {
Some((
&self.dummy_white_gpu_image.texture_view,
&self.dummy_white_gpu_image.sampler,
))
}
}
}
impl FromWorld for PbrPipeline {
fn from_world(world: &mut World) -> Self {
let render_device = world.get_resource::<RenderDevice>().unwrap();
let view_layout = render_device.create_bind_group_layout(&BindGroupLayoutDescriptor {
entries: &[
// View
BindGroupLayoutEntry {
binding: 0,
visibility: ShaderStages::VERTEX | ShaderStages::FRAGMENT,
ty: BindingType::Buffer {
ty: BufferBindingType::Uniform,
has_dynamic_offset: true,
// TODO: change this to ViewUniform::std140_size_static once crevice fixes this!
// Context: https://github.com/LPGhatguy/crevice/issues/29
min_binding_size: BufferSize::new(144),
},
count: None,
},
// Lights
BindGroupLayoutEntry {
binding: 1,
visibility: ShaderStages::FRAGMENT,
ty: BindingType::Buffer {
ty: BufferBindingType::Uniform,
has_dynamic_offset: true,
// TODO: change this to GpuLights::std140_size_static once crevice fixes this!
// Context: https://github.com/LPGhatguy/crevice/issues/29
min_binding_size: BufferSize::new(1424),
},
count: None,
},
// Point Shadow Texture Cube Array
BindGroupLayoutEntry {
binding: 2,
visibility: ShaderStages::FRAGMENT,
ty: BindingType::Texture {
multisampled: false,
sample_type: TextureSampleType::Depth,
view_dimension: TextureViewDimension::CubeArray,
},
count: None,
},
// Point Shadow Texture Array Sampler
BindGroupLayoutEntry {
binding: 3,
visibility: ShaderStages::FRAGMENT,
ty: BindingType::Sampler {
comparison: true,
filtering: true,
},
count: None,
},
// Directional Shadow Texture Array
BindGroupLayoutEntry {
binding: 4,
visibility: ShaderStages::FRAGMENT,
ty: BindingType::Texture {
multisampled: false,
sample_type: TextureSampleType::Depth,
view_dimension: TextureViewDimension::D2Array,
},
count: None,
},
// Directional Shadow Texture Array Sampler
BindGroupLayoutEntry {
binding: 5,
visibility: ShaderStages::FRAGMENT,
ty: BindingType::Sampler {
comparison: true,
filtering: true,
},
count: None,
},
],
label: Some("pbr_view_layout"),
});
let material_layout = render_device.create_bind_group_layout(&BindGroupLayoutDescriptor {
entries: &[
@ -373,264 +175,34 @@ impl FromWorld for PbrPipeline {
label: Some("pbr_material_layout"),
});
let mesh_layout = render_device.create_bind_group_layout(&BindGroupLayoutDescriptor {
entries: &[BindGroupLayoutEntry {
binding: 0,
visibility: ShaderStages::VERTEX | ShaderStages::FRAGMENT,
ty: BindingType::Buffer {
ty: BufferBindingType::Uniform,
has_dynamic_offset: true,
// TODO: change this to MeshUniform::std140_size_static once crevice fixes this!
// Context: https://github.com/LPGhatguy/crevice/issues/29
min_binding_size: BufferSize::new(144),
},
count: None,
}],
label: Some("pbr_mesh_layout"),
});
// A 1x1x1 'all 1.0' texture to use as a dummy texture to use in place of optional StandardMaterial textures
let dummy_white_gpu_image = {
let image = Image::new_fill(
Extent3d::default(),
TextureDimension::D2,
&[255u8; 4],
TextureFormat::bevy_default(),
);
let texture = render_device.create_texture(&image.texture_descriptor);
let sampler = render_device.create_sampler(&image.sampler_descriptor);
let format_size = image.texture_descriptor.format.pixel_size();
let render_queue = world.get_resource_mut::<RenderQueue>().unwrap();
render_queue.write_texture(
ImageCopyTexture {
texture: &texture,
mip_level: 0,
origin: Origin3d::ZERO,
aspect: wgpu::TextureAspect::All,
},
&image.data,
ImageDataLayout {
offset: 0,
bytes_per_row: Some(
std::num::NonZeroU32::new(
image.texture_descriptor.size.width * format_size as u32,
)
.unwrap(),
),
rows_per_image: None,
},
image.texture_descriptor.size,
);
let texture_view = texture.create_view(&TextureViewDescriptor::default());
GpuImage {
texture,
texture_view,
sampler,
}
};
PbrPipeline {
view_layout,
material_layout,
mesh_layout,
dummy_white_gpu_image,
mesh_pipeline: world.get_resource::<MeshPipeline>().unwrap().clone(),
}
}
}
bitflags::bitflags! {
#[repr(transparent)]
// NOTE: Apparently quadro drivers support up to 64x MSAA.
/// MSAA uses the highest 6 bits for the MSAA sample count - 1 to support up to 64x MSAA.
pub struct PbrPipelineKey: u32 {
const NONE = 0;
const VERTEX_TANGENTS = (1 << 0);
const STANDARDMATERIAL_NORMAL_MAP = (1 << 1);
const OPAQUE_MAIN_PASS = (1 << 2);
const ALPHA_MASK_MAIN_PASS = (1 << 3);
const TRANSPARENT_MAIN_PASS = (1 << 4);
const MSAA_RESERVED_BITS = PbrPipelineKey::MSAA_MASK_BITS << PbrPipelineKey::MSAA_SHIFT_BITS;
}
}
impl PbrPipelineKey {
const MSAA_MASK_BITS: u32 = 0b111111;
const MSAA_SHIFT_BITS: u32 = 32 - 6;
pub fn from_msaa_samples(msaa_samples: u32) -> Self {
let msaa_bits = ((msaa_samples - 1) & Self::MSAA_MASK_BITS) << Self::MSAA_SHIFT_BITS;
PbrPipelineKey::from_bits(msaa_bits).unwrap()
}
pub fn msaa_samples(&self) -> u32 {
((self.bits >> Self::MSAA_SHIFT_BITS) & Self::MSAA_MASK_BITS) + 1
}
#[derive(Clone, Copy, Hash, PartialEq, Eq)]
pub struct PbrPipelineKey {
pub mesh_key: MeshPipelineKey,
pub normal_map: bool,
}
impl SpecializedPipeline for PbrPipeline {
type Key = PbrPipelineKey;
fn specialize(&self, key: Self::Key) -> RenderPipelineDescriptor {
let (vertex_array_stride, vertex_attributes) =
if key.contains(PbrPipelineKey::VERTEX_TANGENTS) {
(
48,
vec![
// Position (GOTCHA! Vertex_Position isn't first in the buffer due to how Mesh sorts attributes (alphabetically))
VertexAttribute {
format: VertexFormat::Float32x3,
offset: 12,
shader_location: 0,
},
// Normal
VertexAttribute {
format: VertexFormat::Float32x3,
offset: 0,
shader_location: 1,
},
// Uv (GOTCHA! uv is no longer third in the buffer due to how Mesh sorts attributes (alphabetically))
VertexAttribute {
format: VertexFormat::Float32x2,
offset: 40,
shader_location: 2,
},
// Tangent
VertexAttribute {
format: VertexFormat::Float32x4,
offset: 24,
shader_location: 3,
},
],
)
} else {
(
32,
vec![
// Position (GOTCHA! Vertex_Position isn't first in the buffer due to how Mesh sorts attributes (alphabetically))
VertexAttribute {
format: VertexFormat::Float32x3,
offset: 12,
shader_location: 0,
},
// Normal
VertexAttribute {
format: VertexFormat::Float32x3,
offset: 0,
shader_location: 1,
},
// Uv
VertexAttribute {
format: VertexFormat::Float32x2,
offset: 24,
shader_location: 2,
},
],
)
};
let mut shader_defs = Vec::new();
if key.contains(PbrPipelineKey::VERTEX_TANGENTS) {
shader_defs.push(String::from("VERTEX_TANGENTS"));
let mut descriptor = self.mesh_pipeline.specialize(key.mesh_key);
descriptor.fragment.as_mut().unwrap().shader = PBR_SHADER_HANDLE.typed::<Shader>();
descriptor.layout = Some(vec![
self.mesh_pipeline.view_layout.clone(),
self.material_layout.clone(),
self.mesh_pipeline.mesh_layout.clone(),
]);
if let Some(label) = &mut descriptor.label {
*label = format!("pbr_{}", *label).into();
}
if key.contains(PbrPipelineKey::STANDARDMATERIAL_NORMAL_MAP) {
shader_defs.push(String::from("STANDARDMATERIAL_NORMAL_MAP"));
}
let (label, blend, depth_write_enabled);
if key.contains(PbrPipelineKey::TRANSPARENT_MAIN_PASS) {
label = Some("transparent_pbr_pipeline".into());
blend = Some(BlendState::ALPHA_BLENDING);
// For the transparent pass, fragments that are closer will be alpha blended
// but their depth is not written to the depth buffer
depth_write_enabled = false;
} else {
label = Some("opaque_pbr_pipeline".into());
blend = Some(BlendState::REPLACE);
// For the opaque and alpha mask passes, fragments that are closer will replace
// the current fragment value in the output and the depth is written to the
// depth buffer
depth_write_enabled = true;
}
RenderPipelineDescriptor {
vertex: VertexState {
shader: PBR_SHADER_HANDLE.typed::<Shader>(),
entry_point: "vertex".into(),
shader_defs: shader_defs.clone(),
buffers: vec![VertexBufferLayout {
array_stride: vertex_array_stride,
step_mode: VertexStepMode::Vertex,
attributes: vertex_attributes,
}],
},
fragment: Some(FragmentState {
shader: PBR_SHADER_HANDLE.typed::<Shader>(),
shader_defs,
entry_point: "fragment".into(),
targets: vec![ColorTargetState {
format: TextureFormat::bevy_default(),
blend,
write_mask: ColorWrites::ALL,
}],
}),
layout: Some(vec![
self.view_layout.clone(),
self.material_layout.clone(),
self.mesh_layout.clone(),
]),
primitive: PrimitiveState {
front_face: FrontFace::Ccw,
cull_mode: Some(Face::Back),
polygon_mode: PolygonMode::Fill,
clamp_depth: false,
conservative: false,
topology: PrimitiveTopology::TriangleList,
strip_index_format: None,
},
depth_stencil: Some(DepthStencilState {
format: TextureFormat::Depth32Float,
depth_write_enabled,
depth_compare: CompareFunction::Greater,
stencil: StencilState {
front: StencilFaceState::IGNORE,
back: StencilFaceState::IGNORE,
read_mask: 0,
write_mask: 0,
},
bias: DepthBiasState {
constant: 0,
slope_scale: 0.0,
clamp: 0.0,
},
}),
multisample: MultisampleState {
count: key.msaa_samples(),
mask: !0,
alpha_to_coverage_enabled: false,
},
label,
}
}
}
pub struct TransformBindGroup {
pub value: BindGroup,
}
pub fn queue_transform_bind_group(
mut commands: Commands,
pbr_pipeline: Res<PbrPipeline>,
render_device: Res<RenderDevice>,
transform_uniforms: Res<ComponentUniforms<MeshUniform>>,
) {
if let Some(binding) = transform_uniforms.uniforms().binding() {
commands.insert_resource(TransformBindGroup {
value: render_device.create_bind_group(&BindGroupDescriptor {
entries: &[BindGroupEntry {
binding: 0,
resource: binding,
}],
label: Some("transform_bind_group"),
layout: &pbr_pipeline.mesh_layout,
}),
});
descriptor
}
}
@ -640,128 +212,67 @@ pub struct PbrViewBindGroup {
#[allow(clippy::too_many_arguments)]
pub fn queue_meshes(
mut commands: Commands,
opaque_draw_functions: Res<DrawFunctions<Opaque3d>>,
alpha_mask_draw_functions: Res<DrawFunctions<AlphaMask3d>>,
transparent_draw_functions: Res<DrawFunctions<Transparent3d>>,
render_device: Res<RenderDevice>,
pbr_pipeline: Res<PbrPipeline>,
shadow_pipeline: Res<ShadowPipeline>,
mut pipelines: ResMut<SpecializedPipelines<PbrPipeline>>,
mut pipeline_cache: ResMut<RenderPipelineCache>,
light_meta: Res<LightMeta>,
msaa: Res<Msaa>,
view_uniforms: Res<ViewUniforms>,
render_meshes: Res<RenderAssets<Mesh>>,
render_materials: Res<RenderAssets<StandardMaterial>>,
standard_material_meshes: Query<(&Handle<StandardMaterial>, &Handle<Mesh>, &MeshUniform)>,
mut views: Query<(
Entity,
&ExtractedView,
&ViewLights,
&VisibleEntities,
&mut RenderPhase<Opaque3d>,
&mut RenderPhase<AlphaMask3d>,
&mut RenderPhase<Transparent3d>,
)>,
) {
if let (Some(view_binding), Some(light_binding)) = (
view_uniforms.uniforms.binding(),
light_meta.view_gpu_lights.binding(),
) {
for (
entity,
view,
view_lights,
visible_entities,
mut opaque_phase,
mut alpha_mask_phase,
mut transparent_phase,
) in views.iter_mut()
{
let view_bind_group = render_device.create_bind_group(&BindGroupDescriptor {
entries: &[
BindGroupEntry {
binding: 0,
resource: view_binding.clone(),
},
BindGroupEntry {
binding: 1,
resource: light_binding.clone(),
},
BindGroupEntry {
binding: 2,
resource: BindingResource::TextureView(
&view_lights.point_light_depth_texture_view,
),
},
BindGroupEntry {
binding: 3,
resource: BindingResource::Sampler(&shadow_pipeline.point_light_sampler),
},
BindGroupEntry {
binding: 4,
resource: BindingResource::TextureView(
&view_lights.directional_light_depth_texture_view,
),
},
BindGroupEntry {
binding: 5,
resource: BindingResource::Sampler(
&shadow_pipeline.directional_light_sampler,
),
},
],
label: Some("pbr_view_bind_group"),
layout: &pbr_pipeline.view_layout,
});
for (view, visible_entities, mut opaque_phase, mut alpha_mask_phase, mut transparent_phase) in
views.iter_mut()
{
let draw_opaque_pbr = opaque_draw_functions.read().get_id::<DrawPbr>().unwrap();
let draw_alpha_mask_pbr = alpha_mask_draw_functions
.read()
.get_id::<DrawPbr>()
.unwrap();
let draw_transparent_pbr = transparent_draw_functions
.read()
.get_id::<DrawPbr>()
.unwrap();
commands.entity(entity).insert(PbrViewBindGroup {
value: view_bind_group,
});
let inverse_view_matrix = view.transform.compute_matrix().inverse();
let inverse_view_row_2 = inverse_view_matrix.row(2);
let mesh_key = MeshPipelineKey::from_msaa_samples(msaa.samples);
let draw_opaque_pbr = opaque_draw_functions.read().get_id::<DrawPbr>().unwrap();
let draw_alpha_mask_pbr = alpha_mask_draw_functions
.read()
.get_id::<DrawPbr>()
.unwrap();
let draw_transparent_pbr = transparent_draw_functions
.read()
.get_id::<DrawPbr>()
.unwrap();
let inverse_view_matrix = view.transform.compute_matrix().inverse();
let inverse_view_row_2 = inverse_view_matrix.row(2);
for visible_entity in &visible_entities.entities {
if let Ok((material_handle, mesh_handle, mesh_uniform)) =
standard_material_meshes.get(visible_entity.entity)
{
let mut key = PbrPipelineKey::from_msaa_samples(msaa.samples);
let alpha_mode = if let Some(material) = render_materials.get(material_handle) {
if material.has_normal_map {
key |= PbrPipelineKey::STANDARDMATERIAL_NORMAL_MAP;
}
material.alpha_mode.clone()
} else {
continue;
for visible_entity in &visible_entities.entities {
if let Ok((material_handle, mesh_handle, mesh_uniform)) =
standard_material_meshes.get(visible_entity.entity)
{
if let Some(material) = render_materials.get(material_handle) {
let mut pbr_key = PbrPipelineKey {
mesh_key,
normal_map: material.has_normal_map,
};
if let Some(mesh) = render_meshes.get(mesh_handle) {
if mesh.has_tangents {
key |= PbrPipelineKey::VERTEX_TANGENTS;
pbr_key.mesh_key |= MeshPipelineKey::VERTEX_TANGENTS;
}
}
key |= match alpha_mode {
AlphaMode::Opaque => PbrPipelineKey::OPAQUE_MAIN_PASS,
AlphaMode::Mask(_) => PbrPipelineKey::ALPHA_MASK_MAIN_PASS,
AlphaMode::Blend => PbrPipelineKey::TRANSPARENT_MAIN_PASS,
};
let pipeline_id = pipelines.specialize(&mut pipeline_cache, &pbr_pipeline, key);
if let AlphaMode::Blend = material.alpha_mode {
pbr_key.mesh_key |= MeshPipelineKey::TRANSPARENT_MAIN_PASS
}
let pipeline_id =
pipelines.specialize(&mut pipeline_cache, &pbr_pipeline, pbr_key);
// NOTE: row 2 of the inverse view matrix dotted with column 3 of the model matrix
// gives the z component of translation of the mesh in view space
let mesh_z = inverse_view_row_2.dot(mesh_uniform.transform.col(3));
match alpha_mode {
match material.alpha_mode {
AlphaMode::Opaque => {
opaque_phase.add(Opaque3d {
entity: visible_entity.entity,
@ -809,55 +320,10 @@ pub type DrawPbr = (
SetItemPipeline,
SetMeshViewBindGroup<0>,
SetStandardMaterialBindGroup<1>,
SetTransformBindGroup<2>,
SetMeshBindGroup<2>,
DrawMesh,
);
pub struct SetMeshViewBindGroup<const I: usize>;
impl<const I: usize> EntityRenderCommand for SetMeshViewBindGroup<I> {
type Param = SQuery<(
Read<ViewUniformOffset>,
Read<ViewLights>,
Read<PbrViewBindGroup>,
)>;
#[inline]
fn render<'w>(
view: Entity,
_item: Entity,
view_query: SystemParamItem<'w, '_, Self::Param>,
pass: &mut TrackedRenderPass<'w>,
) {
let (view_uniform, view_lights, pbr_view_bind_group) = view_query.get(view).unwrap();
pass.set_bind_group(
I,
&pbr_view_bind_group.value,
&[view_uniform.offset, view_lights.gpu_light_binding_index],
);
}
}
pub struct SetTransformBindGroup<const I: usize>;
impl<const I: usize> EntityRenderCommand for SetTransformBindGroup<I> {
type Param = (
SRes<TransformBindGroup>,
SQuery<Read<DynamicUniformIndex<MeshUniform>>>,
);
#[inline]
fn render<'w>(
_view: Entity,
item: Entity,
(transform_bind_group, mesh_query): SystemParamItem<'w, '_, Self::Param>,
pass: &mut TrackedRenderPass<'w>,
) {
let transform_index = mesh_query.get(item).unwrap();
pass.set_bind_group(
I,
&transform_bind_group.into_inner().value,
&[transform_index.index()],
);
}
}
pub struct SetStandardMaterialBindGroup<const I: usize>;
impl<const I: usize> EntityRenderCommand for SetStandardMaterialBindGroup<I> {
type Param = (
@ -870,44 +336,12 @@ impl<const I: usize> EntityRenderCommand for SetStandardMaterialBindGroup<I> {
item: Entity,
(materials, handle_query): SystemParamItem<'w, '_, Self::Param>,
pass: &mut TrackedRenderPass<'w>,
) {
) -> RenderCommandResult {
let handle = handle_query.get(item).unwrap();
let materials = materials.into_inner();
let material = materials.get(handle).unwrap();
pass.set_bind_group(I, &material.bind_group, &[]);
}
}
pub struct DrawMesh;
impl EntityRenderCommand for DrawMesh {
type Param = (SRes<RenderAssets<Mesh>>, SQuery<Read<Handle<Mesh>>>);
#[inline]
fn render<'w>(
_view: Entity,
item: Entity,
(meshes, mesh_query): SystemParamItem<'w, '_, Self::Param>,
pass: &mut TrackedRenderPass<'w>,
) {
let mesh_handle = mesh_query.get(item).unwrap();
let gpu_mesh = meshes.into_inner().get(mesh_handle).unwrap();
pass.set_vertex_buffer(0, gpu_mesh.vertex_buffer.slice(..));
if let Some(index_info) = &gpu_mesh.index_info {
pass.set_index_buffer(index_info.buffer.slice(..), 0, index_info.index_format);
pass.draw_indexed(0..index_info.count, 0, 0..1);
} else {
panic!("non-indexed drawing not supported yet")
}
}
}
#[cfg(test)]
mod tests {
use super::PbrPipelineKey;
#[test]
fn pbr_key_msaa_samples() {
for i in 1..=64 {
assert_eq!(PbrPipelineKey::from_msaa_samples(i).msaa_samples(), i);
}
RenderCommandResult::Success
}
}

View file

@ -1,72 +1,3 @@
// NOTE: Keep in sync with depth.wgsl
[[block]]
struct View {
view_proj: mat4x4<f32>;
projection: mat4x4<f32>;
world_position: vec3<f32>;
};
[[block]]
struct Mesh {
model: mat4x4<f32>;
inverse_transpose_model: mat4x4<f32>;
// 'flags' is a bit field indicating various options. u32 is 32 bits so we have up to 32 options.
flags: u32;
};
let MESH_FLAGS_SHADOW_RECEIVER_BIT: u32 = 1u;
[[group(0), binding(0)]]
var<uniform> view: View;
[[group(2), binding(0)]]
var<uniform> mesh: Mesh;
struct Vertex {
[[location(0)]] position: vec3<f32>;
[[location(1)]] normal: vec3<f32>;
[[location(2)]] uv: vec2<f32>;
#ifdef VERTEX_TANGENTS
[[location(3)]] tangent: vec4<f32>;
#endif
};
struct VertexOutput {
[[builtin(position)]] clip_position: vec4<f32>;
[[location(0)]] world_position: vec4<f32>;
[[location(1)]] world_normal: vec3<f32>;
[[location(2)]] uv: vec2<f32>;
#ifdef VERTEX_TANGENTS
[[location(3)]] world_tangent: vec4<f32>;
#endif
};
[[stage(vertex)]]
fn vertex(vertex: Vertex) -> VertexOutput {
let world_position = mesh.model * vec4<f32>(vertex.position, 1.0);
var out: VertexOutput;
out.uv = vertex.uv;
out.world_position = world_position;
out.clip_position = view.view_proj * world_position;
out.world_normal = mat3x3<f32>(
mesh.inverse_transpose_model[0].xyz,
mesh.inverse_transpose_model[1].xyz,
mesh.inverse_transpose_model[2].xyz
) * vertex.normal;
#ifdef VERTEX_TANGENTS
out.world_tangent = vec4<f32>(
mat3x3<f32>(
mesh.model[0].xyz,
mesh.model[1].xyz,
mesh.model[2].xyz
) * vertex.tangent.xyz,
vertex.tangent.w
);
#endif
return out;
}
// From the Filament design doc
// https://google.github.io/filament/Filament.html#table_symbols
// Symbol Definition
@ -101,6 +32,12 @@ fn vertex(vertex: Vertex) -> VertexOutput {
//
// The above integration needs to be approximated.
#import bevy_pbr::mesh_view_bind_group
#import bevy_pbr::mesh_struct
[[group(2), binding(0)]]
var<uniform> mesh: Mesh;
[[block]]
struct StandardMaterial {
base_color: vec4<f32>;
@ -123,49 +60,6 @@ let STANDARD_MATERIAL_FLAGS_ALPHA_MODE_OPAQUE: u32 = 64u;
let STANDARD_MATERIAL_FLAGS_ALPHA_MODE_MASK: u32 = 128u;
let STANDARD_MATERIAL_FLAGS_ALPHA_MODE_BLEND: u32 = 256u;
struct PointLight {
projection: mat4x4<f32>;
color: vec4<f32>;
position: vec3<f32>;
inverse_square_range: f32;
radius: f32;
near: f32;
far: f32;
shadow_depth_bias: f32;
shadow_normal_bias: f32;
};
struct DirectionalLight {
view_projection: mat4x4<f32>;
color: vec4<f32>;
direction_to_light: vec3<f32>;
shadow_depth_bias: f32;
shadow_normal_bias: f32;
};
[[block]]
struct Lights {
// NOTE: this array size must be kept in sync with the constants defined bevy_pbr2/src/render/light.rs
// TODO: this can be removed if we move to storage buffers for light arrays
point_lights: array<PointLight, 10>;
directional_lights: array<DirectionalLight, 1>;
ambient_color: vec4<f32>;
n_point_lights: u32;
n_directional_lights: u32;
};
[[group(0), binding(1)]]
var<uniform> lights: Lights;
[[group(0), binding(2)]]
var point_shadow_textures: texture_depth_cube_array;
[[group(0), binding(3)]]
var point_shadow_textures_sampler: sampler_comparison;
[[group(0), binding(4)]]
var directional_shadow_textures: texture_depth_2d_array;
[[group(0), binding(5)]]
var directional_shadow_textures_sampler: sampler_comparison;
[[group(1), binding(0)]]
var<uniform> material: StandardMaterial;
[[group(1), binding(1)]]

View file

@ -16,7 +16,6 @@ pub use once_cell;
use crate::{
camera::CameraPlugin,
mesh::MeshPlugin,
render_asset::RenderAssetPlugin,
render_graph::RenderGraph,
render_resource::{RenderPipelineCache, Shader, ShaderLoader},
renderer::render_system,
@ -135,7 +134,7 @@ impl Plugin for RenderPlugin {
let mut render_app = App::empty();
let mut extract_stage =
SystemStage::parallel().with_system(RenderPipelineCache::extract_dirty_shaders);
SystemStage::parallel().with_system(RenderPipelineCache::extract_shaders);
// don't apply buffers when the stage finishes running
// extract stage runs on the app world, but the buffers are applied to the render world
extract_stage.set_apply_buffers(false);
@ -271,8 +270,7 @@ impl Plugin for RenderPlugin {
.add_plugin(CameraPlugin)
.add_plugin(ViewPlugin)
.add_plugin(MeshPlugin)
.add_plugin(ImagePlugin)
.add_plugin(RenderAssetPlugin::<Shader>::default());
.add_plugin(ImagePlugin);
}
}

View file

@ -140,7 +140,12 @@ pub trait RenderCommand<P: PhaseItem> {
item: &P,
param: SystemParamItem<'w, '_, Self::Param>,
pass: &mut TrackedRenderPass<'w>,
);
) -> RenderCommandResult;
}
pub enum RenderCommandResult {
Success,
Failure,
}
pub trait EntityRenderCommand {
@ -150,7 +155,7 @@ pub trait EntityRenderCommand {
item: Entity,
param: SystemParamItem<'w, '_, Self::Param>,
pass: &mut TrackedRenderPass<'w>,
);
) -> RenderCommandResult;
}
pub trait EntityPhaseItem: PhaseItem {
@ -170,8 +175,8 @@ impl<P: EntityPhaseItem, E: EntityRenderCommand> RenderCommand<P> for E {
item: &P,
param: SystemParamItem<'w, '_, Self::Param>,
pass: &mut TrackedRenderPass<'w>,
) {
<E as EntityRenderCommand>::render(view, item.entity(), param, pass);
) -> RenderCommandResult {
<E as EntityRenderCommand>::render(view, item.entity(), param, pass)
}
}
@ -184,12 +189,13 @@ impl<P: CachedPipelinePhaseItem> RenderCommand<P> for SetItemPipeline {
item: &P,
pipeline_cache: SystemParamItem<'w, '_, Self::Param>,
pass: &mut TrackedRenderPass<'w>,
) {
let pipeline = pipeline_cache
.into_inner()
.get_state(item.cached_pipeline())
.unwrap();
pass.set_render_pipeline(pipeline);
) -> RenderCommandResult {
if let Some(pipeline) = pipeline_cache.into_inner().get(item.cached_pipeline()) {
pass.set_render_pipeline(pipeline);
RenderCommandResult::Success
} else {
RenderCommandResult::Failure
}
}
}
@ -204,8 +210,11 @@ macro_rules! render_command_tuple_impl {
_item: &P,
($($name,)*): SystemParamItem<'w, '_, Self::Param>,
_pass: &mut TrackedRenderPass<'w>,
) {
$($name::render(_view, _item, $name, _pass);)*
) -> RenderCommandResult{
$(if let RenderCommandResult::Failure = $name::render(_view, _item, $name, _pass) {
return RenderCommandResult::Failure;
})*
RenderCommandResult::Success
}
}
};

View file

@ -1,15 +1,14 @@
use crate::{
render_asset::RenderAssets,
render_resource::{
AsModuleDescriptorError, BindGroupLayout, BindGroupLayoutId, ProcessShaderError,
RawFragmentState, RawRenderPipelineDescriptor, RawVertexState, RenderPipeline,
RenderPipelineDescriptor, Shader, ShaderProcessor,
RenderPipelineDescriptor, Shader, ShaderImport, ShaderProcessor,
},
renderer::RenderDevice,
RenderWorld,
};
use bevy_app::EventReader;
use bevy_asset::{AssetEvent, Handle};
use bevy_asset::{AssetEvent, Assets, Handle};
use bevy_ecs::system::{Res, ResMut};
use bevy_utils::{HashMap, HashSet};
use std::{collections::hash_map::Entry, hash::Hash, ops::Deref, sync::Arc};
@ -20,6 +19,8 @@ use wgpu::{PipelineLayoutDescriptor, ShaderModule, VertexBufferLayout};
pub struct ShaderData {
pipelines: HashSet<CachedPipelineId>,
processed_shaders: HashMap<Vec<String>, Arc<ShaderModule>>,
resolved_imports: HashMap<ShaderImport, Handle<Shader>>,
dependents: HashSet<Handle<Shader>>,
}
#[derive(Copy, Clone, Debug, Hash, Eq, PartialEq)]
@ -32,6 +33,9 @@ impl CachedPipelineId {
#[derive(Default)]
struct ShaderCache {
data: HashMap<Handle<Shader>, ShaderData>,
shaders: HashMap<Handle<Shader>, Shader>,
import_path_shaders: HashMap<ShaderImport, Handle<Shader>>,
waiting_on_import: HashMap<ShaderImport, Vec<Handle<Shader>>>,
processor: ShaderProcessor,
}
@ -39,41 +43,101 @@ impl ShaderCache {
fn get(
&mut self,
render_device: &RenderDevice,
shaders: &RenderAssets<Shader>,
pipeline: CachedPipelineId,
handle: &Handle<Shader>,
shader_defs: &[String],
) -> Result<Arc<ShaderModule>, RenderPipelineError> {
let shader = shaders
let shader = self
.shaders
.get(handle)
.ok_or_else(|| RenderPipelineError::ShaderNotLoaded(handle.clone_weak()))?;
let data = self.data.entry(handle.clone_weak()).or_default();
if shader.imports().len() != data.resolved_imports.len() {
return Err(RenderPipelineError::ShaderImportNotYetAvailable);
}
data.pipelines.insert(pipeline);
// PERF: this shader_defs clone isn't great. use raw_entry_mut when it stabilizes
let module = match data.processed_shaders.entry(shader_defs.to_vec()) {
Entry::Occupied(entry) => entry.into_mut(),
Entry::Vacant(entry) => {
let processed = self.processor.process_shader(shader, shader_defs)?;
let processed = self.processor.process(
shader,
shader_defs,
&self.shaders,
&self.import_path_shaders,
)?;
let module_descriptor = processed.get_module_descriptor()?;
entry.insert(Arc::new(
render_device.create_shader_module(&module_descriptor),
))
}
};
Ok(module.clone())
}
fn clear(
&mut self,
handle: &Handle<Shader>,
) -> Option<impl Iterator<Item = CachedPipelineId> + '_> {
let data = self.data.get_mut(handle)?;
data.processed_shaders.clear();
Some(data.pipelines.drain())
fn clear(&mut self, handle: &Handle<Shader>) -> Vec<CachedPipelineId> {
let mut shaders_to_clear = vec![handle.clone_weak()];
let mut pipelines_to_queue = Vec::new();
while let Some(handle) = shaders_to_clear.pop() {
if let Some(data) = self.data.get_mut(&handle) {
data.processed_shaders.clear();
pipelines_to_queue.extend(data.pipelines.iter().cloned());
shaders_to_clear.extend(data.dependents.iter().map(|h| h.clone_weak()));
}
}
pipelines_to_queue
}
fn remove(&mut self, handle: &Handle<Shader>) {
self.data.remove(handle);
fn set_shader(&mut self, handle: &Handle<Shader>, shader: Shader) -> Vec<CachedPipelineId> {
let pipelines_to_queue = self.clear(handle);
if let Some(path) = shader.import_path() {
self.import_path_shaders
.insert(path.clone(), handle.clone_weak());
if let Some(waiting_shaders) = self.waiting_on_import.get_mut(path) {
for waiting_shader in waiting_shaders.drain(..) {
// resolve waiting shader import
let data = self.data.entry(waiting_shader.clone_weak()).or_default();
data.resolved_imports
.insert(path.clone(), handle.clone_weak());
// add waiting shader as dependent of this shader
let data = self.data.entry(handle.clone_weak()).or_default();
data.dependents.insert(waiting_shader.clone_weak());
}
}
}
for import in shader.imports() {
if let Some(import_handle) = self.import_path_shaders.get(import) {
// resolve import because it is currently available
let data = self.data.entry(handle.clone_weak()).or_default();
data.resolved_imports
.insert(import.clone(), import_handle.clone_weak());
// add this shader as a dependent of the import
let data = self.data.entry(import_handle.clone_weak()).or_default();
data.dependents.insert(handle.clone_weak());
} else {
let waiting = self.waiting_on_import.entry(import.clone()).or_default();
waiting.push(handle.clone_weak());
}
}
self.shaders.insert(handle.clone_weak(), shader);
pipelines_to_queue
}
fn remove(&mut self, handle: &Handle<Shader>) -> Vec<CachedPipelineId> {
let pipelines_to_queue = self.clear(handle);
if let Some(shader) = self.shaders.remove(handle) {
if let Some(import_path) = shader.import_path() {
self.import_path_shaders.remove(import_path);
}
}
pipelines_to_queue
}
}
@ -107,7 +171,7 @@ pub struct RenderPipelineCache {
shader_cache: ShaderCache,
device: RenderDevice,
pipelines: Vec<CachedPipeline>,
waiting_pipelines: Vec<CachedPipelineId>,
waiting_pipelines: HashSet<CachedPipelineId>,
}
struct CachedPipeline {
@ -144,6 +208,8 @@ pub enum RenderPipelineError {
ProcessShaderError(#[from] ProcessShaderError),
#[error(transparent)]
AsModuleDescriptorError(#[from] AsModuleDescriptorError),
#[error("Shader import not yet available.")]
ShaderImportNotYetAvailable,
}
impl RenderPipelineCache {
@ -177,29 +243,27 @@ impl RenderPipelineCache {
descriptor,
state: CachedPipelineState::Queued,
});
self.waiting_pipelines.push(id);
self.waiting_pipelines.insert(id);
id
}
pub fn mark_shader_dirty(&mut self, shader: &Handle<Shader>) {
if let Some(cached_pipelines) = self.shader_cache.clear(shader) {
for cached_pipeline in cached_pipelines {
self.pipelines[cached_pipeline.0].state = CachedPipelineState::Queued;
self.waiting_pipelines.push(cached_pipeline);
}
fn set_shader(&mut self, handle: &Handle<Shader>, shader: &Shader) {
let pipelines_to_queue = self.shader_cache.set_shader(handle, shader.clone());
for cached_pipeline in pipelines_to_queue {
self.pipelines[cached_pipeline.0].state = CachedPipelineState::Queued;
self.waiting_pipelines.insert(cached_pipeline);
}
}
pub fn remove_shader(&mut self, shader: &Handle<Shader>) {
if let Some(cached_pipelines) = self.shader_cache.clear(shader) {
for cached_pipeline in cached_pipelines {
self.pipelines[cached_pipeline.0].state = CachedPipelineState::Queued;
self.waiting_pipelines.push(cached_pipeline);
}
fn remove_shader(&mut self, shader: &Handle<Shader>) {
let pipelines_to_queue = self.shader_cache.remove(shader);
for cached_pipeline in pipelines_to_queue {
self.pipelines[cached_pipeline.0].state = CachedPipelineState::Queued;
self.waiting_pipelines.insert(cached_pipeline);
}
}
pub fn process_queue(&mut self, shaders: &RenderAssets<Shader>) {
pub fn process_queue(&mut self) {
let pipelines = std::mem::take(&mut self.waiting_pipelines);
for id in pipelines {
let state = &mut self.pipelines[id.0];
@ -208,7 +272,8 @@ impl RenderPipelineCache {
CachedPipelineState::Queued => {}
CachedPipelineState::Err(err) => {
match err {
RenderPipelineError::ShaderNotLoaded(_) => { /* retry */ }
RenderPipelineError::ShaderNotLoaded(_)
| RenderPipelineError::ShaderImportNotYetAvailable => { /* retry */ }
RenderPipelineError::ProcessShaderError(_)
| RenderPipelineError::AsModuleDescriptorError(_) => {
// shader could not be processed ... retrying won't help
@ -221,7 +286,6 @@ impl RenderPipelineCache {
let descriptor = &state.descriptor;
let vertex_module = match self.shader_cache.get(
&self.device,
shaders,
id,
&descriptor.vertex.shader,
&descriptor.vertex.shader_defs,
@ -229,7 +293,7 @@ impl RenderPipelineCache {
Ok(module) => module,
Err(err) => {
state.state = CachedPipelineState::Err(err);
self.waiting_pipelines.push(id);
self.waiting_pipelines.insert(id);
continue;
}
};
@ -237,7 +301,6 @@ impl RenderPipelineCache {
let fragment_data = if let Some(fragment) = &descriptor.fragment {
let fragment_module = match self.shader_cache.get(
&self.device,
shaders,
id,
&fragment.shader,
&fragment.shader_defs,
@ -245,7 +308,7 @@ impl RenderPipelineCache {
Ok(module) => module,
Err(err) => {
state.state = CachedPipelineState::Err(err);
self.waiting_pipelines.push(id);
self.waiting_pipelines.insert(id);
continue;
}
};
@ -300,24 +363,24 @@ impl RenderPipelineCache {
}
}
pub(crate) fn process_pipeline_queue_system(
mut cache: ResMut<Self>,
shaders: Res<RenderAssets<Shader>>,
) {
cache.process_queue(&shaders);
pub(crate) fn process_pipeline_queue_system(mut cache: ResMut<Self>) {
cache.process_queue();
}
pub(crate) fn extract_dirty_shaders(
pub(crate) fn extract_shaders(
mut world: ResMut<RenderWorld>,
shaders: Res<Assets<Shader>>,
mut events: EventReader<AssetEvent<Shader>>,
) {
let mut cache = world.get_resource_mut::<Self>().unwrap();
for event in events.iter() {
match event {
AssetEvent::Created { handle } | AssetEvent::Modified { handle } => {
cache.mark_shader_dirty(handle)
if let Some(shader) = shaders.get(handle) {
cache.set_shader(handle, shader);
}
}
AssetEvent::Removed { handle } => cache.shader_cache.remove(handle),
AssetEvent::Removed { handle } => cache.remove_shader(handle),
}
}
}

View file

@ -1,14 +1,15 @@
use bevy_asset::{AssetLoader, LoadContext, LoadedAsset};
use bevy_asset::{AssetLoader, Handle, LoadContext, LoadedAsset};
use bevy_reflect::{TypeUuid, Uuid};
use bevy_utils::{tracing::error, BoxedFuture};
use bevy_utils::{tracing::error, BoxedFuture, HashMap};
use naga::{valid::ModuleInfo, Module};
use once_cell::sync::Lazy;
use regex::Regex;
use std::{borrow::Cow, collections::HashSet, marker::Copy};
use std::{
borrow::Cow, collections::HashSet, marker::Copy, ops::Deref, path::PathBuf, str::FromStr,
};
use thiserror::Error;
use wgpu::{ShaderModuleDescriptor, ShaderSource};
use crate::render_asset::{PrepareAssetError, RenderAsset};
#[derive(Copy, Clone, Hash, Eq, PartialEq, Debug)]
pub struct ShaderId(Uuid);
@ -35,7 +36,60 @@ pub enum ShaderReflectError {
/// This is an "unprocessed" shader. It can contain preprocessor directives.
#[derive(Debug, Clone, TypeUuid)]
#[uuid = "d95bc916-6c55-4de3-9622-37e7b6969fda"]
pub enum Shader {
pub struct Shader {
source: Source,
import_path: Option<ShaderImport>,
imports: Vec<ShaderImport>,
}
impl Shader {
pub fn from_wgsl(source: impl Into<Cow<'static, str>>) -> Shader {
let source = source.into();
Shader {
imports: SHADER_IMPORT_PROCESSOR.get_imports_from_str(&source),
source: Source::Wgsl(source),
import_path: None,
}
}
pub fn from_glsl(source: impl Into<Cow<'static, str>>, stage: naga::ShaderStage) -> Shader {
let source = source.into();
Shader {
imports: SHADER_IMPORT_PROCESSOR.get_imports_from_str(&source),
source: Source::Glsl(source, stage),
import_path: None,
}
}
pub fn from_spirv(source: impl Into<Cow<'static, [u8]>>) -> Shader {
Shader {
imports: Vec::new(),
source: Source::SpirV(source.into()),
import_path: None,
}
}
pub fn set_import_path<P: Into<String>>(&mut self, import_path: P) {
self.import_path = Some(ShaderImport::Custom(import_path.into()));
}
pub fn with_import_path<P: Into<String>>(mut self, import_path: P) -> Self {
self.set_import_path(import_path);
self
}
#[inline]
pub fn import_path(&self) -> Option<&ShaderImport> {
self.import_path.as_ref()
}
pub fn imports(&self) -> impl ExactSizeIterator<Item = &ShaderImport> {
self.imports.iter()
}
}
#[derive(Debug, Clone)]
pub enum Source {
Wgsl(Cow<'static, str>),
Glsl(Cow<'static, str>, naga::ShaderStage),
SpirV(Cow<'static, [u8]>),
@ -45,6 +99,7 @@ pub enum Shader {
}
/// A processed [Shader]. This cannot contain preprocessor directions. It must be "ready to compile"
#[derive(PartialEq, Eq, Debug)]
pub enum ProcessedShader {
Wgsl(Cow<'static, str>),
Glsl(Cow<'static, str>, naga::ShaderStage),
@ -52,6 +107,21 @@ pub enum ProcessedShader {
}
impl ProcessedShader {
pub fn get_wgsl_source(&self) -> Option<&str> {
if let ProcessedShader::Wgsl(source) = self {
Some(source)
} else {
None
}
}
pub fn get_glsl_source(&self) -> Option<&str> {
if let ProcessedShader::Glsl(source, _stage) = self {
Some(source)
} else {
None
}
}
pub fn reflect(&self) -> Result<ShaderReflection, ShaderReflectError> {
let module = match &self {
// TODO: process macros here
@ -138,20 +208,6 @@ impl ShaderReflection {
}
}
impl Shader {
pub fn from_wgsl(source: impl Into<Cow<'static, str>>) -> Shader {
Shader::Wgsl(source.into())
}
pub fn from_glsl(source: impl Into<Cow<'static, str>>, stage: naga::ShaderStage) -> Shader {
Shader::Glsl(source.into(), stage)
}
pub fn from_spirv(source: impl Into<Cow<'static, [u8]>>) -> Shader {
Shader::SpirV(source.into())
}
}
#[derive(Default)]
pub struct ShaderLoader;
@ -164,7 +220,7 @@ impl AssetLoader for ShaderLoader {
Box::pin(async move {
let ext = load_context.path().extension().unwrap().to_str().unwrap();
let shader = match ext {
let mut shader = match ext {
"spv" => Shader::from_spirv(Vec::from(bytes)),
"wgsl" => Shader::from_wgsl(String::from_utf8(Vec::from(bytes))?),
"vert" => Shader::from_glsl(
@ -178,7 +234,19 @@ impl AssetLoader for ShaderLoader {
_ => panic!("unhandled extension: {}", ext),
};
load_context.set_default_asset(LoadedAsset::new(shader));
shader.import_path = Some(ShaderImport::AssetPath(
load_context.path().to_string_lossy().to_string(),
));
let imports = SHADER_IMPORT_PROCESSOR.get_imports(&shader);
let mut asset = LoadedAsset::new(shader);
for import in imports {
if let ShaderImport::AssetPath(asset_path) = import {
let path = PathBuf::from_str(&asset_path)?;
asset.add_dependency(path.into());
}
}
load_context.set_default_asset(asset);
Ok(())
})
}
@ -188,23 +256,6 @@ impl AssetLoader for ShaderLoader {
}
}
impl RenderAsset for Shader {
type ExtractedAsset = Shader;
type PreparedAsset = Shader;
type Param = ();
fn extract_asset(&self) -> Self::ExtractedAsset {
self.clone()
}
fn prepare_asset(
extracted_asset: Self::ExtractedAsset,
_param: &mut bevy_ecs::system::SystemParamItem<Self::Param>,
) -> Result<Self::PreparedAsset, PrepareAssetError<Self::ExtractedAsset>> {
Ok(extracted_asset)
}
}
#[derive(Error, Debug, PartialEq, Eq)]
pub enum ProcessShaderError {
#[error("Too many '# endif' lines. Each endif should be preceded by an if statement.")]
@ -215,8 +266,62 @@ pub enum ProcessShaderError {
NotEnoughEndIfs,
#[error("This Shader's format does not support processing shader defs.")]
ShaderFormatDoesNotSupportShaderDefs,
#[error("This Shader's formatdoes not support imports.")]
ShaderFormatDoesNotSupportImports,
#[error("Unresolved import: {0:?}.")]
UnresolvedImport(ShaderImport),
#[error("The shader import {0:?} does not match the source file type. Support for this might be added in the future.")]
MismatchedImportFormat(ShaderImport),
}
pub struct ShaderImportProcessor {
import_asset_path_regex: Regex,
import_custom_path_regex: Regex,
}
#[derive(Debug, PartialEq, Eq, Clone, Hash)]
pub enum ShaderImport {
AssetPath(String),
Custom(String),
}
impl Default for ShaderImportProcessor {
fn default() -> Self {
Self {
import_asset_path_regex: Regex::new(r#"^\s*#\s*import\s*"(.+)""#).unwrap(),
import_custom_path_regex: Regex::new(r"^\s*#\s*import\s*(.+)").unwrap(),
}
}
}
impl ShaderImportProcessor {
pub fn get_imports(&self, shader: &Shader) -> Vec<ShaderImport> {
match &shader.source {
Source::Wgsl(source) => self.get_imports_from_str(source),
Source::Glsl(source, _stage) => self.get_imports_from_str(source),
Source::SpirV(_source) => Vec::new(),
}
}
pub fn get_imports_from_str(&self, shader: &str) -> Vec<ShaderImport> {
let mut imports = Vec::new();
for line in shader.split('\n') {
if let Some(cap) = self.import_asset_path_regex.captures(line) {
let import = cap.get(1).unwrap();
imports.push(ShaderImport::AssetPath(import.as_str().to_string()));
} else if let Some(cap) = self.import_custom_path_regex.captures(line) {
let import = cap.get(1).unwrap();
imports.push(ShaderImport::Custom(import.as_str().to_string()));
}
}
imports
}
}
pub static SHADER_IMPORT_PROCESSOR: Lazy<ShaderImportProcessor> =
Lazy::new(ShaderImportProcessor::default);
pub struct ShaderProcessor {
ifdef_regex: Regex,
ifndef_regex: Regex,
@ -234,38 +339,29 @@ impl Default for ShaderProcessor {
}
impl ShaderProcessor {
pub fn process_shader(
pub fn process(
&self,
shader: &Shader,
shader_defs: &[String],
shaders: &HashMap<Handle<Shader>, Shader>,
import_handles: &HashMap<ShaderImport, Handle<Shader>>,
) -> Result<ProcessedShader, ProcessShaderError> {
match shader {
Shader::Wgsl(source) => Ok(ProcessedShader::Wgsl(Cow::from(
self.process_str(source, shader_defs)?,
))),
Shader::Glsl(source, stage) => Ok(ProcessedShader::Glsl(
Cow::from(self.process_str(source, shader_defs)?),
*stage,
)),
Shader::SpirV(source) => {
let shader_str = match &shader.source {
Source::Wgsl(source) => source.deref(),
Source::Glsl(source, _stage) => source.deref(),
Source::SpirV(source) => {
if shader_defs.is_empty() {
Ok(ProcessedShader::SpirV(source.clone()))
return Ok(ProcessedShader::SpirV(source.clone()));
} else {
Err(ProcessShaderError::ShaderFormatDoesNotSupportShaderDefs)
return Err(ProcessShaderError::ShaderFormatDoesNotSupportShaderDefs);
}
}
}
}
};
pub fn process_str(
&self,
shader: &str,
shader_defs: &[String],
) -> Result<String, ProcessShaderError> {
let shader_defs = HashSet::<String>::from_iter(shader_defs.iter().cloned());
let mut scopes = vec![true];
let mut final_string = String::new();
for line in shader.split('\n') {
for line in shader_str.split('\n') {
if let Some(cap) = self.ifdef_regex.captures(line) {
let def = cap.get(1).unwrap();
scopes.push(*scopes.last().unwrap() && shader_defs.contains(def.as_str()));
@ -277,6 +373,18 @@ impl ShaderProcessor {
if scopes.is_empty() {
return Err(ProcessShaderError::TooManyEndIfs);
}
} else if let Some(cap) = SHADER_IMPORT_PROCESSOR
.import_asset_path_regex
.captures(line)
{
let import = ShaderImport::AssetPath(cap.get(1).unwrap().as_str().to_string());
apply_import(import_handles, shaders, &import, shader, &mut final_string)?;
} else if let Some(cap) = SHADER_IMPORT_PROCESSOR
.import_custom_path_regex
.captures(line)
{
let import = ShaderImport::Custom(cap.get(1).unwrap().as_str().to_string());
apply_import(import_handles, shaders, &import, shader, &mut final_string)?;
} else if *scopes.last().unwrap() {
final_string.push_str(line);
final_string.push('\n');
@ -288,13 +396,60 @@ impl ShaderProcessor {
if scopes.len() != 1 {
return Err(ProcessShaderError::NotEnoughEndIfs);
}
Ok(final_string)
let processed_source = Cow::from(final_string);
match &shader.source {
Source::Wgsl(_source) => Ok(ProcessedShader::Wgsl(processed_source)),
Source::Glsl(_source, stage) => Ok(ProcessedShader::Glsl(processed_source, *stage)),
Source::SpirV(_source) => {
unreachable!("SpirV has early return");
}
}
}
}
fn apply_import(
import_handles: &HashMap<ShaderImport, Handle<Shader>>,
shaders: &HashMap<Handle<Shader>, Shader>,
import: &ShaderImport,
shader: &Shader,
final_string: &mut String,
) -> Result<(), ProcessShaderError> {
let imported_shader = import_handles
.get(import)
.and_then(|handle| shaders.get(handle))
.ok_or_else(|| ProcessShaderError::UnresolvedImport(import.clone()))?;
match &shader.source {
Source::Wgsl(_) => {
if let Source::Wgsl(import_source) = &imported_shader.source {
final_string.push_str(import_source);
} else {
return Err(ProcessShaderError::MismatchedImportFormat(import.clone()));
}
}
Source::Glsl(_, _) => {
if let Source::Glsl(import_source, _) = &imported_shader.source {
final_string.push_str(import_source);
} else {
return Err(ProcessShaderError::MismatchedImportFormat(import.clone()));
}
}
Source::SpirV(_) => {
return Err(ProcessShaderError::ShaderFormatDoesNotSupportImports);
}
}
Ok(())
}
#[cfg(test)]
mod tests {
use crate::render_resource::{ProcessShaderError, ShaderProcessor};
use bevy_asset::Handle;
use bevy_utils::HashMap;
use naga::ShaderStage;
use crate::render_resource::{ProcessShaderError, Shader, ShaderImport, ShaderProcessor};
#[rustfmt::skip]
const WGSL: &str = r"
[[block]]
@ -305,10 +460,10 @@ struct View {
[[group(0), binding(0)]]
var<uniform> view: View;
# ifdef TEXTURE
#ifdef TEXTURE
[[group(1), binding(0)]]
var sprite_texture: texture_2d<f32>;
# endif
#endif
struct VertexOutput {
[[location(0)]] uv: vec2<f32>;
@ -392,9 +547,14 @@ fn vertex(
";
let processor = ShaderProcessor::default();
let result = processor
.process_str(WGSL, &["TEXTURE".to_string()])
.process(
&Shader::from_wgsl(WGSL),
&["TEXTURE".to_string()],
&HashMap::default(),
&HashMap::default(),
)
.unwrap();
assert_eq!(result, EXPECTED);
assert_eq!(result.get_wgsl_source().unwrap(), EXPECTED);
}
#[test]
@ -427,18 +587,30 @@ fn vertex(
}
";
let processor = ShaderProcessor::default();
let result = processor.process_str(WGSL, &[]).unwrap();
assert_eq!(result, EXPECTED);
let result = processor
.process(
&Shader::from_wgsl(WGSL),
&[],
&HashMap::default(),
&HashMap::default(),
)
.unwrap();
assert_eq!(result.get_wgsl_source().unwrap(), EXPECTED);
}
#[test]
fn process_shader_def_unclosed() {
#[rustfmt::skip]
const INPUT: &str = r"
# ifdef FOO
#ifdef FOO
";
let processor = ShaderProcessor::default();
let result = processor.process_str(INPUT, &[]);
let result = processor.process(
&Shader::from_wgsl(INPUT),
&[],
&HashMap::default(),
&HashMap::default(),
);
assert_eq!(result, Err(ProcessShaderError::NotEnoughEndIfs));
}
@ -446,10 +618,15 @@ fn vertex(
fn process_shader_def_too_closed() {
#[rustfmt::skip]
const INPUT: &str = r"
# endif
#endif
";
let processor = ShaderProcessor::default();
let result = processor.process_str(INPUT, &[]);
let result = processor.process(
&Shader::from_wgsl(INPUT),
&[],
&HashMap::default(),
&HashMap::default(),
);
assert_eq!(result, Err(ProcessShaderError::TooManyEndIfs));
}
@ -457,12 +634,91 @@ fn vertex(
fn process_shader_def_commented() {
#[rustfmt::skip]
const INPUT: &str = r"
// # ifdef FOO
// #ifdef FOO
fn foo() { }
";
let processor = ShaderProcessor::default();
let result = processor.process_str(INPUT, &[]).unwrap();
assert_eq!(result, INPUT);
let result = processor
.process(
&Shader::from_wgsl(INPUT),
&[],
&HashMap::default(),
&HashMap::default(),
)
.unwrap();
assert_eq!(result.get_wgsl_source().unwrap(), INPUT);
}
#[test]
fn process_import_wgsl() {
#[rustfmt::skip]
const FOO: &str = r"
fn foo() { }
";
#[rustfmt::skip]
const INPUT: &str = r"
#import FOO
fn bar() { }
";
#[rustfmt::skip]
const EXPECTED: &str = r"
fn foo() { }
fn bar() { }
";
let processor = ShaderProcessor::default();
let mut shaders = HashMap::default();
let mut import_handles = HashMap::default();
let foo_handle = Handle::<Shader>::default();
shaders.insert(foo_handle.clone_weak(), Shader::from_wgsl(FOO));
import_handles.insert(
ShaderImport::Custom("FOO".to_string()),
foo_handle.clone_weak(),
);
let result = processor
.process(&Shader::from_wgsl(INPUT), &[], &shaders, &import_handles)
.unwrap();
assert_eq!(result.get_wgsl_source().unwrap(), EXPECTED);
}
#[test]
fn process_import_glsl() {
#[rustfmt::skip]
const FOO: &str = r"
void foo() { }
";
#[rustfmt::skip]
const INPUT: &str = r"
#import FOO
void bar() { }
";
#[rustfmt::skip]
const EXPECTED: &str = r"
void foo() { }
void bar() { }
";
let processor = ShaderProcessor::default();
let mut shaders = HashMap::default();
let mut import_handles = HashMap::default();
let foo_handle = Handle::<Shader>::default();
shaders.insert(
foo_handle.clone_weak(),
Shader::from_glsl(FOO, ShaderStage::Vertex),
);
import_handles.insert(
ShaderImport::Custom("FOO".to_string()),
foo_handle.clone_weak(),
);
let result = processor
.process(
&Shader::from_glsl(INPUT, ShaderStage::Vertex),
&[],
&shaders,
&import_handles,
)
.unwrap();
assert_eq!(result.get_glsl_source().unwrap(), EXPECTED);
}
#[test]
@ -496,9 +752,14 @@ fn vertex(
";
let processor = ShaderProcessor::default();
let result = processor
.process_str(WGSL_NESTED_IFDEF, &["TEXTURE".to_string()])
.process(
&Shader::from_wgsl(WGSL_NESTED_IFDEF),
&["TEXTURE".to_string()],
&HashMap::default(),
&HashMap::default(),
)
.unwrap();
assert_eq!(result, EXPECTED);
assert_eq!(result.get_wgsl_source().unwrap(), EXPECTED);
}
#[test]
@ -531,8 +792,15 @@ fn vertex(
}
";
let processor = ShaderProcessor::default();
let result = processor.process_str(WGSL_NESTED_IFDEF, &[]).unwrap();
assert_eq!(result, EXPECTED);
let result = processor
.process(
&Shader::from_wgsl(WGSL_NESTED_IFDEF),
&[],
&HashMap::default(),
&HashMap::default(),
)
.unwrap();
assert_eq!(result.get_wgsl_source().unwrap(), EXPECTED);
}
#[test]
@ -566,9 +834,14 @@ fn vertex(
";
let processor = ShaderProcessor::default();
let result = processor
.process_str(WGSL_NESTED_IFDEF, &["ATTRIBUTE".to_string()])
.process(
&Shader::from_wgsl(WGSL_NESTED_IFDEF),
&["ATTRIBUTE".to_string()],
&HashMap::default(),
&HashMap::default(),
)
.unwrap();
assert_eq!(result, EXPECTED);
assert_eq!(result.get_wgsl_source().unwrap(), EXPECTED);
}
#[test]
@ -604,11 +877,13 @@ fn vertex(
";
let processor = ShaderProcessor::default();
let result = processor
.process_str(
WGSL_NESTED_IFDEF,
.process(
&Shader::from_wgsl(WGSL_NESTED_IFDEF),
&["TEXTURE".to_string(), "ATTRIBUTE".to_string()],
&HashMap::default(),
&HashMap::default(),
)
.unwrap();
assert_eq!(result, EXPECTED);
assert_eq!(result.get_wgsl_source().unwrap(), EXPECTED);
}
}