Skip to content

Commit

Permalink
opt: add Int16 and Float16 to capability trim pass (KhronosGroup#5519)
Browse files Browse the repository at this point in the history
Add support for Int16 and Float16 trim.

Signed-off-by: Nathan Gauër <brioche@google.com>
  • Loading branch information
Keenuts authored Jan 4, 2024
1 parent 0a9f3d1 commit c7affa1
Show file tree
Hide file tree
Showing 3 changed files with 123 additions and 1 deletion.
24 changes: 23 additions & 1 deletion source/opt/trim_capabilities_pass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,16 @@ static bool Has16BitCapability(const FeatureManager* feature_manager) {
// Handler names follow the following convention:
// Handler_<Opcode>_<Capability>()

static std::optional<spv::Capability> Handler_OpTypeFloat_Float16(
const Instruction* instruction) {
assert(instruction->opcode() == spv::Op::OpTypeFloat &&
"This handler only support OpTypeFloat opcodes.");

const uint32_t size =
instruction->GetSingleWordInOperand(kOpTypeFloatSizeIndex);
return size == 16 ? std::optional(spv::Capability::Float16) : std::nullopt;
}

static std::optional<spv::Capability> Handler_OpTypeFloat_Float64(
const Instruction* instruction) {
assert(instruction->opcode() == spv::Op::OpTypeFloat &&
Expand Down Expand Up @@ -274,6 +284,16 @@ static std::optional<spv::Capability> Handler_OpTypePointer_StorageUniform16(
: std::nullopt;
}

static std::optional<spv::Capability> Handler_OpTypeInt_Int16(
const Instruction* instruction) {
assert(instruction->opcode() == spv::Op::OpTypeInt &&
"This handler only support OpTypeInt opcodes.");

const uint32_t size =
instruction->GetSingleWordInOperand(kOpTypeIntSizeIndex);
return size == 16 ? std::optional(spv::Capability::Int16) : std::nullopt;
}

static std::optional<spv::Capability> Handler_OpTypeInt_Int64(
const Instruction* instruction) {
assert(instruction->opcode() == spv::Op::OpTypeInt &&
Expand Down Expand Up @@ -341,12 +361,14 @@ Handler_OpImageSparseRead_StorageImageReadWithoutFormat(
}

// Opcode of interest to determine capabilities requirements.
constexpr std::array<std::pair<spv::Op, OpcodeHandler>, 10> kOpcodeHandlers{{
constexpr std::array<std::pair<spv::Op, OpcodeHandler>, 12> kOpcodeHandlers{{
// clang-format off
{spv::Op::OpImageRead, Handler_OpImageRead_StorageImageReadWithoutFormat},
{spv::Op::OpImageSparseRead, Handler_OpImageSparseRead_StorageImageReadWithoutFormat},
{spv::Op::OpTypeFloat, Handler_OpTypeFloat_Float16 },
{spv::Op::OpTypeFloat, Handler_OpTypeFloat_Float64 },
{spv::Op::OpTypeImage, Handler_OpTypeImage_ImageMSArray},
{spv::Op::OpTypeInt, Handler_OpTypeInt_Int16 },
{spv::Op::OpTypeInt, Handler_OpTypeInt_Int64 },
{spv::Op::OpTypePointer, Handler_OpTypePointer_StorageInputOutput16},
{spv::Op::OpTypePointer, Handler_OpTypePointer_StoragePushConstant16},
Expand Down
2 changes: 2 additions & 0 deletions source/opt/trim_capabilities_pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -76,12 +76,14 @@ class TrimCapabilitiesPass : public Pass {
// clang-format off
spv::Capability::ComputeDerivativeGroupLinearNV,
spv::Capability::ComputeDerivativeGroupQuadsNV,
spv::Capability::Float16,
spv::Capability::Float64,
spv::Capability::FragmentShaderPixelInterlockEXT,
spv::Capability::FragmentShaderSampleInterlockEXT,
spv::Capability::FragmentShaderShadingRateInterlockEXT,
spv::Capability::Groups,
spv::Capability::ImageMSArray,
spv::Capability::Int16,
spv::Capability::Int64,
spv::Capability::Linkage,
spv::Capability::MinLod,
Expand Down
98 changes: 98 additions & 0 deletions test/opt/trim_capabilities_pass_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2486,6 +2486,104 @@ TEST_F(TrimCapabilitiesPassTest,
EXPECT_EQ(std::get<1>(result), Pass::Status::SuccessWithoutChange);
}

TEST_F(TrimCapabilitiesPassTest, Float16_RemovedWhenUnused) {
const std::string kTest = R"(
OpCapability Float16
; CHECK-NOT: OpCapability Float16
OpCapability Shader
OpMemoryModel Logical GLSL450
OpEntryPoint GLCompute %1 "main"
%void = OpTypeVoid
%3 = OpTypeFunction %void
%1 = OpFunction %void None %3
%6 = OpLabel
OpReturn
OpFunctionEnd;
)";
const auto result =
SinglePassRunAndMatch<TrimCapabilitiesPass>(kTest, /* skip_nop= */ false);
EXPECT_EQ(std::get<1>(result), Pass::Status::SuccessWithChange);
}

TEST_F(TrimCapabilitiesPassTest, Float16_RemainsWhenUsed) {
const std::string kTest = R"(
OpCapability Float16
; CHECK: OpCapability Float16
OpCapability Shader
OpMemoryModel Logical GLSL450
OpEntryPoint GLCompute %1 "main"
%void = OpTypeVoid
%float = OpTypeFloat 16
%3 = OpTypeFunction %void
%1 = OpFunction %void None %3
%6 = OpLabel
OpReturn
OpFunctionEnd;
)";
const auto result =
SinglePassRunAndMatch<TrimCapabilitiesPass>(kTest, /* skip_nop= */ false);
EXPECT_EQ(std::get<1>(result), Pass::Status::SuccessWithoutChange);
}

TEST_F(TrimCapabilitiesPassTest, Int16_RemovedWhenUnused) {
const std::string kTest = R"(
OpCapability Int16
; CHECK-NOT: OpCapability Int16
OpCapability Shader
OpMemoryModel Logical GLSL450
OpEntryPoint GLCompute %1 "main"
%void = OpTypeVoid
%3 = OpTypeFunction %void
%1 = OpFunction %void None %3
%6 = OpLabel
OpReturn
OpFunctionEnd;
)";
const auto result =
SinglePassRunAndMatch<TrimCapabilitiesPass>(kTest, /* skip_nop= */ false);
EXPECT_EQ(std::get<1>(result), Pass::Status::SuccessWithChange);
}

TEST_F(TrimCapabilitiesPassTest, Int16_RemainsWhenUsed) {
const std::string kTest = R"(
OpCapability Int16
; CHECK: OpCapability Int16
OpCapability Shader
OpMemoryModel Logical GLSL450
OpEntryPoint GLCompute %1 "main"
%void = OpTypeVoid
%int = OpTypeInt 16 1
%3 = OpTypeFunction %void
%1 = OpFunction %void None %3
%6 = OpLabel
OpReturn
OpFunctionEnd;
)";
const auto result =
SinglePassRunAndMatch<TrimCapabilitiesPass>(kTest, /* skip_nop= */ false);
EXPECT_EQ(std::get<1>(result), Pass::Status::SuccessWithoutChange);
}

TEST_F(TrimCapabilitiesPassTest, UInt16_RemainsWhenUsed) {
const std::string kTest = R"(
OpCapability Int16
; CHECK: OpCapability Int16
OpCapability Shader
OpMemoryModel Logical GLSL450
OpEntryPoint GLCompute %1 "main"
%void = OpTypeVoid
%uint = OpTypeInt 16 0
%3 = OpTypeFunction %void
%1 = OpFunction %void None %3
%6 = OpLabel
OpReturn
OpFunctionEnd;
)";
const auto result =
SinglePassRunAndMatch<TrimCapabilitiesPass>(kTest, /* skip_nop= */ false);
EXPECT_EQ(std::get<1>(result), Pass::Status::SuccessWithoutChange);
}

} // namespace
} // namespace opt
} // namespace spvtools

0 comments on commit c7affa1

Please sign in to comment.