/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/mlir/tfrt/transforms/corert_converter.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/Types.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/Passes.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
#include "tensorflow/core/util/device_name_utils.h"
#include "tfrt/basic_kernels/opdefs/basic_kernels.h" // from @tf_runtime
#include "tfrt/core_runtime/opdefs/attributes.h" // from @tf_runtime
#include "tfrt/core_runtime/opdefs/core_runtime.h" // from @tf_runtime
#include "tfrt/core_runtime/opdefs/types.h" // from @tf_runtime
#include "tfrt/distributed_runtime/opdefs/kernels.h" // from @tf_runtime
namespace tensorflow {
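
// Registers the type conversions used when lowering to the corert dialect:
// TFRT chain, op handler, distributed context, and tensor handle types are
// kept as-is, tensor types are converted to tensor handle types (ref element
// types are rejected), and i1 is passed through unchanged.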
CoreRTConverter::CoreRTConverter(
mlir::MLIRContext *context,
const mlir::TF::SideEffectAnalysis::Info *side_effect_analysis)
: builder_(context), side_effect_analysis_(*side_effect_analysis) {
addConversion([](tfrt::compiler::ChainType type) { return type; });
addConversion([](tfrt::corert::OpHandlerType type) { return type; });
addConversion([](tfrt::dist::DistributedContextType type) { return type; });
addConversion([](tfrt::corert::TensorHandleType type) { return type; });
addConversion([=](mlir::TensorType type) -> llvm::Optional<mlir::Type> {
// Ref types are not supported in either the compiler or the runtime.
if (type.getElementType().isa<mlir::TF::TensorFlowRefType>())
return llvm::None;
return tensor_handle_type();
});
addConversion([=](mlir::Type type) -> llvm::Optional<mlir::Type> {
if (type == builder_.getI1Type()) return type;
return llvm::None;
});
}
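
// If `op` implements DerivedAttributeOpInterface, materializes its derived
// attributes and attaches them to the op as regular attributes.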
void CoreRTConverter::MaterializeDerivedAttributes(mlir::Operation *op) {
if (auto interface = llvm::dyn_cast<mlir::DerivedAttributeOpInterface>(op)) {
auto derived_attrs = interface.materializeDerivedAttributes();
for (auto named_attr : derived_attrs) {
op->setAttr(named_attr.getName(), named_attr.getValue());
}
}
}
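
// Returns true if `type` is a numeric data type supported by the runtime,
// i.e. a standard MLIR float, integer, or f32/f64 complex type.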
bool CoreRTConverter::IsSupportedNumericDType(mlir::Type type) const {
// Most TensorFlow data types (e.g. f32, i64) are supported; they are standard
// MLIR types that need no conversion here.
if (type.isBF16() || type.isF16() || type.isF32() || type.isF64() ||
type.isInteger(1) || type.isInteger(8) || type.isInteger(16) ||
type.isInteger(32) || type.isInteger(64) || type.isUnsignedInteger(8) ||
type.isUnsignedInteger(16) || type.isUnsignedInteger(32) ||
type.isUnsignedInteger(64))
return true;
if (auto complex_type = type.dyn_cast<mlir::ComplexType>()) {
auto element_type = complex_type.getElementType();
if (element_type.isF32() || element_type.isF64()) return true;
}
return false;
}
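
// Converts each used attribute in `attrs` into a {key, value} pair and packs
// all pairs into a single ArrayAttr. Returns a null attribute if any value
// cannot be converted.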
mlir::ArrayAttr CoreRTConverter::CreateOpAttrs(ArrayRef<NamedAttribute> attrs) {
llvm::SmallVector<mlir::Attribute, 4> attr_array;
for (auto key_and_value : attrs) {
if (!IsUnusedAttribute(key_and_value.getName())) {
auto converted = ConvertAttribute(key_and_value.getValue());
if (!converted) return {};
mlir::StringAttr key =
builder_.getStringAttr(key_and_value.getName().strref());
attr_array.push_back(builder_.getArrayAttr({key, converted}));
}
}
return builder_.getArrayAttr(attr_array);
}
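
// Converts symbol-reference (function) attributes in `attrs` into {key, name}
// string pairs and records the handled keys in `func_attr_keys` so the caller
// can remove them from the op afterwards.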
mlir::ArrayAttr CoreRTConverter::CreateOpFuncAttrs(
ArrayRef<NamedAttribute> attrs,
llvm::SmallVector<mlir::StringAttr, 4> *func_attr_keys) {
llvm::SmallVector<mlir::Attribute, 4> attr_array;
for (auto key_and_value : attrs) {
auto attr_key = key_and_value.getName();
auto attr_value = key_and_value.getValue();
if (!IsUnusedAttribute(attr_key) &&
attr_value.isa<mlir::FlatSymbolRefAttr, mlir::SymbolRefAttr>()) {
auto func_attr = attr_value.dyn_cast<mlir::FlatSymbolRefAttr>();
auto converted = ConvertSymbolAttrToStringAttr(func_attr);
mlir::StringAttr key = builder_.getStringAttr(attr_key.strref());
attr_array.push_back(builder_.getArrayAttr({key, converted}));
// Record the key so the caller can remove this attribute from the op and
// avoid converting it again.
func_attr_keys->push_back(attr_key);
}
}
return builder_.getArrayAttr(attr_array);
}
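
// Parses a TF device name string into a device type and an op handler name.
// Returns llvm::None if the name is empty or cannot be parsed.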
// TODO(chky): Add support for multiple device instances.
llvm::Optional<ParseDeviceNameResult> CoreRTConverter::ParseDeviceName(
llvm::StringRef device_name) const {
std::string tf_device_name = device_name.str();
if (tf_device_name.empty()) {
return llvm::None;
}
ParseDeviceNameResult result;
result.device_name = tf_device_name;
// Parse the device name using the current TensorFlow device name format.
DeviceNameUtils::ParsedName parsed_name;
if (!DeviceNameUtils::ParseFullName(result.device_name, &parsed_name)) {
return llvm::None;
}
if (!parsed_name.has_type) {
return llvm::None;
}
result.device_type = parsed_name.type;
result.op_handler_name = tf_device_name;
return result;
}
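
// Parses the device name from the op's `device` attribute. Emits a warning on
// the op if the attribute exists but cannot be parsed.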
llvm::Optional<ParseDeviceNameResult> CoreRTConverter::ParseDeviceName(
mlir::Operation *op) const {
auto device_attr = op->getAttr("device");
if (!device_attr) {
return llvm::None;
}
auto parsed_device_name =
ParseDeviceName(device_attr.cast<mlir::StringAttr>().getValue());
if (!parsed_device_name) op->emitWarning("failed to parse device name.");
return parsed_device_name;
}
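
// Returns the op handler value for `op_handler_name`, creating a GetOpHandler
// op at the start of the op's block on first use and caching the result.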
mlir::Value CoreRTConverter::ConvertOpHandler(
mlir::Operation *op, llvm::StringRef op_handler_name,
ConversionPatternRewriter *rewriter) {
auto iter = op_handler_by_name_.find(op_handler_name);
if (iter != op_handler_by_name_.end()) return iter->second;
mlir::Block *block = op->getBlock();
ConversionPatternRewriter::InsertionGuard insertion_guard(*rewriter);
rewriter->setInsertionPointToStart(block);
func::FuncOp func_op = op->getParentOfType<mlir::func::FuncOp>();
mlir::Value in_chain = func_op.getArgument(0);
auto get_op_handler_op = rewriter->create<tfrt::corert::GetOpHandler>(
block->getParent()->getLoc(), op_handler_type(), in_chain,
op_handler_name);
op_handler_by_name_[op_handler_name] = get_op_handler_op.getResult();
return get_op_handler_op.getResult();
}
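
// Returns the distributed context for the function enclosing `op`, creating a
// GetDistributedContextOp on first use and caching the result per function.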
mlir::Value CoreRTConverter::GetDistributedContext(
mlir::Operation *op, mlir::ConversionPatternRewriter *rewriter) {
mlir::func::FuncOp func_op = op->getParentOfType<mlir::func::FuncOp>();
auto iter = distributed_context_by_func_.find(func_op.getOperation());
if (iter != distributed_context_by_func_.end()) {
return iter->second;
}
ConversionPatternRewriter::InsertionGuard insertion_guard(*rewriter);
rewriter->setInsertionPoint(op);
auto get_dist_ctx_op = rewriter->create<tfrt::dist::GetDistributedContextOp>(
op->getLoc(), distributed_context_type());
mlir::Value result = get_dist_ctx_op.result();
distributed_context_by_func_[func_op.getOperation()] = result;
return result;
}
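
// Returns the remote chain manager for the function enclosing `op`, creating
// a CreateRemoteChainManager op on first use and caching the result per
// function.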
mlir::Value CoreRTConverter::GetRemoteChainManager(
mlir::Operation *op, mlir::ConversionPatternRewriter *rewriter) {
mlir::func::FuncOp func_op = op->getParentOfType<mlir::func::FuncOp>();
auto iter = remote_chain_mgr_by_func_.find(func_op.getOperation());
if (iter != remote_chain_mgr_by_func_.end()) {
return iter->second;
}
ConversionPatternRewriter::InsertionGuard insertion_guard(*rewriter);
rewriter->setInsertionPoint(op);
mlir::Type remote_chain_mgr_type =
builder_.getType<::tfrt::dist::RemoteChainManagerType>();
mlir::Value dist_ctx = GetDistributedContext(op, rewriter);
auto create_mgr_op = rewriter->create<tfrt::dist::CreateRemoteChainManager>(
op->getLoc(), remote_chain_mgr_type, dist_ctx);
mlir::Value result = create_mgr_op.result();
remote_chain_mgr_by_func_[func_op.getOperation()] = result;
return result;
}
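
// Returns the local side-effect chain to be used by `op`: the chain produced
// by its single side-effecting predecessor, a merged chain if there are
// several predecessors, or the function's input chain if there are none.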
mlir::Value CoreRTConverter::GetLocalSideEffectChain(
mlir::Operation *op, mlir::ConversionPatternRewriter *rewriter) {
auto func_op = op->getParentOfType<mlir::func::FuncOp>();
llvm::SmallVector<mlir::Operation *, 4> predecessors;
if (llvm::isa<mlir::func::ReturnOp>(op)) {
auto sinks = side_effect_analysis_.ControlSinks();
predecessors.assign(sinks.begin(), sinks.end());
} else {
predecessors = side_effect_analysis_.DirectControlPredecessors(op);
}
llvm::SmallVector<mlir::Value, 2> chains;
for (auto *pred : predecessors) {
// TODO(chky): ReadVariableOp is removed in the pass and not converted.
// Ideally, every side-effecting op should be converted to a
// tfrt_fallback.executeop.seq op. The special rewrite logic of
// ReadVariableOp should be done in a previous pass.
if (auto chain = local_side_effect_chains_.lookup(pred))
chains.push_back(chain);
}
// If there is no side-effect predecessor, then the input side-effect chain
// is used.
if (chains.empty()) return func_op.getArgument(0);
if (chains.size() == 1) return chains[0];
// If there are multiple side-effect predecessors, insert a merge_chains
// kernel and return the merged chain.
ConversionPatternRewriter::InsertionGuard insertion_guard(*rewriter);
rewriter->setInsertionPoint(op);
return rewriter->create<tfrt::compiler::MergeChainsOp>(op->getLoc(),
chain_type(), chains);
}
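
// Returns the task handle for `task_name` in the function enclosing `op`,
// creating a GetTaskHandleOp on first use and caching the result per function.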
mlir::Value CoreRTConverter::GetTaskHandle(
mlir::Operation *op, StringRef task_name,
mlir::ConversionPatternRewriter *rewriter) {
mlir::func::FuncOp func_op = op->getParentOfType<mlir::func::FuncOp>();
llvm::StringMap<mlir::Value> &task_handle_by_name =
task_handles_by_func_[func_op.getOperation()];
auto iter = task_handle_by_name.find(task_name);
if (iter != task_handle_by_name.end()) {
return iter->second;
}
mlir::Value distributed_context = GetDistributedContext(op, rewriter);
auto task_handle_op = rewriter->create<tfrt::dist::GetTaskHandleOp>(
op->getLoc(), rewriter->getType<tfrt::dist::TaskHandleType>(),
distributed_context, task_name);
task_handle_by_name[task_name] = task_handle_op.getResult();
return task_handle_op.getResult();
}
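
// Returns the side-effect chain for `remote_host` by asking the remote chain
// manager for the chain associated with the remote task handle.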
mlir::Value CoreRTConverter::GetRemoteSideEffectChain(
mlir::Operation *op, StringRef remote_host,
mlir::ConversionPatternRewriter *rewriter) {
mlir::Value remote_chain_mgr = GetRemoteChainManager(op, rewriter);
mlir::Value local_chain = GetLocalSideEffectChain(op, rewriter);
mlir::Value task_handle = GetTaskHandle(op, remote_host, rewriter);
mlir::Type remote_obj_id_ty =
rewriter->getType<tfrt::dist::RemoteObjectIdType>();
// Get the remote chain using the tfrt_dist.get_chain_for_task_handle op.
auto get_chain_op = rewriter->create<tfrt::dist::GetChainForTaskHandleOp>(
op->getLoc(), remote_obj_id_ty, local_chain, remote_chain_mgr,
task_handle);
return get_chain_op.getResult();
}
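
// Converts a TF dialect attribute into an attribute understood by corert
// kernels. Returns a null attribute for unsupported attributes.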
mlir::Attribute CoreRTConverter::ConvertAttribute(mlir::Attribute attr) {
// The supported attributes here should be kept consistent with
// //third_party/tf_runtime/include/tfrt/core_runtime/op_attr_type.h
//
// Currently, not all TensorFlow data types are supported. Unranked shape
// attributes are not supported yet.
// Return directly if the attribute is already supported.
if (attr.isa<mlir::IntegerAttr, mlir::FloatAttr, mlir::BoolAttr,
mlir::StringAttr, mlir::DenseIntOrFPElementsAttr>())
return attr;
// For type attributes, we convert non-standard MLIR types to the
// corresponding corert types.
if (auto type_attr = attr.dyn_cast<mlir::TypeAttr>()) {
if (auto shape_type = type_attr.getValue().dyn_cast<mlir::TensorType>()) {
if (!shape_type.hasRank())
return tfrt::corert::ShapeAttr::get(builder_.getContext());
return tfrt::corert::ShapeAttr::get(builder_.getContext(),
shape_type.getShape());
}
return ConvertTypeAttribute(type_attr);
}
// Convert the attribute to the corresponding format in the TFRT dialect if
// needed.
if (auto shape_attr = attr.dyn_cast<mlir::TF::ShapeAttr>()) {
if (!shape_attr.hasRank())
return tfrt::corert::ShapeAttr::get(builder_.getContext());
return tfrt::corert::ShapeAttr::get(builder_.getContext(),
shape_attr.getShape());
}
// For arrays, we recursively convert the elements.
if (auto array_attr = attr.dyn_cast<mlir::ArrayAttr>()) {
llvm::SmallVector<mlir::Attribute, 8> attrs;
attrs.reserve(array_attr.size());
for (auto attr : array_attr) {
auto converted = ConvertAttribute(attr);
if (!converted) return {};
attrs.push_back(converted);
}
return builder_.getArrayAttr(attrs);
}
return {};
}
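
// Converts a function symbol reference into a string attribute holding the
// original TF function name.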
mlir::StringAttr CoreRTConverter::ConvertSymbolAttrToStringAttr(
mlir::FlatSymbolRefAttr symbol_attr) {
// Currently, when importing a TF graph into MLIR, a "0" is appended to the
// original function name, so we drop it here. The renaming is for TF/XLA v1
// bridge use cases. Refer to b/142268695, b/141617294 for more context.
//
// In TFRT use cases, "0" is almost always the only literal appended, since
// the TF graph already guarantees function name uniqueness.
// TODO(b/172092902): Investigate a better way to make the tf_func_name to
// mlir_tf_func_name conversion reversible.
auto func_name = symbol_attr.getValue().drop_back().str();
return mlir::StringAttr::get(builder_.getContext(), func_name);
}
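
// Converts a type attribute into the corresponding corert type attribute.
// Returns a null attribute for unsupported types.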
mlir::TypeAttr CoreRTConverter::ConvertTypeAttribute(mlir::TypeAttr type_attr) {
auto type = type_attr.getValue();
if (IsSupportedNumericDType(type)) return type_attr;
// For TF custom types, we convert them to the corresponding corert types.
if (type.isa<mlir::TF::StringType>())
return mlir::TypeAttr::get(
tfrt::corert::StringType::get(builder_.getContext()));
if (type.isa<mlir::TF::ResourceType>())
return mlir::TypeAttr::get(
tfrt::corert::ResourceType::get(builder_.getContext()));
if (type.isa<mlir::TF::VariantType>())
return mlir::TypeAttr::get(
tfrt::corert::VariantType::get(builder_.getContext()));
if (type.isa<mlir::TF::Quint8Type>()) {
return mlir::TypeAttr::get(
tfrt::corert::Quint8Type::get(builder_.getContext()));
}
if (type.isa<mlir::TF::Quint16Type>()) {
return mlir::TypeAttr::get(
tfrt::corert::Quint16Type::get(builder_.getContext()));
}
if (type.isa<mlir::TF::Qint8Type>()) {
return mlir::TypeAttr::get(
tfrt::corert::Qint8Type::get(builder_.getContext()));
}
if (type.isa<mlir::TF::Qint16Type>()) {
return mlir::TypeAttr::get(
tfrt::corert::Qint16Type::get(builder_.getContext()));
}
if (type.isa<mlir::TF::Qint32Type>()) {
return mlir::TypeAttr::get(
tfrt::corert::Qint32Type::get(builder_.getContext()));
}
// Return an invalid result so an error is emitted for unsupported types.
return {};
}
} // namespace tensorflow