shader: Fold comparisons and Pack/Unpack16
This commit is contained in:
parent
b4a5e767d0
commit
c3bace756f
@ -3,9 +3,9 @@
|
|||||||
// Refer to the license.txt file included.
|
// Refer to the license.txt file included.
|
||||||
|
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
|
#include <ranges>
|
||||||
#include <tuple>
|
#include <tuple>
|
||||||
#include <type_traits>
|
#include <type_traits>
|
||||||
#include <ranges>
|
|
||||||
|
|
||||||
#include "common/bit_cast.h"
|
#include "common/bit_cast.h"
|
||||||
#include "common/bit_util.h"
|
#include "common/bit_util.h"
|
||||||
@ -332,6 +332,18 @@ void FoldBitCast(IR::Inst& inst, IR::Opcode reverse) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void FoldInverseFunc(IR::Inst& inst, IR::Opcode reverse) {
|
||||||
|
const IR::Value value{inst.Arg(0)};
|
||||||
|
if (value.IsImmediate()) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
IR::Inst* const arg_inst{value.InstRecursive()};
|
||||||
|
if (arg_inst->Opcode() == reverse) {
|
||||||
|
inst.ReplaceUsesWith(arg_inst->Arg(0));
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
template <typename Func, size_t... I>
|
template <typename Func, size_t... I>
|
||||||
IR::Value EvalImmediates(const IR::Inst& inst, Func&& func, std::index_sequence<I...>) {
|
IR::Value EvalImmediates(const IR::Inst& inst, Func&& func, std::index_sequence<I...>) {
|
||||||
using Traits = LambdaTraits<decltype(func)>;
|
using Traits = LambdaTraits<decltype(func)>;
|
||||||
@ -372,6 +384,10 @@ void ConstantPropagation(IR::Block& block, IR::Inst& inst) {
|
|||||||
return FoldBitCast<IR::Opcode::BitCastU32F32, u32, f32>(inst, IR::Opcode::BitCastF32U32);
|
return FoldBitCast<IR::Opcode::BitCastU32F32, u32, f32>(inst, IR::Opcode::BitCastF32U32);
|
||||||
case IR::Opcode::IAdd64:
|
case IR::Opcode::IAdd64:
|
||||||
return FoldAdd<u64>(block, inst);
|
return FoldAdd<u64>(block, inst);
|
||||||
|
case IR::Opcode::PackHalf2x16:
|
||||||
|
return FoldInverseFunc(inst, IR::Opcode::UnpackHalf2x16);
|
||||||
|
case IR::Opcode::UnpackHalf2x16:
|
||||||
|
return FoldInverseFunc(inst, IR::Opcode::PackHalf2x16);
|
||||||
case IR::Opcode::SelectU1:
|
case IR::Opcode::SelectU1:
|
||||||
case IR::Opcode::SelectU8:
|
case IR::Opcode::SelectU8:
|
||||||
case IR::Opcode::SelectU16:
|
case IR::Opcode::SelectU16:
|
||||||
@ -395,6 +411,30 @@ void ConstantPropagation(IR::Block& block, IR::Inst& inst) {
|
|||||||
case IR::Opcode::ULessThan:
|
case IR::Opcode::ULessThan:
|
||||||
FoldWhenAllImmediates(inst, [](u32 a, u32 b) { return a < b; });
|
FoldWhenAllImmediates(inst, [](u32 a, u32 b) { return a < b; });
|
||||||
return;
|
return;
|
||||||
|
case IR::Opcode::SLessThanEqual:
|
||||||
|
FoldWhenAllImmediates(inst, [](s32 a, s32 b) { return a <= b; });
|
||||||
|
return;
|
||||||
|
case IR::Opcode::ULessThanEqual:
|
||||||
|
FoldWhenAllImmediates(inst, [](u32 a, u32 b) { return a <= b; });
|
||||||
|
return;
|
||||||
|
case IR::Opcode::SGreaterThan:
|
||||||
|
FoldWhenAllImmediates(inst, [](s32 a, s32 b) { return a > b; });
|
||||||
|
return;
|
||||||
|
case IR::Opcode::UGreaterThan:
|
||||||
|
FoldWhenAllImmediates(inst, [](u32 a, u32 b) { return a > b; });
|
||||||
|
return;
|
||||||
|
case IR::Opcode::SGreaterThanEqual:
|
||||||
|
FoldWhenAllImmediates(inst, [](s32 a, s32 b) { return a >= b; });
|
||||||
|
return;
|
||||||
|
case IR::Opcode::UGreaterThanEqual:
|
||||||
|
FoldWhenAllImmediates(inst, [](u32 a, u32 b) { return a >= b; });
|
||||||
|
return;
|
||||||
|
case IR::Opcode::IEqual:
|
||||||
|
FoldWhenAllImmediates(inst, [](u32 a, u32 b) { return a == b; });
|
||||||
|
return;
|
||||||
|
case IR::Opcode::INotEqual:
|
||||||
|
FoldWhenAllImmediates(inst, [](u32 a, u32 b) { return a != b; });
|
||||||
|
return;
|
||||||
case IR::Opcode::BitFieldUExtract:
|
case IR::Opcode::BitFieldUExtract:
|
||||||
FoldWhenAllImmediates(inst, [](u32 base, u32 shift, u32 count) {
|
FoldWhenAllImmediates(inst, [](u32 base, u32 shift, u32 count) {
|
||||||
if (static_cast<size_t>(shift) + static_cast<size_t>(count) > Common::BitSize<u32>()) {
|
if (static_cast<size_t>(shift) + static_cast<size_t>(count) > Common::BitSize<u32>()) {
|
||||||
|
Loading…
Reference in New Issue
Block a user