From 32e6727daecab60d368d14619c1e04d0d7e60008 Mon Sep 17 00:00:00 2001
From: ReinUsesLisp <reinuseslisp@airmail.cc>
Date: Fri, 15 May 2020 02:52:35 -0300
Subject: [PATCH] shader/other: Implement MEMBAR.CTS

This silences an assertion we were hitting and uses workgroup memory
barriers when the game requests it.
---
 .../renderer_opengl/gl_shader_decompiler.cpp       | 10 ++++++++--
 .../renderer_vulkan/vk_shader_decompiler.cpp       |  7 ++++---
 src/video_core/shader/decode/other.cpp             | 14 ++++++++++++--
 src/video_core/shader/node.h                       |  5 +++--
 4 files changed, 27 insertions(+), 9 deletions(-)

diff --git a/src/video_core/renderer_opengl/gl_shader_decompiler.cpp b/src/video_core/renderer_opengl/gl_shader_decompiler.cpp
index 253484968..9cb115959 100644
--- a/src/video_core/renderer_opengl/gl_shader_decompiler.cpp
+++ b/src/video_core/renderer_opengl/gl_shader_decompiler.cpp
@@ -2344,7 +2344,12 @@ private:
         return {};
     }
 
-    Expression MemoryBarrierGL(Operation) {
+    Expression MemoryBarrierGroup(Operation) {
+        code.AddLine("groupMemoryBarrier();");
+        return {};
+    }
+
+    Expression MemoryBarrierGlobal(Operation) {
         code.AddLine("memoryBarrier();");
         return {};
     }
@@ -2591,7 +2596,8 @@ private:
         &GLSLDecompiler::ShuffleIndexed,
 
         &GLSLDecompiler::Barrier,
-        &GLSLDecompiler::MemoryBarrierGL,
+        &GLSLDecompiler::MemoryBarrierGroup,
+        &GLSLDecompiler::MemoryBarrierGlobal,
     };
     static_assert(operation_decompilers.size() == static_cast<std::size_t>(OperationCode::Amount));
 
diff --git a/src/video_core/renderer_vulkan/vk_shader_decompiler.cpp b/src/video_core/renderer_vulkan/vk_shader_decompiler.cpp
index 890f34a2c..6f6dedd82 100644
--- a/src/video_core/renderer_vulkan/vk_shader_decompiler.cpp
+++ b/src/video_core/renderer_vulkan/vk_shader_decompiler.cpp
@@ -2215,8 +2215,8 @@ private:
         return {};
     }
 
-    Expression MemoryBarrierGL(Operation) {
-        const auto scope = spv::Scope::Device;
+    template <spv::Scope scope>
+    Expression MemoryBarrier(Operation) {
         const auto semantics =
             spv::MemorySemanticsMask::AcquireRelease | spv::MemorySemanticsMask::UniformMemory |
             spv::MemorySemanticsMask::WorkgroupMemory |
@@ -2681,7 +2681,8 @@ private:
         &SPIRVDecompiler::ShuffleIndexed,
 
         &SPIRVDecompiler::Barrier,
-        &SPIRVDecompiler::MemoryBarrierGL,
+        &SPIRVDecompiler::MemoryBarrier<spv::Scope::Workgroup>,
+        &SPIRVDecompiler::MemoryBarrier<spv::Scope::Device>,
     };
     static_assert(operation_decompilers.size() == static_cast<std::size_t>(OperationCode::Amount));
 
diff --git a/src/video_core/shader/decode/other.cpp b/src/video_core/shader/decode/other.cpp
index 694b325e1..d00e10913 100644
--- a/src/video_core/shader/decode/other.cpp
+++ b/src/video_core/shader/decode/other.cpp
@@ -299,9 +299,19 @@ u32 ShaderIR::DecodeOther(NodeBlock& bb, u32 pc) {
         break;
     }
     case OpCode::Id::MEMBAR: {
-        UNIMPLEMENTED_IF(instr.membar.type != Tegra::Shader::MembarType::GL);
         UNIMPLEMENTED_IF(instr.membar.unknown != Tegra::Shader::MembarUnknown::Default);
-        bb.push_back(Operation(OperationCode::MemoryBarrierGL));
+        const OperationCode type = [instr] {
+            switch (instr.membar.type) {
+            case Tegra::Shader::MembarType::CTA:
+                return OperationCode::MemoryBarrierGroup;
+            case Tegra::Shader::MembarType::GL:
+                return OperationCode::MemoryBarrierGlobal;
+            default:
+                UNIMPLEMENTED_MSG("MEMBAR type={}", static_cast<int>(instr.membar.type.Value()));
+                return OperationCode::MemoryBarrierGlobal;
+            }
+        }();
+        bb.push_back(Operation(type));
         break;
     }
     case OpCode::Id::DEPBAR: {
diff --git a/src/video_core/shader/node.h b/src/video_core/shader/node.h
index c06512413..c5e5165ff 100644
--- a/src/video_core/shader/node.h
+++ b/src/video_core/shader/node.h
@@ -233,8 +233,9 @@ enum class OperationCode {
     ThreadLtMask,   /// () -> uint
     ShuffleIndexed, /// (uint value, uint index) -> uint
 
-    Barrier,         /// () -> void
-    MemoryBarrierGL, /// () -> void
+    Barrier,             /// () -> void
+    MemoryBarrierGroup,  /// () -> void
+    MemoryBarrierGlobal, /// () -> void
 
     Amount,
 };