// blob: 3e13e8494c4b4ecb508b79ed18b16750bab25b23 [file] [log] [blame]
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <string>
#include <type_traits>
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/IR/Attributes.h" // TF:local_config_mlir
#include "mlir/IR/Builders.h" // TF:local_config_mlir
#include "mlir/IR/Module.h" // TF:local_config_mlir
#include "mlir/IR/Operation.h" // TF:local_config_mlir
#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir
#include "mlir/IR/Types.h" // TF:local_config_mlir
#include "mlir/Pass/Pass.h" // TF:local_config_mlir
#include "mlir/Pass/PassRegistry.h" // TF:local_config_mlir
#include "mlir/Support/LogicalResult.h" // TF:local_config_mlir
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h"
#include "tensorflow/compiler/xla/xla.pb.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h"
#include "tensorflow/core/protobuf/tpu/dynamic_padding.pb.h"
namespace mlir {
namespace TFTPU {
// Command-line flag that selects how the TPUCompileMetadataProto is stored in
// the `metadata` attribute of the 'tf._TPUCompileMlir' op: when set, the
// proto's debug string is used; otherwise the proto is written with
// SerializeToString.
// NOLINTNEXTLINE
static llvm::cl::opt<bool> tpu_compile_metadata_debug(
"tpu_compile_metadata_debug",
llvm::cl::desc("Serialize TPUCompileMetadataProto metadata in "
"'tf._TPUCompileMlir' op as a proto debug string"));
// Names of the `tf_device.launch_func` attributes that are read to populate the
// compile metadata proto.
// Integer attribute: number of replicas.
constexpr char kNumReplicasAttr[] = "num_replicas";
// Integer attribute: number of cores per replica.
constexpr char kNumCoresPerReplicaAttr[] = "num_cores_per_replica";
// String attribute: name of an xla::DebugOptions::StepMarkerLocation value; may
// be empty.
constexpr char kStepMarkerLocationAttr[] = "step_marker_location";
// Array attribute of strings, each parsed as a tensorflow::tpu::PaddingMap.
constexpr char kPaddingMapAttr[] = "padding_map";
// Rewrites `tf_device.launch_func` operations assigned to TPU into actual TPU
// jit-compile runtime ops.
//
// For example:
// %1 = "tf_device.launch_func"(%0) {_tpu_replicate = "cluster", func =
// @tpu_func}
// %2 = "tf.SomeOp"(%1)
//
// Would become the following ops (unimportant attributes and types are
// omitted):
// %1 = "tf.Shape"(%0)
// %2:2 = "tf._TPUCompileMlir"(%1) {mlir_module = "<Serialized @tpu_func>"}
// "tf.TPUCompileSucceededAssert"(%2#0)
// %3 = "tf.TPUExecute"(%0, %2#1)
// %4 = "tf.SomeOp"(%3)
namespace {
// Module-level pass whose `runOnModule` rewrites TPU `tf_device.launch_func`
// ops into TPU compile/execute runtime ops.
struct TPURewritePass : ModulePass<TPURewritePass> {
  // Entry point invoked by the pass framework for the module being processed.
  void runOnModule() override;
};
// Collects every `SymbolRefAttr` reachable from the attributes of `op`.
//
// The search descends into the elements of nested `ArrayAttr`s and the values
// of nested `DictionaryAttr`s; attributes of any other kind are ignored.
llvm::SmallVector<SymbolRefAttr, 8> GetAllSymbolRefAttrs(Operation* op) {
  llvm::SmallVector<SymbolRefAttr, 8> found;

  // Stack of attributes that still have to be examined, seeded with the
  // values of the op's own attributes.
  llvm::SmallVector<Attribute, 8> pending;
  for (const NamedAttribute& entry : op->getAttrs())
    pending.push_back(entry.second);

  while (!pending.empty()) {
    Attribute current = pending.pop_back_val();

    // A symbol reference is a result in itself.
    if (auto symbol = current.dyn_cast<SymbolRefAttr>()) {
      found.push_back(symbol);
      continue;
    }

    // An array contributes each of its elements for further inspection.
    if (auto elements = current.dyn_cast<ArrayAttr>()) {
      pending.append(elements.begin(), elements.end());
      continue;
    }

    // A dictionary contributes each of its values for further inspection.
    if (auto dictionary = current.dyn_cast<DictionaryAttr>()) {
      for (const NamedAttribute& entry : dictionary.getValue())
        pending.push_back(entry.second);
    }
  }

  return found;
}
// Creates a new self-contained module that contains a clone of `entry_func`
// and clones of all functions transitively referenced from it through
// `SymbolRefAttr`s. The clone of `entry_func` is renamed to "main".
//
// Returns the newly created module serialized in MLIR text format.
//
// Note: symbol references to the entry function's original name inside the
// cloned functions are not rewritten; the rename relies on there being no
// other reference to the entry function.
std::string EncapsulateFuncAndSerialize(FuncOp entry_func) {
  ModuleOp module = entry_func.getParentOfType<ModuleOp>();

  // Worklist of functions that still have to be cloned.
  llvm::SmallVector<FuncOp, 4> referenced({entry_func});

  // Functions that have already been processed. They are tracked by operation
  // identity instead of by looking up their name in the new module, because
  // the entry function's clone is renamed to "main": a name lookup would never
  // find it, so a reference back to the entry function would be cloned again
  // (and, for a self-recursive entry function, forever).
  llvm::SmallVector<Operation*, 4> visited;

  // Create a new module to hold func and all referenced functions.
  OwningModuleRef module_for_func =
      ModuleOp::create(mlir::UnknownLoc::get(entry_func.getContext()));
  ModuleManager module_manager(module_for_func.get());

  while (!referenced.empty()) {
    FuncOp func = referenced.pop_back_val();

    // Skip functions that have already been cloned into the new module.
    if (llvm::is_contained(visited, func.getOperation())) continue;
    visited.push_back(func.getOperation());

    // Find any SymbolRefAttr in func that maps to a FuncOp. All such FuncOps
    // must be cloned too so that the new module is self-contained.
    func.walk([&](Operation* op) {
      for (auto symbol_ref_attr : GetAllSymbolRefAttrs(op)) {
        FuncOp referenced_func =
            module.lookupSymbol<FuncOp>(symbol_ref_attr.getValue());
        // Skip symbols that do not map to a function.
        if (!referenced_func) continue;
        referenced.emplace_back(referenced_func);
      }
    });

    auto clone = func.clone();
    if (clone.getName() == entry_func.getName()) {
      // We can simply change name of TPU program's main function because there
      // should be no other reference to it.
      clone.setName("main");
    }
    module_manager.insert(clone);
  }

  // Serialize module and return. The stream is scoped so that it is flushed
  // into `txt_module` before the string is returned.
  std::string txt_module;
  {
    llvm::raw_string_ostream os(txt_module);
    module_for_func.get().print(os);
  }
  return txt_module;
}
// Builds the diagnostic text reported when the attribute named `attribute` is
// required but absent.
std::string CreateMissingAttributeMsg(llvm::StringRef attribute) {
  std::string message = "requires attribute '";
  message += attribute.str();
  message += "'";
  return message;
}
// Populates `metadata` from a `tf_device::LaunchFuncOp`:
//   * `num_replicas`, `num_cores_per_replica`, `step_marker_location` and
//     `padding_map` are read from the op's attributes;
//   * one `args` entry (dtype, kind, shape, sharding) is added per operand;
//   * one `retvals` entry (sharding only) is added per result.
//
// Returns failure, after emitting an error on `op`, if a required attribute is
// missing or malformed or if an operand type cannot be converted to a
// TensorFlow data type. `metadata` may be partially populated in that case.
//
// TODO(lyandy): Propagate and support device assignment.
// TODO(lyandy): Support session handle and guaranteed consts.
LogicalResult SetMetadataProtoFromLaunchFuncOp(
    tf_device::LaunchFuncOp op,
    tensorflow::tpu::TPUCompileMetadataProto* metadata) {
  // TODO(lyandy): Determine proper sharding of args and retvals once topology
  // and devices are propagated to the pass. Until then every arg and retval
  // receives the same MAXIMAL sharding assigned to device 0.
  const auto make_placeholder_sharding = [] {
    xla::OpSharding sharding;
    sharding.set_type(xla::OpSharding::MAXIMAL);
    sharding.add_tile_assignment_dimensions(1);
    sharding.add_tile_assignment_devices(0);
    return sharding;
  };

  auto num_replicas = op.getAttrOfType<IntegerAttr>(kNumReplicasAttr);
  if (!num_replicas)
    return op.emitOpError(CreateMissingAttributeMsg(kNumReplicasAttr));
  metadata->set_num_replicas(num_replicas.getInt());

  auto num_cores_per_replica =
      op.getAttrOfType<IntegerAttr>(kNumCoresPerReplicaAttr);
  if (!num_cores_per_replica)
    return op.emitOpError(CreateMissingAttributeMsg(kNumCoresPerReplicaAttr));
  metadata->set_num_cores_per_replica(num_cores_per_replica.getInt());

  auto step_marker_location =
      op.getAttrOfType<StringAttr>(kStepMarkerLocationAttr);
  if (!step_marker_location)
    return op.emitOpError(CreateMissingAttributeMsg(kStepMarkerLocationAttr));

  // Default to `STEP_MARK_AT_ENTRY` for step marker location if attribute is
  // empty.
  xla::DebugOptions::StepMarkerLocation location =
      xla::DebugOptions::STEP_MARK_AT_ENTRY;
  if (!step_marker_location.getValue().empty() &&
      !xla::DebugOptions::StepMarkerLocation_Parse(
          step_marker_location.getValue(), &location))
    return op.emitOpError(llvm::formatv("bad '{0}' attribute with value '{1}'",
                                        kStepMarkerLocationAttr,
                                        step_marker_location.getValue()));
  metadata->set_step_marker_location(location);

  // Every element of the padding map must be a string that parses as a
  // PaddingMap proto.
  auto padding_map = op.getAttrOfType<ArrayAttr>(kPaddingMapAttr);
  if (!padding_map)
    return op.emitOpError(CreateMissingAttributeMsg(kPaddingMapAttr));
  for (const auto padding_and_idx : llvm::enumerate(padding_map)) {
    auto& padding_attr = padding_and_idx.value();
    auto padding_attr_str = padding_attr.dyn_cast<StringAttr>();
    if (!padding_attr_str)
      return op.emitOpError(
          llvm::formatv("bad '{0}' attribute at index {1}, not a string",
                        kPaddingMapAttr, padding_and_idx.index()));

    tensorflow::tpu::PaddingMap* padding =
        metadata->mutable_padding_maps()->Add();
    if (!padding->ParseFromString(padding_attr_str.getValue()))
      return op.emitOpError(llvm::formatv(
          "bad '{0}' attribute at index {1} with value '{2}'", kPaddingMapAttr,
          padding_and_idx.index(), padding_attr_str.getValue()));
  }

  // Set args metadata in proto.
  for (auto operand_type_and_idx : llvm::enumerate(op.getOperandTypes())) {
    Type operand_type = operand_type_and_idx.value();
    tensorflow::tpu::TPUCompileMetadataProto::Arg* arg = metadata->add_args();

    tensorflow::DataType dtype;
    tensorflow::Status status =
        tensorflow::ConvertToDataType(operand_type, &dtype);
    if (!status.ok())
      return op.emitOpError(
          llvm::formatv("failed to determine operand type at index {0}: {1}",
                        operand_type_and_idx.index(), status.error_message()));
    arg->set_dtype(dtype);

    // TODO(lyandy): Support other arg kinds.
    if (dtype == tensorflow::DT_RESOURCE)
      arg->set_kind(tensorflow::tpu::TPUCompileMetadataProto::Arg::VARIABLE);
    else
      arg->set_kind(tensorflow::tpu::TPUCompileMetadataProto::Arg::PARAMETER);

    // Populate argument shapes. The freshly added arg starts with a default
    // shape, so only the ranked / unknown-rank distinction has to be written.
    if (auto ranked_tensor_type = operand_type.dyn_cast<RankedTensorType>()) {
      tensorflow::TensorShapeProto shape_proto;
      ConvertToTensorShapeProto(ranked_tensor_type.getShape(), &shape_proto);
      *arg->mutable_shape() = std::move(shape_proto);
    } else {
      arg->mutable_shape()->set_unknown_rank(true);
    }

    *arg->mutable_sharding() = make_placeholder_sharding();
  }

  // Set retvals metadata in proto: one entry per result of the op.
  for (unsigned i = 0, num_results = op.getNumResults(); i < num_results; ++i)
    *metadata->add_retvals()->mutable_sharding() = make_placeholder_sharding();

  return success();
}
// Creates a `tf._TPUCompileMlir` op that contains an MLIR module functionally
// equivalent to the function referenced by `launch_func`.
//
// The created op carries:
//   * attribute `metadata`: the TPUCompileMetadataProto built from
//     `launch_func`, as a debug string when `tpu_compile_metadata_debug` is
//     set and as a serialized proto otherwise;
//   * attribute `mlir_module`: the text form of the self-contained module
//     built from the referenced function;
//   * attribute `NumDynamicShapes`: the number of operands;
//   * operands: one `tf.Shape` result per `launch_func` operand whose shape is
//     not fully defined;
//   * result #0: compilation status, result #1: compilation cache key.
//
// Ops are created at the current insertion point of `builder`. Returns nullptr
// after emitting an error on `launch_func` if the metadata cannot be built or
// serialized, or if the `func` attribute is missing or does not resolve to a
// function in the enclosing module. All validation happens before any op is
// created, so a failure leaves no stray ops behind.
Operation* BuildCompileOp(tf_device::LaunchFuncOp launch_func,
                          OpBuilder* builder) {
  // TODO(b/139377366): Use tf_tpu.compile build method when it is defined.
  OperationState compile_op_state(launch_func.getLoc(), "tf._TPUCompileMlir");

  // Set metadata from attributes.
  tensorflow::tpu::TPUCompileMetadataProto metadata;
  if (failed(SetMetadataProtoFromLaunchFuncOp(launch_func, &metadata)))
    return nullptr;

  std::string txt_metadata;
  if (tpu_compile_metadata_debug) {
    txt_metadata = metadata.DebugString();
  } else if (!metadata.SerializeToString(&txt_metadata)) {
    launch_func.emitOpError("failed to serialize TPUCompileMetadataProto");
    return nullptr;
  }

  // Resolve the function to compile up front so that a bad reference is
  // reported before any `tf.Shape` op has been inserted.
  SymbolRefAttr func_attr = launch_func.getAttrOfType<SymbolRefAttr>("func");
  if (!func_attr) {
    launch_func.emitOpError("does not have `func` attribute");
    return nullptr;
  }
  FuncOp func = launch_func.getParentOfType<ModuleOp>().lookupSymbol<FuncOp>(
      func_attr.getValue());
  if (!func) {
    launch_func.emitOpError(
        llvm::formatv("`func` attribute references undefined function '{0}'",
                      func_attr.getValue()));
    return nullptr;
  }

  compile_op_state.addAttribute("metadata",
                                builder->getStringAttr(txt_metadata));

  // Build a shape op for each input to launch_func.
  // TODO(b/139377366): When shape inference is ready, we can use compile time
  // shape inference to get inputs that have static shapes and only use shape
  // ops for the rest.
  llvm::SmallVector<Value*, 4> compile_op_operands;
  compile_op_operands.reserve(launch_func.getNumOperands());
  for (auto operand_and_idx : llvm::enumerate(launch_func.getOperands())) {
    // Skip adding shape op for operands that have static shapes.
    tensorflow::PartialTensorShape shape(
        metadata.args(operand_and_idx.index()).shape());
    if (shape.IsFullyDefined()) continue;

    auto shape_op = builder->create<TF::ShapeOp>(
        launch_func.getLoc(),
        builder->getTensorType({-1}, builder->getIntegerType(64)),
        operand_and_idx.value());
    compile_op_operands.emplace_back(shape_op.getResult());
  }
  compile_op_state.addOperands(compile_op_operands);
  compile_op_state.addAttribute(
      "NumDynamicShapes",
      builder->getI64IntegerAttr(compile_op_operands.size()));

  std::string txt_module = EncapsulateFuncAndSerialize(func);
  compile_op_state.addAttribute("mlir_module",
                                builder->getStringAttr(txt_module));

  // Result #0 is a string indicating whether compilation is successful or not.
  compile_op_state.addTypes(
      builder->getTensorType({}, builder->getType<TF::StringType>()));

  // Result #1 is key to look up executable binary in compilation cache.
  compile_op_state.addTypes(
      builder->getTensorType({}, builder->getType<TF::StringType>()));

  return builder->createOperation(compile_op_state);
}
// Creates a `tf.TPUExecute` op that runs the TPU program produced by
// `compile_op`.
//
// The op's operands are all operands of `launch_func` followed by result #1 of
// `compile_op` (the compilation cache key). Its result types are those of
// `launch_func`. The `Targs` and `Tresults` attributes list the types of the
// `launch_func` operands and results respectively.
Operation* BuildExecuteOp(Operation* compile_op,
                          tf_device::LaunchFuncOp launch_func,
                          OpBuilder* builder) {
  // TODO(b/139377366): Use tf.TPUExecute build method when it is defined.
  OperationState state(launch_func.getLoc(), "tf.TPUExecute");

  // Operands: every launch_func input, then the compilation cache key.
  // TODO(b/139377366): Need to snapshot all resource variable inputs in
  // follow-up CLs.
  llvm::SmallVector<Value*, 4> inputs(launch_func.getOperands());
  state.addOperands(inputs);
  state.addOperands(compile_op->getResult(1));

  // Targs mirrors the types of the launch_func inputs.
  llvm::SmallVector<Attribute, 4> input_type_attrs;
  input_type_attrs.reserve(inputs.size());
  for (Value* input : inputs)
    input_type_attrs.push_back(builder->getTypeAttr(input->getType()));
  state.addAttribute("Targs", builder->getArrayAttr(input_type_attrs));

  // Tresults and the op's own result types mirror the launch_func results.
  llvm::SmallVector<Type, 4> result_types(launch_func.getResultTypes());
  llvm::SmallVector<Attribute, 4> result_type_attrs;
  result_type_attrs.reserve(result_types.size());
  for (Type result_type : result_types)
    result_type_attrs.push_back(builder->getTypeAttr(result_type));
  state.addAttribute("Tresults", builder->getArrayAttr(result_type_attrs));
  state.addTypes(result_types);

  return builder->createOperation(state);
}
// Creates a `tf.TPUCompileSucceededAssert` op, at the location of
// `compile_op`, whose single operand is result #0 of `compile_op` (the
// compilation status).
void BuildTPUCompileSucceededAssertOp(Operation* compile_op,
                                      OpBuilder* builder) {
  Value* compilation_status = compile_op->getResult(0);

  OperationState state(compile_op->getLoc(), "tf.TPUCompileSucceededAssert");
  state.addOperands(compilation_status);
  builder->createOperation(state);
}
// Replaces a TPU `tf_device.launch_func` with the runtime ops that jit-compile
// and execute its function on TPU: a compile op, a compile-succeeded assert
// and an execute op, all inserted in front of `launch_func`, which is then
// erased.
//
// A `launch_func` without a string `_tpu_replicate` attribute is left
// untouched and success is returned. Failure is returned when the compile op
// cannot be built.
LogicalResult Rewrite(tf_device::LaunchFuncOp launch_func, OpBuilder* builder) {
  // Only launch_funcs carrying `_tpu_replicate` are assigned to TPU.
  if (!launch_func.getAttrOfType<StringAttr>("_tpu_replicate"))
    return success();

  builder->setInsertionPoint(launch_func);
  Operation* compile_op = BuildCompileOp(launch_func, builder);
  if (compile_op == nullptr) return failure();

  // TPUCompilationResultOp is a placeholder used during graph creation to hook
  // up the ops that are intended to consume the compile result. Redirect the
  // uses of each such placeholder in the block to the compile op's status.
  // NOTE(review): every TPUCompilationResultOp in the block is rewired; the
  // `_tpu_replicate` values are not compared here.
  Block* parent_block = launch_func.getOperation()->getBlock();
  for (auto placeholder : parent_block->getOps<TF::TPUCompilationResultOp>()) {
    placeholder.output()->replaceAllUsesWith(compile_op->getResult(0));
  }

  BuildTPUCompileSucceededAssertOp(compile_op, builder);

  // TODO(ycao): Right now we only support single-core case. The right thing to
  // do is to read from launch_func attributes to determine how many execute
  // ops to build.
  Operation* execute_op = BuildExecuteOp(compile_op, launch_func, builder);

  // Hand the launch_func results over to the execute op and drop the original.
  launch_func.replaceAllUsesWith(execute_op);
  launch_func.erase();
  return success();
}
void TPURewritePass::runOnModule() {
OpBuilder builder(&getContext());
auto result = getModule().walk([&](tf_device::LaunchFuncOp op) {
if (failed(Rewrite(op, &builder))) return WalkResult::interrupt();
return WalkResult::advance();
});
if (result.wasInterrupted()) {
signalPassFailure();
return;
}
// Eliminate TPUReplicatedInput and TPUReplicatedOutput now that the rewrite
// is complete.
getModule().walk([&](Operation* op) {
auto op_name = op->getName().getStringRef();
if (op_name != "tf.TPUReplicatedInput" &&
op_name != "tf.TPUReplicatedOutput")
return;
op->getResult(0)->replaceAllUsesWith(op->getOperand(0));
op->erase();
});
// TODO(b/139377366): Remove functions that are no longer needed.
}
} // namespace
// Returns a newly allocated instance of the TPU rewrite pass, owned by the
// caller.
std::unique_ptr<OpPassBase<ModuleOp>> CreateTPURewritePass() {
  std::unique_ptr<OpPassBase<ModuleOp>> rewrite_pass =
      std::make_unique<TPURewritePass>();
  return rewrite_pass;
}
// Static registration object for TPURewritePass, giving it the argument name
// `tf-tpu-rewrite` and the description below.
static PassRegistration<TPURewritePass> pass(
"tf-tpu-rewrite",
"Rewriting `tf_device.launch_func` on TPUs into TPU runtime ops");
} // namespace TFTPU
} // namespace mlir