From 3482df1176203b4999353e8266f42032536b561c Mon Sep 17 00:00:00 2001
From: ameerj <52414509+ameerj@users.noreply.github.com>
Date: Tue, 25 May 2021 02:01:32 -0400
Subject: [PATCH] glsl: Simply FP storage atomics

---
 .../backend/glsl/emit_context.cpp             | 28 +++++------
 .../backend/glsl/emit_glsl_atomic.cpp         | 48 ++++++-------------
 2 files changed, 28 insertions(+), 48 deletions(-)

diff --git a/src/shader_recompiler/backend/glsl/emit_context.cpp b/src/shader_recompiler/backend/glsl/emit_context.cpp
index a413219e3..9c3fd44ba 100644
--- a/src/shader_recompiler/backend/glsl/emit_context.cpp
+++ b/src/shader_recompiler/backend/glsl/emit_context.cpp
@@ -78,32 +78,32 @@ void EmitContext::DefineHelperFunctions() {
             "uint CasDecrement(uint op_a,uint op_b){return(op_a==0||op_a>op_b)?op_b:(op_a-1u);}\n";
     }
     if (info.uses_atomic_f32_add) {
-        code += "uint CasFloatAdd(uint op_a,uint op_b){return "
-                "floatBitsToUint(uintBitsToFloat(op_a)+uintBitsToFloat(op_b));}\n";
+        code += "uint CasFloatAdd(uint op_a,float op_b){return "
+                "floatBitsToUint(uintBitsToFloat(op_a)+op_b);}\n";
     }
     if (info.uses_atomic_f32x2_add) {
-        code += "uint CasFloatAdd32x2(uint op_a,uint op_b){return "
-                "packHalf2x16(unpackHalf2x16(op_a)+unpackHalf2x16(op_b));}\n";
+        code += "uint CasFloatAdd32x2(uint op_a,vec2 op_b){return "
+                "packHalf2x16(unpackHalf2x16(op_a)+op_b);}\n";
     }
     if (info.uses_atomic_f32x2_min) {
-        code += "uint CasFloatMin32x2(uint op_a,uint op_b){return "
-                "packHalf2x16(min(unpackHalf2x16(op_a),unpackHalf2x16(op_b)));}\n";
+        code += "uint CasFloatMin32x2(uint op_a,vec2 op_b){return "
+                "packHalf2x16(min(unpackHalf2x16(op_a),op_b));}\n";
     }
     if (info.uses_atomic_f32x2_max) {
-        code += "uint CasFloatMax32x2(uint op_a,uint op_b){return "
-                "packHalf2x16(max(unpackHalf2x16(op_a),unpackHalf2x16(op_b)));}\n";
+        code += "uint CasFloatMax32x2(uint op_a,vec2 op_b){return "
+                "packHalf2x16(max(unpackHalf2x16(op_a),op_b));}\n";
     }
     if (info.uses_atomic_f16x2_add) {
-        code += "uint CasFloatAdd16x2(uint op_a,uint op_b){return "
-                "packFloat2x16(unpackFloat2x16(op_a)+unpackFloat2x16(op_b));}\n";
+        code += "uint CasFloatAdd16x2(uint op_a,f16vec2 op_b){return "
+                "packFloat2x16(unpackFloat2x16(op_a)+op_b);}\n";
     }
     if (info.uses_atomic_f16x2_min) {
-        code += "uint CasFloatMin16x2(uint op_a,uint op_b){return "
-                "packFloat2x16(min(unpackFloat2x16(op_a),unpackFloat2x16(op_b)));}\n";
+        code += "uint CasFloatMin16x2(uint op_a,f16vec2 op_b){return "
+                "packFloat2x16(min(unpackFloat2x16(op_a),op_b));}\n";
     }
     if (info.uses_atomic_f16x2_max) {
-        code += "uint CasFloatMax16x2(uint op_a,uint op_b){return "
-                "packFloat2x16(max(unpackFloat2x16(op_a),unpackFloat2x16(op_b)));}\n";
+        code += "uint CasFloatMax16x2(uint op_a,f16vec2 op_b){return "
+                "packFloat2x16(max(unpackFloat2x16(op_a),op_b));}\n";
     }
     // TODO: Track this usage
     code += "uint CasMinS32(uint op_a,uint op_b){return uint(min(int(op_a),int(op_b)));}";
