blob: 397c99bbc0256cb4d67b48665acd70c5e0992971 [file] [log] [blame]
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/mlir/lite/experimental/tac/hardwares/target_hardware.h"
#include <algorithm>
#include <cctype>
#include <memory>
#include "llvm/ADT/DenseMap.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
#include "mlir/Support/TypeID.h" // from @llvm-project
#include "tensorflow/compiler/mlir/lite/experimental/tac/common/targets.h"
#include "tensorflow/compiler/mlir/lite/experimental/tac/common/utils.h"
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
namespace mlir {
namespace TFL {
namespace tac {
namespace {
// One entry in the global hardware registry. An entry is created either from
// an already-constructed TargetHardware instance or from a factory;
// RegisterTargetHardwareFactory may later attach a factory to an existing
// entry.
struct RegisteredTargetHardware {
  // TODO(b/177376459): Remove this constructor.
  RegisteredTargetHardware(const std::string& name,
                           const std::string& description, mlir::TypeID type_id,
                           std::unique_ptr<TargetHardware> target_hardware)
      : unique_name(GetCanonicalHardwareName(name)),
        description(description),
        type_id(type_id),
        target_hardware(std::move(target_hardware)) {}

  RegisteredTargetHardware(
      const std::string& name, const std::string& description,
      mlir::TypeID type_id,
      std::function<std::unique_ptr<TargetHardware>()> target_hardware_factory)
      : unique_name(GetCanonicalHardwareName(name)),
        description(description),
        // Must be recorded here too: GetHardwareName() matches entries by
        // type_id, and leaving it default-constructed made factory-created
        // entries unfindable.
        type_id(type_id),
        target_hardware_factory(std::move(target_hardware_factory)) {}

