From ee61ec2c39e6db53c56e7ac761a2223d99f06908 Mon Sep 17 00:00:00 2001
From: FernandoS27 <fsahmkow27@gmail.com>
Date: Sun, 2 May 2021 01:50:27 +0200
Subject: [PATCH] shader: Optimize NVN Fallthrough

---
 .../backend/spirv/emit_context.cpp            |  3 +
 .../frontend/maxwell/program.cpp              |  7 +-
 .../ir_opt/collect_shader_info_pass.cpp       | 78 +++++++++++++++++--
 src/shader_recompiler/shader_info.h           |  4 +
 4 files changed, 83 insertions(+), 9 deletions(-)

diff --git a/src/shader_recompiler/backend/spirv/emit_context.cpp b/src/shader_recompiler/backend/spirv/emit_context.cpp
index 549df0d4bd..be88b76f7e 100644
--- a/src/shader_recompiler/backend/spirv/emit_context.cpp
+++ b/src/shader_recompiler/backend/spirv/emit_context.cpp
@@ -840,6 +840,9 @@ void EmitContext::DefineGlobalMemoryFunctions(const Info& info) {
         AddLabel();
         const size_t num_buffers{info.storage_buffers_descriptors.size()};
         for (size_t index = 0; index < num_buffers; ++index) {
+            if (!info.nvn_buffer_used[index]) {
+                continue;
+            }
             const auto& ssbo{info.storage_buffers_descriptors[index]};
             const Id ssbo_addr_cbuf_offset{Const(ssbo.cbuf_offset / 8)};
             const Id ssbo_size_cbuf_offset{Const(ssbo.cbuf_offset / 4 + 2)};
diff --git a/src/shader_recompiler/frontend/maxwell/program.cpp b/src/shader_recompiler/frontend/maxwell/program.cpp
index a4fa4319da..0d3f006991 100644
--- a/src/shader_recompiler/frontend/maxwell/program.cpp
+++ b/src/shader_recompiler/frontend/maxwell/program.cpp
@@ -88,17 +88,20 @@ void AddNVNStorageBuffers(IR::Program& program) {
     }()};
     auto& descs{program.info.storage_buffers_descriptors};
     for (u32 index = 0; index < num_buffers; ++index) {
+        if (!program.info.nvn_buffer_used[index]) {
+            continue;
+        }
         const u32 offset{base + index * descriptor_size};
         const auto it{std::ranges::find(descs, offset, &StorageBufferDescriptor::cbuf_offset)};
         if (it != descs.end()) {
+            it->is_written |= program.info.stores_global_memory;
             continue;
         }
-        // Assume these are written for now
         descs.push_back({
             .cbuf_index = driver_cbuf,
             .cbuf_offset = offset,
             .count = 1,
-            .is_written = true,
+            .is_written = program.info.stores_global_memory,
         });
     }
 }
diff --git a/src/shader_recompiler/ir_opt/collect_shader_info_pass.cpp b/src/shader_recompiler/ir_opt/collect_shader_info_pass.cpp
index 7d8794a7e7..13b793d572 100644
--- a/src/shader_recompiler/ir_opt/collect_shader_info_pass.cpp
+++ b/src/shader_recompiler/ir_opt/collect_shader_info_pass.cpp
@@ -132,6 +132,30 @@ void SetPatch(Info& info, IR::Patch patch) {
     }
 }
 
