Async pipeline compilation (#10812)

# Objective

- Pipeline compilation is slow and blocks the frame
- Closes https://github.com/bevyengine/bevy/issues/8224

## Solution

- Compile pipelines in a Task on the AsyncComputeTaskPool

---

## Changelog

- Render/compute pipeline compilation is now performed asynchronously over
multiple frames when the `multi-threaded` cargo feature is enabled, on
non-WASM and non-macOS platforms
- Added `CachedPipelineState::Creating` 
- Added `PipelineCache::block_on_render_pipeline()`
- Added `bevy_utils::futures::check_ready`
- Added `bevy_render/multi-threaded` cargo feature

## Migration Guide

- When exhaustively matching on `CachedPipelineState`, handle the new
`Creating` variant
This commit is contained in:
JMS55 2024-02-05 05:50:50 -08:00 committed by GitHub
parent 7705c1dd6c
commit 9f7e61b819
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 259 additions and 165 deletions

View file

@ -63,7 +63,9 @@ impl Node for CASNode {
return Ok(());
};
let pipeline = pipeline_cache.get_render_pipeline(pipeline.0).unwrap();
let Some(pipeline) = pipeline_cache.get_render_pipeline(pipeline.0) else {
return Ok(());
};
let view_target = target.post_process_write();
let source = view_target.source;

View file

@ -39,9 +39,9 @@ impl ViewNode for FxaaNode {
return Ok(());
};
let pipeline = pipeline_cache
.get_render_pipeline(pipeline.pipeline_id)
.unwrap();
let Some(pipeline) = pipeline_cache.get_render_pipeline(pipeline.pipeline_id) else {
return Ok(());
};
let post_process = target.post_process_write();
let source = post_process.source;

View file

@ -27,7 +27,7 @@ pub struct ViewUpscalingPipeline(CachedRenderPipelineId);
fn prepare_view_upscaling_pipelines(
mut commands: Commands,
pipeline_cache: Res<PipelineCache>,
mut pipeline_cache: ResMut<PipelineCache>,
mut pipelines: ResMut<SpecializedRenderPipelines<BlitPipeline>>,
blit_pipeline: Res<BlitPipeline>,
view_targets: Query<(Entity, &ViewTarget, Option<&ExtractedCamera>)>,
@ -49,6 +49,9 @@ fn prepare_view_upscaling_pipelines(
};
let pipeline = pipelines.specialize(&pipeline_cache, &blit_pipeline, key);
// Ensure the pipeline is loaded before continuing the frame to prevent frames without any GPU work submitted
pipeline_cache.block_on_render_pipeline(pipeline);
commands
.entity(entity)
.insert(ViewUpscalingPipeline(pipeline));

View file

@ -77,6 +77,7 @@ serialize = [
multi-threaded = [
"bevy_asset/multi-threaded",
"bevy_ecs/multi-threaded",
"bevy_render/multi-threaded",
"bevy_tasks/multi-threaded",
]
async-io = ["bevy_tasks/async-io"]

View file

@ -40,11 +40,11 @@ smallvec = "1.6"
thread_local = "1.0"
[target.'cfg(target_arch = "wasm32")'.dependencies]
naga_oil = "0.12"
naga_oil = "0.13"
[target.'cfg(not(target_arch = "wasm32"))'.dependencies]
# Omit the `glsl` feature in non-WebAssembly by default.
naga_oil = { version = "0.12", default-features = false, features = [
naga_oil = { version = "0.13", default-features = false, features = [
"test_shader",
] }

View file

@ -18,6 +18,7 @@ bmp = ["image/bmp"]
webp = ["image/webp"]
dds = ["ddsfile"]
pnm = ["image/pnm"]
multi-threaded = ["bevy_tasks/multi-threaded"]
bevy_ci_testing = ["bevy_app/bevy_ci_testing"]
shader_format_glsl = ["naga/glsl-in", "naga/wgsl-out", "naga_oil/glsl"]
@ -72,7 +73,7 @@ wgpu = { version = "0.19.1", default-features = false, features = [
"fragile-send-sync-non-atomic-wasm",
] }
naga = { version = "0.19", features = ["wgsl-in"] }
naga_oil = { version = "0.12", default-features = false, features = [
naga_oil = { version = "0.13", default-features = false, features = [
"test_shader",
] }
serde = { version = "1", features = ["derive"] }

View file

@ -1,15 +1,8 @@
use crate::{
render_resource::{
BindGroupLayout, BindGroupLayoutId, ComputePipeline, ComputePipelineDescriptor,
RawComputePipelineDescriptor, RawFragmentState, RawRenderPipelineDescriptor,
RawVertexState, RenderPipeline, RenderPipelineDescriptor, Shader, ShaderImport, Source,
},
renderer::RenderDevice,
Extract,
};
use crate::{render_resource::*, renderer::RenderDevice, Extract};
use bevy_asset::{AssetEvent, AssetId, Assets};
use bevy_ecs::system::{Res, ResMut};
use bevy_ecs::{event::EventReader, system::Resource};
use bevy_tasks::Task;
use bevy_utils::{
default,
tracing::{debug, error},
@ -18,10 +11,11 @@ use bevy_utils::{
use naga::valid::Capabilities;
use std::{
borrow::Cow,
future::Future,
hash::Hash,
mem,
ops::Deref,
sync::{Mutex, PoisonError},
sync::{Arc, Mutex, PoisonError},
};
use thiserror::Error;
#[cfg(feature = "shader_format_spirv")]
@ -94,6 +88,8 @@ pub struct CachedPipeline {
pub enum CachedPipelineState {
/// The pipeline GPU object is queued for creation.
Queued,
/// The pipeline GPU object is being created.
Creating(Task<Result<Pipeline, PipelineCacheError>>),
/// The pipeline GPU object was created successfully and is available (allocated on the GPU).
Ok(Pipeline),
/// An error occurred while trying to create the pipeline GPU object.
@ -117,6 +113,9 @@ impl CachedPipelineState {
CachedPipelineState::Queued => {
panic!("Pipeline has not been compiled yet. It is still in the 'Queued' state.")
}
CachedPipelineState::Creating(..) => {
panic!("Pipeline has not been compiled yet. It is still in the 'Creating' state.")
}
CachedPipelineState::Err(err) => panic!("{}", err),
}
}
@ -440,7 +439,7 @@ impl LayoutCache {
render_device: &RenderDevice,
bind_group_layouts: &[BindGroupLayout],
push_constant_ranges: Vec<PushConstantRange>,
) -> &wgpu::PipelineLayout {
) -> ErasedPipelineLayout {
let bind_group_ids = bind_group_layouts.iter().map(|l| l.id()).collect();
self.layouts
.entry((bind_group_ids, push_constant_ranges))
@ -457,6 +456,7 @@ impl LayoutCache {
},
))
})
.clone()
}
}
@ -468,14 +468,14 @@ impl LayoutCache {
/// pipeline object is deferred to the [`RenderSet::Render`] step, just before the render
/// graph starts being processed, as this requires access to the GPU.
///
/// Note that the cache do not perform automatic deduplication of identical pipelines. It is
/// Note that the cache does not perform automatic deduplication of identical pipelines. It is
/// up to the user not to insert the same pipeline twice to avoid wasting GPU resources.
///
/// [`RenderSet::Render`]: crate::RenderSet::Render
#[derive(Resource)]
pub struct PipelineCache {
layout_cache: LayoutCache,
shader_cache: ShaderCache,
layout_cache: Arc<Mutex<LayoutCache>>,
shader_cache: Arc<Mutex<ShaderCache>>,
device: RenderDevice,
pipelines: Vec<CachedPipeline>,
waiting_pipelines: HashSet<CachedPipelineId>,
@ -490,7 +490,7 @@ impl PipelineCache {
/// Create a new pipeline cache associated with the given render device.
pub fn new(device: RenderDevice) -> Self {
Self {
shader_cache: ShaderCache::new(&device),
shader_cache: Arc::new(Mutex::new(ShaderCache::new(&device))),
device,
layout_cache: default(),
waiting_pipelines: default(),
@ -561,6 +561,22 @@ impl PipelineCache {
}
}
/// Wait for a render pipeline to finish compiling.
///
/// If the pipeline identified by `id` is still in the
/// [`CachedPipelineState::Creating`] state, this blocks the calling thread
/// until the async compilation task completes, then stores the resulting
/// `Ok`/`Err` state back into the cache entry.
#[inline]
pub fn block_on_render_pipeline(&mut self, id: CachedRenderPipelineId) {
// The id may refer to a pipeline queued this frame that has not been
// materialized into `self.pipelines` yet; process the queue so the
// index below is in bounds.
if self.pipelines.len() <= id.0 {
self.process_queue();
}
let state = &mut self.pipelines[id.0].state;
if let CachedPipelineState::Creating(task) = state {
// Block on the in-flight compilation task and replace the
// transient `Creating` state with its final result.
*state = match bevy_tasks::block_on(task) {
Ok(p) => CachedPipelineState::Ok(p),
Err(e) => CachedPipelineState::Err(e),
};
}
}
/// Try to retrieve a compute pipeline GPU object from a cached ID.
///
/// # Returns
@ -638,7 +654,8 @@ impl PipelineCache {
}
fn set_shader(&mut self, id: AssetId<Shader>, shader: &Shader) {
let pipelines_to_queue = self.shader_cache.set_shader(id, shader.clone());
let mut shader_cache = self.shader_cache.lock().unwrap();
let pipelines_to_queue = shader_cache.set_shader(id, shader.clone());
for cached_pipeline in pipelines_to_queue {
self.pipelines[cached_pipeline].state = CachedPipelineState::Queued;
self.waiting_pipelines.insert(cached_pipeline);
@ -646,135 +663,152 @@ impl PipelineCache {
}
fn remove_shader(&mut self, shader: AssetId<Shader>) {
let pipelines_to_queue = self.shader_cache.remove(shader);
let mut shader_cache = self.shader_cache.lock().unwrap();
let pipelines_to_queue = shader_cache.remove(shader);
for cached_pipeline in pipelines_to_queue {
self.pipelines[cached_pipeline].state = CachedPipelineState::Queued;
self.waiting_pipelines.insert(cached_pipeline);
}
}
fn process_render_pipeline(
fn start_create_render_pipeline(
&mut self,
id: CachedPipelineId,
descriptor: &RenderPipelineDescriptor,
descriptor: RenderPipelineDescriptor,
) -> CachedPipelineState {
let vertex_module = match self.shader_cache.get(
&self.device,
id,
descriptor.vertex.shader.id(),
&descriptor.vertex.shader_defs,
) {
Ok(module) => module,
Err(err) => {
return CachedPipelineState::Err(err);
}
};
let device = self.device.clone();
let shader_cache = self.shader_cache.clone();
let layout_cache = self.layout_cache.clone();
create_pipeline_task(async move {
let mut shader_cache = shader_cache.lock().unwrap();
let mut layout_cache = layout_cache.lock().unwrap();
let fragment_data = if let Some(fragment) = &descriptor.fragment {
let fragment_module = match self.shader_cache.get(
&self.device,
let vertex_module = match shader_cache.get(
&device,
id,
fragment.shader.id(),
&fragment.shader_defs,
descriptor.vertex.shader.id(),
&descriptor.vertex.shader_defs,
) {
Ok(module) => module,
Err(err) => {
return CachedPipelineState::Err(err);
}
Err(err) => return Err(err),
};
Some((
fragment_module,
fragment.entry_point.deref(),
fragment.targets.as_slice(),
let fragment_module = match &descriptor.fragment {
Some(fragment) => {
match shader_cache.get(&device, id, fragment.shader.id(), &fragment.shader_defs)
{
Ok(module) => Some(module),
Err(err) => return Err(err),
}
}
None => None,
};
let layout =
if descriptor.layout.is_empty() && descriptor.push_constant_ranges.is_empty() {
None
} else {
Some(layout_cache.get(
&device,
&descriptor.layout,
descriptor.push_constant_ranges.to_vec(),
))
};
drop((shader_cache, layout_cache));
let vertex_buffer_layouts = descriptor
.vertex
.buffers
.iter()
.map(|layout| RawVertexBufferLayout {
array_stride: layout.array_stride,
attributes: &layout.attributes,
step_mode: layout.step_mode,
})
.collect::<Vec<_>>();
let fragment_data = descriptor.fragment.as_ref().map(|fragment| {
(
fragment_module.unwrap(),
fragment.entry_point.deref(),
fragment.targets.as_slice(),
)
});
let descriptor = RawRenderPipelineDescriptor {
multiview: None,
depth_stencil: descriptor.depth_stencil.clone(),
label: descriptor.label.as_deref(),
layout: layout.as_deref(),
multisample: descriptor.multisample,
primitive: descriptor.primitive,
vertex: RawVertexState {
buffers: &vertex_buffer_layouts,
entry_point: descriptor.vertex.entry_point.deref(),
module: &vertex_module,
},
fragment: fragment_data
.as_ref()
.map(|(module, entry_point, targets)| RawFragmentState {
entry_point,
module,
targets,
}),
};
Ok(Pipeline::RenderPipeline(
device.create_render_pipeline(&descriptor),
))
} else {
None
};
let vertex_buffer_layouts = descriptor
.vertex
.buffers
.iter()
.map(|layout| RawVertexBufferLayout {
array_stride: layout.array_stride,
attributes: &layout.attributes,
step_mode: layout.step_mode,
})
.collect::<Vec<_>>();
let layout = if descriptor.layout.is_empty() && descriptor.push_constant_ranges.is_empty() {
None
} else {
Some(self.layout_cache.get(
&self.device,
&descriptor.layout,
descriptor.push_constant_ranges.to_vec(),
))
};
let descriptor = RawRenderPipelineDescriptor {
multiview: None,
depth_stencil: descriptor.depth_stencil.clone(),
label: descriptor.label.as_deref(),
layout,
multisample: descriptor.multisample,
primitive: descriptor.primitive,
vertex: RawVertexState {
buffers: &vertex_buffer_layouts,
entry_point: descriptor.vertex.entry_point.deref(),
module: &vertex_module,
},
fragment: fragment_data
.as_ref()
.map(|(module, entry_point, targets)| RawFragmentState {
entry_point,
module,
targets,
}),
};
let pipeline = self.device.create_render_pipeline(&descriptor);
CachedPipelineState::Ok(Pipeline::RenderPipeline(pipeline))
})
}
fn process_compute_pipeline(
fn start_create_compute_pipeline(
&mut self,
id: CachedPipelineId,
descriptor: &ComputePipelineDescriptor,
descriptor: ComputePipelineDescriptor,
) -> CachedPipelineState {
let compute_module = match self.shader_cache.get(
&self.device,
id,
descriptor.shader.id(),
&descriptor.shader_defs,
) {
Ok(module) => module,
Err(err) => {
return CachedPipelineState::Err(err);
}
};
let device = self.device.clone();
let shader_cache = self.shader_cache.clone();
let layout_cache = self.layout_cache.clone();
create_pipeline_task(async move {
let mut shader_cache = shader_cache.lock().unwrap();
let mut layout_cache = layout_cache.lock().unwrap();
let layout = if descriptor.layout.is_empty() && descriptor.push_constant_ranges.is_empty() {
None
} else {
Some(self.layout_cache.get(
&self.device,
&descriptor.layout,
descriptor.push_constant_ranges.to_vec(),
let compute_module = match shader_cache.get(
&device,
id,
descriptor.shader.id(),
&descriptor.shader_defs,
) {
Ok(module) => module,
Err(err) => return Err(err),
};
let layout =
if descriptor.layout.is_empty() && descriptor.push_constant_ranges.is_empty() {
None
} else {
Some(layout_cache.get(
&device,
&descriptor.layout,
descriptor.push_constant_ranges.to_vec(),
))
};
drop((shader_cache, layout_cache));
let descriptor = RawComputePipelineDescriptor {
label: descriptor.label.as_deref(),
layout: layout.as_deref(),
module: &compute_module,
entry_point: &descriptor.entry_point,
};
Ok(Pipeline::ComputePipeline(
device.create_compute_pipeline(&descriptor),
))
};
let descriptor = RawComputePipelineDescriptor {
label: descriptor.label.as_deref(),
layout,
module: &compute_module,
entry_point: descriptor.entry_point.as_ref(),
};
let pipeline = self.device.create_compute_pipeline(&descriptor);
CachedPipelineState::Ok(Pipeline::ComputePipeline(pipeline))
})
}
/// Process the pipeline queue and create all pending pipelines if possible.
@ -800,44 +834,61 @@ impl PipelineCache {
}
for id in waiting_pipelines {
let pipeline = &mut pipelines[id];
if matches!(pipeline.state, CachedPipelineState::Ok(_)) {
continue;
}
pipeline.state = match &pipeline.descriptor {
PipelineDescriptor::RenderPipelineDescriptor(descriptor) => {
self.process_render_pipeline(id, descriptor)
}
PipelineDescriptor::ComputePipelineDescriptor(descriptor) => {
self.process_compute_pipeline(id, descriptor)
}
};
if let CachedPipelineState::Err(err) = &pipeline.state {
match err {
PipelineCacheError::ShaderNotLoaded(_)
| PipelineCacheError::ShaderImportNotYetAvailable => {
// retry
self.waiting_pipelines.insert(id);
}
// shader could not be processed ... retrying won't help
PipelineCacheError::ProcessShaderError(err) => {
let error_detail = err.emit_to_string(&self.shader_cache.composer);
error!("failed to process shader:\n{}", error_detail);
continue;
}
PipelineCacheError::CreateShaderModule(description) => {
error!("failed to create shader module: {}", description);
continue;
}
}
}
self.process_pipeline(&mut pipelines[id], id);
}
self.pipelines = pipelines;
}
/// Advance a single cached pipeline through its compilation state machine.
///
/// - `Queued`: kick off (or synchronously run, depending on platform/features)
///   pipeline creation via `start_create_*_pipeline`.
/// - `Creating`: poll the async task once; promote to `Ok` on success or to
///   `Err` on failure.
/// - `Err`: log fatal errors and drop the pipeline; transient errors
///   (shader not loaded yet) fall through to be retried.
/// - `Ok`: nothing to do.
///
/// Any state that does not `return` early is re-inserted into
/// `waiting_pipelines` so it is processed again next frame.
fn process_pipeline(&mut self, cached_pipeline: &mut CachedPipeline, id: usize) {
match &mut cached_pipeline.state {
CachedPipelineState::Queued => {
// Descriptors are boxed in `PipelineDescriptor`; clone the inner
// value so the creation task can own it.
cached_pipeline.state = match &cached_pipeline.descriptor {
PipelineDescriptor::RenderPipelineDescriptor(descriptor) => {
self.start_create_render_pipeline(id, *descriptor.clone())
}
PipelineDescriptor::ComputePipelineDescriptor(descriptor) => {
self.start_create_compute_pipeline(id, *descriptor.clone())
}
};
}
CachedPipelineState::Creating(ref mut task) => {
// Non-blocking poll: only transition if the task has finished.
match bevy_utils::futures::check_ready(task) {
Some(Ok(pipeline)) => {
cached_pipeline.state = CachedPipelineState::Ok(pipeline);
return;
}
Some(Err(err)) => cached_pipeline.state = CachedPipelineState::Err(err),
_ => (),
}
}
CachedPipelineState::Err(err) => match err {
// Retry
PipelineCacheError::ShaderNotLoaded(_)
| PipelineCacheError::ShaderImportNotYetAvailable => {}
// Shader could not be processed ... retrying won't help
PipelineCacheError::ProcessShaderError(err) => {
let error_detail =
err.emit_to_string(&self.shader_cache.lock().unwrap().composer);
error!("failed to process shader:\n{}", error_detail);
return;
}
PipelineCacheError::CreateShaderModule(description) => {
error!("failed to create shader module: {}", description);
return;
}
},
CachedPipelineState::Ok(_) => return,
}
// Retry
self.waiting_pipelines.insert(id);
}
/// ECS system wrapper that drives the pipeline queue each time it runs.
pub(crate) fn process_pipeline_queue_system(mut cache: ResMut<Self>) {
cache.process_queue();
}
@ -850,6 +901,7 @@ impl PipelineCache {
for event in events.read() {
#[allow(clippy::match_same_arms)]
match event {
// PERF: Instead of blocking waiting for the shader cache lock, try again next frame if the lock is currently held
AssetEvent::Added { id } | AssetEvent::Modified { id } => {
if let Some(shader) = shaders.get(*id) {
cache.set_shader(*id, shader);
@ -865,6 +917,27 @@ impl PipelineCache {
}
}
/// Run a pipeline-creation future, either asynchronously or synchronously
/// depending on platform and enabled features.
///
/// With the `multi-threaded` feature on non-WASM, non-macOS targets the
/// future is spawned on the [`bevy_tasks::AsyncComputeTaskPool`] and a
/// [`CachedPipelineState::Creating`] handle is returned. Everywhere else
/// the future is blocked on immediately and the final `Ok`/`Err` state is
/// returned. (macOS is excluded here — presumably due to driver behavior;
/// see the PR discussion for the rationale.)
fn create_pipeline_task(
task: impl Future<Output = Result<Pipeline, PipelineCacheError>> + Send + 'static,
) -> CachedPipelineState {
#[cfg(all(
not(target_arch = "wasm32"),
not(target_os = "macos"),
feature = "multi-threaded"
))]
return CachedPipelineState::Creating(bevy_tasks::AsyncComputeTaskPool::get().spawn(task));
// Fallback: no async path available, compile the pipeline on this thread.
#[cfg(any(
target_arch = "wasm32",
target_os = "macos",
not(feature = "multi-threaded")
))]
match futures_lite::future::block_on(task) {
Ok(pipeline) => CachedPipelineState::Ok(pipeline),
Err(err) => CachedPipelineState::Err(err),
}
}
/// Type of error returned by a [`PipelineCache`] when the creation of a GPU pipeline object failed.
#[derive(Error, Debug)]
pub enum PipelineCacheError {

View file

@ -5,7 +5,7 @@ use std::{
task::{Context, Poll, RawWaker, RawWakerVTable, Waker},
};
/// Consumes the future, polls it once, and immediately returns the output
/// Consumes a future, polls it once, and immediately returns the output
/// or returns `None` if it wasn't ready yet.
///
/// This will cancel the future if it's not ready.
@ -22,6 +22,20 @@ pub fn now_or_never<F: Future>(mut future: F) -> Option<F::Output> {
}
}
/// Polls a future exactly once, without blocking.
///
/// Returns `Some(output)` if the future completed, or `None` if it is still
/// pending. Unlike `now_or_never`, the future is borrowed rather than
/// consumed, so a pending future can be polled again later.
pub fn check_ready<F: Future + Unpin>(future: &mut F) -> Option<F::Output> {
    let waker = noop_waker();
    let mut context = Context::from_waker(&waker);
    if let Poll::Ready(output) = Pin::new(future).poll(&mut context) {
        Some(output)
    } else {
        None
    }
}
/// `RawWakerVTable` clone hook for the no-op waker: there is no state to
/// duplicate, so just hand back a fresh no-op `RawWaker`.
// SAFETY: `_data` is never dereferenced; the returned waker does nothing.
unsafe fn noop_clone(_data: *const ()) -> RawWaker {
noop_raw_waker()
}