  // Name after GetCanonicalHardwareName(); lookups compare canonical names
  // against this.
  std::string unique_name;
  std::string description;
  // Compared against TargetHardware::GetTypeId() in GetHardwareName().
  mlir::TypeID type_id;
  // Shared instance; null for entries created only from a factory.
  std::unique_ptr<TargetHardware> target_hardware;
  // Creates fresh instances; empty for instance-only entries until
  // RegisterTargetHardwareFactory attaches one.
  std::function<std::unique_ptr<TargetHardware>()> target_hardware_factory;
};
struct RegisteredTargetHardwareOps {
explicit RegisteredTargetHardwareOps(mlir::TypeID hardware_type)
: hardware_typeid(hardware_type) {}
// Key is the Operation TypeID
llvm::DenseMap<mlir::TypeID, std::unique_ptr<TargetHardwareOperation>>
target_hardware_ops;
// Key is the Operation TypeID
llvm::DenseMap<mlir::TypeID,
std::function<std::unique_ptr<TargetHardwareOperation>()>>
target_hardware_ops_factory;
mlir::TypeID hardware_typeid;
};
std::vector<std::unique_ptr<RegisteredTargetHardwareOps>>*
GetRegisteredTargetHardwareOps() {
static std::vector<std::unique_ptr<RegisteredTargetHardwareOps>>*
hardwares_ops =
[]() -> std::vector<std::unique_ptr<RegisteredTargetHardwareOps>>* {
return new std::vector<std::unique_ptr<RegisteredTargetHardwareOps>>();
}();
return hardwares_ops;
}
// Returns the process-wide hardware registry. The vector is heap-allocated on
// first call and never freed (presumably to sidestep static destruction-order
// issues).
std::vector<RegisteredTargetHardware>* GetRegisteredHardwares() {
  static auto* const registry = new std::vector<RegisteredTargetHardware>();
  return registry;
}
// Returns the map of pre-built op handlers for the hardware identified by
// `type_id`, or nullptr when no op table exists for that hardware.
llvm::DenseMap<mlir::TypeID, std::unique_ptr<TargetHardwareOperation>>*
getRegisteredOperationsForHardware(mlir::TypeID type_id) {
  auto& tables = *GetRegisteredTargetHardwareOps();
  const auto match = std::find_if(
      tables.begin(), tables.end(),
      [type_id](const std::unique_ptr<RegisteredTargetHardwareOps>& table) {
        return table->hardware_typeid == type_id;
      });
  return match == tables.end() ? nullptr : &(*match)->target_hardware_ops;
}
// Deny list for op cost computation: the ops matched here are not arithmetic,
// so they are excluded from cost totals.
inline bool IsNonArithmeticOp(mlir::Operation* op) {
  return llvm::isa<func::ReturnOp, func::FuncOp>(op) ||
         op->hasTrait<OpTrait::ConstantLike>() ||
         llvm::isa<QConstOp, SparseQConstOp>(op) ||
         // A false result from NotTFLQuantDequantizeOp also marks the op as
         // non-arithmetic (presumably TFL quantize/dequantize ops).
         !NotTFLQuantDequantizeOp(op);
}
} // namespace
bool TargetHardware::Init() {
auto* hardware_ops_factory = GetRegisteredTargetHardwareOps();
for (auto& hardware_ops : *hardware_ops_factory) {
if (hardware_ops->hardware_typeid != this->GetTypeId()) continue;
auto& op_factories = hardware_ops->target_hardware_ops_factory;
for (auto& op_factory : op_factories) {
hardware_ops_.emplace_back(op_factory.getSecond()());
}
break;
}
return true;
}
// Returns the cost of `op` on this hardware as reported by the
// TargetHardwareOperation registered for the op's type. Falls back to
// kDefaultFixedValuedCost when this hardware has no op table, when `op` is not
// a registered operation, or when no handler exists for its type.
double TargetHardware::GetOpCost(mlir::Operation* op) const {
  auto* registered_ops = getRegisteredOperationsForHardware(GetTypeId());
  if (registered_ops == nullptr) {
    return kDefaultFixedValuedCost;
  }
  // getRegisteredInfo() is empty for unregistered operations; dereferencing
  // it unchecked is undefined behavior.
  auto abstract_op = op->getRegisteredInfo();
  if (!abstract_op) return kDefaultFixedValuedCost;
  auto hardware_op = registered_ops->find(abstract_op->getTypeID());
  if (hardware_op == registered_ops->end()) return kDefaultFixedValuedCost;
  return hardware_op->second->GetOpCost(op);
}
// Returns whether this hardware supports `op`, as decided by the
// TargetHardwareOperation registered for the op's type. Returns false when
// this hardware has no op table, when `op` is not a registered operation, or
// when no handler exists for its type.
bool TargetHardware::IsOpSupported(mlir::Operation* op) const {
  auto* registered_ops = getRegisteredOperationsForHardware(GetTypeId());
  if (registered_ops == nullptr) {
    return false;
  }
  // getRegisteredInfo() is empty for unregistered operations; dereferencing
  // it unchecked is undefined behavior.
  auto abstract_op = op->getRegisteredInfo();
  if (!abstract_op) return false;
  auto hardware_op = registered_ops->find(abstract_op->getTypeID());
  if (hardware_op == registered_ops->end()) return false;
  return hardware_op->second->IsOpSupported(op);
}
// Sums GetOpCost() over the operations visited by func->walk(), skipping
// those IsNonArithmeticOp() deny-lists. Each op's cost is always left to the
// hardware to decide.
double TargetHardware::GetFuncCost(func::FuncOp* func) const {
  double accumulated = 0.0;
  func->walk([&](Operation* nested_op) {
    if (!IsNonArithmeticOp(nested_op)) {
      accumulated += GetOpCost(nested_op);
    }
  });
  return accumulated;
}
// Returns the registered TargetHardware instance whose canonical name matches
// `hardware_name`, or nullptr if none matches (or the entry holds no instance).
const TargetHardware* GetTargetHardware(const std::string& hardware_name) {
  const std::string wanted = GetCanonicalHardwareName(hardware_name);
  // A linear scan is fine while only a few hardwares are registered; switch to
  // a map if that grows.
  const auto& registry = *GetRegisteredHardwares();
  const auto entry = std::find_if(
      registry.begin(), registry.end(),
      [&wanted](const RegisteredTargetHardware& candidate) {
        return candidate.unique_name == wanted;
      });
  return entry != registry.end() ? entry->target_hardware.get() : nullptr;
}
std::function<std::unique_ptr<TargetHardware>()> GetTargetHardwareFactory(
const std::string& hardware_name) {
const std::string canonical_name = GetCanonicalHardwareName(hardware_name);
// Just loop for now, we don't expect number of hardwares to be huge.
// Revisit to have map if number of elements increased.
auto* registered_hardwares = GetRegisteredHardwares();
for (const auto& hardware : *registered_hardwares) {
if (hardware.unique_name == canonical_name) {
return hardware.target_hardware_factory;
}
}
return nullptr;
}
namespace internal {
void RegisterTargetHardware(
const std::string& unique_name, const std::string& description,
mlir::TypeID type_id,
std::function<std::unique_ptr<TargetHardware>()> target_hardware_factory) {
auto* registered_hardwares = GetRegisteredHardwares();
for (const auto& hardware : *registered_hardwares) {
if (hardware.unique_name == unique_name) {
llvm::errs() << "Ignoring duplicate hardware. Hardware " << unique_name
<< " already registered\n";
return;
}
}
registered_hardwares->push_back(RegisteredTargetHardware(
unique_name, description, type_id, target_hardware_factory()));
}
void RegisterTargetHardwareFactory(
const std::string& unique_name, const std::string& description,
mlir::TypeID type_id,
std::function<std::unique_ptr<TargetHardware>()> target_hardware_factory) {
auto* registered_hardwares = GetRegisteredHardwares();
for (auto& hardware : *registered_hardwares) {
if (hardware.unique_name == unique_name) {
llvm::errs() << "Ignoring duplicate hardware. Hardware " << unique_name
<< " already registered\n";
hardware.target_hardware_factory = target_hardware_factory;
return;
}
}
registered_hardwares->push_back(RegisteredTargetHardware(
unique_name, description, type_id, target_hardware_factory));
}
// Registers a TargetHardwareOperation, created immediately by
// `target_hardware_op_factory`, for op type `op_type` on hardware
// `hardware_type`. Creates the hardware's op table on first use. A second
// registration for the same (hardware, op) pair prints an error and is
// ignored without invoking the factory.
void RegisterTargetHardwareOp(
    mlir::TypeID hardware_type, mlir::TypeID op_type,
    std::function<std::unique_ptr<TargetHardwareOperation>()>
        target_hardware_op_factory) {
  auto* registered_hardware_ops = GetRegisteredTargetHardwareOps();
  RegisteredTargetHardwareOps* table = nullptr;
  for (auto& hardware : *registered_hardware_ops) {
    if (hardware->hardware_typeid == hardware_type) {
      table = hardware.get();
      break;
    }
  }
  if (table == nullptr) {
    // First op registered for this hardware: create its table.
    registered_hardware_ops->push_back(
        std::make_unique<RegisteredTargetHardwareOps>(hardware_type));
    table = registered_hardware_ops->back().get();
  }
  if (table->target_hardware_ops.count(op_type)) {
    llvm::errs() << "Trying to register duplicate Op\n";
    return;
  }
  table->target_hardware_ops[op_type] = target_hardware_op_factory();
}
// Registers `target_hardware_op_factory` for op type `op_type` on hardware
// `hardware_type`. Creates the hardware's op table on first use. A second
// factory for the same (hardware, op) pair prints an error and is ignored.
void RegisterTargetHardwareOpFactory(
    mlir::TypeID hardware_type, mlir::TypeID op_type,
    std::function<std::unique_ptr<TargetHardwareOperation>()>
        target_hardware_op_factory) {
  auto* registered_hardware_ops = GetRegisteredTargetHardwareOps();
  RegisteredTargetHardwareOps* table = nullptr;
  for (auto& hardware : *registered_hardware_ops) {
    if (hardware->hardware_typeid == hardware_type) {
      table = hardware.get();
      break;
    }
  }
  if (table == nullptr) {
    // First op factory registered for this hardware: create its table.
    registered_hardware_ops->push_back(
        std::make_unique<RegisteredTargetHardwareOps>(hardware_type));
    table = registered_hardware_ops->back().get();
  }
  if (table->target_hardware_ops_factory.count(op_type)) {
    llvm::errs() << "Trying to register duplicate Op\n";
    return;
  }
  table->target_hardware_ops_factory[op_type] =
      std::move(target_hardware_op_factory);
}
} // namespace internal
// Appends the canonical form of each name in `specified_device_specs` to
// `device_specs`, then appends "CPU" if no specified name canonicalized to
// "CPU". Returns false, after printing an error, if any entry now in
// `device_specs` has no registered TargetHardware instance; true otherwise.
bool ProcessTargetDevices(llvm::ArrayRef<std::string> specified_device_specs,
                          std::vector<std::string>* device_specs) {
  bool cpu_include = false;
  for (const auto& device_spec : specified_device_specs) {
    auto device = GetCanonicalHardwareName(device_spec);
    if (device == "CPU") cpu_include = true;
    device_specs->push_back(std::move(device));
  }
  if (!cpu_include) {
    device_specs->push_back("CPU");
  }
  // Make sure all the devices are registered.
  for (const std::string& device : *device_specs) {
    if (GetTargetHardware(device) == nullptr) {
      // Terminate the line like the file's other diagnostics do.
      llvm::errs() << "cannot get target hardware for device: " << device
                   << "\n";
      return false;
    }
  }
  return true;
}
// Returns the registered (canonical) name of `hardware`, matched by TypeID.
// Returns "" when `hardware` is null or no entry has its TypeID.
std::string GetHardwareName(const TargetHardware* hardware) {
  // Guard instead of dereferencing null (e.g. a failed GetTargetHardware()).
  if (hardware == nullptr) return "";
  const auto wanted_id = hardware->GetTypeId();
  const auto* registered_hardwares = GetRegisteredHardwares();
  for (const auto& registered_hardware : *registered_hardwares) {
    if (registered_hardware.type_id == wanted_id)
      return registered_hardware.unique_name;
  }
  return "";
}
} // namespace tac
} // namespace TFL
} // namespace mlir