From 5c61e860e4f83524ffce10ca447398e83de81640 Mon Sep 17 00:00:00 2001
From: ReinUsesLisp <reinuseslisp@airmail.cc>
Date: Sun, 11 Apr 2021 19:16:12 -0300
Subject: [PATCH] shader: Implement SR_THREAD_KILL

---
 src/shader_recompiler/backend/spirv/emit_context.cpp          | 3 +++
 src/shader_recompiler/backend/spirv/emit_context.h            | 1 +
 src/shader_recompiler/backend/spirv/emit_spirv.h              | 1 +
 .../backend/spirv/emit_spirv_context_get_set.cpp              | 4 ++++
 src/shader_recompiler/frontend/ir/ir_emitter.cpp              | 4 ++++
 src/shader_recompiler/frontend/ir/ir_emitter.h                | 2 ++
 src/shader_recompiler/frontend/ir/opcodes.inc                 | 1 +
 .../frontend/maxwell/translate/impl/move_special_register.cpp | 2 ++
 src/shader_recompiler/ir_opt/collect_shader_info_pass.cpp     | 3 +++
 src/shader_recompiler/shader_info.h                           | 1 +
 10 files changed, 22 insertions(+)

diff --git a/src/shader_recompiler/backend/spirv/emit_context.cpp b/src/shader_recompiler/backend/spirv/emit_context.cpp
index e5d83e9b4b..bf2210899f 100644
--- a/src/shader_recompiler/backend/spirv/emit_context.cpp
+++ b/src/shader_recompiler/backend/spirv/emit_context.cpp
@@ -790,6 +790,9 @@ void EmitContext::DefineInputs(const Info& info) {
     if (info.uses_local_invocation_id) {
         local_invocation_id = DefineInput(*this, U32[3], spv::BuiltIn::LocalInvocationId);
     }
+    if (info.uses_is_helper_invocation) {
+        is_helper_invocation = DefineInput(*this, U1, spv::BuiltIn::HelperInvocation);
+    }
     if (info.uses_subgroup_mask) {
         subgroup_mask_eq = DefineInput(*this, U32[4], spv::BuiltIn::SubgroupEqMaskKHR);
         subgroup_mask_lt = DefineInput(*this, U32[4], spv::BuiltIn::SubgroupLtMaskKHR);
diff --git a/src/shader_recompiler/backend/spirv/emit_context.h b/src/shader_recompiler/backend/spirv/emit_context.h
index 34f38454fa..98a9140bf8 100644
--- a/src/shader_recompiler/backend/spirv/emit_context.h
+++ b/src/shader_recompiler/backend/spirv/emit_context.h
@@ -107,6 +107,7 @@ public:
 
     Id workgroup_id{};
     Id local_invocation_id{};
+    Id is_helper_invocation{};
     Id subgroup_local_invocation_id{};
     Id subgroup_mask_eq{};
     Id subgroup_mask_lt{};
diff --git a/src/shader_recompiler/backend/spirv/emit_spirv.h b/src/shader_recompiler/backend/spirv/emit_spirv.h
index a3398a6054..04340fa704 100644
--- a/src/shader_recompiler/backend/spirv/emit_spirv.h
+++ b/src/shader_recompiler/backend/spirv/emit_spirv.h
@@ -65,6 +65,7 @@ void EmitSetCFlag(EmitContext& ctx);
 void EmitSetOFlag(EmitContext& ctx);
 Id EmitWorkgroupId(EmitContext& ctx);
 Id EmitLocalInvocationId(EmitContext& ctx);
+Id EmitIsHelperInvocation(EmitContext& ctx);
 Id EmitLoadLocal(EmitContext& ctx, Id word_offset);
 void EmitWriteLocal(EmitContext& ctx, Id word_offset, Id value);
 Id EmitUndefU1(EmitContext& ctx);
diff --git a/src/shader_recompiler/backend/spirv/emit_spirv_context_get_set.cpp b/src/shader_recompiler/backend/spirv/emit_spirv_context_get_set.cpp
index 5dc150ce29..d552a1b529 100644
--- a/src/shader_recompiler/backend/spirv/emit_spirv_context_get_set.cpp
+++ b/src/shader_recompiler/backend/spirv/emit_spirv_context_get_set.cpp
@@ -274,6 +274,10 @@ Id EmitLocalInvocationId(EmitContext& ctx) {
     return ctx.OpLoad(ctx.U32[3], ctx.local_invocation_id);
 }
 
+Id EmitIsHelperInvocation(EmitContext& ctx) {
+    return ctx.OpLoad(ctx.U1, ctx.is_helper_invocation);
+}
+
 Id EmitLoadLocal(EmitContext& ctx, Id word_offset) {
     const Id pointer{ctx.OpAccessChain(ctx.private_u32, ctx.local_memory, word_offset)};
     return ctx.OpLoad(ctx.U32[1], pointer);
diff --git a/src/shader_recompiler/frontend/ir/ir_emitter.cpp b/src/shader_recompiler/frontend/ir/ir_emitter.cpp
index a3339f624a..54a273a921 100644
--- a/src/shader_recompiler/frontend/ir/ir_emitter.cpp
+++ b/src/shader_recompiler/frontend/ir/ir_emitter.cpp
@@ -347,6 +347,10 @@ U32 IREmitter::LocalInvocationIdZ() {
     return U32{CompositeExtract(Inst(Opcode::LocalInvocationId), 2)};
 }
 
+U1 IREmitter::IsHelperInvocation() {
+    return Inst<U1>(Opcode::IsHelperInvocation);
+}
+
 U32 IREmitter::LaneId() {
     return Inst<U32>(Opcode::LaneId);
 }
diff --git a/src/shader_recompiler/frontend/ir/ir_emitter.h b/src/shader_recompiler/frontend/ir/ir_emitter.h
index f9cbf1304c..d04224707a 100644
--- a/src/shader_recompiler/frontend/ir/ir_emitter.h
+++ b/src/shader_recompiler/frontend/ir/ir_emitter.h
@@ -90,6 +90,8 @@ public:
     [[nodiscard]] U32 LocalInvocationIdY();
     [[nodiscard]] U32 LocalInvocationIdZ();
 
+    [[nodiscard]] U1 IsHelperInvocation();
+
     [[nodiscard]] U32 LaneId();
 
     [[nodiscard]] U32 LoadGlobalU8(const U64& address);
diff --git a/src/shader_recompiler/frontend/ir/opcodes.inc b/src/shader_recompiler/frontend/ir/opcodes.inc
index dc776a73e6..f700086822 100644
--- a/src/shader_recompiler/frontend/ir/opcodes.inc
+++ b/src/shader_recompiler/frontend/ir/opcodes.inc
@@ -58,6 +58,7 @@ OPCODE(SetCFlag,                                            Void,           U1,
 OPCODE(SetOFlag,                                            Void,           U1,                                                                             )
 OPCODE(WorkgroupId,                                         U32x3,                                                                                          )
 OPCODE(LocalInvocationId,                                   U32x3,                                                                                          )
+OPCODE(IsHelperInvocation,                                  U1,                                                                                             )
 
 // Undefined
 OPCODE(UndefU1,                                             U1,                                                                                             )
diff --git a/src/shader_recompiler/frontend/maxwell/translate/impl/move_special_register.cpp b/src/shader_recompiler/frontend/maxwell/translate/impl/move_special_register.cpp
index be1f21e7b4..50650cc567 100644
--- a/src/shader_recompiler/frontend/maxwell/translate/impl/move_special_register.cpp
+++ b/src/shader_recompiler/frontend/maxwell/translate/impl/move_special_register.cpp
@@ -113,6 +113,8 @@ enum class SpecialRegister : u64 {
 
 [[nodiscard]] IR::U32 Read(IR::IREmitter& ir, SpecialRegister special_register) {
     switch (special_register) {
+    case SpecialRegister::SR_THREAD_KILL:
+        return IR::U32{ir.Select(ir.IsHelperInvocation(), ir.Imm32(-1), ir.Imm32(0))};
     case SpecialRegister::SR_TID_X:
         return ir.LocalInvocationIdX();
     case SpecialRegister::SR_TID_Y:
diff --git a/src/shader_recompiler/ir_opt/collect_shader_info_pass.cpp b/src/shader_recompiler/ir_opt/collect_shader_info_pass.cpp
index 73373576b1..c80d2d29ce 100644
--- a/src/shader_recompiler/ir_opt/collect_shader_info_pass.cpp
+++ b/src/shader_recompiler/ir_opt/collect_shader_info_pass.cpp
@@ -348,6 +348,9 @@ void VisitUsages(Info& info, IR::Inst& inst) {
     case IR::Opcode::LocalInvocationId:
         info.uses_local_invocation_id = true;
         break;
+    case IR::Opcode::IsHelperInvocation:
+        info.uses_is_helper_invocation = true;
+        break;
     case IR::Opcode::LaneId:
     case IR::Opcode::ShuffleIndex:
     case IR::Opcode::ShuffleUp:
diff --git a/src/shader_recompiler/shader_info.h b/src/shader_recompiler/shader_info.h
index 7bcecf5549..aa204ae37e 100644
--- a/src/shader_recompiler/shader_info.h
+++ b/src/shader_recompiler/shader_info.h
@@ -92,6 +92,7 @@ struct Info {
 
     bool uses_workgroup_id{};
     bool uses_local_invocation_id{};
+    bool uses_is_helper_invocation{};
     bool uses_subgroup_invocation_id{};
 
     std::array<InputVarying, 32> input_generics{};