From b758c861b06a37cdf823476898a289497d0ddf60 Mon Sep 17 00:00:00 2001
From: ReinUsesLisp <reinuseslisp@airmail.cc>
Date: Thu, 14 Mar 2019 02:36:14 -0300
Subject: [PATCH] vk_shader_decompiler: Implement OperationCode decompilation
 interface

---
 .../renderer_vulkan/vk_shader_decompiler.cpp  | 412 +++++++++++++++++-
 1 file changed, 411 insertions(+), 1 deletion(-)

diff --git a/src/video_core/renderer_vulkan/vk_shader_decompiler.cpp b/src/video_core/renderer_vulkan/vk_shader_decompiler.cpp
index 5f174bb7f..e4c3e3d9c 100644
--- a/src/video_core/renderer_vulkan/vk_shader_decompiler.cpp
+++ b/src/video_core/renderer_vulkan/vk_shader_decompiler.cpp
@@ -30,6 +30,7 @@ using namespace VideoCommon::Shader;
 
 using Maxwell = Tegra::Engines::Maxwell3D::Regs;
 using ShaderStage = Tegra::Engines::Maxwell3D::Regs::ShaderStage;
+using Operation = const OperationNode&;
 
 // TODO(Rodrigo): Use rasterizer's value
 constexpr u32 MAX_CONSTBUFFER_ELEMENTS = 0x1000;
@@ -73,6 +74,19 @@ constexpr u32 GetGenericAttributeLocation(Attribute::Index attribute) {
     return static_cast<u32>(attribute) - static_cast<u32>(Attribute::Index::Attribute_0);
 }
 
+/// Returns true if an object has to be treated as precise
+bool IsPrecise(Operation operand) {
+    const auto& meta = operand.GetMeta();
+
+    if (std::holds_alternative<MetaArithmetic>(meta)) {
+        return std::get<MetaArithmetic>(meta).precise;
+    }
+    if (std::holds_alternative<MetaHalfArithmetic>(meta)) {
+        return std::get<MetaHalfArithmetic>(meta).precise;
+    }
+    return false;
+}
+
 } // namespace
 
 class SPIRVDecompiler : public Sirit::Module {
@@ -194,6 +208,10 @@ public:
     }
 
 private:
+    using OperationDecompilerFn = Id (SPIRVDecompiler::*)(Operation);
+    using OperationDecompilersArray =
+        std::array<OperationDecompilerFn, static_cast<std::size_t>(OperationCode::Amount)>;
+
     static constexpr auto INTERNAL_FLAGS_COUNT = static_cast<std::size_t>(InternalFlag::Amount);
     static constexpr u32 CBUF_STRIDE = 16;
 
