From 8c684b3e2327bc7b0c02f2a22dbf52c11884ecd3 Mon Sep 17 00:00:00 2001
From: ameerj <52414509+ameerj@users.noreply.github.com>
Date: Mon, 31 May 2021 23:07:13 -0400
Subject: [PATCH] glsl: Implement tessellation shaders

---
 .../backend/glsl/emit_context.cpp             | 90 ++++++++++++++++---
 .../glsl/emit_glsl_context_get_set.cpp        | 69 +++++++++++++-
 .../backend/glsl/emit_glsl_floating_point.cpp |  2 +-
 .../backend/glsl/emit_glsl_instructions.h     |  2 +-
 .../glsl/emit_glsl_not_implemented.cpp        | 10 +--
 5 files changed, 146 insertions(+), 27 deletions(-)

diff --git a/src/shader_recompiler/backend/glsl/emit_context.cpp b/src/shader_recompiler/backend/glsl/emit_context.cpp
index 923060386..01403ca17 100644
--- a/src/shader_recompiler/backend/glsl/emit_context.cpp
+++ b/src/shader_recompiler/backend/glsl/emit_context.cpp
@@ -21,10 +21,21 @@ std::string_view InterpDecorator(Interpolation interp) {
     throw InvalidArgument("Invalid interpolation {}", interp);
 }
 
