blob: b8f21209c768a2d5cbf3c457c33f97ed6bb7c5b4 [file] [log] [blame]
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <cstdint>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "absl/container/inlined_vector.h"
#include "absl/memory/memory.h"
#include "absl/strings/string_view.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
#include "mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project
#include "mlir/IR/Builders.h" // from @llvm-project
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
#include "mlir/IR/BuiltinTypes.h" // from @llvm-project
#include "mlir/IR/Diagnostics.h" // from @llvm-project
#include "mlir/IR/Location.h" // from @llvm-project
#include "mlir/IR/Operation.h" // from @llvm-project
#include "mlir/IR/PatternMatch.h" // from @llvm-project
#include "mlir/IR/Types.h" // from @llvm-project
#include "mlir/IR/Value.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Support/LogicalResult.h" // from @llvm-project
#include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project
#include "tensorflow/compiler/mlir/op_or_arg_name_mapper.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/translate_utils.h"
#include "tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.h"
#include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
#include "tensorflow/compiler/tf2xla/xla_context.h"
#include "tensorflow/compiler/tf2xla/xla_expression.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/process_function_library_runtime.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/framework/node_properties.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/protobuf/config.pb.h"
#include "tensorflow/core/public/session_options.h"
#include "tensorflow/stream_executor/lib/statusor.h"
#include "tensorflow/stream_executor/stream_executor.h"
namespace mlir {
namespace mhlo {
bool IsOpAllowedTf2XlaFallback(Operation* op) {
// Allowlisted TensorFlow ops are known to have well behaved tf2xla kernels
// building valid MLIR using MlirHloBuilder.
// TODO(hinsu): Drop explicit allowlist when MLIR based bridge is enabled for
// all tf2xla kernels.
// clang-format off
static llvm::SmallDenseSet<mlir::TypeID, 512> ops = {
TypeID::get<TF::AcoshOp>(),
TypeID::get<TF::AcosOp>(),
TypeID::get<TF::AddNOp>(),
TypeID::get<TF::AddV2Op>(),
TypeID::get<TF::AngleOp>(),
TypeID::get<TF::AdjustContrastv2Op>(),
TypeID::get<TF::AdjustHueOp>(),
TypeID::get<TF::AdjustSaturationOp>(),
TypeID::get<TF::ApproximateEqualOp>(),
TypeID::get<TF::ArgMaxOp>(),
TypeID::get<TF::ArgMinOp>(),
TypeID::get<TF::AsinhOp>(),
TypeID::get<TF::AsinOp>(),
TypeID::get<TF::Atan2Op>(),
TypeID::get<TF::AtanhOp>(),
TypeID::get<TF::BatchMatMulV2Op>(),
TypeID::get<TF::BatchMatMulV3Op>(),
TypeID::get<TF::BatchToSpaceOp>(),
TypeID::get<TF::BesselI0eOp>(),
TypeID::get<TF::BesselI1eOp>(),
TypeID::get<TF::BetaincOp>(),
TypeID::get<TF::BiasAddOp>(),
TypeID::get<TF::BitwiseAndOp>(),
TypeID::get<TF::BitwiseOrOp>(),
TypeID::get<TF::BitwiseXorOp>(),
TypeID::get<TF::BucketizeOp>(),
TypeID::get<TF::CastOp>(),
TypeID::get<TF::ClipByValueOp>(),
TypeID::get<TF::CholeskyOp>(),
TypeID::get<TF::ComplexAbsOp>(),
TypeID::get<TF::ConjugateTransposeOp>(),
TypeID::get<TF::CoshOp>(),
TypeID::get<TF::CrossOp>(),
TypeID::get<TF::DataFormatDimMapOp>(),
TypeID::get<TF::DataFormatVecPermuteOp>(),
TypeID::get<TF::DepthToSpaceOp>(),
TypeID::get<TF::DepthwiseConv2dNativeBackpropFilterOp>(),
TypeID::get<TF::DepthwiseConv2dNativeBackpropInputOp>(),
TypeID::get<TF::DiagOp>(),
TypeID::get<TF::DigammaOp>(),
TypeID::get<TF::DivNoNanOp>(),
TypeID::get<TF::EluGradOp>(),
TypeID::get<TF::EluOp>(),
TypeID::get<TF::EnsureShapeOp>(),
TypeID::get<TF::EqualOp>(),
TypeID::get<TF::ErfcOp>(),
TypeID::get<TF::ErfinvOp>(),
TypeID::get<TF::ErfOp>(),
TypeID::get<TF::ExtractImagePatchesOp>(),
TypeID::get<TF::FFT2DOp>(),
TypeID::get<TF::FFT3DOp>(),
TypeID::get<TF::FFTOp>(),
TypeID::get<TF::FakeParamOp>(),
TypeID::get<TF::FakeQuantWithMinMaxArgsGradientOp>(),
TypeID::get<TF::FakeQuantWithMinMaxVarsGradientOp>(),
TypeID::get<TF::FloorDivOp>(),
TypeID::get<TF::FloorModOp>(),
TypeID::get<TF::GatherNdOp>(),
TypeID::get<TF::GreaterOp>(),
TypeID::get<TF::HSVToRGBOp>(),
TypeID::get<TF::IFFT2DOp>(),
TypeID::get<TF::IFFT3DOp>(),
TypeID::get<TF::IRFFT2DOp>(),
TypeID::get<TF::IRFFT3DOp>(),
TypeID::get<TF::IgammaOp>(),
TypeID::get<TF::IgammacOp>(),
TypeID::get<TF::IgammaGradAOp>(),
TypeID::get<TF::InplaceAddOp>(),
TypeID::get<TF::InTopKV2Op>(),
TypeID::get<TF::InvertOp>(),
TypeID::get<TF::InvOp>(),
TypeID::get<TF::KthOrderStatisticOp>(),
TypeID::get<TF::LRNOp>(),
TypeID::get<TF::LRNGradOp>(),
TypeID::get<TF::LeakyReluGradOp>(),
TypeID::get<TF::LeakyReluOp>(),
TypeID::get<TF::LeftShiftOp>(),
TypeID::get<TF::LessOp>(),
TypeID::get<TF::ListDiffOp>(),
TypeID::get<TF::LogicalAndOp>(),
TypeID::get<TF::LogicalNotOp>(),
TypeID::get<TF::LogOp>(),
TypeID::get<TF::LowerBoundOp>(),
TypeID::get<TF::MakeUniqueOp>(),
TypeID::get<TF::MatMulOp>(),
TypeID::get<TF::MatrixDiagV3Op>(),
TypeID::get<TF::MatrixInverseOp>(),
TypeID::get<TF::MatrixSetDiagV3Op>(),
TypeID::get<TF::MatrixSolveOp>(),
TypeID::get<TF::MatrixTriangularSolveOp>(),
TypeID::get<TF::MaxPool3DGradGradOp>(),
TypeID::get<TF::MaxPoolGradGradOp>(),
TypeID::get<TF::MirrorPadOp>(),
TypeID::get<TF::MirrorPadGradOp>(),
TypeID::get<TF::MulOp>(),
TypeID::get<TF::MultinomialOp>(),
TypeID::get<TF::NdtriOp>(),
TypeID::get<TF::NegOp>(),
TypeID::get<TF::NextAfterOp>(),
TypeID::get<TF::NonMaxSuppressionV4Op>(),
TypeID::get<TF::NotEqualOp>(),
TypeID::get<TF::PadOp>(),
TypeID::get<TF::ParameterizedTruncatedNormalOp>(),
TypeID::get<TF::PlaceholderWithDefaultOp>(),
TypeID::get<TF::PolygammaOp>(),
TypeID::get<TF::PopulationCountOp>(),
TypeID::get<TF::PowOp>(),
// TODO(hinsu): Canonicalize QuantizeAndDequantize and
// QuantizeAndDequantizeV2 to QuantizeAndDequantizeV3 by converting
// attributes to operands.
TypeID::get<TF::QuantizeAndDequantizeOp>(),
TypeID::get<TF::QuantizeAndDequantizeV2Op>(),
TypeID::get<TF::QuantizeAndDequantizeV3Op>(),
TypeID::get<TF::QuantizeAndDequantizeV4Op>(),
TypeID::get<TF::RFFT2DOp>(),
TypeID::get<TF::RFFT3DOp>(),
TypeID::get<TF::RGBToHSVOp>(),
TypeID::get<TF::RandomUniformIntOp>(),
TypeID::get<TF::RealDivOp>(),
TypeID::get<TF::ReciprocalGradOp>(),
TypeID::get<TF::Relu6GradOp>(),
TypeID::get<TF::ResizeBilinearOp>(),
TypeID::get<TF::ResizeBilinearGradOp>(),
TypeID::get<TF::ResizeNearestNeighborOp>(),
TypeID::get<TF::ResizeNearestNeighborGradOp>(),
TypeID::get<TF::ReverseSequenceOp>(),
TypeID::get<TF::RightShiftOp>(),
TypeID::get<TF::RintOp>(),
TypeID::get<TF::RollOp>(),
TypeID::get<TF::RoundOp>(),
TypeID::get<TF::SelectV2Op>(),
TypeID::get<TF::SelfAdjointEigV2Op>(),
TypeID::get<TF::SeluGradOp>(),
TypeID::get<TF::SeluOp>(),
TypeID::get<TF::SigmoidGradOp>(),
TypeID::get<TF::SinhOp>(),
TypeID::get<TF::SinOp>(),
TypeID::get<TF::SoftplusGradOp>(),
TypeID::get<TF::SoftsignGradOp>(),
TypeID::get<TF::SoftsignOp>(),
TypeID::get<TF::SpaceToBatchNDOp>(),
TypeID::get<TF::SpaceToBatchOp>(),
TypeID::get<TF::SpaceToDepthOp>(),
TypeID::get<TF::SparseToDenseOp>(),
TypeID::get<TF::SquareOp>(),
TypeID::get<TF::StatelessMultinomialOp>(),
TypeID::get<TF::StatelessRandomGetAlgOp>(),
TypeID::get<TF::StatelessRandomGetKeyCounterOp>(),
TypeID::get<TF::StatelessRandomGetKeyCounterAlgOp>(),
TypeID::get<TF::StatelessRandomNormalOp>(),
TypeID::get<TF::StatelessRandomNormalV2Op>(),
TypeID::get<TF::StatelessRandomUniformOp>(),
TypeID::get<TF::StatelessRandomUniformFullIntOp>(),
TypeID::get<TF::StatelessRandomUniformFullIntV2Op>(),
TypeID::get<TF::StatelessRandomUniformV2Op>(),
TypeID::get<TF::StatelessRandomUniformIntOp>(),
TypeID::get<TF::StatelessRandomUniformIntV2Op>(),
TypeID::get<TF::StatelessTruncatedNormalOp>(),
TypeID::get<TF::StatelessTruncatedNormalV2Op>(),
TypeID::get<TF::SubOp>(),
TypeID::get<TF::SvdOp>(),
TypeID::get<TF::TanOp>(),
TypeID::get<TF::TensorScatterAddOp>(),
TypeID::get<TF::TensorScatterSubOp>(),
TypeID::get<TF::TPUEmbeddingActivationsOp>(),
TypeID::get<TF::TopKUniqueOp>(),
TypeID::get<TF::TopKWithUniqueOp>(),
TypeID::get<TF::TransposeOp>(),
TypeID::get<TF::TridiagonalSolveOp>(),
TypeID::get<TF::TruncateDivOp>(),
TypeID::get<TF::TruncatedNormalOp>(),
TypeID::get<TF::TruncateModOp>(),
TypeID::get<TF::UnpackOp>(),
TypeID::get<TF::UpperBoundOp>(),
TypeID::get<TF::XlaBroadcastHelperOp>(),
TypeID::get<TF::XlaConvOp>(),
TypeID::get<TF::XlaConvV2Op>(),
TypeID::get<TF::XlaDotOp>(),
TypeID::get<TF::XlaDotV2Op>(),
TypeID::get<TF::XlaDynamicSliceOp>(),
TypeID::get<TF::XlaDynamicUpdateSliceOp>(),
TypeID::get<TF::XlaEinsumOp>(),
TypeID::get<TF::XlaKeyValueSortOp>(),
TypeID::get<TF::XlaPadOp>(),
TypeID::get<TF::XlaSetDynamicDimensionSizeOp>(),
TypeID::get<TF::XlaSortOp>(),
TypeID::get<TF::XlaSvdOp>(),
};
// clang-format on
auto* abstractOp = op->getAbstractOperation();
if (!abstractOp) return false;
return ops.count(abstractOp->typeID);
}
// List ops that should use MLIR legalization in the case of prefer_tf2xla. All
// other ops not in this list should use the old bridge's XlaOpKernel
// legalization.
const llvm::DenseSet<mlir::TypeID>& MlirLegalizedUnderPreferTf2XlaSet() {
// The static variable is a pointer in order to avoid destruction upon thread
// termination.
// clang-format off
static const llvm::DenseSet<mlir::TypeID>* ops =
new llvm::DenseSet<mlir::TypeID>{
// Ops that are legalized in the old bridge using MlirXlaOpKernel
TypeID::get<TF::AbsOp>(),
};
// clang-format on
return *ops;
}
bool IsOpMlirLegalizedUnderPreferTf2Xla(Operation* op) {
auto mlir_ops = MlirLegalizedUnderPreferTf2XlaSet();
auto* abstractOp = op->getAbstractOperation();
if (!abstractOp) return false;
return mlir_ops.count(abstractOp->typeID);
}
namespace {
template <typename T, size_t N>
using InlinedVector = tensorflow::gtl::InlinedVector<T, N>; // non-absl ok
static std::unique_ptr<tensorflow::StaticDeviceMgr> CreateDeviceMgr(
const std::string& device_type) {
// Register compilation kernels for all registered XLA backends.
tensorflow::XlaOpRegistry::RegisterCompilationKernels();
auto device = absl::make_unique<tensorflow::XlaCompilationDevice>(
tensorflow::SessionOptions(), tensorflow::DeviceType(device_type));
return absl::make_unique<tensorflow::StaticDeviceMgr>(std::move(device));
}
class Tf2XlaRewriter {
public:
static LogicalResult RewriteOp(Operation* op, PatternRewriter& rewriter,
const std::string& device_type) {
Tf2XlaRewriter tf2xla_rewriter(op, rewriter, device_type);
return tf2xla_rewriter.LegalizeOp();
}
private:
Tf2XlaRewriter(Operation* op, PatternRewriter& rewriter,
const std::string& device_type)
: op_(op),
device_type_(device_type),
rewriter_(rewriter),
hlo_builder_(op->getName().getStringRef().str(), rewriter_,
op->getLoc()),
context_(nullptr) {}
~Tf2XlaRewriter() {
if (context_) context_->Unref();
}
// Prepares OpKernelContext params common to all the ops.
// Emits an error on failure.
LogicalResult PrepareParams();
// Tries to legalize the specified TensorFlow op, if supported.
//
// Emits an error and returns failure if an error is encountered during
// conversion. Note that success return value doesn't mean successful
// legalization.
LogicalResult LegalizeOp();
// Converts the given operand to expression of kind kConstant or kXlaOp.
// Emits a remark and returns expression of kind kInvalid on failure.
tensorflow::XlaExpression GetExprForOperand(Value operand, Operation* op);
Operation* op_;
std::string device_type_;
PatternRewriter& rewriter_;
::xla::MlirHloBuilder hlo_builder_;
tensorflow::OpOrArgLocNameMapper name_mapper_;
tensorflow::XlaContext* context_; // Ref-counted.
std::unique_ptr<tensorflow::StaticDeviceMgr> device_mgr_;
tensorflow::Device* device_; // Owned by device_mgr_;
std::unique_ptr<tensorflow::ScopedStepContainer> step_container_;
std::unique_ptr<tensorflow::FunctionLibraryDefinition> flib_def_;
std::unique_ptr<tensorflow::ProcessFunctionLibraryRuntime> pflr_;
tensorflow::OpKernelContext::Params params_;
};
LogicalResult Tf2XlaRewriter::PrepareParams() {
// XlaCompiler within the context is only used by the functional ops to
// compile functions. We are not handling those at the moment so XlaCompiler
// is not required.
context_ = new tensorflow::XlaContext(/*compiler=*/nullptr, &hlo_builder_,
/*graph=*/nullptr);
context_->Ref();
device_mgr_ = CreateDeviceMgr(device_type_);
if (!device_mgr_) return failure();
// Type of params_.device is DeviceBase* so store it as Device* to access
// derived class method.
device_ = device_mgr_->ListDevices().front();
params_.device = device_;
params_.resource_manager = device_->resource_manager();
// Resources are cleared at the time of device manager destruction so pass
// no-op cleanup function.
auto cleanup = [](const std::string& name) {};
// Use step_id zero as we only have a single context concurrently and
// concurrently running each of the MLIR functions create a new device.
step_container_ = absl::make_unique<tensorflow::ScopedStepContainer>(
/*step_id=*/0, cleanup);
tensorflow::Status status = step_container_->Create(
device_->resource_manager(),
tensorflow::XlaContext::kXlaContextResourceName, context_);
if (!status.ok()) {
return emitError(op_->getLoc())
<< "failed to create XlaContext resource: " << status.ToString();
}
params_.step_container = step_container_.get();
tensorflow::StatusOr<int64_t> version_or =
tensorflow::GetTfGraphProducerVersion(
op_->getParentOfType<mlir::ModuleOp>());
if (!version_or.ok()) {
return emitError(op_->getLoc()) << version_or.status().ToString();
}
flib_def_ = absl::make_unique<tensorflow::FunctionLibraryDefinition>(
tensorflow::OpRegistry::Global(), tensorflow::FunctionDefLibrary());
pflr_ = absl::make_unique<tensorflow::ProcessFunctionLibraryRuntime>(
device_mgr_.get(), tensorflow::Env::Default(), /*config=*/nullptr,
version_or.ValueOrDie(), flib_def_.get(), tensorflow::OptimizerOptions());
params_.function_library = pflr_->GetFLR(device_->name());
return success();
}
LogicalResult Tf2XlaRewriter::LegalizeOp() {
// Only static shaped operands are supported in XLA builders for now.
for (Type ty : op_->getOperandTypes()) {
auto ranked_ty = ty.dyn_cast<ShapedType>();
if (!ranked_ty || !ranked_ty.hasStaticShape()) {
return op_->emitRemark()
<< "lowering requires static shaped tensor operands";
}
}
for (const auto& attr : op_->getAttrs()) {
if (attr.second.isa<SymbolRefAttr>()) {
return op_->emitRemark()
<< "ops with symbol references are not supported";
}
}
auto nodedef_or = tensorflow::ConvertTFDialectOpToNodeDef(
op_, name_mapper_.GetUniqueName(op_), /*ignore_unregistered_attrs=*/true);
if (!nodedef_or.ok()) {
return op_->emitRemark() << "failed to convert op to NodeDef: "
<< nodedef_or.status().ToString();
}
if (failed(PrepareParams())) return failure();
std::shared_ptr<const tensorflow::NodeProperties> props;
tensorflow::Status status = tensorflow::NodeProperties::CreateFromNodeDef(
*nodedef_or.ValueOrDie(),
params_.function_library->GetFunctionLibraryDefinition(), &props);
if (!status.ok()) {
return op_->emitRemark()
<< "failed to create NodeProperties: " << status.ToString();
}
tensorflow::OpKernel* op_kernel_raw;
status = params_.function_library->CreateKernel(props, &op_kernel_raw);
if (!status.ok()) {
return op_->emitRemark()
<< "failed to create tf2xla kernel: " << status.ToString();
}
// Transfer ownership of the kernel to a local smart pointer.
auto op_kernel = absl::WrapUnique(op_kernel_raw);
std::vector<int> required_constants;
status = tensorflow::XlaOpRegistry::CompileTimeConstantInputs(
*op_kernel, &required_constants);
if (!status.ok()) {
return op_->emitRemark()
<< "failed to compute required constants: " << status.ToString();
}
llvm::SmallDenseSet<int, 4> required_consts;
required_consts.insert(required_constants.begin(), required_constants.end());
// TensorValue in inputs are backed by tensors which in turn depend on
// expressions. So, pre-allocate them to the required size.
InlinedVector<tensorflow::XlaExpression, 4> expressions;
InlinedVector<tensorflow::Tensor, 4> tensors;
InlinedVector<tensorflow::TensorValue, 4> inputs;
expressions.reserve(op_->getNumOperands());
tensors.reserve(op_->getNumOperands());
inputs.reserve(op_->getNumOperands());
// Prepare the list of Tensor inputs for the kernel.
for (auto it : llvm::enumerate(op_->getOperands())) {
Value operand = it.value();
size_t idx = it.index();
tensorflow::XlaExpression expr = GetExprForOperand(operand, op_);
tensorflow::XlaExpression::Kind kind = expr.kind();
if (kind == tensorflow::XlaExpression::Kind::kInvalid) return failure();
if (required_consts.count(idx) &&
kind != tensorflow::XlaExpression::Kind::kConstant) {
return op_->emitRemark()
<< "lowering requires operand #" << idx << " to be a constant";
}
expressions.push_back(expr);
if (!tensorflow::DataTypeCanUseMemcpy(expr.dtype())) {
return op_->emitRemark()
<< "skipping legalization due to unsupported type "
<< operand.getType();
}
auto shape_or = expr.GetShape();
if (!shape_or.ok()) {
return op_->emitRemark()
<< "failed to get shape for expression. " << expr.HumanString();
}
tensors.emplace_back(
device_->GetAllocator(tensorflow::AllocatorAttributes()), expr.dtype(),
shape_or.ValueOrDie());
tensorflow::Tensor& tensor = tensors.back();
tensorflow::XlaExpression::AssignExpressionToTensor(expr, &tensor);
inputs.emplace_back(&tensor);
}
params_.inputs = &inputs;
params_.op_kernel = op_kernel.get();
llvm::SmallVector<tensorflow::AllocatorAttributes, 4> output_attr(
op_->getNumResults());
params_.output_attr_array = output_attr.data();
hlo_builder_.setInsertionPoint(op_);
hlo_builder_.SetLocation(op_->getLoc());
// Execute the kernel.
tensorflow::OpKernelContext op_context(&params_, op_->getNumResults());
device_->Compute(params_.op_kernel, &op_context);
if (!op_context.status().ok()) {
return op_->emitRemark()
<< "compilation to HLO failed: " << op_context.status().ToString();
}
// Replace uses of old results using the corresponding value after the
// lowering.
llvm::SmallVector<Value, 2> values;
values.reserve(op_->getNumResults());
for (int i = 0, e = op_->getNumResults(); i < e; i++) {
tensorflow::Tensor* output = op_context.mutable_output(i);
const tensorflow::XlaExpression* expr =
tensorflow::XlaExpression::CastExpressionFromTensor(*output);
if (expr->kind() != tensorflow::XlaExpression::Kind::kXlaOp)
return op_->emitError(
"expects XlaExpression of kind kXlaOp in compiled output");
auto value = hlo_builder_.GetValue(expr->handle());
mlir::OpResult old_result = op_->getResult(i);
if (value.getType() != old_result.getType()) {
value = hlo_builder_.create<mlir::tensor::CastOp>(old_result.getType(),
value);
}
values.push_back(value);
}
rewriter_.replaceOp(op_, values);
return success();
}
tensorflow::XlaExpression Tf2XlaRewriter::GetExprForOperand(Value operand,
Operation* op) {
ElementsAttr const_attr;
auto defining_op = operand.getDefiningOp();
if (defining_op && matchPattern(defining_op, m_Constant(&const_attr))) {
tensorflow::Tensor tensor;
auto status = tensorflow::ConvertToTensor(const_attr, &tensor);
if (!status.ok()) {
op->emitRemark() << "skipping legalization due to failed const conversion"
<< status.ToString();
return tensorflow::XlaExpression::Invalid();
}
return tensorflow::XlaExpression::Constant(tensor);
}
// Skip this op if XLA doesn't support this operand type.
auto xla_op_or = hlo_builder_.MakeXlaOp(operand);
if (!xla_op_or.ok()) {
op->emitRemark() << "skipping legalization due to "
<< xla_op_or.status().ToString();
return tensorflow::XlaExpression::Invalid();
}
::xla::XlaOp xla_op = xla_op_or.ValueOrDie();
tensorflow::DataType dtype;
auto status = tensorflow::ConvertToDataType(operand.getType(), &dtype);
if (!status.ok()) {
op->emitRemark() << "skipping legalization due to " << status.ToString();
return tensorflow::XlaExpression::Invalid();
}
return tensorflow::XlaExpression::XlaOp(xla_op, dtype);
}
class Tf2XlaRewritePattern : public RewritePattern {
public:
explicit Tf2XlaRewritePattern(MLIRContext* ctx,
const std::string& device_type,
bool prefer_tf2xla)
: RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, ctx),
device_type_(device_type),
prefer_tf2xla_(prefer_tf2xla) {}
LogicalResult matchAndRewrite(Operation* op,
PatternRewriter& rewriter) const override {
if (prefer_tf2xla_) {
if (IsOpMlirLegalizedUnderPreferTf2Xla(op)) return failure();
} else {
if (!IsOpAllowedTf2XlaFallback(op)) return failure();
}
return Tf2XlaRewriter::RewriteOp(op, rewriter, device_type_);
}
private:
std::string device_type_;
bool prefer_tf2xla_;
};
class LegalizeTF : public PassWrapper<LegalizeTF, FunctionPass> {
public:
LegalizeTF() = default;
explicit LegalizeTF(llvm::StringRef device_type, bool prefer_tf2xla) {
device_type_ = device_type.str();
prefer_tf2xla_ = prefer_tf2xla;
}
LegalizeTF(const LegalizeTF&) {}
void runOnFunction() override {
OwningRewritePatternList patterns(&getContext());
patterns.insert<Tf2XlaRewritePattern>(&getContext(), device_type_,
prefer_tf2xla_);
if (failed(
applyPatternsAndFoldGreedily(getFunction(), std::move(patterns))))
signalPassFailure();
}
private:
// TODO(hinsu): Support finer grained device type assignment instead of a
// global device type for all TensorFlow ops.
Option<std::string> device_type_{
*this, "device-type",
llvm::cl::desc("XLA device type for execution of TensorFlow ops.")};
Option<bool> prefer_tf2xla_{
*this,
"prefer-tf2xla",
llvm::cl::desc("Enable legalization when it is not in the list of "
"MLIR-legalized ops."),
};
};
static PassRegistration<LegalizeTF> pass(
"xla-legalize-tf-with-tf2xla",
"Legalize from TensorFlow to the HLO dialect using tf2xla kernels");
} // end namespace
void PopulateLegalizeTfWithTf2XlaPatterns(llvm::StringRef device_type,
OwningRewritePatternList& patterns,
MLIRContext* ctx,
bool prefer_tf2xla) {
patterns.insert<Tf2XlaRewritePattern>(ctx, device_type.str(), prefer_tf2xla);
}
std::unique_ptr<OperationPass<FuncOp>> createLegalizeTfWithTf2XlaPass(
llvm::StringRef device_type, bool prefer_tf2xla) {
return std::make_unique<LegalizeTF>(device_type, prefer_tf2xla);
}
} // end namespace mhlo
} // end namespace mlir