blob: da2a4179c4ed00070343ccfafa5cb0ed1fa5a532 [file] [log] [blame]
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_remaining_ops.h"
#include <algorithm>
#include <cstdint>
#include <functional>
#include <limits>
#include <numeric>
#include <string>
#include <tuple>
#include <type_traits>
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/StringSwitch.h"
#include "llvm/ADT/iterator_range.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/FormatVariadic.h"
#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
#include "mlir/Dialect/Traits.h" // from @llvm-project
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/Builders.h" // from @llvm-project
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
#include "mlir/IR/BuiltinTypes.h" // from @llvm-project
#include "mlir/IR/Diagnostics.h" // from @llvm-project
#include "mlir/IR/DialectImplementation.h" // from @llvm-project
#include "mlir/IR/Location.h" // from @llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/IR/Matchers.h" // from @llvm-project
#include "mlir/IR/OpDefinition.h" // from @llvm-project
#include "mlir/IR/OpImplementation.h" // from @llvm-project
#include "mlir/IR/PatternMatch.h" // from @llvm-project
#include "mlir/IR/TypeUtilities.h" // from @llvm-project
#include "mlir/IR/Types.h" // from @llvm-project
#include "mlir/IR/Value.h" // from @llvm-project
#include "mlir/Parser/Parser.h" // from @llvm-project
#include "mlir/Support/LLVM.h" // from @llvm-project
#include "mlir/Support/LogicalResult.h" // from @llvm-project
#include "mlir/Transforms/InliningUtils.h" // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_op_interfaces.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_side_effects.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_structs.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/rewrite_util.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/serialize_mlir_module_utils.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/tensor_format.h"
namespace mlir {
namespace TF {
namespace {
#include "tensorflow/compiler/mlir/tensorflow/transforms/generated_canonicalize.inc"
} // namespace
//===----------------------------------------------------------------------===//
// _XlaHostComputeMlirOp
//===----------------------------------------------------------------------===//
// Checks the `host_mlir_module` attribute of `_XlaHostComputeMlirOp`.
//
// An empty attribute is accepted as-is. Otherwise it must deserialize into an
// MLIR module that contains a `host_func` function. That function must have
// exactly as many inputs as this op has operands, and exactly as many results
// as this op has results. All other attributes receive only the default
// verification.
LogicalResult _XlaHostComputeMlirOp::verify() {
  Operation* operation = getOperation();

  const StringRef serialized_module = host_mlir_module();
  // No attached host module means there is nothing further to validate.
  if (serialized_module.empty()) return success();

  mlir::OwningOpRef<mlir::ModuleOp> host_module;
  const tensorflow::Status deserialize_status =
      tensorflow::DeserializeMlirModule(serialized_module.str(),
                                        operation->getContext(), &host_module);
  if (!deserialize_status.ok()) {
    return emitError()
           << "attribute 'host_mlir_module' can not be deserialized. "
           << deserialize_status.error_message();
  }

  func::FuncOp host_func =
      host_module->lookupSymbol<func::FuncOp>("host_func");
  if (!host_func) {
    return emitError()
           << "serialized module in attribute 'host_mlir_module' does not "
              "contain 'host_func' function.";
  }

  // Query the signature once and reuse it for both arity checks.
  FunctionType host_func_type = host_func.getFunctionType();

  const unsigned num_func_inputs = host_func_type.getNumInputs();
  const unsigned num_op_operands = operation->getNumOperands();
  if (num_op_operands != num_func_inputs) {
    return emitError() << "'host_func' has " << num_func_inputs
                       << " inputs and '_XlaHostComputeMlir' has "
                       << num_op_operands
                       << " operands. Number of operands/inputs should be "
                          "the same.";
  }

  const unsigned num_func_results = host_func_type.getNumResults();
  const unsigned num_op_results = operation->getNumResults();
  if (num_op_results == num_func_results) return success();

  return emitError() << "'host_func' has " << num_func_results
                     << " results and '_XlaHostComputeMlir' has "
                     << num_op_results
                     << " results. Number of results should be the same.";
}
// Deserializes this op's `host_mlir_module` attribute into `*mlir_module` and
// returns the `host_func` function found in it.
//
// Returns a null FuncOp if deserialization fails or if the module has no
// `host_func` symbol. `mlir_module` is dereferenced, so it must be non-null.
// The returned FuncOp lives inside the module held by `*mlir_module`, so it is
// only usable while that OwningOpRef keeps the module alive.
func::FuncOp _XlaHostComputeMlirOp::GetHostFunc(
    mlir::OwningOpRef<mlir::ModuleOp>* mlir_module) {
  const tensorflow::Status deserialize_status =
      tensorflow::DeserializeMlirModule(host_mlir_module().str(), getContext(),
                                        mlir_module);
  if (!deserialize_status.ok()) return nullptr;

  return mlir_module->get().lookupSymbol<func::FuncOp>("host_func");
}
//===----------------------------------------------------------------------===//
// XLA Send/Recv ops
//===----------------------------------------------------------------------===//
// Each XLA host Send/Recv op identifies its resource instance by its `key`
// attribute. The key string is returned unchanged.
std::string _XlaRecvAtHostOp::GetResourceInstanceStr() {
  return std::string(key());
}
std::string _XlaRecvAtHostV2Op::GetResourceInstanceStr() {
  return std::string(key());
}
std::string _XlaSendFromHostOp::GetResourceInstanceStr() {
  return std::string(key());
}
std::string _XlaSendFromHostV2Op::GetResourceInstanceStr() {
  return std::string(key());
}
} // namespace TF
} // namespace mlir
//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//
#define GET_OP_CLASSES
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_remaining_ops.cc.inc"