+void CheckCBufNVN(Info& info, IR::Inst& inst) {
+    const IR::Value cbuf_index{inst.Arg(0)};
+    if (!cbuf_index.IsImmediate()) {
+        info.nvn_buffer_used.set();
+        return;
+    }
+    const u32 index{cbuf_index.U32()};
+    if (index != 0) {
+        return;
+    }
+    const IR::Value cbuf_offset{inst.Arg(1)};
+    if (!cbuf_offset.IsImmediate()) {
+        info.nvn_buffer_used.set();
+        return;
+    }
+    const u32 offset{cbuf_offset.U32()};
+    const u32 descriptor_size{0x10};
+    const u32 upper_limit{info.nvn_buffer_base + descriptor_size * 16};
+    if (offset >= info.nvn_buffer_base && offset < upper_limit) {
+        const std::size_t nvn_index{(offset - info.nvn_buffer_base) / descriptor_size};
+        info.nvn_buffer_used.set(nvn_index, true);
+    }
+}
+
 void VisitUsages(Info& info, IR::Inst& inst) {
     switch (inst.GetOpcode()) {
     case IR::Opcode::CompositeConstructF16x2:
@@ -382,13 +406,6 @@ void VisitUsages(Info& info, IR::Inst& inst) {
         break;
     }
     switch (inst.GetOpcode()) {
-    case IR::Opcode::LoadGlobalU8:
-    case IR::Opcode::LoadGlobalS8:
-    case IR::Opcode::LoadGlobalU16:
-    case IR::Opcode::LoadGlobalS16:
-    case IR::Opcode::LoadGlobal32:
-    case IR::Opcode::LoadGlobal64:
-    case IR::Opcode::LoadGlobal128:
     case IR::Opcode::WriteGlobalU8:
     case IR::Opcode::WriteGlobalS8:
     case IR::Opcode::WriteGlobalU16:
@@ -423,6 +440,15 @@ void VisitUsages(Info& info, IR::Inst& inst) {
     case IR::Opcode::GlobalAtomicMinF32x2:
     case IR::Opcode::GlobalAtomicMaxF16x2:
     case IR::Opcode::GlobalAtomicMaxF32x2:
+        info.stores_global_memory = true;
+        [[fallthrough]];
+    case IR::Opcode::LoadGlobalU8:
+    case IR::Opcode::LoadGlobalS8:
+    case IR::Opcode::LoadGlobalU16:
+    case IR::Opcode::LoadGlobalS16:
+    case IR::Opcode::LoadGlobal32:
+    case IR::Opcode::LoadGlobal64:
+    case IR::Opcode::LoadGlobal128:
         info.uses_int64 = true;
         info.uses_global_memory = true;
         info.used_constant_buffer_types |= IR::Type::U32 | IR::Type::U32x2;
@@ -800,9 +826,27 @@ void VisitFpModifiers(Info& info, IR::Inst& inst) {
     }
 }
 
+void VisitCbufs(Info& info, IR::Inst& inst) {
+    switch (inst.GetOpcode()) {
+    case IR::Opcode::GetCbufU8:
+    case IR::Opcode::GetCbufS8:
+    case IR::Opcode::GetCbufU16:
+    case IR::Opcode::GetCbufS16:
+    case IR::Opcode::GetCbufU32:
+    case IR::Opcode::GetCbufF32:
+    case IR::Opcode::GetCbufU32x2: {
+        CheckCBufNVN(info, inst);
+        break;
+    }
+    default:
+        break;
+    }
+}
+
 void Visit(Info& info, IR::Inst& inst) {
     VisitUsages(info, inst);
     VisitFpModifiers(info, inst);
+    VisitCbufs(info, inst);
 }
 
 void GatherInfoFromHeader(Environment& env, Info& info) {
@@ -839,6 +883,26 @@ void GatherInfoFromHeader(Environment& env, Info& info) {
 
 void CollectShaderInfoPass(Environment& env, IR::Program& program) {
     Info& info{program.info};
+    const u32 base{[&] {
+        switch (program.stage) {
+        case Stage::VertexA:
+        case Stage::VertexB:
+            return 0x110u;
+        case Stage::TessellationControl:
+            return 0x210u;
+        case Stage::TessellationEval:
+            return 0x310u;
+        case Stage::Geometry:
+            return 0x410u;
+        case Stage::Fragment:
+            return 0x510u;
+        case Stage::Compute:
+            return 0x310u;
+        }
+        throw InvalidArgument("Invalid stage {}", program.stage);
+    }()};
+    info.nvn_buffer_base = base;
+
     for (IR::Block* const block : program.post_order_blocks) {
         for (IR::Inst& inst : block->Instructions()) {
             Visit(info, inst);
diff --git a/src/shader_recompiler/shader_info.h b/src/shader_recompiler/shader_info.h
index 2f6adf714e..a50a9a18c5 100644
--- a/src/shader_recompiler/shader_info.h
+++ b/src/shader_recompiler/shader_info.h
@@ -5,6 +5,7 @@
 #pragma once
 
 #include <array>
+#include <bitset>
 
 #include "common/common_types.h"
 #include "shader_recompiler/frontend/ir/type.h"
@@ -140,6 +141,7 @@ struct Info {
     bool stores_tess_level_outer{};
     bool stores_tess_level_inner{};
     bool stores_indexed_attributes{};
+    bool stores_global_memory{};
 
     bool uses_fp16{};
     bool uses_fp64{};
@@ -180,6 +182,8 @@ struct Info {
     IR::Type used_storage_buffer_types{};
 
     u32 constant_buffer_mask{};
+    u32 nvn_buffer_base{};
+    std::bitset<16> nvn_buffer_used{};
 
     boost::container::static_vector<ConstantBufferDescriptor, MAX_CBUFS>
         constant_buffer_descriptors;