blob: 0170dcb0d1b6f8c6c4728aa8ca37bfe0c00e6bce [file] [log] [blame]
// Copyright (c) 2016 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "fold_spec_constant_op_and_composite_pass.h"
#include <algorithm>
#include <cassert>
#include <initializer_list>
#include <memory>
#include <tuple>
#include <unordered_map>
#include <vector>
#include "constants.h"
#include "make_unique.h"
namespace spvtools {
namespace opt {
namespace {
// Returns the single-word result from performing the given unary operation on
// the operand value which is passed in as a 32-bit word. Boolean results use
// 1u for true and 0u for false.
uint32_t UnaryOperate(SpvOp opcode, uint32_t operand) {
  switch (opcode) {
    // Arithmetics
    case SpvOp::SpvOpSNegate:
      // Negate with unsigned arithmetic, which wraps modulo 2^32 and yields the
      // two's complement result. Negating the int32_t value instead would be
      // signed-overflow UB for INT32_MIN.
      return 0u - operand;
    case SpvOp::SpvOpNot:
      return ~operand;
    case SpvOp::SpvOpLogicalNot:
      return !static_cast<bool>(operand);
    default:
      assert(false &&
             "Unsupported unary operation for OpSpecConstantOp instruction");
      return 0u;
  }
}
// Returns the single-word result from performing the given binary operation on
// the operand values which are passed in as two 32-bit words. Boolean results
// use 1u for true and 0u for false.
//
// SPIR-V leaves some results undefined: division or remainder by zero, the
// INT32_MIN / -1 overflow, and shifts by at least the bit width. Evaluating
// those naively in C++ is undefined behavior (or a trap) inside the optimizer
// itself, so each such case yields a fixed, deterministic value here instead.
uint32_t BinaryOperate(SpvOp opcode, uint32_t a, uint32_t b) {
  switch (opcode) {
    // Arithmetics
    case SpvOp::SpvOpIAdd:
      return a + b;
    case SpvOp::SpvOpISub:
      return a - b;
    case SpvOp::SpvOpIMul:
      return a * b;
    case SpvOp::SpvOpUDiv:
      // Division by zero is undefined in SPIR-V; choose 0 rather than trap.
      if (b == 0u) return 0u;
      return a / b;
    case SpvOp::SpvOpSDiv:
      if (b == 0u) return 0u;
      // INT32_MIN / -1 overflows int32_t (UB in C++). Division by -1 is a
      // negation, so compute it with wrapping unsigned arithmetic.
      if (static_cast<int32_t>(b) == -1) return 0u - a;
      return static_cast<uint32_t>(static_cast<int32_t>(a) /
                                   static_cast<int32_t>(b));
    case SpvOp::SpvOpSRem: {
      // The sign of a non-zero result comes from the first operand: a. This is
      // guaranteed by C++11 rules for the integer division operator. The
      // division result is rounded toward zero, so the result of '%' has the
      // sign of the first operand.
      if (b == 0u) return 0u;
      const int32_t b_prim = static_cast<int32_t>(b);
      // x % -1 is always 0, and INT32_MIN % -1 is UB in C++.
      if (b_prim == -1) return 0u;
      return static_cast<uint32_t>(static_cast<int32_t>(a) % b_prim);
    }
    case SpvOp::SpvOpSMod: {
      // The sign of a non-zero result comes from the second operand: b.
      if (b == 0u) return 0u;
      int32_t rem = static_cast<int32_t>(BinaryOperate(SpvOp::SpvOpSRem, a, b));
      const int32_t b_prim = static_cast<int32_t>(b);
      // |rem| < |b|. When rem and b have opposite signs, adding b moves the
      // result to b's sign and cannot overflow. In contrast, (rem + b) % b
      // can overflow when rem and b share a sign.
      if (rem != 0 && ((rem < 0) != (b_prim < 0))) rem += b_prim;
      return static_cast<uint32_t>(rem);
    }
    case SpvOp::SpvOpUMod:
      if (b == 0u) return 0u;
      return a % b;
    // Shifting. SPIR-V leaves the result undefined when the shift amount is at
    // least the bit width, while the same shift in C++ is UB. Out-of-range
    // amounts therefore get a fixed result: every bit shifted out.
    case SpvOp::SpvOpShiftRightLogical:
      if (b >= 32u) return 0u;
      return a >> b;
    case SpvOp::SpvOpShiftRightArithmetic:
      if (b >= 32u) {
        // All bits shifted out: the result is filled with the sign bit.
        return static_cast<int32_t>(a) < 0 ? 0xFFFFFFFFu : 0u;
      }
      return static_cast<uint32_t>(static_cast<int32_t>(a) >> b);
    case SpvOp::SpvOpShiftLeftLogical:
      if (b >= 32u) return 0u;
      return a << b;
    // Bitwise operations
    case SpvOp::SpvOpBitwiseOr:
      return a | b;
    case SpvOp::SpvOpBitwiseAnd:
      return a & b;
    case SpvOp::SpvOpBitwiseXor:
      return a ^ b;
    // Logical
    case SpvOp::SpvOpLogicalEqual:
      return (static_cast<bool>(a)) == (static_cast<bool>(b));
    case SpvOp::SpvOpLogicalNotEqual:
      return (static_cast<bool>(a)) != (static_cast<bool>(b));
    case SpvOp::SpvOpLogicalOr:
      return (static_cast<bool>(a)) || (static_cast<bool>(b));
    case SpvOp::SpvOpLogicalAnd:
      return (static_cast<bool>(a)) && (static_cast<bool>(b));
    // Comparison
    case SpvOp::SpvOpIEqual:
      return a == b;
    case SpvOp::SpvOpINotEqual:
      return a != b;
    case SpvOp::SpvOpULessThan:
      return a < b;
    case SpvOp::SpvOpSLessThan:
      return (static_cast<int32_t>(a)) < (static_cast<int32_t>(b));
    case SpvOp::SpvOpUGreaterThan:
      return a > b;
    case SpvOp::SpvOpSGreaterThan:
      return (static_cast<int32_t>(a)) > (static_cast<int32_t>(b));
    case SpvOp::SpvOpULessThanEqual:
      return a <= b;
    case SpvOp::SpvOpSLessThanEqual:
      return (static_cast<int32_t>(a)) <= (static_cast<int32_t>(b));
    case SpvOp::SpvOpUGreaterThanEqual:
      return a >= b;
    case SpvOp::SpvOpSGreaterThanEqual:
      return (static_cast<int32_t>(a)) >= (static_cast<int32_t>(b));
    default:
      assert(false &&
             "Unsupported binary operation for OpSpecConstantOp instruction");
      return 0u;
  }
}
// Evaluates a three-operand OpSpecConstantOp operation on 32-bit words and
// returns the single-word result. OpSelect is the only supported opcode: a
// non-zero first word selects the second word, zero selects the third.
uint32_t TernaryOperate(SpvOp opcode, uint32_t a, uint32_t b, uint32_t c) {
  if (opcode == SpvOp::SpvOpSelect) {
    return a != 0u ? b : c;
  }
  assert(false &&
         "Unsupported ternary operation for OpSpecConstantOp instruction");
  return 0u;
}
// Evaluates |opcode| on the given 32-bit operand words by dispatching on the
// number of words: one goes to UnaryOperate(), two to BinaryOperate(), three to
// TernaryOperate(). Booleans follow the convention that 0u is false and any
// other value is true. Any other operand count asserts and yields 0.
// TODO(qining): Support operands other than 32-bit wide.
uint32_t OperateWords(SpvOp opcode,
                      const std::vector<uint32_t>& operand_words) {
  const auto arity = operand_words.size();
  if (arity == 1) {
    return UnaryOperate(opcode, operand_words[0]);
  }
  if (arity == 2) {
    return BinaryOperate(opcode, operand_words[0], operand_words[1]);
  }
  if (arity == 3) {
    return TernaryOperate(opcode, operand_words[0], operand_words[1],
                          operand_words[2]);
  }
  assert(false && "Invalid number of operands");
  return 0;
}
// Evaluates |opcode| over scalar constant operands and returns the result as a
// single 32-bit word. A scalar constant contributes its only word; a null
// constant contributes 0u. Scalars wider than 32 bits, and any other kind of
// constant, are rejected by assertion.
uint32_t OperateScalars(SpvOp opcode,
                        const std::vector<analysis::Constant*>& operands) {
  std::vector<uint32_t> raw_words;
  raw_words.reserve(operands.size());
  for (analysis::Constant* operand : operands) {
    analysis::ScalarConstant* scalar = operand->AsScalarConstant();
    if (scalar != nullptr) {
      const auto& words_of_scalar = scalar->words();
      assert(words_of_scalar.size() == 1 &&
             "Scalar constants with longer than 32-bit width are not allowed "
             "in OperateScalars()");
      raw_words.push_back(words_of_scalar.front());
      continue;
    }
    if (operand->AsNullConstant() != nullptr) {
      raw_words.push_back(0u);
      continue;
    }
    assert(false &&
           "OperateScalars() only accepts ScalarConst or NullConst type of "
           "constant");
  }
  return OperateWords(opcode, raw_words);
}
// Returns the result of performing an operation over constant vectors. For each
// of the |num_dims| result elements, this gathers the corresponding 32-bit word
// from every operand and evaluates |opcode| on them.
//  - A non-null vector operand contributes its d-th component. A scalar
//    component contributes its word; a null component contributes 0u.
//  - A null vector operand contributes 0u for every element.
//  - A scalar operand contributes its word for every element, e.g. an OpSelect
//    condition, which newer SPIR-V versions allow to be a scalar even for
//    vector results.
// Components or scalars wider than 32 bits are not accepted in this function.
std::vector<uint32_t> OperateVectors(
    SpvOp opcode, uint32_t num_dims,
    const std::vector<analysis::Constant*>& operands) {
  std::vector<uint32_t> result;
  result.reserve(num_dims);
  // Reused for every dimension; cleared at the start of each one.
  std::vector<uint32_t> operand_values_for_one_dimension;
  for (uint32_t d = 0; d < num_dims; d++) {
    operand_values_for_one_dimension.clear();
    for (analysis::Constant* operand : operands) {
      if (analysis::VectorConstant* vector_operand =
              operand->AsVectorConstant()) {
        // Extract the raw 32-bit word of the component here instead of using
        // OperateScalars(). This avoids creating temporary null constants for
        // null components, because Constant creation may need extra validity
        // checks that are not managed here.
        const analysis::Constant* component =
            vector_operand->GetComponents().at(d);
        if (const analysis::ScalarConstant* scalar_component =
                component->AsScalarConstant()) {
          const auto& scalar_words = scalar_component->words();
          assert(
              scalar_words.size() == 1 &&
              "Vector components with longer than 32-bit width are not allowed "
              "in OperateVectors()");
          operand_values_for_one_dimension.push_back(scalar_words.front());
        } else if (component->AsNullConstant()) {
          // A null component (e.g. an OpConstantNull used inside an
          // OpConstantComposite) reads as 0u. This must test the component:
          // the enclosing operand is a VectorConstant here, never a null one.
          operand_values_for_one_dimension.push_back(0u);
        } else {
          assert(false &&
                 "VectorConst should only has ScalarConst or NullConst as "
                 "components");
        }
      } else if (const analysis::ScalarConstant* scalar_operand =
                     operand->AsScalarConstant()) {
        // A scalar operand applies unchanged to every dimension.
        const auto& scalar_words = scalar_operand->words();
        assert(scalar_words.size() == 1 &&
               "Scalar operands with longer than 32-bit width are not allowed "
               "in OperateVectors()");
        operand_values_for_one_dimension.push_back(scalar_words.front());
      } else if (operand->AsNullConstant()) {
        // Every component of a null vector is 0u.
        operand_values_for_one_dimension.push_back(0u);
      } else {
        assert(false &&
               "OperateVectors() only accepts VectorConst, ScalarConst or "
               "NullConst type of constant");
      }
    }
    result.push_back(OperateWords(opcode, operand_values_for_one_dimension));
  }
  return result;
}
} // anonymous namespace
FoldSpecConstantOpAndCompositePass::FoldSpecConstantOpAndCompositePass()
: max_id_(0),
module_(nullptr),
def_use_mgr_(nullptr),
type_mgr_(nullptr),
id_to_const_val_() {}
// Folds spec constants in the types/values section of |module| in a single
// in-order walk, and reports whether anything changed.
//  - Normal constants (OpConstantTrue/False, OpConstant, OpConstantNull,
//    OpConstantComposite) are recorded in id_to_const_val_ and
//    const_val_to_id_.
//  - An OpSpecConstantComposite whose components are all recorded constants is
//    rewritten to OpConstantComposite and recorded as well.
//  - An OpSpecConstantOp is handed to ProcessOpSpecConstantOp(), which may
//    insert folded constant instructions before it and redirect its uses.
// Returns SuccessWithChange if any instruction was rewritten or folded,
// SuccessWithoutChange otherwise.
//
// NOTE(review): the helpers reached from here use module_, max_id_,
// def_use_mgr_ and type_mgr_, but this function does not set them. Presumably
// the caller initializes them (with module_ referring to |module|) — confirm.
Pass::Status FoldSpecConstantOpAndCompositePass::ProcessImpl(
ir::Module* module) {
bool modified = false;
// Traverse through all the constant defining instructions. For Normal
// Constants whose values are determined and do not depend on OpUndef
// instructions, record their values in two internal maps, id_to_const_val_
// and const_val_to_id_, so that we can use them to infer the values of Spec
// Constants later.
// For Spec Constants defined with OpSpecConstantComposite instructions, if
// all of their components are Normal Constants, they will be turned into
// Normal Constants too. For Spec Constants defined with OpSpecConstantOp
// instructions, we check if they only depend on Normal Constants and fold
// them when possible. The two maps for Normal Constants, id_to_const_val_
// and const_val_to_id_, are updated along the traversal so that the new
// Normal Constants generated from folding can be used to fold the following
// Spec Constants.
// This algorithm relies on the SSA property of SPIR-V when defining
// constants: every constant a definition uses must be defined before that
// definition. So when a Spec Constant is reached, all the constants it uses
// have already been processed, and all the Spec Constants it uses have
// already been folded if that was possible.
for (ir::Module::inst_iterator inst_iter = module->types_values_begin();
// Need to re-evaluate the end iterator since we may modify the list of
// instructions in this section of the module as the process goes.
inst_iter != module->types_values_end(); ++inst_iter) {
ir::Instruction* inst = &*inst_iter;
// Collect constant values of normal constants and process the
// OpSpecConstantOp and OpSpecConstantComposite instructions if possible.
// The constant values will be stored in analysis::Constant instances.
// OpConstantSampler instruction is not collected here because it cannot be
// used in OpSpecConstant{Composite|Op} instructions.
// Instructions whose type carries decorations are skipped entirely.
// TODO(qining): If the constant itself (not just its type) has decorations,
// we may need to skip it too.
if (GetType(inst) && !GetType(inst)->decoration_empty()) continue;
switch (SpvOp opcode = inst->opcode()) {
// Records the values of Normal Constants.
case SpvOp::SpvOpConstantTrue:
case SpvOp::SpvOpConstantFalse:
case SpvOp::SpvOpConstant:
case SpvOp::SpvOpConstantNull:
case SpvOp::SpvOpConstantComposite:
case SpvOp::SpvOpSpecConstantComposite: {
// A Constant instance will be created if the given instruction is a
// Normal Constant whose value(s) are fixed. Note that for a composite
// Spec Constant defined with an OpSpecConstantComposite instruction, if
// all of its components are Normal Constants already, the Spec
// Constant will be turned into a Normal Constant. In that case, a
// Constant instance should also be created successfully and recorded
// in the id_to_const_val_ and const_val_to_id_ maps.
if (auto const_value = CreateConstFromInst(inst)) {
// Need to replace the OpSpecConstantComposite instruction with a
// corresponding OpConstantComposite instruction.
if (opcode == SpvOp::SpvOpSpecConstantComposite) {
inst->SetOpcode(SpvOp::SpvOpConstantComposite);
modified = true;
}
const_val_to_id_[const_value.get()] = inst->result_id();
id_to_const_val_[inst->result_id()] = std::move(const_value);
}
break;
}
// For a Spec Constant defined with an OpSpecConstantOp instruction, check
// if it only depends on Normal Constants. If so, the Spec Constant will
// be folded. The original Spec Constant defining instruction will be
// replaced by Normal Constant defining instructions, and the new Normal
// Constants will be added to id_to_const_val_ and const_val_to_id_ so
// that we can use the new Normal Constants when folding the following Spec
// Constants.
// ProcessOpSpecConstantOp() inserts new instructions before inst_iter and
// leaves inst_iter on the original instruction.
// NOTE(review): the original instruction is then killed via
// def_use_mgr_->KillDef(); this loop assumes that keeps inst_iter valid
// (e.g. the instruction is turned into a no-op rather than erased) — confirm.
case SpvOp::SpvOpSpecConstantOp:
modified |= ProcessOpSpecConstantOp(&inst_iter);
break;
default:
break;
}
}
return modified ? Status::SuccessWithChange : Status::SuccessWithoutChange;
}
// Tries to fold the OpSpecConstantOp instruction at |*pos|. On success, the
// folded constant instructions are inserted before |*pos|. All uses of the
// spec constant are redirected to the folded result, and the spec constant's
// definition is killed. Returns true if the instruction was folded, false if
// it was left untouched.
bool FoldSpecConstantOpAndCompositePass::ProcessOpSpecConstantOp(
    ir::Module::inst_iterator* pos) {
  ir::Instruction* spec_inst = &**pos;
  assert(spec_inst->GetInOperand(0).type ==
             SPV_OPERAND_TYPE_SPEC_CONSTANT_OP_NUMBER &&
         "The first in-operand of OpSpecContantOp instruction must be of "
         "SPV_OPERAND_TYPE_SPEC_CONSTANT_OP_NUMBER type");
  const SpvOp spec_opcode =
      static_cast<SpvOp>(spec_inst->GetSingleWordInOperand(0));

  ir::Instruction* replacement = nullptr;
  if (spec_opcode == SpvOp::SpvOpCompositeExtract) {
    replacement = DoCompositeExtract(pos);
  } else if (spec_opcode == SpvOp::SpvOpVectorShuffle) {
    replacement = DoVectorShuffle(pos);
  } else if (spec_opcode == SpvOp::SpvOpCompositeInsert) {
    // Glslang does not currently emit OpSpecConstantOp CompositeInsert, so
    // folding it is not implemented yet.
    // TODO(qining): Implement CompositeInsert case.
    return false;
  } else {
    // Everything else is treated as a component-wise operation.
    replacement = DoComponentWiseOperation(pos);
  }
  if (replacement == nullptr) return false;

  // Point every user of the spec constant at the folded constant, then drop
  // the spec constant's definition.
  const uint32_t folded_id = replacement->result_id();
  const uint32_t spec_id = spec_inst->result_id();
  def_use_mgr_->ReplaceAllUsesWith(spec_id, folded_id);
  def_use_mgr_->KillDef(spec_id);
  return true;
}
// Folds an OpSpecConstantOp CompositeExtract whose composite operand is a
// recorded constant. The literal indices are followed through nested composite
// constants, and a normal constant declaring the selected element is inserted
// before |*pos|. If the walk reaches a null constant, the result is an
// OpConstantNull of the result type, since every part of a null composite is
// null.
// Returns the new instruction, or nullptr when the operand is not a recorded
// constant or an index is applied to a non-composite constant.
ir::Instruction* FoldSpecConstantOpAndCompositePass::DoCompositeExtract(
    ir::Module::inst_iterator* pos) {
  ir::Instruction* inst = &**pos;
  assert(inst->NumInOperands() - 1 >= 2 &&
         "OpSpecConstantOp CompositeExtract requires at least two non-type "
         "non-opcode operands.");
  assert(inst->GetInOperand(1).type == SPV_OPERAND_TYPE_ID &&
         "The vector operand must have a SPV_OPERAND_TYPE_ID type");
  assert(
      inst->GetInOperand(2).type == SPV_OPERAND_TYPE_LITERAL_INTEGER &&
      "The literal operand must have a SPV_OPERAND_TYPE_LITERAL_INTEGER type");
  // In-operand 0 is the spec opcode, in-operand 1 is the composite id, and
  // every in-operand after that is a literal index.
  const analysis::Constant* element =
      FindRecordedConst(inst->GetSingleWordInOperand(1));
  if (element == nullptr) return nullptr;

  const uint32_t num_in_operands = inst->NumInOperands();
  for (uint32_t index_operand = 2; index_operand < num_in_operands;
       ++index_operand) {
    const uint32_t literal = inst->GetSingleWordInOperand(index_operand);
    const analysis::CompositeConstant* composite =
        element->AsCompositeConstant();
    if (composite == nullptr) {
      // Indexing into a non-composite, non-null constant is invalid.
      if (element->AsNullConstant() == nullptr) return nullptr;
      // Every component of a null constant is null as well, so the extracted
      // value is simply a null constant of the result type.
      return BuildInstructionAndAddToModule(CreateConst(GetType(inst), {}),
                                            pos);
    }
    assert(literal < composite->GetComponents().size() &&
           "Literal index out of bound of the composite constant");
    element = composite->GetComponents().at(literal);
  }
  return BuildInstructionAndAddToModule(element->Copy(), pos);
}
// Folds an OpSpecConstantOp VectorShuffle whose two vector operands are
// recorded constants.
//  - The components of both operands are concatenated. A null vector operand
//    contributes null scalar components of its element type.
//  - The literal operands then select the result components.
// Returns the instruction declaring the new vector constant, or nullptr when
// the instruction cannot be folded. That covers an operand that is not a
// recorded vector/null constant, or a literal outside the concatenated range.
// When folding fails, no instruction is added to the module.
ir::Instruction* FoldSpecConstantOpAndCompositePass::DoVectorShuffle(
    ir::Module::inst_iterator* pos) {
  ir::Instruction* inst = &**pos;
  analysis::Vector* result_vec_type = GetType(inst)->AsVector();
  assert(inst->NumInOperands() - 1 > 2 &&
         "OpSpecConstantOp DoVectorShuffle instruction requires more than 2 "
         "operands (2 vector ids and at least one literal operand");
  assert(result_vec_type &&
         "The result of VectorShuffle must be of type vector");
  // A temporary null constant used as the component of any null vector
  // operand. It is only added to the module once folding is known to succeed.
  std::unique_ptr<analysis::Constant> null_component_constants;
  // The components of the first operand followed by those of the second.
  std::vector<const analysis::Constant*> concatenated_components;
  // Note that for OpSpecConstantOp, the second in-operand is the first id
  // operand. The first in-operand is the spec opcode.
  for (uint32_t i : {1, 2}) {
    assert(inst->GetInOperand(i).type == SPV_OPERAND_TYPE_ID &&
           "The vector operand must have a SPV_OPERAND_TYPE_ID type");
    uint32_t operand_id = inst->GetSingleWordInOperand(i);
    analysis::Constant* operand_const = FindRecordedConst(operand_id);
    if (!operand_const) return nullptr;
    const analysis::Type* operand_type = operand_const->type();
    assert(operand_type->AsVector() &&
           "The first two operand of VectorShuffle must be of vector type");
    if (analysis::VectorConstant* vec_const =
            operand_const->AsVectorConstant()) {
      // Case 1: the operand is a non-null vector constant.
      concatenated_components.insert(concatenated_components.end(),
                                     vec_const->GetComponents().begin(),
                                     vec_const->GetComponents().end());
    } else if (operand_const->AsNullConstant()) {
      // Case 2: the operand is a null vector constant. Each of its components
      // is a null scalar of the vector's element type.
      if (!null_component_constants) {
        const analysis::Type* component_type =
            operand_type->AsVector()->element_type();
        null_component_constants = CreateConst(component_type, {});
      }
      concatenated_components.insert(concatenated_components.end(),
                                     operand_type->AsVector()->element_count(),
                                     null_component_constants.get());
    } else {
      // No other kind of operand can be shuffled.
      return nullptr;
    }
  }
  // Select the result components. Any literal outside the concatenated range
  // cannot be folded; this includes the 0xFFFFFFFF "undefined component"
  // marker that OpVectorShuffle allows. All literals are validated before
  // anything is emitted, so a failed fold leaves the module untouched.
  std::vector<const analysis::Constant*> selected_components;
  for (uint32_t i = 3; i < inst->NumInOperands(); i++) {
    assert(inst->GetInOperand(i).type == SPV_OPERAND_TYPE_LITERAL_INTEGER &&
           "The literal operand must of type SPV_OPERAND_TYPE_LITERAL_INTEGER");
    uint32_t literal = inst->GetSingleWordInOperand(i);
    if (literal >= concatenated_components.size()) return nullptr;
    selected_components.push_back(concatenated_components[literal]);
  }
  // The null component constant must be declared before the composite that
  // uses it, to satisfy SSA def-use dominance. Ownership moves into
  // id_to_const_val_, so the raw pointers collected above stay valid.
  if (null_component_constants) {
    BuildInstructionAndAddToModule(std::move(null_component_constants), pos);
  }
  // Create the new vector constant with the selected components.
  auto new_vec_const = MakeUnique<analysis::VectorConstant>(
      result_vec_type, selected_components);
  return BuildInstructionAndAddToModule(std::move(new_vec_const), pos);
}
namespace {
// Returns true if |type| is something the 32-bit component-wise evaluators can
// handle: bool, a 32-bit integer, or a vector whose element type is bool or a
// 32-bit integer. Every other type yields false.
bool IsValidTypeForComponentWiseOperation(const analysis::Type* type) {
  // A vector is judged by its element type; any other type by itself.
  const auto* vector_type = type->AsVector();
  const analysis::Type* component_type =
      vector_type ? vector_type->element_type() : type;
  if (component_type->AsBool()) return true;
  const auto* int_type = component_type->AsInteger();
  return int_type != nullptr && int_type->width() == 32;
}
}  // anonymous namespace
ir::Instruction* FoldSpecConstantOpAndCompositePass::DoComponentWiseOperation(
ir::Module::inst_iterator* pos) {
const ir::Instruction* inst = &**pos;
const analysis::Type* result_type = GetType(inst);
SpvOp spec_opcode = static_cast<SpvOp>(inst->GetSingleWordInOperand(0));
// Check and collect operands.
std::vector<analysis::Constant*> operands;
if (!std::all_of(inst->cbegin(), inst->cend(),
[&operands, this](const ir::Operand& o) {
// skip the operands that is not an id.
if (o.type != spv_operand_type_t::SPV_OPERAND_TYPE_ID)
return true;
uint32_t id = o.words.front();
if (analysis::Constant* c = FindRecordedConst(id)) {
if (IsValidTypeForComponentWiseOperation(c->type())) {
operands.push_back(c);
return true;
}
}
return false;
}))
return nullptr;
if (result_type->AsInteger() || result_type->AsBool()) {
// Scalar operation
uint32_t result_val = OperateScalars(spec_opcode, operands);
auto result_const = CreateConst(result_type, {result_val});
return BuildInstructionAndAddToModule(std::move(result_const), pos);
} else if (result_type->AsVector()) {
// Vector operation
const analysis::Type* element_type =
result_type->AsVector()->element_type();
uint32_t num_dims = result_type->AsVector()->element_count();
std::vector<uint32_t> result_vec =
OperateVectors(spec_opcode, num_dims, operands);
std::vector<const analysis::Constant*> result_vector_components;
for (uint32_t r : result_vec) {
if (auto rc = CreateConst(element_type, {r})) {
result_vector_components.push_back(rc.get());
if (!BuildInstructionAndAddToModule(std::move(rc), pos)) {
assert(false &&
"Failed to build and insert constant declaring instruction "
"for the given vector component constant");
}
} else {
assert(false && "Failed to create constants with 32-bit word");
}
}
auto new_vec_const = MakeUnique<analysis::VectorConstant>(
result_type->AsVector(), result_vector_components);
return BuildInstructionAndAddToModule(std::move(new_vec_const), pos);
} else {
// Cannot process invalid component wise operation. The result of component
// wise operation must be of integer or bool scalar or vector of
// integer/bool type.
return nullptr;
}
}
// Assigns a fresh result id to constant |c| and creates the instruction that
// declares it. The instruction is inserted before |*pos|, and |*pos| is left
// pointing at the same instruction as before the call.
//
// Ownership of |c| always moves into id_to_const_val_, so raw pointers to it
// that callers already hold stay valid even if this fails.
//
// Returns the new instruction, or nullptr if |c| is null or no declaring
// instruction can be created for it. On failure, |c| is not registered in
// const_val_to_id_, so no later composite can reference an id that was never
// declared.
ir::Instruction*
FoldSpecConstantOpAndCompositePass::BuildInstructionAndAddToModule(
    std::unique_ptr<analysis::Constant> c, ir::Module::inst_iterator* pos) {
  if (!c) return nullptr;
  analysis::Constant* new_const = c.get();
  uint32_t new_id = ++max_id_;
  module_->SetIdBound(new_id + 1);
  // Keep |c| alive regardless of the outcome below: callers may already hold
  // raw pointers to it, e.g. as components of a vector constant being built.
  id_to_const_val_[new_id] = std::move(c);
  auto new_inst = CreateInstruction(new_id, new_const);
  if (!new_inst) return nullptr;
  // Only map the constant to its id once that id is actually declared.
  const_val_to_id_[new_const] = new_id;
  auto* new_inst_ptr = new_inst.get();
  // InsertBefore returns an iterator to the inserted instruction; step past it
  // so |*pos| points back at the original instruction.
  *pos = pos->InsertBefore(std::move(new_inst));
  ++(*pos);
  def_use_mgr_->AnalyzeInstDefUse(new_inst_ptr);
  return new_inst_ptr;
}
// Creates a Constant describing the value defined by |inst| when |inst| is one
// of: OpConstantTrue, OpConstantFalse, OpConstant, OpConstantNull,
// OpConstantComposite or OpSpecConstantComposite.
// Returns nullptr in these cases:
//  - the opcode is anything else;
//  - the instruction's type is unknown;
//  - CreateConst() cannot build the constant, e.g. when a component id has
//    not been recorded as a constant.
std::unique_ptr<analysis::Constant>
FoldSpecConstantOpAndCompositePass::CreateConstFromInst(ir::Instruction* inst) {
  // Reject non-constant opcodes before doing any work.
  switch (inst->opcode()) {
    case SpvOp::SpvOpConstantTrue:
    case SpvOp::SpvOpConstantFalse:
    case SpvOp::SpvOpConstantNull:
    case SpvOp::SpvOpConstant:
    case SpvOp::SpvOpConstantComposite:
    case SpvOp::SpvOpSpecConstantComposite:
      break;
    default:
      return nullptr;
  }
  const analysis::Type* type = GetType(inst);
  if (type == nullptr) return nullptr;

  // Collect the constant defining literals or component ids.
  std::vector<uint32_t> literal_words_or_ids;
  for (uint32_t i = 0; i < inst->NumInOperands(); i++) {
    const auto& operand_words = inst->GetInOperand(i).words;
    literal_words_or_ids.insert(literal_words_or_ids.end(),
                                operand_words.begin(), operand_words.end());
  }
  // OpConstant{True|False} carry their value in the opcode rather than in an
  // operand, so the value is appended explicitly.
  if (inst->opcode() == SpvOp::SpvOpConstantTrue) {
    literal_words_or_ids.push_back(true);
  } else if (inst->opcode() == SpvOp::SpvOpConstantFalse) {
    literal_words_or_ids.push_back(false);
  }
  return CreateConst(type, literal_words_or_ids);
}
// Returns the constant recorded for result id |id|, or nullptr when no
// constant value has been recorded for that id.
analysis::Constant* FoldSpecConstantOpAndCompositePass::FindRecordedConst(
    uint32_t id) {
  const auto found = id_to_const_val_.find(id);
  return found != id_to_const_val_.end() ? found->second.get() : nullptr;
}
// Returns the result id recorded for constant |c|, or 0 when |c| has not been
// given an id in const_val_to_id_.
uint32_t FoldSpecConstantOpAndCompositePass::FindRecordedConst(
    const analysis::Constant* c) {
  const auto found = const_val_to_id_.find(c);
  if (found != const_val_to_id_.end()) return found->second;
  return 0;
}
std::vector<const analysis::Constant*>
FoldSpecConstantOpAndCompositePass::GetConstsFromIds(
const std::vector<uint32_t>& ids) {
std::vector<const analysis::Constant*> constants;
for (uint32_t id : ids) {
if (analysis::Constant* c = FindRecordedConst(id)) {
constants.push_back(c);
} else {
return {};
}
}
return constants;
}
// Creates a Constant of |type| from |literal_words_or_ids|. For scalar types
// this holds the literal value words; for composite types it holds the
// component result ids. An empty |literal_words_or_ids| creates a null
// constant, as declared by OpConstantNull.
// Returns nullptr when the constant cannot be created:
//  - |type| is null or unsupported;
//  - a component id has no recorded constant;
//  - the vector components are not all bool/int/float of one type;
//  - the number of vector components differs from the vector type's size.
std::unique_ptr<analysis::Constant>
FoldSpecConstantOpAndCompositePass::CreateConst(
    const analysis::Type* type,
    const std::vector<uint32_t>& literal_words_or_ids) {
  if (literal_words_or_ids.empty()) {
    // Constant declared with OpConstantNull
    return MakeUnique<analysis::NullConstant>(type);
  }
  if (type == nullptr) return nullptr;
  if (auto* bt = type->AsBool()) {
    assert(literal_words_or_ids.size() == 1 &&
           "Bool constant should be declared with one operand");
    return MakeUnique<analysis::BoolConstant>(bt, literal_words_or_ids.front());
  }
  if (auto* it = type->AsInteger()) {
    return MakeUnique<analysis::IntConstant>(it, literal_words_or_ids);
  }
  if (auto* ft = type->AsFloat()) {
    return MakeUnique<analysis::FloatConstant>(ft, literal_words_or_ids);
  }
  if (auto* vt = type->AsVector()) {
    auto components = GetConstsFromIds(literal_words_or_ids);
    if (components.empty()) return nullptr;
    // A vector constant must supply exactly one component per element.
    if (components.size() != vt->element_count()) return nullptr;
    // All components of VectorConstant must be of type Bool, Integer or Float.
    const auto is_scalar = [](const analysis::Constant* c) {
      const analysis::Type* t = c->type();
      return t->AsBool() || t->AsInteger() || t->AsFloat();
    };
    if (!std::all_of(components.begin(), components.end(), is_scalar))
      return nullptr;
    // All components of VectorConstant must be of the same type.
    const analysis::Type* component_type = components.front()->type();
    const auto has_component_type = [component_type](
                                        const analysis::Constant* c) {
      return c->type() == component_type;
    };
    if (!std::all_of(components.begin(), components.end(), has_component_type))
      return nullptr;
    return MakeUnique<analysis::VectorConstant>(vt, components);
  }
  if (auto* st = type->AsStruct()) {
    auto components = GetConstsFromIds(literal_words_or_ids);
    if (components.empty()) return nullptr;
    return MakeUnique<analysis::StructConstant>(st, components);
  }
  if (auto* at = type->AsArray()) {
    auto components = GetConstsFromIds(literal_words_or_ids);
    if (components.empty()) return nullptr;
    return MakeUnique<analysis::ArrayConstant>(at, components);
  }
  return nullptr;
}
std::vector<ir::Operand> BuildOperandsFromIds(
const std::vector<uint32_t>& ids) {
std::vector<ir::Operand> operands;
for (uint32_t id : ids) {
operands.emplace_back(spv_operand_type_t::SPV_OPERAND_TYPE_ID,
std::initializer_list<uint32_t>{id});
}
return operands;
}
std::unique_ptr<ir::Instruction>
FoldSpecConstantOpAndCompositePass::CreateInstruction(uint32_t id,
analysis::Constant* c) {
if (c->AsNullConstant()) {
return MakeUnique<ir::Instruction>(SpvOp::SpvOpConstantNull,
type_mgr_->GetId(c->type()), id,
std::initializer_list<ir::Operand>{});
} else if (analysis::BoolConstant* bc = c->AsBoolConstant()) {
return MakeUnique<ir::Instruction>(
bc->value() ? SpvOp::SpvOpConstantTrue : SpvOp::SpvOpConstantFalse,
type_mgr_->GetId(c->type()), id, std::initializer_list<ir::Operand>{});
} else if (analysis::IntConstant* ic = c->AsIntConstant()) {
return MakeUnique<ir::Instruction>(
SpvOp::SpvOpConstant, type_mgr_->GetId(c->type()), id,
std::initializer_list<ir::Operand>{ir::Operand(
spv_operand_type_t::SPV_OPERAND_TYPE_TYPED_LITERAL_NUMBER,
ic->words())});
} else if (analysis::FloatConstant* fc = c->AsFloatConstant()) {
return MakeUnique<ir::Instruction>(
SpvOp::SpvOpConstant, type_mgr_->GetId(c->type()), id,
std::initializer_list<ir::Operand>{ir::Operand(
spv_operand_type_t::SPV_OPERAND_TYPE_TYPED_LITERAL_NUMBER,
fc->words())});
} else if (analysis::CompositeConstant* cc = c->AsCompositeConstant()) {
return CreateCompositeInstruction(id, cc);
} else {
return nullptr;
}
}
std::unique_ptr<ir::Instruction>
FoldSpecConstantOpAndCompositePass::CreateCompositeInstruction(
uint32_t result_id, analysis::CompositeConstant* cc) {
std::vector<ir::Operand> operands;
for (const analysis::Constant* component_const : cc->GetComponents()) {
uint32_t id = FindRecordedConst(component_const);
if (id == 0) {
// Cannot get the id of the component constant, while all components
// should have been added to the module prior to the composite constant.
// Cannot create OpConstantComposite instruction in this case.
return nullptr;
}
operands.emplace_back(spv_operand_type_t::SPV_OPERAND_TYPE_ID,
std::initializer_list<uint32_t>{id});
}
return MakeUnique<ir::Instruction>(SpvOp::SpvOpConstantComposite,
type_mgr_->GetId(cc->type()), result_id,
std::move(operands));
}
} // namespace opt
} // namespace spvtools