blob: e75337f689cfe765127f2382d93e70357d9e9ecc [file] [log] [blame]
// Copyright (c) 2020 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "transformation_replace_load_store_with_copy_memory.h"
#include "source/fuzz/fuzzer_util.h"
#include "source/fuzz/instruction_descriptor.h"
#include "source/opcode.h"
namespace spvtools {
namespace fuzz {
namespace {
const uint32_t kOpStoreOperandIndexTargetVariable = 0;
const uint32_t kOpStoreOperandIndexIntermediateIdToWrite = 1;
const uint32_t kOpLoadOperandIndexSourceVariable = 2;
} // namespace
TransformationReplaceLoadStoreWithCopyMemory::
TransformationReplaceLoadStoreWithCopyMemory(
protobufs::TransformationReplaceLoadStoreWithCopyMemory message)
: message_(std::move(message)) {}
TransformationReplaceLoadStoreWithCopyMemory::
TransformationReplaceLoadStoreWithCopyMemory(
const protobufs::InstructionDescriptor& load_instruction_descriptor,
const protobufs::InstructionDescriptor& store_instruction_descriptor) {
*message_.mutable_load_instruction_descriptor() = load_instruction_descriptor;
*message_.mutable_store_instruction_descriptor() =
store_instruction_descriptor;
}
bool TransformationReplaceLoadStoreWithCopyMemory::IsApplicable(
opt::IRContext* ir_context, const TransformationContext& /*unused*/) const {
// This transformation is only applicable to the pair of OpLoad and OpStore
// instructions.
// The OpLoad instruction must be defined.
auto load_instruction =
FindInstruction(message_.load_instruction_descriptor(), ir_context);
if (!load_instruction || load_instruction->opcode() != SpvOpLoad) {
return false;
}
// The OpStore instruction must be defined.
auto store_instruction =
FindInstruction(message_.store_instruction_descriptor(), ir_context);
if (!store_instruction || store_instruction->opcode() != SpvOpStore) {
return false;
}
// Intermediate values of the OpLoad and the OpStore must match.
if (load_instruction->result_id() !=
store_instruction->GetSingleWordOperand(
kOpStoreOperandIndexIntermediateIdToWrite)) {
return false;
}
// Get storage class of the variable pointed by the source operand in OpLoad.
opt::Instruction* source_id = ir_context->get_def_use_mgr()->GetDef(
load_instruction->GetSingleWordOperand(2));
SpvStorageClass storage_class = fuzzerutil::GetStorageClassFromPointerType(
ir_context, source_id->type_id());
// Iterate over all instructions between |load_instruction| and
// |store_instruction|.
for (auto it = load_instruction; it != store_instruction;
it = it->NextNode()) {
//|load_instruction| and |store_instruction| are not in the same block.
if (it == nullptr) {
return false;
}
// We need to make sure that the value pointed to by the source of the
// OpLoad hasn't changed by the time we see the matching OpStore
// instruction.
if (IsMemoryWritingOpCode(it->opcode())) {
return false;
} else if (IsMemoryBarrierOpCode(it->opcode()) &&
!IsStorageClassSafeAcrossMemoryBarriers(storage_class)) {
return false;
}
}
return true;
}
void TransformationReplaceLoadStoreWithCopyMemory::Apply(
opt::IRContext* ir_context, TransformationContext* /*unused*/) const {
// OpLoad and OpStore instructions must be defined.
auto load_instruction =
FindInstruction(message_.load_instruction_descriptor(), ir_context);
assert(load_instruction && load_instruction->opcode() == SpvOpLoad &&
"The required OpLoad instruction must be defined.");
auto store_instruction =
FindInstruction(message_.store_instruction_descriptor(), ir_context);
assert(store_instruction && store_instruction->opcode() == SpvOpStore &&
"The required OpStore instruction must be defined.");
// Intermediate values of the OpLoad and the OpStore must match.
assert(load_instruction->result_id() ==
store_instruction->GetSingleWordOperand(
kOpStoreOperandIndexIntermediateIdToWrite) &&
"OpLoad and OpStore must refer to the same value.");
// Get the ids of the source operand of the OpLoad and the target operand of
// the OpStore.
uint32_t source_variable_id =
load_instruction->GetSingleWordOperand(kOpLoadOperandIndexSourceVariable);
uint32_t target_variable_id = store_instruction->GetSingleWordOperand(
kOpStoreOperandIndexTargetVariable);
// Insert the OpCopyMemory instruction before the OpStore instruction.
store_instruction->InsertBefore(MakeUnique<opt::Instruction>(
ir_context, SpvOpCopyMemory, 0, 0,
opt::Instruction::OperandList(
{{SPV_OPERAND_TYPE_ID, {target_variable_id}},
{SPV_OPERAND_TYPE_ID, {source_variable_id}}})));
// Remove the OpStore instruction.
ir_context->KillInst(store_instruction);
ir_context->InvalidateAnalysesExceptFor(opt::IRContext::kAnalysisNone);
}
bool TransformationReplaceLoadStoreWithCopyMemory::IsMemoryWritingOpCode(
SpvOp op_code) {
if (spvOpcodeIsAtomicOp(op_code)) {
return op_code != SpvOpAtomicLoad;
}
switch (op_code) {
case SpvOpStore:
case SpvOpCopyMemory:
case SpvOpCopyMemorySized:
return true;
default:
return false;
}
}
bool TransformationReplaceLoadStoreWithCopyMemory::IsMemoryBarrierOpCode(
SpvOp op_code) {
switch (op_code) {
case SpvOpMemoryBarrier:
case SpvOpMemoryNamedBarrier:
return true;
default:
return false;
}
}
bool TransformationReplaceLoadStoreWithCopyMemory::
IsStorageClassSafeAcrossMemoryBarriers(SpvStorageClass storage_class) {
switch (storage_class) {
case SpvStorageClassUniformConstant:
case SpvStorageClassInput:
case SpvStorageClassUniform:
case SpvStorageClassPrivate:
case SpvStorageClassFunction:
return true;
default:
return false;
}
}
protobufs::Transformation
TransformationReplaceLoadStoreWithCopyMemory::ToMessage() const {
protobufs::Transformation result;
*result.mutable_replace_load_store_with_copy_memory() = message_;
return result;
}
std::unordered_set<uint32_t>
TransformationReplaceLoadStoreWithCopyMemory::GetFreshIds() const {
return std::unordered_set<uint32_t>();
}
} // namespace fuzz
} // namespace spvtools