blob: bf7bfe8efa00e5d079b941a9d2609e7d614523f1 [file] [log] [blame]
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <cstdint>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "absl/container/inlined_vector.h"
#include "absl/memory/memory.h"
#include "absl/strings/string_view.h"
#include "llvm/ADT/Optional.h"
#include "mlir/IR/Diagnostics.h" // from @llvm-project
#include "mlir/IR/Function.h" // from @llvm-project
#include "mlir/IR/Location.h" // from @llvm-project
#include "mlir/IR/Module.h" // from @llvm-project
#include "mlir/IR/Operation.h" // from @llvm-project
#include "mlir/IR/StandardTypes.h" // from @llvm-project
#include "mlir/IR/Types.h" // from @llvm-project
#include "mlir/IR/Value.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Support/LogicalResult.h" // from @llvm-project
#include "tensorflow/compiler/mlir/op_or_arg_name_mapper.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h.inc"
#include "tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/translate_utils.h"
#include "tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.h"
#include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
#include "tensorflow/compiler/tf2xla/xla_context.h"
#include "tensorflow/compiler/tf2xla/xla_expression.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/process_function_library_runtime.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/framework/node_properties.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/protobuf/config.pb.h"
#include "tensorflow/core/public/session_options.h"
#include "tensorflow/stream_executor/lib/statusor.h"
#include "tensorflow/stream_executor/stream_executor.h"
namespace mlir {
namespace xla_hlo {
namespace {
// Short local alias for TensorFlow's inline-storage vector. It is used for the
// per-op input containers handed to OpKernelContext::Params below.
template <typename Elem, size_t InlineCapacity>
using InlinedVector = tensorflow::gtl::InlinedVector<Elem, InlineCapacity>;  // non-absl ok
// Returns true if `op` is one of the TensorFlow ops whose tf2xla kernels are
// known to build valid MLIR through MlirHloBuilder. Every other op is left
// alone by this pass.
// TODO(hinsu): Drop explicit whitelist when MLIR based bridge is enabled for
// all tf2xla kernels.
static bool IsOpWhitelisted(Operation* op) {
  if (isa<TF::AbsOp>(op) || isa<TF::Atan2Op>(op)) return true;
  if (isa<TF::CastOp>(op) || isa<TF::GreaterOp>(op)) return true;
  return isa<TF::InvOp>(op) || isa<TF::SelectV2Op>(op);
}
static std::unique_ptr<tensorflow::StaticDeviceMgr> CreateDeviceMgr(
const std::string& device_type, const Location& loc) {
// Register compilation kernels for all registered XLA backends.
tensorflow::XlaOpRegistry::RegisterCompilationKernels();
auto device = absl::make_unique<tensorflow::XlaCompilationDevice>(
tensorflow::SessionOptions(), tensorflow::DeviceType(device_type));
return absl::make_unique<tensorflow::StaticDeviceMgr>(std::move(device));
}
// Legalizes the whitelisted TensorFlow ops of a single function by running
// their tf2xla kernels against an MlirHloBuilder bound to that function.
// Instances are created only through the static Legalize() entry point.
class FuncLegalizer {
 public:
  // Legalizes `func` using the tf2xla kernels of `device_type`. Returns
  // failure if the kernel environment cannot be prepared or if legalizing an
  // op reports an error. Both paths emit a diagnostic first.
  static LogicalResult Legalize(FuncOp func, const std::string& device_type) {
    FuncLegalizer legalizer(func, device_type);
    if (failed(legalizer.PrepareParams())) return failure();
    return legalizer.Legalize();
  }

  // Holds a reference on context_ and pointers into members it owns, so
  // copying would double-release the reference.
  FuncLegalizer(const FuncLegalizer&) = delete;
  FuncLegalizer& operator=(const FuncLegalizer&) = delete;

 private:
  FuncLegalizer(FuncOp func, const std::string& device_type)
      : func_(func), device_type_(device_type), hlo_builder_(func) {}

  // Drops the reference taken in PrepareParams(). The guard keeps destruction
  // well-defined even if PrepareParams() was never reached.
  ~FuncLegalizer() {
    if (context_ != nullptr) context_->Unref();
  }

