shader: Support SSA loops on IR

This commit is contained in:
ReinUsesLisp 2021-02-14 20:15:42 -03:00 committed by ameerj
parent 8af9297f09
commit cbfb7d182a
12 changed files with 150 additions and 46 deletions

View File

@ -32,6 +32,8 @@ add_executable(shader_recompiler
frontend/ir/opcodes.cpp frontend/ir/opcodes.cpp
frontend/ir/opcodes.h frontend/ir/opcodes.h
frontend/ir/opcodes.inc frontend/ir/opcodes.inc
frontend/ir/post_order.cpp
frontend/ir/post_order.h
frontend/ir/pred.h frontend/ir/pred.h
frontend/ir/program.cpp frontend/ir/program.cpp
frontend/ir/program.h frontend/ir/program.h

View File

@ -159,10 +159,10 @@ private:
Id EmitWorkgroupId(EmitContext& ctx); Id EmitWorkgroupId(EmitContext& ctx);
Id EmitLocalInvocationId(EmitContext& ctx); Id EmitLocalInvocationId(EmitContext& ctx);
Id EmitUndefU1(EmitContext& ctx); Id EmitUndefU1(EmitContext& ctx);
void EmitUndefU8(EmitContext& ctx); Id EmitUndefU8(EmitContext& ctx);
void EmitUndefU16(EmitContext& ctx); Id EmitUndefU16(EmitContext& ctx);
void EmitUndefU32(EmitContext& ctx); Id EmitUndefU32(EmitContext& ctx);
void EmitUndefU64(EmitContext& ctx); Id EmitUndefU64(EmitContext& ctx);
void EmitLoadGlobalU8(EmitContext& ctx); void EmitLoadGlobalU8(EmitContext& ctx);
void EmitLoadGlobalS8(EmitContext& ctx); void EmitLoadGlobalS8(EmitContext& ctx);
void EmitLoadGlobalU16(EmitContext& ctx); void EmitLoadGlobalU16(EmitContext& ctx);
@ -297,12 +297,12 @@ private:
void EmitBitFieldInsert(EmitContext& ctx); void EmitBitFieldInsert(EmitContext& ctx);
void EmitBitFieldSExtract(EmitContext& ctx); void EmitBitFieldSExtract(EmitContext& ctx);
Id EmitBitFieldUExtract(EmitContext& ctx, Id base, Id offset, Id count); Id EmitBitFieldUExtract(EmitContext& ctx, Id base, Id offset, Id count);
void EmitSLessThan(EmitContext& ctx); Id EmitSLessThan(EmitContext& ctx, Id lhs, Id rhs);
void EmitULessThan(EmitContext& ctx); void EmitULessThan(EmitContext& ctx);
void EmitIEqual(EmitContext& ctx); void EmitIEqual(EmitContext& ctx);
void EmitSLessThanEqual(EmitContext& ctx); void EmitSLessThanEqual(EmitContext& ctx);
void EmitULessThanEqual(EmitContext& ctx); void EmitULessThanEqual(EmitContext& ctx);
void EmitSGreaterThan(EmitContext& ctx); Id EmitSGreaterThan(EmitContext& ctx, Id lhs, Id rhs);
void EmitUGreaterThan(EmitContext& ctx); void EmitUGreaterThan(EmitContext& ctx);
void EmitINotEqual(EmitContext& ctx); void EmitINotEqual(EmitContext& ctx);
void EmitSGreaterThanEqual(EmitContext& ctx); void EmitSGreaterThanEqual(EmitContext& ctx);

View File

@ -73,8 +73,8 @@ Id EmitSPIRV::EmitBitFieldUExtract(EmitContext& ctx, Id base, Id offset, Id coun
return ctx.OpBitFieldUExtract(ctx.u32[1], base, offset, count); return ctx.OpBitFieldUExtract(ctx.u32[1], base, offset, count);
} }
void EmitSPIRV::EmitSLessThan(EmitContext&) { Id EmitSPIRV::EmitSLessThan(EmitContext& ctx, Id lhs, Id rhs) {
throw NotImplementedException("SPIR-V Instruction"); return ctx.OpSLessThan(ctx.u1, lhs, rhs);
} }
void EmitSPIRV::EmitULessThan(EmitContext&) { void EmitSPIRV::EmitULessThan(EmitContext&) {
@ -93,8 +93,8 @@ void EmitSPIRV::EmitULessThanEqual(EmitContext&) {
throw NotImplementedException("SPIR-V Instruction"); throw NotImplementedException("SPIR-V Instruction");
} }
void EmitSPIRV::EmitSGreaterThan(EmitContext&) { Id EmitSPIRV::EmitSGreaterThan(EmitContext& ctx, Id lhs, Id rhs) {
throw NotImplementedException("SPIR-V Instruction"); return ctx.OpSGreaterThan(ctx.u1, lhs, rhs);
} }
void EmitSPIRV::EmitUGreaterThan(EmitContext&) { void EmitSPIRV::EmitUGreaterThan(EmitContext&) {

View File

@ -10,19 +10,19 @@ Id EmitSPIRV::EmitUndefU1(EmitContext& ctx) {
return ctx.OpUndef(ctx.u1); return ctx.OpUndef(ctx.u1);
} }
void EmitSPIRV::EmitUndefU8(EmitContext&) { Id EmitSPIRV::EmitUndefU8(EmitContext&) {
throw NotImplementedException("SPIR-V Instruction"); throw NotImplementedException("SPIR-V Instruction");
} }
void EmitSPIRV::EmitUndefU16(EmitContext&) { Id EmitSPIRV::EmitUndefU16(EmitContext&) {
throw NotImplementedException("SPIR-V Instruction"); throw NotImplementedException("SPIR-V Instruction");
} }
void EmitSPIRV::EmitUndefU32(EmitContext&) { Id EmitSPIRV::EmitUndefU32(EmitContext& ctx) {
throw NotImplementedException("SPIR-V Instruction"); return ctx.OpUndef(ctx.u32[1]);
} }
void EmitSPIRV::EmitUndefU64(EmitContext&) { Id EmitSPIRV::EmitUndefU64(EmitContext&) {
throw NotImplementedException("SPIR-V Instruction"); throw NotImplementedException("SPIR-V Instruction");
} }

View File

@ -12,6 +12,7 @@ namespace Shader::IR {
struct Function { struct Function {
BlockList blocks; BlockList blocks;
BlockList post_order_blocks;
}; };
} // namespace Shader::IR } // namespace Shader::IR

View File

@ -0,0 +1,48 @@
// Copyright 2021 yuzu Emulator Project
// Licensed under GPLv2 or any later version
// Refer to the license.txt file included.
#include <boost/container/flat_set.hpp>
#include <boost/container/small_vector.hpp>
#include "shader_recompiler/frontend/ir/basic_block.h"
#include "shader_recompiler/frontend/ir/post_order.h"
namespace Shader::IR {
BlockList PostOrder(const BlockList& blocks) {
boost::container::small_vector<Block*, 16> block_stack;
boost::container::flat_set<Block*> visited;
BlockList post_order_blocks;
post_order_blocks.reserve(blocks.size());
Block* const first_block{blocks.front()};
visited.insert(first_block);
block_stack.push_back(first_block);
const auto visit_branch = [&](Block* block, Block* branch) {
if (!branch) {
return false;
}
if (!visited.insert(branch).second) {
return false;
}
// Calling push_back twice is faster than insert on msvc
block_stack.push_back(block);
block_stack.push_back(branch);
return true;
};
while (!block_stack.empty()) {
Block* const block{block_stack.back()};
block_stack.pop_back();
if (!visit_branch(block, block->TrueBranch()) &&
!visit_branch(block, block->FalseBranch())) {
post_order_blocks.push_back(block);
}
}
return post_order_blocks;
}
} // namespace Shader::IR

View File

@ -0,0 +1,13 @@
// Copyright 2021 yuzu Emulator Project
// Licensed under GPLv2 or any later version
// Refer to the license.txt file included.
#pragma once
#include "shader_recompiler/frontend/ir/basic_block.h"
namespace Shader::IR {
BlockList PostOrder(const BlockList& blocks);
} // namespace Shader::IR

View File

@ -7,6 +7,7 @@
#include <vector> #include <vector>
#include "shader_recompiler/frontend/ir/basic_block.h" #include "shader_recompiler/frontend/ir/basic_block.h"
#include "shader_recompiler/frontend/ir/post_order.h"
#include "shader_recompiler/frontend/ir/structured_control_flow.h" #include "shader_recompiler/frontend/ir/structured_control_flow.h"
#include "shader_recompiler/frontend/maxwell/program.h" #include "shader_recompiler/frontend/maxwell/program.h"
#include "shader_recompiler/frontend/maxwell/translate/translate.h" #include "shader_recompiler/frontend/maxwell/translate/translate.h"
@ -56,11 +57,14 @@ IR::Program TranslateProgram(ObjectPool<IR::Inst>& inst_pool, ObjectPool<IR::Blo
} }
fmt::print(stdout, "No optimizations: {}", IR::DumpProgram(program)); fmt::print(stdout, "No optimizations: {}", IR::DumpProgram(program));
std::ranges::for_each(functions, Optimization::SsaRewritePass);
for (IR::Function& function : functions) { for (IR::Function& function : functions) {
Optimization::Invoke(Optimization::GlobalMemoryToStorageBufferPass, function); function.post_order_blocks = PostOrder(function.blocks);
Optimization::Invoke(Optimization::ConstantPropagationPass, function); Optimization::SsaRewritePass(function.post_order_blocks);
Optimization::Invoke(Optimization::DeadCodeEliminationPass, function); }
for (IR::Function& function : functions) {
Optimization::PostOrderInvoke(Optimization::GlobalMemoryToStorageBufferPass, function);
Optimization::PostOrderInvoke(Optimization::ConstantPropagationPass, function);
Optimization::PostOrderInvoke(Optimization::DeadCodeEliminationPass, function);
Optimization::IdentityRemovalPass(function); Optimization::IdentityRemovalPass(function);
Optimization::VerificationPass(function); Optimization::VerificationPass(function);
} }

View File

@ -13,7 +13,7 @@ namespace Shader::Optimization {
void DeadCodeEliminationPass(IR::Block& block) { void DeadCodeEliminationPass(IR::Block& block) {
// We iterate over the instructions in reverse order. // We iterate over the instructions in reverse order.
// This is because removing an instruction reduces the number of uses for earlier instructions. // This is because removing an instruction reduces the number of uses for earlier instructions.
for (IR::Inst& inst : std::views::reverse(block)) { for (IR::Inst& inst : block | std::views::reverse) {
if (!inst.HasUses() && !inst.MayHaveSideEffects()) { if (!inst.HasUses() && !inst.MayHaveSideEffects()) {
inst.Invalidate(); inst.Invalidate();
} }

View File

@ -4,14 +4,16 @@
#pragma once #pragma once
#include <span>
#include "shader_recompiler/frontend/ir/basic_block.h" #include "shader_recompiler/frontend/ir/basic_block.h"
#include "shader_recompiler/frontend/ir/function.h" #include "shader_recompiler/frontend/ir/function.h"
namespace Shader::Optimization { namespace Shader::Optimization {
template <typename Func> template <typename Func>
void Invoke(Func&& func, IR::Function& function) { void PostOrderInvoke(Func&& func, IR::Function& function) {
for (const auto& block : function.blocks) { for (const auto& block : function.post_order_blocks) {
func(*block); func(*block);
} }
} }
@ -20,7 +22,7 @@ void ConstantPropagationPass(IR::Block& block);
void DeadCodeEliminationPass(IR::Block& block); void DeadCodeEliminationPass(IR::Block& block);
void GlobalMemoryToStorageBufferPass(IR::Block& block); void GlobalMemoryToStorageBufferPass(IR::Block& block);
void IdentityRemovalPass(IR::Function& function); void IdentityRemovalPass(IR::Function& function);
void SsaRewritePass(IR::Function& function); void SsaRewritePass(std::span<IR::Block* const> post_order_blocks);
void VerificationPass(const IR::Function& function); void VerificationPass(const IR::Function& function);
} // namespace Shader::Optimization } // namespace Shader::Optimization

View File

@ -14,7 +14,13 @@
// https://link.springer.com/chapter/10.1007/978-3-642-37051-9_6 // https://link.springer.com/chapter/10.1007/978-3-642-37051-9_6
// //
#include <ranges>
#include <span>
#include <variant>
#include <vector>
#include <boost/container/flat_map.hpp> #include <boost/container/flat_map.hpp>
#include <boost/container/flat_set.hpp>
#include "shader_recompiler/frontend/ir/basic_block.h" #include "shader_recompiler/frontend/ir/basic_block.h"
#include "shader_recompiler/frontend/ir/function.h" #include "shader_recompiler/frontend/ir/function.h"
@ -26,9 +32,9 @@
namespace Shader::Optimization { namespace Shader::Optimization {
namespace { namespace {
using ValueMap = boost::container::flat_map<IR::Block*, IR::Value, std::less<IR::Block*>>; struct FlagTag {
auto operator<=>(const FlagTag&) const noexcept = default;
struct FlagTag {}; };
struct ZeroFlagTag : FlagTag {}; struct ZeroFlagTag : FlagTag {};
struct SignFlagTag : FlagTag {}; struct SignFlagTag : FlagTag {};
struct CarryFlagTag : FlagTag {}; struct CarryFlagTag : FlagTag {};
@ -38,9 +44,15 @@ struct GotoVariable : FlagTag {
GotoVariable() = default; GotoVariable() = default;
explicit GotoVariable(u32 index_) : index{index_} {} explicit GotoVariable(u32 index_) : index{index_} {}
auto operator<=>(const GotoVariable&) const noexcept = default;
u32 index; u32 index;
}; };
using Variant = std::variant<IR::Reg, IR::Pred, ZeroFlagTag, SignFlagTag, CarryFlagTag,
OverflowFlagTag, GotoVariable>;
using ValueMap = boost::container::flat_map<IR::Block*, IR::Value, std::less<IR::Block*>>;
struct DefTable { struct DefTable {
[[nodiscard]] ValueMap& operator[](IR::Reg variable) noexcept { [[nodiscard]] ValueMap& operator[](IR::Reg variable) noexcept {
return regs[IR::RegIndex(variable)]; return regs[IR::RegIndex(variable)];
@ -102,19 +114,35 @@ public:
} }
IR::Value ReadVariable(auto variable, IR::Block* block) { IR::Value ReadVariable(auto variable, IR::Block* block) {
auto& def{current_def[variable]}; const ValueMap& def{current_def[variable]};
if (const auto it{def.find(block)}; it != def.end()) { if (const auto it{def.find(block)}; it != def.end()) {
return it->second; return it->second;
} }
return ReadVariableRecursive(variable, block); return ReadVariableRecursive(variable, block);
} }
void SealBlock(IR::Block* block) {
const auto it{incomplete_phis.find(block)};
if (it != incomplete_phis.end()) {
for (auto& [variant, phi] : it->second) {
std::visit([&](auto& variable) { AddPhiOperands(variable, *phi, block); }, variant);
}
}
sealed_blocks.insert(block);
}
private: private:
IR::Value ReadVariableRecursive(auto variable, IR::Block* block) { IR::Value ReadVariableRecursive(auto variable, IR::Block* block) {
IR::Value val; IR::Value val;
if (const std::span preds{block->ImmediatePredecessors()}; preds.size() == 1) { if (!sealed_blocks.contains(block)) {
// Incomplete CFG
IR::Inst* phi{&*block->PrependNewInst(block->begin(), IR::Opcode::Phi)};
incomplete_phis[block].insert_or_assign(variable, phi);
val = IR::Value{&*phi};
} else if (const std::span imm_preds{block->ImmediatePredecessors()};
imm_preds.size() == 1) {
// Optimize the common case of one predecessor: no phi needed // Optimize the common case of one predecessor: no phi needed
val = ReadVariable(variable, preds.front()); val = ReadVariable(variable, imm_preds.front());
} else { } else {
// Break potential cycles with operandless phi // Break potential cycles with operandless phi
IR::Inst& phi_inst{*block->PrependNewInst(block->begin(), IR::Opcode::Phi)}; IR::Inst& phi_inst{*block->PrependNewInst(block->begin(), IR::Opcode::Phi)};
@ -127,8 +155,8 @@ private:
} }
IR::Value AddPhiOperands(auto variable, IR::Inst& phi, IR::Block* block) { IR::Value AddPhiOperands(auto variable, IR::Inst& phi, IR::Block* block) {
for (IR::Block* const pred : block->ImmediatePredecessors()) { for (IR::Block* const imm_pred : block->ImmediatePredecessors()) {
phi.AddPhiOperand(pred, ReadVariable(variable, pred)); phi.AddPhiOperand(imm_pred, ReadVariable(variable, imm_pred));
} }
return TryRemoveTrivialPhi(phi, block, UndefOpcode(variable)); return TryRemoveTrivialPhi(phi, block, UndefOpcode(variable));
} }
@ -159,6 +187,9 @@ private:
return same; return same;
} }
boost::container::flat_set<IR::Block*> sealed_blocks;
boost::container::flat_map<IR::Block*, boost::container::flat_map<Variant, IR::Inst*>>
incomplete_phis;
DefTable current_def; DefTable current_def;
}; };
@ -218,14 +249,19 @@ void VisitInst(Pass& pass, IR::Block* block, IR::Inst& inst) {
break; break;
} }
} }
} // Anonymous namespace
void SsaRewritePass(IR::Function& function) { void VisitBlock(Pass& pass, IR::Block* block) {
Pass pass;
for (IR::Block* const block : function.blocks) {
for (IR::Inst& inst : block->Instructions()) { for (IR::Inst& inst : block->Instructions()) {
VisitInst(pass, block, inst); VisitInst(pass, block, inst);
} }
pass.SealBlock(block);
}
} // Anonymous namespace
void SsaRewritePass(std::span<IR::Block* const> post_order_blocks) {
Pass pass;
for (IR::Block* const block : post_order_blocks | std::views::reverse) {
VisitBlock(pass, block);
} }
} }

View File

@ -69,7 +69,6 @@ int main() {
// FileEnvironment env{"D:\\Shaders\\Database\\Oninaki\\CS8F146B41DB6BD826.bin"}; // FileEnvironment env{"D:\\Shaders\\Database\\Oninaki\\CS8F146B41DB6BD826.bin"};
FileEnvironment env{"D:\\Shaders\\shader.bin"}; FileEnvironment env{"D:\\Shaders\\shader.bin"};
for (int i = 0; i < 1; ++i) {
block_pool->ReleaseContents(); block_pool->ReleaseContents();
inst_pool->ReleaseContents(); inst_pool->ReleaseContents();
flow_block_pool->ReleaseContents(); flow_block_pool->ReleaseContents();
@ -77,6 +76,5 @@ int main() {
fmt::print(stdout, "{}\n", cfg.Dot()); fmt::print(stdout, "{}\n", cfg.Dot());
IR::Program program{TranslateProgram(*inst_pool, *block_pool, env, cfg)}; IR::Program program{TranslateProgram(*inst_pool, *block_pool, env, cfg)};
fmt::print(stdout, "{}\n", IR::DumpProgram(program)); fmt::print(stdout, "{}\n", IR::DumpProgram(program));
Backend::SPIRV::EmitSPIRV spirv{program}; // Backend::SPIRV::EmitSPIRV spirv{program};
}
} }