-std::string_view ArrayDecorator(Stage stage) {
+std::string_view InputArrayDecorator(Stage stage) {
     switch (stage) {
     case Stage::Geometry:
-        return "[1]";
+    case Stage::TessellationControl:
+    case Stage::TessellationEval:
+        return "[]";
+    default:
+        return "";
+    }
+}
+
+std::string OutputDecorator(Stage stage, u32 size) {
+    switch (stage) {
+    case Stage::TessellationControl:
+        return fmt::format("[{}]", size);
     default:
         return "";
     }
@@ -73,6 +84,30 @@ std::string_view SamplerType(TextureType type, bool is_depth) {
     }
 }
 
+std::string_view GetTessMode(TessPrimitive primitive) {
+    switch (primitive) {
+    case TessPrimitive::Triangles:
+        return "triangles";
+    case TessPrimitive::Quads:
+        return "quads";
+    case TessPrimitive::Isolines:
+        return "isolines";
+    }
+    throw InvalidArgument("Invalid tessellation primitive {}", primitive);
+}
+
+std::string_view GetTessSpacing(TessSpacing spacing) {
+    switch (spacing) {
+    case TessSpacing::Equal:
+        return "equal_spacing";
+    case TessSpacing::FractionalOdd:
+        return "fractional_odd_spacing";
+    case TessSpacing::FractionalEven:
+        return "fractional_even_spacing";
+    }
+    throw InvalidArgument("Invalid tessellation spacing {}", spacing);
+}
+
 std::string_view InputPrimitive(InputTopology topology) {
     switch (topology) {
     case InputTopology::Points:
@@ -100,6 +135,23 @@ std::string_view OutputPrimitive(OutputTopology topology) {
     }
     throw InvalidArgument("Invalid output topology {}", topology);
 }
+
+void SetupOutPerVertex(Stage stage, const Info& info, std::string& header) {
+    if (stage != Stage::VertexA && stage != Stage::VertexB && stage != Stage::Geometry) {
+        return;
+    }
+    header += "out gl_PerVertex{";
+    if (info.stores_position) {
+        header += "vec4 gl_Position;";
+    }
+    if (info.stores_point_size) {
+        header += "float gl_PointSize;";
+    }
+    if (info.stores_clip_distance) {
+        header += "float gl_ClipDistance[];";
+    }
+    header += "};";
+}
 } // namespace
 
 EmitContext::EmitContext(IR::Program& program, Bindings& bindings, const Profile& profile_,
@@ -111,17 +163,20 @@ EmitContext::EmitContext(IR::Program& program, Bindings& bindings, const Profile
     case Stage::VertexA:
     case Stage::VertexB:
         stage_name = "vs";
-        // TODO: add only what's used by the shader
-        header +=
-            "out gl_PerVertex {vec4 gl_Position;float gl_PointSize;float gl_ClipDistance[];};";
         break;
     case Stage::TessellationControl:
+        stage_name = "tsc";
+        header += fmt::format("layout(vertices={})out;\n", program.invocations);
+        break;
     case Stage::TessellationEval:
-        stage_name = "ts";
+        stage_name = "tse";
+        header += fmt::format("layout({},{},{})in;\n", GetTessMode(runtime_info.tess_primitive),
+                              GetTessSpacing(runtime_info.tess_spacing),
+                              runtime_info.tess_clockwise ? "cw" : "ccw");
         break;
     case Stage::Geometry:
         stage_name = "gs";
-        header += fmt::format("layout({})in;layout({}, max_vertices={})out;\n",
+        header += fmt::format("layout({})in;layout({},max_vertices={})out;\n",
                               InputPrimitive(runtime_info.input_topology),
                               OutputPrimitive(program.output_topology), program.output_vertices);
         break;
@@ -135,12 +190,23 @@ EmitContext::EmitContext(IR::Program& program, Bindings& bindings, const Profile
                               program.workgroup_size[2]);
         break;
     }
+    SetupOutPerVertex(stage, info, header);
     for (size_t index = 0; index < info.input_generics.size(); ++index) {
         const auto& generic{info.input_generics[index]};
         if (generic.used) {
-            header +=
-                fmt::format("layout(location={}){} in vec4 in_attr{}{};", index,
-                            InterpDecorator(generic.interpolation), index, ArrayDecorator(stage));
+            header += fmt::format("layout(location={}){} in vec4 in_attr{}{};", index,
+                                  InterpDecorator(generic.interpolation), index,
+                                  InputArrayDecorator(stage));
+        }
+    }
+    for (size_t index = 0; index < info.uses_patches.size(); ++index) {
+        if (!info.uses_patches[index]) {
+            continue;
+        }
+        if (stage == Stage::TessellationControl) {
+            header += fmt::format("layout(location={})patch out vec4 patch{};", index, index);
+        } else {
+            header += fmt::format("layout(location={})patch in vec4 patch{};", index, index);
         }
     }
     for (size_t index = 0; index < info.stores_frag_color.size(); ++index) {
@@ -151,8 +217,8 @@ EmitContext::EmitContext(IR::Program& program, Bindings& bindings, const Profile
     }
     for (size_t index = 0; index < info.stores_generics.size(); ++index) {
         // TODO: Properly resolve attribute issues
-        const auto declaration{
-            fmt::format("layout(location={}) out vec4 out_attr{};", index, index)};
+        const auto declaration{fmt::format("layout(location={}) out vec4 out_attr{}{};", index,
+                                           index, OutputDecorator(stage, program.invocations))};
         if (info.stores_generics[index] || stage == Stage::VertexA || stage == Stage::VertexB) {
             header += declaration;
         }
diff --git a/src/shader_recompiler/backend/glsl/emit_glsl_context_get_set.cpp b/src/shader_recompiler/backend/glsl/emit_glsl_context_get_set.cpp
index 28e89a0a6..5c56477bf 100644
--- a/src/shader_recompiler/backend/glsl/emit_glsl_context_get_set.cpp
+++ b/src/shader_recompiler/backend/glsl/emit_glsl_context_get_set.cpp
@@ -25,9 +25,24 @@ bool IsInputArray(Stage stage) {
            stage == Stage::TessellationEval;
 }
 
-std::string VertexIndex(EmitContext& ctx, std::string_view vertex) {
+std::string InputVertexIndex(EmitContext& ctx, std::string_view vertex) {
     return IsInputArray(ctx.stage) ? fmt::format("[{}]", vertex) : "";
 }
+
+bool IsOutputArray(Stage stage) {
+    return stage == Stage::Geometry || stage == Stage::TessellationControl;
+}
+
+std::string OutputVertexIndex(EmitContext& ctx, std::string_view vertex) {
+    switch (ctx.stage) {
+    case Stage::Geometry:
+        return fmt::format("[{}]", vertex);
+    case Stage::TessellationControl:
+        return "[gl_InvocationID]";
+    default:
+        return "";
+    }
+}
 } // namespace
 
 void EmitGetCbufU8([[maybe_unused]] EmitContext& ctx, [[maybe_unused]] IR::Inst& inst,
@@ -132,12 +147,12 @@ void EmitGetCbufU32x2(EmitContext& ctx, IR::Inst& inst, const IR::Value& binding
 }
 
 void EmitGetAttribute(EmitContext& ctx, IR::Inst& inst, IR::Attribute attr,
-                      [[maybe_unused]] std::string_view vertex) {
+                      std::string_view vertex) {
     const u32 element{static_cast<u32>(attr) % 4};
     const char swizzle{"xyzw"[element]};
     if (IR::IsGeneric(attr)) {
         const u32 index{IR::GenericAttributeIndex(attr)};
-        ctx.AddF32("{}=in_attr{}{}.{};", inst, index, VertexIndex(ctx, vertex), swizzle);
+        ctx.AddF32("{}=in_attr{}{}.{};", inst, index, InputVertexIndex(ctx, vertex), swizzle);
         return;
     }
     switch (attr) {
@@ -150,6 +165,10 @@ void EmitGetAttribute(EmitContext& ctx, IR::Inst& inst, IR::Attribute attr,
         case Stage::VertexB:
             ctx.AddF32("{}=gl_Position.{};", inst, swizzle);
             break;
+        case Stage::TessellationEval:
+            ctx.AddF32("{}=gl_TessCoord.{};", inst, swizzle);
+            break;
+        case Stage::TessellationControl:
         case Stage::Geometry:
             ctx.AddF32("{}=gl_in[{}].gl_Position.{};", inst, vertex, swizzle);
             break;
@@ -173,6 +192,10 @@ void EmitGetAttribute(EmitContext& ctx, IR::Inst& inst, IR::Attribute attr,
     case IR::Attribute::FrontFace:
         ctx.AddF32("{}=intBitsToFloat(gl_FrontFacing?-1:0);", inst);
         break;
+    case IR::Attribute::TessellationEvaluationPointU:
+    case IR::Attribute::TessellationEvaluationPointV:
+        ctx.AddF32("{}=gl_TessCoord.{};", inst, swizzle);
+        break;
     default:
         fmt::print("Get attribute {}", attr);
         throw NotImplementedException("Get attribute {}", attr);
@@ -185,7 +208,7 @@ void EmitSetAttribute(EmitContext& ctx, IR::Attribute attr, std::string_view val
     const char swizzle{"xyzw"[element]};
     if (IR::IsGeneric(attr)) {
         const u32 index{IR::GenericAttributeIndex(attr)};
-        ctx.Add("out_attr{}.{}={};", index, swizzle, value);
+        ctx.Add("out_attr{}{}.{}={};", index, OutputVertexIndex(ctx, vertex), swizzle, value);
         return;
     }
     switch (attr) {
@@ -219,6 +242,44 @@ void EmitSetAttribute(EmitContext& ctx, IR::Attribute attr, std::string_view val
     }
 }
 
+void EmitGetPatch([[maybe_unused]] EmitContext& ctx, IR::Inst& inst,
+                  [[maybe_unused]] IR::Patch patch) {
+    if (!IR::IsGeneric(patch)) {
+        throw NotImplementedException("Non-generic patch load");
+    }
+    const u32 index{IR::GenericPatchIndex(patch)};
+    const u32 element{IR::GenericPatchElement(patch)};
+    const char swizzle{"xyzw"[element]};
+    ctx.AddF32("{}=patch{}.{};", inst, index, swizzle);
+}
+
+void EmitSetPatch(EmitContext& ctx, IR::Patch patch, std::string_view value) {
+    if (IR::IsGeneric(patch)) {
+        const u32 index{IR::GenericPatchIndex(patch)};
+        const u32 element{IR::GenericPatchElement(patch)};
+        ctx.Add("patch{}.{}={};", index, "xyzw"[element], value);
+        return;
+    }
+    switch (patch) {
+    case IR::Patch::TessellationLodLeft:
+    case IR::Patch::TessellationLodRight:
+    case IR::Patch::TessellationLodTop:
+    case IR::Patch::TessellationLodBottom: {
+        const u32 index{static_cast<u32>(patch) - u32(IR::Patch::TessellationLodLeft)};
+        ctx.Add("gl_TessLevelOuter[{}]={};", index, value);
+        break;
+    }
+    case IR::Patch::TessellationLodInteriorU:
+        ctx.Add("gl_TessLevelInner[0]={};", value);
+        break;
+    case IR::Patch::TessellationLodInteriorV:
+        ctx.Add("gl_TessLevelInner[1]={};", value);
+        break;
+    default:
+        throw NotImplementedException("Patch {}", patch);
+    }
+}
+
 void EmitSetFragColor(EmitContext& ctx, u32 index, u32 component, std::string_view value) {
     const char swizzle{"xyzw"[component]};
     ctx.Add("frag_color{}.{}={};", index, swizzle, value);
diff --git a/src/shader_recompiler/backend/glsl/emit_glsl_floating_point.cpp b/src/shader_recompiler/backend/glsl/emit_glsl_floating_point.cpp
index 49ab182ea..f4b81407a 100644
--- a/src/shader_recompiler/backend/glsl/emit_glsl_floating_point.cpp
+++ b/src/shader_recompiler/backend/glsl/emit_glsl_floating_point.cpp
@@ -161,7 +161,7 @@ void EmitFPRecip64(EmitContext& ctx, IR::Inst& inst, std::string_view value) {
 
 void EmitFPRecipSqrt32([[maybe_unused]] EmitContext& ctx, [[maybe_unused]] IR::Inst& inst,
                        [[maybe_unused]] std::string_view value) {
-    ctx.AddF32("{}=(1.0f)/sqrt({});", inst, value);
+    ctx.AddF32("{}=inversesqrt({});", inst, value);
 }
 
 void EmitFPRecipSqrt64([[maybe_unused]] EmitContext& ctx, [[maybe_unused]] IR::Inst& inst,
diff --git a/src/shader_recompiler/backend/glsl/emit_glsl_instructions.h b/src/shader_recompiler/backend/glsl/emit_glsl_instructions.h
index e7009d8e9..89ded3614 100644
--- a/src/shader_recompiler/backend/glsl/emit_glsl_instructions.h
+++ b/src/shader_recompiler/backend/glsl/emit_glsl_instructions.h
@@ -75,7 +75,7 @@ void EmitSetAttribute(EmitContext& ctx, IR::Attribute attr, std::string_view val
 void EmitGetAttributeIndexed(EmitContext& ctx, std::string_view offset, std::string_view vertex);
 void EmitSetAttributeIndexed(EmitContext& ctx, std::string_view offset, std::string_view value,
                              std::string_view vertex);
-void EmitGetPatch(EmitContext& ctx, IR::Patch patch);
+void EmitGetPatch(EmitContext& ctx, IR::Inst& inst, IR::Patch patch);
 void EmitSetPatch(EmitContext& ctx, IR::Patch patch, std::string_view value);
 void EmitSetFragColor(EmitContext& ctx, u32 index, u32 component, std::string_view value);
 void EmitSetSampleMask(EmitContext& ctx, std::string_view value);
diff --git a/src/shader_recompiler/backend/glsl/emit_glsl_not_implemented.cpp b/src/shader_recompiler/backend/glsl/emit_glsl_not_implemented.cpp
index 3ed4e04d3..cf7b2a51e 100644
--- a/src/shader_recompiler/backend/glsl/emit_glsl_not_implemented.cpp
+++ b/src/shader_recompiler/backend/glsl/emit_glsl_not_implemented.cpp
@@ -151,14 +151,6 @@ void EmitSetAttributeIndexed(EmitContext& ctx, std::string_view offset, std::str
     NotImplemented();
 }
 
-void EmitGetPatch(EmitContext& ctx, IR::Patch patch) {
-    NotImplemented();
-}
-
-void EmitSetPatch(EmitContext& ctx, IR::Patch patch, std::string_view value) {
-    NotImplemented();
-}
-
 void EmitSetSampleMask(EmitContext& ctx, std::string_view value) {
     NotImplemented();
 }
@@ -204,7 +196,7 @@ void EmitWorkgroupId(EmitContext& ctx, IR::Inst& inst) {
 }
 
 void EmitInvocationId(EmitContext& ctx, IR::Inst& inst) {
-    NotImplemented();
+    ctx.AddU32("{}=uint(gl_InvocationID);", inst);
 }
 
 void EmitSampleId(EmitContext& ctx, IR::Inst& inst) {