@@ -468,7 +486,12 @@ private:
 
     Id Visit(Node node) {
         if (const auto operation = std::get_if<OperationNode>(node)) {
-            UNIMPLEMENTED();
+            const auto operation_index = static_cast<std::size_t>(operation->GetCode());
+            const auto decompiler = operation_decompilers[operation_index];
+            if (decompiler == nullptr) {
+                UNREACHABLE_MSG("Operation decompiler {} not defined", operation_index);
+            }
+            return (this->*decompiler)(*operation);
 
         } else if (const auto gpr = std::get_if<GprNode>(node)) {
             UNIMPLEMENTED();
@@ -500,6 +523,184 @@ private:
         return {};
     }
 
+    template <Id (Module::*func)(Id, Id), Type result_type, Type type_a = result_type>
+    Id Unary(Operation operation) {
+        const Id type_def = GetTypeDefinition(result_type);
+        const Id op_a = VisitOperand<type_a>(operation, 0);
+
+        const Id value = BitcastFrom<result_type>(Emit((this->*func)(type_def, op_a)));
+        if (IsPrecise(operation)) {
+            Decorate(value, spv::Decoration::NoContraction);
+        }
+        return value;
+    }
+
+    template <Id (Module::*func)(Id, Id, Id), Type result_type, Type type_a = result_type,
+              Type type_b = type_a>
+    Id Binary(Operation operation) {
+        const Id type_def = GetTypeDefinition(result_type);
+        const Id op_a = VisitOperand<type_a>(operation, 0);
+        const Id op_b = VisitOperand<type_b>(operation, 1);
+
+        const Id value = BitcastFrom<result_type>(Emit((this->*func)(type_def, op_a, op_b)));
+        if (IsPrecise(operation)) {
+            Decorate(value, spv::Decoration::NoContraction);
+        }
+        return value;
+    }
+
+    template <Id (Module::*func)(Id, Id, Id, Id), Type result_type, Type type_a = result_type,
+              Type type_b = type_a, Type type_c = type_b>
+    Id Ternary(Operation operation) {
+        const Id type_def = GetTypeDefinition(result_type);
+        const Id op_a = VisitOperand<type_a>(operation, 0);
+        const Id op_b = VisitOperand<type_b>(operation, 1);
+        const Id op_c = VisitOperand<type_c>(operation, 2);
+
+        const Id value = BitcastFrom<result_type>(Emit((this->*func)(type_def, op_a, op_b, op_c)));
+        if (IsPrecise(operation)) {
+            Decorate(value, spv::Decoration::NoContraction);
+        }
+        return value;
+    }
+
+    template <Id (Module::*func)(Id, Id, Id, Id, Id), Type result_type, Type type_a = result_type,
+              Type type_b = type_a, Type type_c = type_b, Type type_d = type_c>
+    Id Quaternary(Operation operation) {
+        const Id type_def = GetTypeDefinition(result_type);
+        const Id op_a = VisitOperand<type_a>(operation, 0);
+        const Id op_b = VisitOperand<type_b>(operation, 1);
+        const Id op_c = VisitOperand<type_c>(operation, 2);
+        const Id op_d = VisitOperand<type_d>(operation, 3);
+
+        const Id value =
+            BitcastFrom<result_type>(Emit((this->*func)(type_def, op_a, op_b, op_c, op_d)));
+        if (IsPrecise(operation)) {
+            Decorate(value, spv::Decoration::NoContraction);
+        }
+        return value;
+    }
+
+    Id Assign(Operation operation) {
+        UNIMPLEMENTED();
+        return {};
+    }
+
+    Id HNegate(Operation operation) {
+        UNIMPLEMENTED();
+        return {};
+    }
+
+    Id HMergeF32(Operation operation) {
+        UNIMPLEMENTED();
+        return {};
+    }
+
+    Id HMergeH0(Operation operation) {
+        UNIMPLEMENTED();
+        return {};
+    }
+
+    Id HMergeH1(Operation operation) {
+        UNIMPLEMENTED();
+        return {};
+    }
+
+    Id HPack2(Operation operation) {
+        UNIMPLEMENTED();
+        return {};
+    }
+
+    Id LogicalAssign(Operation operation) {
+        UNIMPLEMENTED();
+        return {};
+    }
+
+    Id LogicalPick2(Operation operation) {
+        UNIMPLEMENTED();
+        return {};
+    }
+
+    Id LogicalAll2(Operation operation) {
+        UNIMPLEMENTED();
+        return {};
+    }
+
+    Id LogicalAny2(Operation operation) {
+        UNIMPLEMENTED();
+        return {};
+    }
+
+    Id Texture(Operation operation) {
+        UNIMPLEMENTED();
+        return {};
+    }
+
+    Id TextureLod(Operation operation) {
+        UNIMPLEMENTED();
+        return {};
+    }
+
+    Id TextureGather(Operation operation) {
+        UNIMPLEMENTED();
+        return {};
+    }
+
+    Id TextureQueryDimensions(Operation operation) {
+        UNIMPLEMENTED();
+        return {};
+    }
+
+    Id TextureQueryLod(Operation operation) {
+        UNIMPLEMENTED();
+        return {};
+    }
+
+    Id TexelFetch(Operation operation) {
+        UNIMPLEMENTED();
+        return {};
+    }
+
+    Id Branch(Operation operation) {
+        UNIMPLEMENTED();
+        return {};
+    }
+
+    Id PushFlowStack(Operation operation) {
+        UNIMPLEMENTED();
+        return {};
+    }
+
+    Id PopFlowStack(Operation operation) {
+        UNIMPLEMENTED();
+        return {};
+    }
+
+    Id Exit(Operation operation) {
+        UNIMPLEMENTED();
+        return {};
+    }
+
+    Id Discard(Operation operation) {
+        UNIMPLEMENTED();
+        return {};
+    }
+
+    Id EmitVertex(Operation operation) {
+        UNIMPLEMENTED();
+        return {};
+    }
+
+    Id EndPrimitive(Operation operation) {
+        UNIMPLEMENTED();
+        return {};
+    }
+
+    Id YNegate(Operation operation) {
+        UNIMPLEMENTED();
+        return {};
+    }
+
     Id DeclareBuiltIn(spv::BuiltIn builtin, spv::StorageClass storage, Id type,
                       const std::string& name) {
         const Id id = OpVariable(type, storage);
@@ -518,6 +719,215 @@ private:
         return false;
     }
 
+    template <Type type>
+    Id VisitOperand(Operation operation, std::size_t operand_index) {
+        const Id value = Visit(operation[operand_index]);
+
+        switch (type) {
+        case Type::Bool:
+        case Type::Bool2:
+        case Type::Float:
+            return value;
+        case Type::Int:
+            return Emit(OpBitcast(t_int, value));
+        case Type::Uint:
+            return Emit(OpBitcast(t_uint, value));
+        case Type::HalfFloat:
+            UNIMPLEMENTED();
+        }
+        UNREACHABLE();
+        return value;
+    }
+
+    template <Type type>
+    Id BitcastFrom(Id value) {
+        switch (type) {
+        case Type::Bool:
+        case Type::Bool2:
+        case Type::Float:
+            return value;
+        case Type::Int:
+        case Type::Uint:
+            return Emit(OpBitcast(t_float, value));
+        case Type::HalfFloat:
+            UNIMPLEMENTED();
+        }
+        UNREACHABLE();
+        return value;
+    }
+
+    template <Type type>
+    Id BitcastTo(Id value) {
+        switch (type) {
+        case Type::Bool:
+        case Type::Bool2:
+            UNREACHABLE();
+        case Type::Float:
+            return Emit(OpBitcast(t_float, value));
+        case Type::Int:
+            return Emit(OpBitcast(t_int, value));
+        case Type::Uint:
+            return Emit(OpBitcast(t_uint, value));
+        case Type::HalfFloat:
+            UNIMPLEMENTED();
+        }
+        UNREACHABLE();
+        return value;
+    }
+
+    Id GetTypeDefinition(Type type) {
+        switch (type) {
+        case Type::Bool:
+            return t_bool;
+        case Type::Bool2:
+            return t_bool2;
+        case Type::Float:
+            return t_float;
+        case Type::Int:
+            return t_int;
+        case Type::Uint:
+            return t_uint;
+        case Type::HalfFloat:
+            UNIMPLEMENTED();
+        }
+        UNREACHABLE();
+        return {};
+    }
+
+    static constexpr OperationDecompilersArray operation_decompilers = {
+        &SPIRVDecompiler::Assign,
+
+        &SPIRVDecompiler::Ternary<&Module::OpSelect, Type::Float, Type::Bool, Type::Float,
+                                  Type::Float>,
+
+        &SPIRVDecompiler::Binary<&Module::OpFAdd, Type::Float>,
+        &SPIRVDecompiler::Binary<&Module::OpFMul, Type::Float>,
+        &SPIRVDecompiler::Binary<&Module::OpFDiv, Type::Float>,
+        &SPIRVDecompiler::Ternary<&Module::OpFma, Type::Float>,
+        &SPIRVDecompiler::Unary<&Module::OpFNegate, Type::Float>,
+        &SPIRVDecompiler::Unary<&Module::OpFAbs, Type::Float>,
+        &SPIRVDecompiler::Ternary<&Module::OpFClamp, Type::Float>,
+        &SPIRVDecompiler::Binary<&Module::OpFMin, Type::Float>,
+        &SPIRVDecompiler::Binary<&Module::OpFMax, Type::Float>,
+        &SPIRVDecompiler::Unary<&Module::OpCos, Type::Float>,
+        &SPIRVDecompiler::Unary<&Module::OpSin, Type::Float>,
+        &SPIRVDecompiler::Unary<&Module::OpExp2, Type::Float>,
+        &SPIRVDecompiler::Unary<&Module::OpLog2, Type::Float>,
+        &SPIRVDecompiler::Unary<&Module::OpInverseSqrt, Type::Float>,
+        &SPIRVDecompiler::Unary<&Module::OpSqrt, Type::Float>,
+        &SPIRVDecompiler::Unary<&Module::OpRoundEven, Type::Float>,
+        &SPIRVDecompiler::Unary<&Module::OpFloor, Type::Float>,
+        &SPIRVDecompiler::Unary<&Module::OpCeil, Type::Float>,
+        &SPIRVDecompiler::Unary<&Module::OpTrunc, Type::Float>,
+        &SPIRVDecompiler::Unary<&Module::OpConvertSToF, Type::Float, Type::Int>,
+        &SPIRVDecompiler::Unary<&Module::OpConvertUToF, Type::Float, Type::Uint>,
+
+        &SPIRVDecompiler::Binary<&Module::OpIAdd, Type::Int>,
+        &SPIRVDecompiler::Binary<&Module::OpIMul, Type::Int>,
+        &SPIRVDecompiler::Binary<&Module::OpSDiv, Type::Int>,
+        &SPIRVDecompiler::Unary<&Module::OpSNegate, Type::Int>,
+        &SPIRVDecompiler::Unary<&Module::OpSAbs, Type::Int>,
+        &SPIRVDecompiler::Binary<&Module::OpSMin, Type::Int>,
+        &SPIRVDecompiler::Binary<&Module::OpSMax, Type::Int>,
+
+        &SPIRVDecompiler::Unary<&Module::OpConvertFToS, Type::Int, Type::Float>,
+        &SPIRVDecompiler::Unary<&Module::OpBitcast, Type::Int, Type::Uint>,
+        &SPIRVDecompiler::Binary<&Module::OpShiftLeftLogical, Type::Int, Type::Int, Type::Uint>,
+        &SPIRVDecompiler::Binary<&Module::OpShiftRightLogical, Type::Int, Type::Int, Type::Uint>,
+        &SPIRVDecompiler::Binary<&Module::OpShiftRightArithmetic, Type::Int, Type::Int, Type::Uint>,
+        &SPIRVDecompiler::Binary<&Module::OpBitwiseAnd, Type::Int>,
+        &SPIRVDecompiler::Binary<&Module::OpBitwiseOr, Type::Int>,
+        &SPIRVDecompiler::Binary<&Module::OpBitwiseXor, Type::Int>,
+        &SPIRVDecompiler::Unary<&Module::OpNot, Type::Int>,
+        &SPIRVDecompiler::Quaternary<&Module::OpBitFieldInsert, Type::Int>,
+        &SPIRVDecompiler::Ternary<&Module::OpBitFieldSExtract, Type::Int>,
+        &SPIRVDecompiler::Unary<&Module::OpBitCount, Type::Int>,
+
+        &SPIRVDecompiler::Binary<&Module::OpIAdd, Type::Uint>,
+        &SPIRVDecompiler::Binary<&Module::OpIMul, Type::Uint>,
+        &SPIRVDecompiler::Binary<&Module::OpUDiv, Type::Uint>,
+        &SPIRVDecompiler::Binary<&Module::OpUMin, Type::Uint>,
+        &SPIRVDecompiler::Binary<&Module::OpUMax, Type::Uint>,
+        &SPIRVDecompiler::Unary<&Module::OpConvertFToU, Type::Uint, Type::Float>,
+        &SPIRVDecompiler::Unary<&Module::OpBitcast, Type::Uint, Type::Int>,
+        &SPIRVDecompiler::Binary<&Module::OpShiftLeftLogical, Type::Uint>,
+        &SPIRVDecompiler::Binary<&Module::OpShiftRightLogical, Type::Uint>,
+        &SPIRVDecompiler::Binary<&Module::OpShiftRightArithmetic, Type::Uint>,
+        &SPIRVDecompiler::Binary<&Module::OpBitwiseAnd, Type::Uint>,
+        &SPIRVDecompiler::Binary<&Module::OpBitwiseOr, Type::Uint>,
+        &SPIRVDecompiler::Binary<&Module::OpBitwiseXor, Type::Uint>,
+        &SPIRVDecompiler::Unary<&Module::OpNot, Type::Uint>,
+        &SPIRVDecompiler::Quaternary<&Module::OpBitFieldInsert, Type::Uint>,
+        &SPIRVDecompiler::Ternary<&Module::OpBitFieldUExtract, Type::Uint>,
+        &SPIRVDecompiler::Unary<&Module::OpBitCount, Type::Uint>,
+
+        &SPIRVDecompiler::Binary<&Module::OpFAdd, Type::HalfFloat>,
+        &SPIRVDecompiler::Binary<&Module::OpFMul, Type::HalfFloat>,
+        &SPIRVDecompiler::Ternary<&Module::OpFma, Type::HalfFloat>,
+        &SPIRVDecompiler::Unary<&Module::OpFAbs, Type::HalfFloat>,
+        &SPIRVDecompiler::HNegate,
+        &SPIRVDecompiler::HMergeF32,
+        &SPIRVDecompiler::HMergeH0,
+        &SPIRVDecompiler::HMergeH1,
+        &SPIRVDecompiler::HPack2,
+
+        &SPIRVDecompiler::LogicalAssign,
+        &SPIRVDecompiler::Binary<&Module::OpLogicalAnd, Type::Bool>,
+        &SPIRVDecompiler::Binary<&Module::OpLogicalOr, Type::Bool>,
+        &SPIRVDecompiler::Binary<&Module::OpLogicalNotEqual, Type::Bool>,
+        &SPIRVDecompiler::Unary<&Module::OpLogicalNot, Type::Bool>,
+        &SPIRVDecompiler::LogicalPick2,
+        &SPIRVDecompiler::LogicalAll2,
+        &SPIRVDecompiler::LogicalAny2,
+
+        &SPIRVDecompiler::Binary<&Module::OpFOrdLessThan, Type::Bool, Type::Float>,
+        &SPIRVDecompiler::Binary<&Module::OpFOrdEqual, Type::Bool, Type::Float>,
+        &SPIRVDecompiler::Binary<&Module::OpFOrdLessThanEqual, Type::Bool, Type::Float>,
+        &SPIRVDecompiler::Binary<&Module::OpFOrdGreaterThan, Type::Bool, Type::Float>,
+        &SPIRVDecompiler::Binary<&Module::OpFOrdNotEqual, Type::Bool, Type::Float>,
+        &SPIRVDecompiler::Binary<&Module::OpFOrdGreaterThanEqual, Type::Bool, Type::Float>,
+        &SPIRVDecompiler::Unary<&Module::OpIsNan, Type::Bool>,
+
+        &SPIRVDecompiler::Binary<&Module::OpSLessThan, Type::Bool, Type::Int>,
+        &SPIRVDecompiler::Binary<&Module::OpIEqual, Type::Bool, Type::Int>,
+        &SPIRVDecompiler::Binary<&Module::OpSLessThanEqual, Type::Bool, Type::Int>,
+        &SPIRVDecompiler::Binary<&Module::OpSGreaterThan, Type::Bool, Type::Int>,
+        &SPIRVDecompiler::Binary<&Module::OpINotEqual, Type::Bool, Type::Int>,
+        &SPIRVDecompiler::Binary<&Module::OpSGreaterThanEqual, Type::Bool, Type::Int>,
+
+        &SPIRVDecompiler::Binary<&Module::OpULessThan, Type::Bool, Type::Uint>,
+        &SPIRVDecompiler::Binary<&Module::OpIEqual, Type::Bool, Type::Uint>,
+        &SPIRVDecompiler::Binary<&Module::OpULessThanEqual, Type::Bool, Type::Uint>,
+        &SPIRVDecompiler::Binary<&Module::OpUGreaterThan, Type::Bool, Type::Uint>,
+        &SPIRVDecompiler::Binary<&Module::OpINotEqual, Type::Bool, Type::Uint>,
+        &SPIRVDecompiler::Binary<&Module::OpUGreaterThanEqual, Type::Bool, Type::Uint>,
+
+        &SPIRVDecompiler::Binary<&Module::OpFOrdLessThan, Type::Bool, Type::HalfFloat>,
+        &SPIRVDecompiler::Binary<&Module::OpFOrdEqual, Type::Bool, Type::HalfFloat>,
+        &SPIRVDecompiler::Binary<&Module::OpFOrdLessThanEqual, Type::Bool, Type::HalfFloat>,
+        &SPIRVDecompiler::Binary<&Module::OpFOrdGreaterThan, Type::Bool, Type::HalfFloat>,
+        &SPIRVDecompiler::Binary<&Module::OpFOrdNotEqual, Type::Bool, Type::HalfFloat>,
+        &SPIRVDecompiler::Binary<&Module::OpFOrdGreaterThanEqual, Type::Bool, Type::HalfFloat>,
+
+        &SPIRVDecompiler::Texture,
+        &SPIRVDecompiler::TextureLod,
+        &SPIRVDecompiler::TextureGather,
+        &SPIRVDecompiler::TextureQueryDimensions,
+        &SPIRVDecompiler::TextureQueryLod,
+        &SPIRVDecompiler::TexelFetch,
+
+        &SPIRVDecompiler::Branch,
+        &SPIRVDecompiler::PushFlowStack,
+        &SPIRVDecompiler::PopFlowStack,
+        &SPIRVDecompiler::Exit,
+        &SPIRVDecompiler::Discard,
+
+        &SPIRVDecompiler::EmitVertex,
+        &SPIRVDecompiler::EndPrimitive,
+
+        &SPIRVDecompiler::YNegate,
+    };
+
     const ShaderIR& ir;
     const ShaderStage stage;
     const Tegra::Shader::Header header;