  // Prepares OpKernelContext params common to all the ops.
  // Emits an error on failure.
  LogicalResult PrepareParams();

  // Tries to legalize supported TensorFlow ops.
  // Emits an error on failure.
  LogicalResult Legalize();

  // Tries to legalize the specified TensorFlow op, if supported.
  //
  // Emits an error and returns failure if an error is encountered during
  // conversion. Note that success return value doesn't mean successful
  // legalization.
  LogicalResult LegalizeOp(Operation* op);

  FuncOp func_;
  std::string device_type_;

  ::xla::MlirHloBuilder hlo_builder_;
  tensorflow::OpOrArgLocNameMapper name_mapper_;

  tensorflow::XlaContext* context_ = nullptr;  // Ref-counted.

  std::unique_ptr<tensorflow::StaticDeviceMgr> device_mgr_;
  tensorflow::Device* device_ = nullptr;  // Owned by device_mgr_.

  std::unique_ptr<tensorflow::ScopedStepContainer> step_container_;
  std::unique_ptr<tensorflow::FunctionLibraryDefinition> flib_def_;
  std::unique_ptr<tensorflow::ProcessFunctionLibraryRuntime> pflr_;
  tensorflow::OpKernelContext::Params params_;
};
// Builds the kernel execution environment shared by every op in func_:
// XlaContext, compilation device, step container and function library
// runtime. All of it is recorded in params_. Emits an error at the function's
// location and returns failure if any piece cannot be created.
LogicalResult FuncLegalizer::PrepareParams() {
  // The context's XlaCompiler is only needed by functional ops to compile
  // function bodies. Those ops are not handled here, so no compiler is given.
  context_ = new tensorflow::XlaContext(/*compiler=*/nullptr, &hlo_builder_);
  // Extra reference: ~FuncLegalizer releases one. The other presumably goes
  // to the resource registered through step_container_ below. TODO confirm.
  context_->Ref();

  const mlir::Location func_loc = func_.getLoc();
  device_mgr_ = CreateDeviceMgr(device_type_, func_loc);
  if (device_mgr_ == nullptr) return failure();

  // params_.device is only a DeviceBase*. Keep the Device* as well so the
  // Device-level API remains reachable.
  device_ = device_mgr_->ListDevices().front();
  params_.device = device_;
  params_.resource_manager = device_->resource_manager();

  // step_id is zero because only a single context exists at a time, and every
  // MLIR function being legalized gets a fresh device. Resources are cleared
  // when the device manager is destroyed, so the cleanup callback is a no-op.
  step_container_ = absl::make_unique<tensorflow::ScopedStepContainer>(
      /*step_id=*/0, [](const std::string& name) {});

  const tensorflow::Status create_status = step_container_->Create(
      device_->resource_manager(),
      tensorflow::XlaContext::kXlaContextResourceName, context_);
  if (!create_status.ok()) {
    emitError(func_loc) << "failed to create XlaContext resource: "
                        << create_status.ToString();
    return failure();
  }
  params_.step_container = step_container_.get();

  tensorflow::StatusOr<int64_t> producer_version =
      tensorflow::GetTfGraphProducerVersion(
          func_.getParentOfType<mlir::ModuleOp>());
  if (!producer_version.ok()) {
    emitError(func_loc) << producer_version.status().ToString();
    return failure();
  }

  flib_def_ = absl::make_unique<tensorflow::FunctionLibraryDefinition>(
      tensorflow::OpRegistry::Global(), tensorflow::FunctionDefLibrary());
  pflr_ = absl::make_unique<tensorflow::ProcessFunctionLibraryRuntime>(
      device_mgr_.get(), tensorflow::Env::Default(), /*config=*/nullptr,
      producer_version.ValueOrDie(), flib_def_.get(),
      tensorflow::OptimizerOptions());
  params_.function_library = pflr_->GetFLR(device_->name());
  return success();
}
// Attempts to legalize every operation in the single block of func_.
// - A function with no blocks succeeds trivially.
// - A function with more than one block is rejected with an error.
// - Otherwise returns failure as soon as LegalizeOp reports an error.
LogicalResult FuncLegalizer::Legalize() {
  auto& blocks = func_.getBlocks();
  if (blocks.empty()) return success();

  // TensorFlow functions don't use CFGs.
  if (blocks.size() > 1) {
    emitError(func_.getLoc()) << "requires at most one block in a TF function";
    return failure();
  }

  // Snapshot the op list up front. LegalizeOp may erase the op it is handed
  // and insert replacement ops, which must not disturb the traversal.
  Block& body = blocks.front();
  std::vector<Operation*> worklist;
  worklist.reserve(body.getOperations().size());
  for (Operation& candidate : body) worklist.push_back(&candidate);

  for (Operation* candidate : worklist)
    if (failed(LegalizeOp(candidate))) return failure();
  return success();
}
// Tries to lower a single whitelisted TF op by executing its tf2xla kernel
// against hlo_builder_, positioned at `op`.
// - On success, the op's results are replaced by the values the kernel built
//   and the op is erased.
// - Ops that cannot be lowered are left in place with a remark, and success is
//   returned.
// - Only malformed kernel output is reported as an error (failure).
LogicalResult FuncLegalizer::LegalizeOp(Operation* op) {
  if (!IsOpWhitelisted(op)) return success();

  // Only static shaped operands are supported in XLA builders for now.
  // dyn_cast (not cast): an unranked operand must be skipped with a remark
  // rather than tripping cast<>'s type assertion.
  for (Type ty : op->getOperandTypes()) {
    auto ranked_ty = ty.dyn_cast<RankedTensorType>();
    if (!ranked_ty || !ranked_ty.hasStaticShape()) {
      op->emitRemark() << "lowering requires static shaped operands";
      return success();
    }
  }

  auto nodedef_or = tensorflow::ConvertTFDialectOpToNodeDef(
      op, name_mapper_.GetUniqueName(op), /*ignore_unregistered_attrs=*/true);
  if (!nodedef_or.ok()) {
    op->emitRemark() << "failed to convert op to NodeDef: "
                     << nodedef_or.status().ToString();
    return success();
  }

  std::shared_ptr<const tensorflow::NodeProperties> props;
  tensorflow::Status status = tensorflow::NodeProperties::CreateFromNodeDef(
      *nodedef_or.ValueOrDie(),
      params_.function_library->GetFunctionLibraryDefinition(), &props);
  if (!status.ok()) {
    op->emitRemark() << "failed to create NodeProperties: "
                     << status.ToString();
    return success();
  }
  tensorflow::OpKernel* op_kernel_raw = nullptr;
  status = params_.function_library->CreateKernel(props, &op_kernel_raw);
  if (!status.ok()) {
    op->emitRemark() << "failed to create tf2xla kernel: " << status.ToString();
    return success();
  }
  // Transfer ownership of the kernel to a local smart pointer.
  auto op_kernel = absl::WrapUnique(op_kernel_raw);

  // TensorValue in inputs are backed by tensors which in turn depend on
  // expressions. So, pre-allocate them to the required size: `inputs` stores
  // pointers to elements of `tensors`, which must not move while the loop
  // below keeps appending.
  InlinedVector<tensorflow::XlaExpression, 4> expressions;
  InlinedVector<tensorflow::Tensor, 4> tensors;
  InlinedVector<tensorflow::TensorValue, 4> inputs;
  expressions.reserve(op->getNumOperands());
  tensors.reserve(op->getNumOperands());
  inputs.reserve(op->getNumOperands());

  // Prepare the list of Tensor inputs for the kernel.
  for (Value operand : op->getOperands()) {
    // Skip this op if XLA doesn't support this operand type.
    auto xla_op_or = hlo_builder_.MakeXlaOp(operand);
    if (!xla_op_or.ok()) {
      op->emitRemark() << "skipping legalization due to "
                       << xla_op_or.status().ToString();
      return success();
    }
    ::xla::XlaOp xla_op = xla_op_or.ValueOrDie();

    tensorflow::DataType dtype;
    status = tensorflow::ConvertToDataType(operand.getType(), &dtype);
    if (!status.ok()) {
      op->emitRemark() << "skipping legalization due to " << status.ToString();
      return success();
    }

    auto expression = tensorflow::XlaExpression::XlaOp(xla_op, dtype);
    expressions.push_back(expression);

    if (!tensorflow::DataTypeCanUseMemcpy(dtype)) {
      op->emitRemark() << "skipping legalization due to unsupported type "
                       << operand.getType();
      return success();
    }

    auto shape_or = expression.GetShape();
    if (!shape_or.ok()) {
      op->emitRemark() << "failed to get shape for expression. "
                       << expression.HumanString();
      return success();
    }

    tensors.emplace_back(
        device_->GetAllocator(tensorflow::AllocatorAttributes()), dtype,
        shape_or.ValueOrDie());
    tensorflow::Tensor& tensor = tensors.back();
    tensorflow::XlaOpKernelContext::AssignExpressionToTensor(expression,
                                                             &tensor);
    inputs.emplace_back(&tensor);
  }

  // params_ now points at locals of this call. They are only used by the
  // kernel invocation below and are overwritten on the next call.
  params_.inputs = &inputs;
  params_.op_kernel = op_kernel.get();
  llvm::SmallVector<tensorflow::AllocatorAttributes, 4> output_attr(
      op->getNumResults());
  params_.output_attr_array = output_attr.data();

  // New HLO ops are emitted right before `op` and carry its location.
  hlo_builder_.setInsertionPoint(op);
  hlo_builder_.SetLocation(op->getLoc());

  // Execute the kernel.
  tensorflow::OpKernelContext op_context(&params_, op->getNumResults());
  device_->Compute(params_.op_kernel, &op_context);
  if (!op_context.status().ok()) {
    op->emitRemark() << "compilation to HLO failed: "
                     << op_context.status().ToString();
    return success();
  }

  // Replace uses of old results using the corresponding value after the
  // lowering.
  for (int i = 0, e = op->getNumResults(); i < e; i++) {
    tensorflow::Tensor* output = op_context.mutable_output(i);
    // A kernel that left an output unset would otherwise lead to a null
    // dereference here; report it instead.
    if (output == nullptr)
      return op->emitError("expects kernel to produce every output");
    const tensorflow::XlaExpression* expr =
        tensorflow::XlaOpKernelContext::CastExpressionFromTensor(*output);
    if (expr->kind() != tensorflow::XlaExpression::Kind::kXlaOp)
      return op->emitError(
          "expects XlaExpression of kind kXlaOp in compiled output");
    auto value = hlo_builder_.GetValue(expr->handle());
    op->getResult(i).replaceAllUsesWith(value);
  }

  op->erase();
  return success();
}
// Function pass that lowers whitelisted TensorFlow ops to the HLO dialect by
// executing their tf2xla kernels for the configured XLA device type.
class LegalizeTF : public PassWrapper<LegalizeTF, FunctionPass> {
 public:
  LegalizeTF() = default;
  explicit LegalizeTF(llvm::StringRef device_type) {
    device_type_ = device_type.str();
  }
  // Pass options wrap llvm::cl::opt, which is not copyable, so this copy
  // constructor is hand-written. It still has to carry the configured device
  // type over; the previous empty body silently dropped it in every copy.
  LegalizeTF(const LegalizeTF& other) {
    device_type_ = other.device_type_.getValue();
  }

  void runOnFunction() override {
    if (failed(FuncLegalizer::Legalize(getFunction(), device_type_)))
      signalPassFailure();
  }

 private:
  // TODO(hinsu): Support finer grained device type assignment instead of a
  // global device type for all TensorFlow ops.
  Option<std::string> device_type_{
      *this, "device-type",
      llvm::cl::desc("XLA device type for execution of TensorFlow ops. "
                     "Supports XLA_CPU_JIT and XLA_TPU_JIT for now.")};
};

static PassRegistration<LegalizeTF> pass(
    "xla-legalize-tf-with-tf2xla",
    "Legalize from TensorFlow to the HLO dialect using tf2xla kernels");
} // end namespace
// Creates the tf2xla-based TF -> HLO legalization pass, targeting the XLA
// device type `device_type`.
std::unique_ptr<OperationPass<FuncOp>> createLegalizeTfWithTf2XlaPass(
    llvm::StringRef device_type) {
  std::unique_ptr<OperationPass<FuncOp>> legalize_pass =
      std::make_unique<LegalizeTF>(device_type);
  return legalize_pass;
}
} // end namespace xla_hlo
} // end namespace mlir