From 5355568a2dbbb5bc4122109e2bd04ce6903adff1 Mon Sep 17 00:00:00 2001
From: ameerj <52414509+ameerj@users.noreply.github.com>
Date: Thu, 3 Jun 2021 20:24:56 -0400
Subject: [PATCH] glsl: Refactor Global memory functions

---
 .../backend/glsl/emit_context.cpp             | 149 +++++++++---------
 .../backend/glsl/emit_context.h               |   1 +
 2 files changed, 76 insertions(+), 74 deletions(-)

diff --git a/src/shader_recompiler/backend/glsl/emit_context.cpp b/src/shader_recompiler/backend/glsl/emit_context.cpp
index f68f33212..fbc4b9c0f 100644
--- a/src/shader_recompiler/backend/glsl/emit_context.cpp
+++ b/src/shader_recompiler/backend/glsl/emit_context.cpp
@@ -9,11 +9,11 @@
 
 namespace Shader::Backend::GLSL {
 namespace {
-u32 CbufIndex(u32 offset) {
+u32 CbufIndex(size_t offset) {
     return (offset / 4) % 4;
 }
 
-char OffsetSwizzle(u32 offset) {
+char CbufSwizzle(size_t offset) {
     return "xyzw"[CbufIndex(offset)];
 }
 
@@ -434,81 +434,82 @@ void EmitContext::DefineHelperFunctions() {
         header += "uint CasMaxS32(uint op_a,uint op_b){return uint(max(int(op_a),int(op_b)));}";
     }
     if (info.uses_global_memory) {
-        std::string write_func{"void WriteGlobal32(uint64_t addr,uint data){\n"};
-        std::string write_func_64{"void WriteGlobal64(uint64_t addr,uvec2 data){\n"};
-        std::string write_func_128{"void WriteGlobal128(uint64_t addr,uvec4 data){\n"};
-
-        std::string load_func{"uint LoadGlobal32(uint64_t addr){\n"};
-        std::string load_func_64{"uvec2 LoadGlobal64(uint64_t addr){\n"};
-        std::string load_func_128{"uvec4 LoadGlobal128(uint64_t addr){\n"};
-        const size_t num_buffers{info.storage_buffers_descriptors.size()};
-        for (size_t index = 0; index < num_buffers; ++index) {
-            if (!info.nvn_buffer_used[index]) {
-                continue;
-            }
-            const auto& ssbo{info.storage_buffers_descriptors[index]};
-            const u32 size_cbuf_offset{ssbo.cbuf_offset + 8};
-            const auto ssbo_addr{fmt::format("ssbo_addr{}", index)};
-            const auto cbuf{fmt::format("{}_cbuf{}", stage_name, ssbo.cbuf_index)};
-            const auto cbuf_value{fmt::format(
-                "uint64_t {}=packUint2x32(uvec2(ftou({}[{}].{}),ftou({}[{}].{})));", ssbo_addr,
-                cbuf, ssbo.cbuf_offset / 16, OffsetSwizzle(ssbo.cbuf_offset), cbuf,
-                (ssbo.cbuf_offset + 4) / 16, OffsetSwizzle(ssbo.cbuf_offset + 4))};
-
-            write_func += cbuf_value;
-            write_func_64 += cbuf_value;
-            write_func_128 += cbuf_value;
-            load_func += cbuf_value;
-            load_func_64 += cbuf_value;
-            load_func_128 += cbuf_value;
-            const auto ssbo_size{fmt::format("ftou({}[{}].{}),ftou({}[{}].{})", cbuf,
-                                             size_cbuf_offset / 16, OffsetSwizzle(size_cbuf_offset),
-                                             cbuf, (size_cbuf_offset + 4) / 16,
-                                             OffsetSwizzle(size_cbuf_offset + 4))};
-            const auto comparison{fmt::format("if((addr>={})&&(addr<({}+\nuint64_t(uvec2({}))))){{",
-                                              ssbo_addr, ssbo_addr, ssbo_size)};
-            write_func += comparison;
-            write_func_64 += comparison;
-            write_func_128 += comparison;
-            load_func += comparison;
-            load_func_64 += comparison;
-            load_func_128 += comparison;
-
-            const auto ssbo_name{fmt::format("{}_ssbo{}", stage_name, index)};
-            write_func += fmt::format("{}[uint(addr-{})>>2]=data;return;}}", ssbo_name, ssbo_addr);
-            write_func_64 +=
-                fmt::format("{}[uint(addr-{})>>2]=data.x;{}[uint(addr-{}+4)>>2]=data.y;return;}}",
-                            ssbo_name, ssbo_addr, ssbo_name, ssbo_addr);
-            write_func_128 +=
-                fmt::format("{}[uint(addr-{})>>2]=data.x;{}[uint(addr-{}+4)>>2]=data.y;{}[uint("
-                            "addr-{}+8)>>2]=data.z;{}[uint(addr-{}+12)>>2]=data.w;return;}}",
-                            ssbo_name, ssbo_addr, ssbo_name, ssbo_addr, ssbo_name, ssbo_addr,
-                            ssbo_name, ssbo_addr);
-            load_func += fmt::format("return {}[uint(addr-{})>>2];}}", ssbo_name, ssbo_addr);
-            load_func_64 +=
-                fmt::format("return uvec2({}[uint(addr-{})>>2],{}[uint(addr-{}+4)>>2]);}}",
-                            ssbo_name, ssbo_addr, ssbo_name, ssbo_addr);
-            load_func_128 += fmt::format("return "
-                                         "uvec4({}[uint(addr-{})>>2],{}[uint(addr-{}+4)>>2],{}["
-                                         "uint(addr-{}+8)>>2],{}[uint(addr-{}+12)>>2]);}}",
-                                         ssbo_name, ssbo_addr, ssbo_name, ssbo_addr, ssbo_name,
-                                         ssbo_addr, ssbo_name, ssbo_addr);
-        }
-        write_func += "}\n";
-        write_func_64 += "}\n";
-        write_func_128 += "}\n";
-        load_func += "return 0u;}\n";
-        load_func_64 += "return uvec2(0);}\n";
-        load_func_128 += "return uvec4(0);}\n";
-        header += write_func;
-        header += write_func_64;
-        header += write_func_128;
-        header += load_func;
-        header += load_func_64;
-        header += load_func_128;
+        header += DefineGlobalMemoryFunctions();
     }
 }
 
+std::string EmitContext::DefineGlobalMemoryFunctions() {
+    const auto define_body{[&](std::string& func, size_t index, u32 num_components,
+                               std::string_view return_statement) {
+        const auto& ssbo{info.storage_buffers_descriptors[index]};
+        const u32 size_cbuf_offset{ssbo.cbuf_offset + 8};
+        const auto ssbo_addr{fmt::format("ssbo_addr{}", index)};
+        const auto cbuf{fmt::format("{}_cbuf{}", stage_name, ssbo.cbuf_index)};
+        std::array<std::string, 2> addr_xy;
+        std::array<std::string, 2> size_xy;
+        for (size_t i = 0; i < addr_xy.size(); ++i) {
+            const auto addr_loc{ssbo.cbuf_offset + 4 * i};
+            const auto size_loc{size_cbuf_offset + 4 * i};
+            addr_xy[i] = fmt::format("ftou({}[{}].{})", cbuf, addr_loc / 16, CbufSwizzle(addr_loc));
+            size_xy[i] = fmt::format("ftou({}[{}].{})", cbuf, size_loc / 16, CbufSwizzle(size_loc));
+        }
+        const auto addr_pack{fmt::format("packUint2x32(uvec2({},{}))", addr_xy[0], addr_xy[1])};
+        const auto addr_statment{fmt::format("uint64_t {}={};", ssbo_addr, addr_pack)};
+        func += addr_statment;
+
+        const auto size_vec{fmt::format("uvec2({},{})", size_xy[0], size_xy[1])};
+        const auto comp_lhs{fmt::format("(addr>={})", ssbo_addr)};
+        const auto comp_rhs{fmt::format("(addr<({}+uint64_t({})))", ssbo_addr, size_vec)};
+        const auto comparison{fmt::format("if({}&&{}){{", comp_lhs, comp_rhs)};
+        func += comparison;
+
+        const auto ssbo_name{fmt::format("{}_ssbo{}", stage_name, index)};
+        switch (num_components) {
+        case 1:
+            func += fmt::format(return_statement, ssbo_name, ssbo_addr);
+            break;
+        case 2:
+            func += fmt::format(return_statement, ssbo_name, ssbo_addr, ssbo_name, ssbo_addr);
+            break;
+        case 4:
+            func += fmt::format(return_statement, ssbo_name, ssbo_addr, ssbo_name, ssbo_addr,
+                                ssbo_name, ssbo_addr, ssbo_name, ssbo_addr);
+            break;
+        }
+    }};
+    std::string write_func{"void WriteGlobal32(uint64_t addr,uint data){\n"};
+    std::string write_func_64{"void WriteGlobal64(uint64_t addr,uvec2 data){\n"};
+    std::string write_func_128{"void WriteGlobal128(uint64_t addr,uvec4 data){\n"};
+    std::string load_func{"uint LoadGlobal32(uint64_t addr){\n"};
+    std::string load_func_64{"uvec2 LoadGlobal64(uint64_t addr){\n"};
+    std::string load_func_128{"uvec4 LoadGlobal128(uint64_t addr){\n"};
+    const size_t num_buffers{info.storage_buffers_descriptors.size()};
+    for (size_t index = 0; index < num_buffers; ++index) {
+        if (!info.nvn_buffer_used[index]) {
+            continue;
+        }
+        define_body(write_func, index, 1, "{}[uint(addr-{})>>2]=data;return;}}");
+        define_body(write_func_64, index, 2,
+                    "{}[uint(addr-{})>>2]=data.x;{}[uint(addr-{}+4)>>2]=data.y;return;}}");
+        define_body(write_func_128, index, 4,
+                    "{}[uint(addr-{})>>2]=data.x;{}[uint(addr-{}+4)>>2]=data.y;{}[uint("
+                    "addr-{}+8)>>2]=data.z;{}[uint(addr-{}+12)>>2]=data.w;return;}}");
+        define_body(load_func, index, 1, "return {}[uint(addr-{})>>2];}}");
+        define_body(load_func_64, index, 2,
+                    "return uvec2({}[uint(addr-{})>>2],{}[uint(addr-{}+4)>>2]);}}");
+        define_body(load_func_128, index, 4,
+                    "return uvec4({}[uint(addr-{})>>2],{}[uint(addr-{}+4)>>2],{}["
+                    "uint(addr-{}+8)>>2],{}[uint(addr-{}+12)>>2]);}}");
+    }
+    write_func += "}";
+    write_func_64 += "}";
+    write_func_128 += "}";
+    load_func += "return 0u;}";
+    load_func_64 += "return uvec2(0);}";
+    load_func_128 += "return uvec4(0);}";
+    return write_func + write_func_64 + write_func_128 + load_func + load_func_64 + load_func_128;
+}
+
 void EmitContext::SetupImages(Bindings& bindings) {
     image_buffer_bindings.reserve(info.image_buffer_descriptors.size());
     for (const auto& desc : info.image_buffer_descriptors) {
diff --git a/src/shader_recompiler/backend/glsl/emit_context.h b/src/shader_recompiler/backend/glsl/emit_context.h
index 26a76f8a3..daca1b6f9 100644
--- a/src/shader_recompiler/backend/glsl/emit_context.h
+++ b/src/shader_recompiler/backend/glsl/emit_context.h
@@ -167,6 +167,7 @@ private:
     void DefineStorageBuffers(Bindings& bindings);
     void DefineGenericOutput(size_t index, u32 invocations);
     void DefineHelperFunctions();
+    std::string DefineGlobalMemoryFunctions();
     void SetupImages(Bindings& bindings);
 };