add storage_texture option to as_bind_group macro (#9943)

# Objective

- Add the ability to describe storage texture bindings when deriving
`AsBindGroup`.
- This is especially valuable for the compute story of bevy which
deserves some extra love imo.

## Solution

- This add the ability to annotate struct fields with a
`#[storage_texture(0)]` annotation.
- Instead of adding specific option parsing for all the image formats
and access modes, I simply accept a token stream and defer checking to
see if the option is valid to the compiler. This still results in useful
and friendly errors and is free to maintain and always compatible with
wgpu changes.

---

## Changelog

- The `#[storage_texture(..)]` annotation is now accepted for fields of
`Handle<Image>` in structs that derive `AsBindGroup`.
- The game_of_life compute shader example has been updated to use
`AsBindGroup` together with `[storage_texture(..)]` to obtain the
`BindGroupLayout`.

## Migration Guide
This commit is contained in:
HugoPeters1024 2024-01-21 19:47:13 +01:00 committed by GitHub
parent 0fa14c86de
commit 8afb3ceb89
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 146 additions and 13 deletions

View file

@ -11,6 +11,7 @@ use syn::{
const UNIFORM_ATTRIBUTE_NAME: Symbol = Symbol("uniform");
const TEXTURE_ATTRIBUTE_NAME: Symbol = Symbol("texture");
const STORAGE_TEXTURE_ATTRIBUTE_NAME: Symbol = Symbol("storage_texture");
const SAMPLER_ATTRIBUTE_NAME: Symbol = Symbol("sampler");
const STORAGE_ATTRIBUTE_NAME: Symbol = Symbol("storage");
const BIND_GROUP_DATA_ATTRIBUTE_NAME: Symbol = Symbol("bind_group_data");
@ -19,6 +20,7 @@ const BIND_GROUP_DATA_ATTRIBUTE_NAME: Symbol = Symbol("bind_group_data");
enum BindingType {
Uniform,
Texture,
StorageTexture,
Sampler,
Storage,
}
@ -133,6 +135,8 @@ pub fn derive_as_bind_group(ast: syn::DeriveInput) -> Result<TokenStream> {
BindingType::Uniform
} else if attr_ident == TEXTURE_ATTRIBUTE_NAME {
BindingType::Texture
} else if attr_ident == STORAGE_TEXTURE_ATTRIBUTE_NAME {
BindingType::StorageTexture
} else if attr_ident == SAMPLER_ATTRIBUTE_NAME {
BindingType::Sampler
} else if attr_ident == STORAGE_ATTRIBUTE_NAME {
@ -255,6 +259,45 @@ pub fn derive_as_bind_group(ast: syn::DeriveInput) -> Result<TokenStream> {
}
});
}
BindingType::StorageTexture => {
let StorageTextureAttrs {
dimension,
image_format,
access,
visibility,
} = get_storage_texture_binding_attr(nested_meta_items)?;
let visibility =
visibility.hygienic_quote(&quote! { #render_path::render_resource });
let fallback_image = get_fallback_image(&render_path, dimension);
binding_impls.push(quote! {
( #binding_index,
#render_path::render_resource::OwnedBindingResource::TextureView({
let handle: Option<&#asset_path::Handle<#render_path::texture::Image>> = (&self.#field_name).into();
if let Some(handle) = handle {
images.get(handle).ok_or_else(|| #render_path::render_resource::AsBindGroupError::RetryNextUpdate)?.texture_view.clone()
} else {
#fallback_image.texture_view.clone()
}
})
)
});
binding_layouts.push(quote! {
#render_path::render_resource::BindGroupLayoutEntry {
binding: #binding_index,
visibility: #visibility,
ty: #render_path::render_resource::BindingType::StorageTexture {
access: #render_path::render_resource::StorageTextureAccess::#access,
format: #render_path::render_resource::TextureFormat::#image_format,
view_dimension: #render_path::render_resource::#dimension,
},
count: None,
}
});
}
BindingType::Texture => {
let TextureAttrs {
dimension,
@ -585,6 +628,10 @@ impl ShaderStageVisibility {
fn vertex_fragment() -> Self {
Self::Flags(VisibilityFlags::vertex_fragment())
}
fn compute() -> Self {
Self::Flags(VisibilityFlags::compute())
}
}
impl VisibilityFlags {
@ -595,6 +642,13 @@ impl VisibilityFlags {
..Default::default()
}
}
fn compute() -> Self {
Self {
compute: true,
..Default::default()
}
}
}
impl ShaderStageVisibility {
@ -741,7 +795,72 @@ impl Default for TextureAttrs {
}
}
struct StorageTextureAttrs {
dimension: BindingTextureDimension,
// Parsing of the image_format parameter is deferred to the type checker,
// which will error if the format is not member of the TextureFormat enum.
image_format: proc_macro2::TokenStream,
// Parsing of the access parameter is deferred to the type checker,
// which will error if the access is not member of the StorageTextureAccess enum.
access: proc_macro2::TokenStream,
visibility: ShaderStageVisibility,
}
impl Default for StorageTextureAttrs {
fn default() -> Self {
Self {
dimension: Default::default(),
image_format: quote! { Rgba8Unorm },
access: quote! { ReadWrite },
visibility: ShaderStageVisibility::compute(),
}
}
}
fn get_storage_texture_binding_attr(metas: Vec<Meta>) -> Result<StorageTextureAttrs> {
let mut storage_texture_attrs = StorageTextureAttrs::default();
for meta in metas {
use syn::Meta::{List, NameValue};
match meta {
// Parse #[storage_texture(0, dimension = "...")].
NameValue(m) if m.path == DIMENSION => {
let value = get_lit_str(DIMENSION, &m.value)?;
storage_texture_attrs.dimension = get_texture_dimension_value(value)?;
}
// Parse #[storage_texture(0, format = ...))].
NameValue(m) if m.path == IMAGE_FORMAT => {
storage_texture_attrs.image_format = m.value.into_token_stream();
}
// Parse #[storage_texture(0, access = ...))].
NameValue(m) if m.path == ACCESS => {
storage_texture_attrs.access = m.value.into_token_stream();
}
// Parse #[storage_texture(0, visibility(...))].
List(m) if m.path == VISIBILITY => {
storage_texture_attrs.visibility = get_visibility_flag_value(&m)?;
}
NameValue(m) => {
return Err(Error::new_spanned(
m.path,
"Not a valid name. Available attributes: `dimension`, `image_format`, `access`.",
));
}
_ => {
return Err(Error::new_spanned(
meta,
"Not a name value pair: `foo = \"...\"`",
));
}
}
}
Ok(storage_texture_attrs)
}
const DIMENSION: Symbol = Symbol("dimension");
const IMAGE_FORMAT: Symbol = Symbol("image_format");
const ACCESS: Symbol = Symbol("access");
const SAMPLE_TYPE: Symbol = Symbol("sample_type");
const FILTERABLE: Symbol = Symbol("filterable");
const MULTISAMPLED: Symbol = Symbol("multisampled");

View file

@ -51,7 +51,7 @@ pub fn derive_extract_component(input: TokenStream) -> TokenStream {
#[proc_macro_derive(
AsBindGroup,
attributes(uniform, texture, sampler, bind_group_data, storage)
attributes(uniform, storage_texture, texture, sampler, bind_group_data, storage)
)]
pub fn derive_as_bind_group(input: TokenStream) -> TokenStream {
let input = parse_macro_input!(input as DeriveInput);

View file

@ -87,6 +87,8 @@ impl Deref for BindGroup {
/// values: Vec<f32>,
/// #[storage(4, read_only, buffer)]
/// buffer: Buffer,
/// #[storage_texture(5)]
/// storage_texture: Handle<Image>,
/// }
/// ```
///
@ -97,6 +99,7 @@ impl Deref for BindGroup {
/// @group(2) @binding(1) var color_texture: texture_2d<f32>;
/// @group(2) @binding(2) var color_sampler: sampler;
/// @group(2) @binding(3) var<storage> values: array<f32>;
/// @group(2) @binding(5) var storage_texture: texture_storage_2d<rgba8unorm, read_write>;
/// ```
/// Note that the "group" index is determined by the usage context. It is not defined in [`AsBindGroup`]. For example, in Bevy material bind groups
/// are generally bound to group 2.
@ -123,6 +126,19 @@ impl Deref for BindGroup {
/// | `multisampled` = ... | `true`, `false` | `false` |
/// | `visibility(...)` | `all`, `none`, or a list-combination of `vertex`, `fragment`, `compute` | `vertex`, `fragment` |
///
/// * `storage_texture(BINDING_INDEX, arguments)`
/// * This field's [`Handle<Image>`](bevy_asset::Handle) will be used to look up the matching [`Texture`](crate::render_resource::Texture)
/// GPU resource, which will be bound as a storage texture in shaders. The field will be assumed to implement [`Into<Option<Handle<Image>>>`]. In practice,
/// most fields should be a [`Handle<Image>`](bevy_asset::Handle) or [`Option<Handle<Image>>`]. If the value of an [`Option<Handle<Image>>`] is
/// [`None`], the [`FallbackImage`] resource will be used instead.
///
/// | Arguments | Values | Default |
/// |------------------------|--------------------------------------------------------------------------------------------|---------------|
/// | `dimension` = "..." | `"1d"`, `"2d"`, `"2d_array"`, `"3d"`, `"cube"`, `"cube_array"` | `"2d"` |
/// | `image_format` = ... | any member of [`TextureFormat`](crate::render_resource::TextureFormat) | `Rgba8Unorm` |
/// | `access` = ... | any member of [`StorageTextureAccess`](crate::render_resource::StorageTextureAccess) | `ReadWrite` |
/// | `visibility(...)` | `all`, `none`, or a list-combination of `vertex`, `fragment`, `compute` | `compute` |
///
/// * `sampler(BINDING_INDEX, arguments)`
/// * This field's [`Handle<Image>`](bevy_asset::Handle) will be used to look up the matching [`Sampler`] GPU
/// resource, which will be bound as a sampler in shaders. The field will be assumed to implement [`Into<Option<Handle<Image>>>`]. In practice,

View file

@ -10,7 +10,7 @@ use bevy::{
render_asset::RenderAssetPersistencePolicy,
render_asset::RenderAssets,
render_graph::{self, RenderGraph},
render_resource::{binding_types::texture_storage_2d, *},
render_resource::*,
renderer::{RenderContext, RenderDevice},
Render, RenderApp, RenderSet,
},
@ -65,7 +65,7 @@ fn setup(mut commands: Commands, mut images: ResMut<Assets<Image>>) {
});
commands.spawn(Camera2dBundle::default());
commands.insert_resource(GameOfLifeImage(image));
commands.insert_resource(GameOfLifeImage { texture: image });
}
pub struct GameOfLifeComputePlugin;
@ -95,8 +95,11 @@ impl Plugin for GameOfLifeComputePlugin {
}
}
#[derive(Resource, Clone, Deref, ExtractResource)]
struct GameOfLifeImage(Handle<Image>);
#[derive(Resource, Clone, Deref, ExtractResource, AsBindGroup)]
struct GameOfLifeImage {
#[storage_texture(0, image_format = Rgba8Unorm, access = ReadWrite)]
texture: Handle<Image>,
}
#[derive(Resource)]
struct GameOfLifeImageBindGroup(BindGroup);
@ -108,7 +111,7 @@ fn prepare_bind_group(
game_of_life_image: Res<GameOfLifeImage>,
render_device: Res<RenderDevice>,
) {
let view = gpu_images.get(&game_of_life_image.0).unwrap();
let view = gpu_images.get(&game_of_life_image.texture).unwrap();
let bind_group = render_device.create_bind_group(
None,
&pipeline.texture_bind_group_layout,
@ -126,13 +129,8 @@ pub struct GameOfLifePipeline {
impl FromWorld for GameOfLifePipeline {
fn from_world(world: &mut World) -> Self {
let texture_bind_group_layout = world.resource::<RenderDevice>().create_bind_group_layout(
None,
&BindGroupLayoutEntries::single(
ShaderStages::COMPUTE,
texture_storage_2d(TextureFormat::Rgba8Unorm, StorageTextureAccess::ReadWrite),
),
);
let render_device = world.resource::<RenderDevice>();
let texture_bind_group_layout = GameOfLifeImage::bind_group_layout(render_device);
let shader = world
.resource::<AssetServer>()
.load("shaders/game_of_life.wgsl");