diff --git a/src/shader_recompiler/backend/glsl/emit_glsl_atomic.cpp b/src/shader_recompiler/backend/glsl/emit_glsl_atomic.cpp
index b6b326762..0c3af75f5 100644
--- a/src/shader_recompiler/backend/glsl/emit_glsl_atomic.cpp
+++ b/src/shader_recompiler/backend/glsl/emit_glsl_atomic.cpp
@@ -19,16 +19,11 @@ for (;;){{
     if ({}==old_value){{break;}}
 }})"};
 
-void CasFunction(EmitContext& ctx, std::string_view ret, std::string_view ssbo,
-                 std::string_view value, std::string_view function) {
-    ctx.Add(cas_loop.data(), ret, ssbo, ret, ssbo, function, ssbo, value, ret);
-}
-
-void CasFunctionInt32(EmitContext& ctx, IR::Inst& inst, const IR::Value& binding,
-                      const IR::Value& offset, std::string_view value, std::string_view function) {
+void CasFunction(EmitContext& ctx, IR::Inst& inst, const IR::Value& binding,
+                 const IR::Value& offset, std::string_view value, std::string_view function) {
     const auto ret{ctx.reg_alloc.Define(inst)};
     const std::string ssbo{fmt::format("ssbo{}[{}]", binding.U32(), offset.U32())};
-    CasFunction(ctx, ret, ssbo, value, function);
+    ctx.Add(cas_loop.data(), ret, ssbo, ret, ssbo, function, ssbo, value, ret);
 }
 
 void CasFunctionF32(EmitContext& ctx, IR::Inst& inst, const IR::Value& binding,
@@ -37,25 +32,10 @@ void CasFunctionF32(EmitContext& ctx, IR::Inst& inst, const IR::Value& binding,
     const std::string u32_value{fmt::format("floatBitsToUint({})", value)};
     const auto ret{ctx.reg_alloc.Define(inst)};
     const auto ret_32{ret + "_u32"};
-    CasFunction(ctx, ret_32, ssbo, u32_value, function);
+    ctx.Add(cas_loop.data(), ret_32, ssbo, ret_32, ssbo, function, ssbo, value, ret_32);
     ctx.Add("float {}=uintBitsToFloat({});", ret, ret_32);
 }
 
-void CasFunctionF32x2(EmitContext& ctx, IR::Inst& inst, const IR::Value& binding,
-                      const IR::Value& offset, std::string_view value, std::string_view function) {
-    const std::string ssbo{fmt::format("ssbo{}[{}]", binding.U32(), offset.U32())};
-    const std::string u32_value{fmt::format("packHalf2x16({})", value)};
-    const auto ret{ctx.reg_alloc.Define(inst)};
-    CasFunction(ctx, ret, ssbo, u32_value, function);
-}
-
-void CasFunctionF16x2(EmitContext& ctx, IR::Inst& inst, const IR::Value& binding,
-                      const IR::Value& offset, std::string_view value, std::string_view function) {
-    const std::string ssbo{fmt::format("ssbo{}[{}]", binding.U32(), offset.U32())};
-    const std::string u32_value{fmt::format("packFloat2x16({})", value)};
-    const auto ret{ctx.reg_alloc.Define(inst)};
-    CasFunction(ctx, ret, ssbo, u32_value, function);
-}
 } // namespace
 
 void EmitStorageAtomicIAdd32(EmitContext& ctx, IR::Inst& inst, const IR::Value& binding,
@@ -66,7 +46,7 @@ void EmitStorageAtomicIAdd32(EmitContext& ctx, IR::Inst& inst, const IR::Value&
 void EmitStorageAtomicSMin32(EmitContext& ctx, IR::Inst& inst, const IR::Value& binding,
                              const IR::Value& offset, std::string_view value) {
     const std::string u32_value{fmt::format("uint({})", value)};
-    CasFunctionInt32(ctx, inst, binding, offset, u32_value, "CasMinS32");
+    CasFunction(ctx, inst, binding, offset, u32_value, "CasMinS32");
 }
 
 void EmitStorageAtomicUMin32(EmitContext& ctx, IR::Inst& inst, const IR::Value& binding,
@@ -77,7 +57,7 @@ void EmitStorageAtomicUMin32(EmitContext& ctx, IR::Inst& inst, const IR::Value&
 void EmitStorageAtomicSMax32(EmitContext& ctx, IR::Inst& inst, const IR::Value& binding,
                              const IR::Value& offset, std::string_view value) {
     const std::string u32_value{fmt::format("uint({})", value)};
-    CasFunctionInt32(ctx, inst, binding, offset, u32_value, "CasMaxS32");
+    CasFunction(ctx, inst, binding, offset, u32_value, "CasMaxS32");
 }
 
 void EmitStorageAtomicUMax32(EmitContext& ctx, IR::Inst& inst, const IR::Value& binding,
@@ -87,12 +67,12 @@ void EmitStorageAtomicUMax32(EmitContext& ctx, IR::Inst& inst, const IR::Value&
 
 void EmitStorageAtomicInc32(EmitContext& ctx, IR::Inst& inst, const IR::Value& binding,
                             const IR::Value& offset, std::string_view value) {
-    CasFunctionInt32(ctx, inst, binding, offset, value, "CasIncrement");
+    CasFunction(ctx, inst, binding, offset, value, "CasIncrement");
 }
 
 void EmitStorageAtomicDec32(EmitContext& ctx, IR::Inst& inst, const IR::Value& binding,
                             const IR::Value& offset, std::string_view value) {
-    CasFunctionInt32(ctx, inst, binding, offset, value, "CasDecrement");
+    CasFunction(ctx, inst, binding, offset, value, "CasDecrement");
 }
 
 void EmitStorageAtomicAnd32(EmitContext& ctx, IR::Inst& inst, const IR::Value& binding,
@@ -199,32 +179,32 @@ void EmitStorageAtomicAddF32(EmitContext& ctx, IR::Inst& inst, const IR::Value&
 
 void EmitStorageAtomicAddF16x2(EmitContext& ctx, IR::Inst& inst, const IR::Value& binding,
                                const IR::Value& offset, std::string_view value) {
-    CasFunctionF16x2(ctx, inst, binding, offset, value, "CasFloatAdd16x2");
+    CasFunction(ctx, inst, binding, offset, value, "CasFloatAdd16x2");
 }
 
 void EmitStorageAtomicAddF32x2(EmitContext& ctx, IR::Inst& inst, const IR::Value& binding,
                                const IR::Value& offset, std::string_view value) {
-    CasFunctionF32x2(ctx, inst, binding, offset, value, "CasFloatAdd32x2");
+    CasFunction(ctx, inst, binding, offset, value, "CasFloatAdd32x2");
 }
 
 void EmitStorageAtomicMinF16x2(EmitContext& ctx, IR::Inst& inst, const IR::Value& binding,
                                const IR::Value& offset, std::string_view value) {
-    CasFunctionF16x2(ctx, inst, binding, offset, value, "CasFloatMin16x2");
+    CasFunction(ctx, inst, binding, offset, value, "CasFloatMin16x2");
 }
 
 void EmitStorageAtomicMinF32x2(EmitContext& ctx, IR::Inst& inst, const IR::Value& binding,
                                const IR::Value& offset, std::string_view value) {
-    CasFunctionF32x2(ctx, inst, binding, offset, value, "CasFloatMin32x2");
+    CasFunction(ctx, inst, binding, offset, value, "CasFloatMin32x2");
 }
 
 void EmitStorageAtomicMaxF16x2(EmitContext& ctx, IR::Inst& inst, const IR::Value& binding,
                                const IR::Value& offset, std::string_view value) {
-    CasFunctionF16x2(ctx, inst, binding, offset, value, "CasFloatMax16x2");
+    CasFunction(ctx, inst, binding, offset, value, "CasFloatMax16x2");
 }
 
 void EmitStorageAtomicMaxF32x2(EmitContext& ctx, IR::Inst& inst, const IR::Value& binding,
                                const IR::Value& offset, std::string_view value) {
-    CasFunctionF32x2(ctx, inst, binding, offset, value, "CasFloatMax32x2");
+    CasFunction(ctx, inst, binding, offset, value, "CasFloatMax32x2");
 }
 
 void EmitGlobalAtomicIAdd32(EmitContext&) {