/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <memory>
#include <string>
#include <utility>
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/FormatVariadic.h"
#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
#include "mlir/IR/BlockAndValueMapping.h" // from @llvm-project
#include "mlir/IR/Builders.h" // from @llvm-project
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
#include "mlir/IR/BuiltinTypes.h" // from @llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/IR/Operation.h" // from @llvm-project
#include "mlir/IR/OperationSupport.h" // from @llvm-project
#include "mlir/IR/TypeRange.h" // from @llvm-project
#include "mlir/IR/Visitors.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Pass/PassRegistry.h" // from @llvm-project
#include "mlir/Support/LogicalResult.h" // from @llvm-project
#include "mlir/Transforms/RegionUtils.h" // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/device_util.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/serialize_mlir_module_utils.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.h"
namespace mlir {
namespace TFTPU {
namespace {
constexpr char kDeviceAttr[] = "device";
constexpr char kHostFunctionAttr[] = "host_func";
constexpr char kXlaOutsideCompilationAttr[] = "_xla_outside_compilation";
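// This pass extracts ops marked with the `_xla_outside_compilation` attribute
// from a `tf_device.cluster` into a host-side `tf_device.launch`, pairing the
// two under a `tf_device.parallel_execute` and inserting send/recv ops for
// data dependencies between them. A schematic sketch (op names abbreviated,
// most attributes and types elided, values illustrative):
//
//   tf_device.cluster {
//     %a = "tf.OpA"()
//     %b = "tf.OpB"(%a) {_xla_outside_compilation = "cluster1"}
//     "tf.OpC"(%b)
//   }
//
// becomes roughly:
//
//   tf_device.parallel_execute {
//     // Host computation.
//     tf_device.launch {
//       %key = "tf._TPUCompileMlirPlaceholderProgramKey"()
//       %a' = "tf._XlaRecvAtHost"(%key)
//               {key = "host_compute_channel_0_args"}
//       %b' = "tf.OpB"(%a')
//       "tf._XlaSendFromHost"(%b', %key)
//               {key = "host_compute_channel_0_retvals"}
//     }
//     // Device computation.
//     tf_device.cluster {
//       %a = "tf.OpA"()
//       %b = "tf._XlaHostComputeMlir"(%a)
//       "tf.OpC"(%b)
//     }
//   }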
struct TPUExtractOutsideCompilation
: public TF::TPUExtractOutsideCompilationPassBase<
TPUExtractOutsideCompilation> {
void runOnOperation() override;
};
// Builds a function containing `ops`, with `inputs` and `outputs`, using
// `builder`. The `ops` are cloned and modified to use the function arguments
// as inputs.
func::FuncOp BuildFunction(llvm::ArrayRef<Operation*> ops,
llvm::ArrayRef<Value> inputs,
llvm::ArrayRef<Value> outputs, OpBuilder* builder) {
llvm::SmallVector<Type, 4> operand_types;
operand_types.reserve(inputs.size());
for (Value v : inputs) operand_types.emplace_back(v.getType());
llvm::SmallVector<Type, 4> output_types;
output_types.reserve(outputs.size());
for (Value v : outputs) output_types.emplace_back(v.getType());
auto func_type = builder->getFunctionType(operand_types, output_types);
func::FuncOp outlined_func =
func::FuncOp::create(ops.front()->getLoc(), kHostFunctionAttr, func_type);
// Create function body.
Block* outlined_func_block = outlined_func.addEntryBlock();
// Clone the operations and remap the inputs to use the function arguments.
BlockAndValueMapping mapping;
mapping.map(inputs, outlined_func.getArguments());
builder->setInsertionPoint(outlined_func_block, outlined_func_block->begin());
for (Operation* op : ops) {
builder->clone(*op, mapping);
}
// Set the returned values to the cloned ops' results, found via the mapping.
llvm::SmallVector<Value, 4> results_after_mapping;
for (Value result : outputs) {
results_after_mapping.push_back(mapping.lookupOrDefault(result));
}
builder->create<func::ReturnOp>(ops.front()->getLoc(), results_after_mapping);
return outlined_func;
}
// Encapsulates `func` in a module and serializes that module.
// `serialized_func_module` is set to the serialized module.
void EncapsulateFuncAndSerialize(func::FuncOp func,
std::string* serialized_func_module) {
// Create a new module to hold func and all referenced functions.
OwningOpRef<mlir::ModuleOp> module_for_func =
ModuleOp::create(mlir::UnknownLoc::get(func.getContext()));
SymbolTable symbol_table(module_for_func.get());
symbol_table.insert(func);
*serialized_func_module =
tensorflow::SerializeMlirModule(module_for_func.get());
}
// Returns whether any op nested in `op` (excluding `op` itself) is outside
// compiled.
bool HasOutsideCompilationNested(Operation* op) {
return op
->walk([&](Operation* walked_op) {
if (op == walked_op) return WalkResult::advance();
if (walked_op->hasAttrOfType<StringAttr>(kXlaOutsideCompilationAttr)) {
return WalkResult::interrupt();
}
return WalkResult::advance();
})
.wasInterrupted();
}
// Returns whether `op` or any ancestors of `op` are outside compiled.
bool HasOutsideCompilationAncestor(Operation* op) {
while (op) {
if (op->hasAttrOfType<StringAttr>(kXlaOutsideCompilationAttr)) {
return true;
}
op = op->getParentOp();
}
return false;
}
// Returns whether any ancestors of `op` are outside compiled.
bool HasOutsideCompilationAncestorExclusive(Operation* op) {
Operation* parent_op = op->getParentOp();
if (!parent_op) return false;
return HasOutsideCompilationAncestor(parent_op);
}
Operation* ApplyXlaHostTransferAttr(Operation* op, OpBuilder& builder) {
op->setAttr("_xla_has_host_transfer", builder.getBoolAttr(true));
return op;
}
// Creates a tf._XlaSendFromHost op, or a tf._XlaSendFromHostV2 op if a device
// ordinal is present.
Operation* CreateSendFromHostOp(OpBuilder& builder, Location loc,
ValueRange inputs, Value compilation_key,
Value device_ordinal,
llvm::StringRef communication_key) {
if (device_ordinal)
return ApplyXlaHostTransferAttr(
builder.create<TF::_XlaSendFromHostV2Op>(
loc, inputs,
/*dynamic_key=*/compilation_key, device_ordinal,
builder.getStringAttr(communication_key)),
builder);
return ApplyXlaHostTransferAttr(
builder.create<TF::_XlaSendFromHostOp>(
loc, inputs,
/*dynamic_key=*/compilation_key,
builder.getStringAttr(communication_key),
/*device_ordinal=*/builder.getI64IntegerAttr(0)),
builder);
}
// Creates a tf._XlaRecvAtHost op, or a tf._XlaRecvAtHostV2 op if a device
// ordinal is present.
Operation* CreateRecvAtHostOp(OpBuilder& builder, Location loc,
TypeRange output_types, Value compilation_key,
Value device_ordinal,
llvm::StringRef communication_key) {
if (device_ordinal)
return ApplyXlaHostTransferAttr(
builder.create<TF::_XlaRecvAtHostV2Op>(
loc, output_types, /*dynamic_key=*/compilation_key, device_ordinal,
builder.getStringAttr(communication_key)),
builder);
return ApplyXlaHostTransferAttr(
builder.create<TF::_XlaRecvAtHostOp>(
loc, output_types, /*dynamic_key=*/compilation_key,
builder.getStringAttr(communication_key),
/*device_ordinal=*/builder.getI64IntegerAttr(0)),
builder);
}
// Clones the IfRegionOp `if_region`, preserving its attributes, and creates
// then/else regions each containing a single block that holds only a yield
// op.
TF::IfRegionOp CloneEmptyIfWithPredicate(TF::IfRegionOp if_region,
OpBuilder& builder) {
auto host_side_if = builder.create<TF::IfRegionOp>(
if_region.getLoc(), llvm::SmallVector<Type, 4>{}, if_region.cond(),
if_region.is_stateless(), if_region._then_func_nameAttr(),
if_region._else_func_nameAttr());
// Create empty then branch region.
auto& then_branch = host_side_if.then_branch();
then_branch.push_back(new Block);
builder.setInsertionPointToEnd(&then_branch.front());
builder.create<TF::YieldOp>(if_region.getLoc(),
/*operands=*/ArrayRef<Value>{});
// Create empty else branch region.
auto& else_branch = host_side_if.else_branch();
else_branch.push_back(new Block);
builder.setInsertionPointToEnd(&else_branch.front());
builder.create<TF::YieldOp>(if_region.getLoc(),
/*operands=*/ArrayRef<Value>{});
return host_side_if;
}
// Creates a WhileRegionOp with an empty cond region and a body region
// containing a single block that holds only a yield op.
TF::WhileRegionOp CloneEmptyWhile(bool is_stateless,
uint64_t parallel_iterations, Location loc,
OpBuilder& builder) {
auto host_side_while = builder.create<TF::WhileRegionOp>(
loc, /*output=*/ArrayRef<Type>{}, /*input=*/ArrayRef<Value>{},
parallel_iterations, is_stateless, /*shape_invariant=*/false);
// Create empty body region with a yield terminator.
auto& body = host_side_while.body();
body.push_back(new Block);
builder.setInsertionPointToEnd(&body.front());
builder.create<TF::YieldOp>(loc, /*operands=*/ArrayRef<Value>{});
return host_side_while;
}
// TODO(b/157054714): Use a better abstraction instead of
// _TPUCompileMlirOp and _XlaRecvAtHostOp and _XlaSendFromHostOp.
// Creates a placeholder compilation key. A placeholder compilation cache key
// is created because it is a required input to _XlaRecvAtHost and
// _XlaSendFromHost, but the _TPUCompileMlir op has not yet been created for
// the TPU cluster that contains the outside compiled ops. This placeholder
// should be replaced with the TPU cluster's _TPUCompileMlir op in a
// subsequent pass.
TF::_TPUCompileMlirPlaceholderProgramKeyOp CreateCompilationKeyPlaceholder(
Location loc, OpBuilder& builder) {
auto result_type =
RankedTensorType::get({3}, builder.getType<TF::StringType>());
return builder.create<TF::_TPUCompileMlirPlaceholderProgramKeyOp>(
loc, /*program=*/result_type, llvm::ArrayRef<Value>{});
}
// Creates a `tf_device.launch` op on `host_device` to wrap the host-side
// outside compiled ops.
tf_device::LaunchOp CreateLaunchOpForOutsideCluster(
OpBuilder& builder, Operation* loc_op, llvm::StringRef host_device) {
auto launch_op = builder.create<tf_device::LaunchOp>(
loc_op->getLoc(), builder.getStringAttr(host_device),
/*result_types=*/ArrayRef<Type>{});
launch_op.body().push_back(new Block);
builder.setInsertionPointToEnd(&launch_op.GetBody());
builder.create<tf_device::ReturnOp>(loc_op->getLoc(),
llvm::ArrayRef<Value>{});
return launch_op;
}
// Returns true if `op` has non-static shaped outputs.
bool HasDynamicOutputs(Operation* op) {
for (Value v : op->getResults()) {
if (TF::CanBeRefined(v.getType())) return true;
}
return false;
}
// Returns true if any op in `cluster_ops` has a non-statically shaped output
// that is consumed by an op outside of `cluster_ops`.
bool HasDynamicOutputs(const llvm::SmallSetVector<Operation*, 4>& cluster_ops) {
for (Operation* op : cluster_ops) {
for (const OpOperand& use : op->getUses()) {
if (cluster_ops.count(use.getOwner())) {
continue;
}
if (TF::CanBeRefined(use.get().getType())) return true;
}
}
return false;
}
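// Returns true if any operand of `op`, or of any op nested within `op`, has a
// non-static shape.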
bool HasDynamicExternalValues(Operation* op) {
return op
->walk([](Operation* walked_op) {
for (Value v : walked_op->getOperands()) {
if (TF::CanBeRefined(v.getType())) {
return WalkResult::interrupt();
}
}
return WalkResult::advance();
})
.wasInterrupted();
}
// Returns operands of `cluster_ops` that need to be
// communicated from device->host. This is for the case when all operands have a
// static shape.
llvm::SmallSetVector<Value, 4> GetStaticExternalOperands(
tf_device::ClusterOp tpu_cluster,
const llvm::SmallSetVector<Operation*, 4>& cluster_ops) {
llvm::SmallSetVector<Value, 4> external_values;
for (Operation* op : cluster_ops) {
op->walk([&](Operation* walked_op) {
if (llvm::isa<TF::_XlaRecvAtHostV2Op, TF::_XlaSendFromHostV2Op>(
walked_op))
return WalkResult::advance();
for (Value v : walked_op->getOperands()) {
if (auto* defining_op = v.getDefiningOp()) {
if (!op->isAncestor(defining_op) &&
tpu_cluster->isAncestor(defining_op) &&
!HasOutsideCompilationAncestor(defining_op) &&
!llvm::isa<TF::_XlaRecvAtHostV2Op>(defining_op)) {
external_values.insert(v);
}
continue;
}
auto block_arg = v.cast<BlockArgument>();
if (block_arg.getParentRegion() == op->getParentRegion())
external_values.insert(v);
}
return WalkResult::advance();
});
}
return external_values;
}
// Returns every operand of `cluster_ops` that does not come from an op in
// `cluster_ops`.
llvm::SmallSetVector<Value, 4> GetAllExternalOperands(
const llvm::SmallSetVector<Operation*, 4>& cluster_ops) {
llvm::SmallSetVector<Value, 4> external_values;
for (Operation* op : cluster_ops) {
op->walk([&](Operation* walked_op) {
for (Value v : walked_op->getOperands()) {
Operation* defining_op = v.getDefiningOp();
if (!defining_op || !cluster_ops.count(defining_op)) {
external_values.insert(v);
}
}
});
}
return external_values;
}
// Returns a SmallSetVector containing all of the operands that need to be
// communicated from device->host.
llvm::SmallSetVector<Value, 4> GetExternalOperands(
tf_device::ClusterOp tpu_cluster,
const llvm::SmallSetVector<Operation*, 4>& cluster_ops) {
// If there are any dynamic outputs, get all of the operands which are defined
// external to `cluster_ops`.
bool has_dynamic_outputs = HasDynamicOutputs(cluster_ops);
if (has_dynamic_outputs) {
return GetAllExternalOperands(cluster_ops);
} else {
return GetStaticExternalOperands(tpu_cluster, cluster_ops);
}
}
// Gets all outputs that need to be communicated from host->device.
llvm::SmallSetVector<Value, 4> GetExternalOutputs(
const llvm::SmallSetVector<Operation*, 4>& cluster_ops) {
llvm::SmallSetVector<Value, 4> external_outputs;
bool has_dynamic_outputs = HasDynamicOutputs(cluster_ops);
for (Operation* op : cluster_ops) {
for (Operation* user : op->getUsers()) {
      // Skip operations in the same outside compilation cluster since both
      // the defining op and the user op will be moved to the host together.
if (cluster_ops.count(user)) {
continue;
}
// This is pessimistic and in some cases will add extra communication.
if (!HasOutsideCompilationAncestor(user) || has_dynamic_outputs ||
HasDynamicOutputs(user)) {
for (Value v : user->getOperands()) {
if (v.getDefiningOp() == op) external_outputs.insert(v);
}
}
}
}
return external_outputs;
}
// Creates the `_XlaHostComputeMlir` op with `inputs` and `outputs`, using
// `args_communication_key` and `retvals_communication_key` for the transfers.
TF::_XlaHostComputeMlirOp CreateHostCompute(
OpBuilder& builder, Location loc,
const llvm::SmallSetVector<Value, 4>& inputs, llvm::ArrayRef<Value> outputs,
llvm::StringRef args_communication_key,
llvm::StringRef retvals_communication_key,
llvm::StringRef serialized_func_module) {
llvm::SmallVector<Type, 4> device_output_types;
for (const auto& output : outputs)
device_output_types.push_back(output.getType());
auto host_compute = builder.create<TF::_XlaHostComputeMlirOp>(
loc, device_output_types, inputs.getArrayRef(),
builder.getStringAttr(args_communication_key),
builder.getStringAttr(retvals_communication_key),
/*host_mlir_module=*/builder.getStringAttr(serialized_func_module));
return host_compute;
}
void MarkOutsideCompiled(Operation* op) {
op->setAttr(kXlaOutsideCompilationAttr,
StringAttr::get(op->getContext(), "temp"));
}
// Returns whether an outside compilation cluster should be closed. True when:
// 1. There is a dynamically shaped output consumed by a non-outside compiled
// op.
// 2. There is no dynamically shaped output.
bool ShouldCloseCluster(llvm::ArrayRef<Value> outputs) {
bool has_dynamic_output = false;
for (Value v : outputs) {
if (TF::CanBeRefined(v.getType())) {
has_dynamic_output = true;
for (Operation* user : v.getUsers()) {
if (!HasOutsideCompilationAncestor(user)) return true;
}
}
}
return !has_dynamic_output;
}
// Replaces uses of `external_operands` with the results from `recv_at_host`.
// For non-statically shaped operands, a use is replaced only if its owner is
// in the same region as the insertion point. For statically shaped operands,
// a use is also replaced if its owner is outside compiled and will be moved
// to the host later.
void ReplaceExternalOperandUsage(
const llvm::SmallSetVector<Value, 4>& external_operands,
Operation* recv_at_host, Operation* insertion_point,
Block* original_op_block) {
auto replace_operand_usage = [&](OpOperand& operand) {
if (TF::CanBeRefined(operand.get().getType()) ||
HasDynamicOutputs(operand.getOwner())) {
return insertion_point->getParentRegion()->isAncestor(
operand.getOwner()->getParentRegion());
}
return insertion_point->getParentRegion()->isAncestor(
operand.getOwner()->getParentRegion()) ||
(HasOutsideCompilationAncestor(operand.getOwner()) &&
original_op_block == operand.getOwner()->getBlock());
};
for (auto result : llvm::zip(external_operands, recv_at_host->getResults())) {
Value external_operand = std::get<0>(result);
external_operand.replaceUsesWithIf(std::get<1>(result),
replace_operand_usage);
}
}
bool HasDynamicOutputs(llvm::ArrayRef<Value> outputs) {
for (Value v : outputs) {
if (TF::CanBeRefined(v.getType())) {
return true;
}
}
return false;
}
// Replaces usages of `external_outputs`, the values returned by outside
// compilation, with the corresponding outputs from `host_compute`.
void ReplaceExternalOutputUsage(
const llvm::SmallSetVector<Value, 4>& external_outputs,
TF::_XlaHostComputeMlirOp host_compute) {
bool has_dynamic_outputs = HasDynamicOutputs(external_outputs.getArrayRef());
auto replace_output_usage = [&](OpOperand& operand) {
// Don't replace output usages if in host computation (defining op and user
// in same region).
bool in_same_region =
operand.get().getDefiningOp()->getParentRegion()->isAncestor(
operand.getOwner()->getParentRegion());
if (has_dynamic_outputs || HasDynamicOutputs(operand.getOwner())) {
return !in_same_region;
} else {
// Don't replace output usages in host computation or for outside
// compiled ops.
return !in_same_region &&
!HasOutsideCompilationAncestor(operand.getOwner());
}
};
for (auto result : llvm::zip(external_outputs, host_compute.getResults())) {
Value external_output = std::get<0>(result);
external_output.replaceUsesWithIf(std::get<1>(result),
replace_output_usage);
}
}
// Moves `clustered_ops` to run on the host and adds communication ops to
// transfer `external_operands` and `external_outputs` between device and
// host. Inserts ops at `insertion_point` and uses `compilation_key` and
// `device_ordinal` when creating the communication ops.
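// Schematically, one moved cluster yields (illustrative, attributes elided):
//   Device side: %outs = "tf._XlaHostComputeMlir"(%ins)
//   Host side:   %ins' = "tf._XlaRecvAtHost"(%key)   // ..._args channel
//                <clustered_ops, rewired to use %ins'>
//                "tf._XlaSendFromHost"(%outs', %key)  // ..._retvals channel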
void MoveOpsToHost(const llvm::SmallSetVector<Operation*, 4>& clustered_ops,
const llvm::SmallSetVector<Value, 4>& external_operands,
const llvm::SmallSetVector<Value, 4>& external_outputs,
Operation* insertion_point, Value compilation_key,
Value device_ordinal, int& communication_key_index) {
OpBuilder builder(insertion_point);
Operation& op = *clustered_ops.back();
std::string args_communication_key =
llvm::formatv("host_compute_channel_{0}_args", (communication_key_index))
.str();
std::string retvals_communication_key =
llvm::formatv("host_compute_channel_{0}_retvals",
(communication_key_index))
.str();
  // Use a unique name when sending just the IfRegion predicate. This is for
  // readability and to match the key used by the TF2XLA bridge.
if (clustered_ops.size() == 1 && llvm::isa<TF::IfRegionOp>(op) &&
external_operands.size() == 1) {
args_communication_key =
llvm::formatv("if_predicate_channel_{0}", (communication_key_index))
.str();
}
std::string serialized_func_module;
if (HasDynamicOutputs(external_outputs.getArrayRef())) {
func::FuncOp shape_op = BuildFunction(
clustered_ops.getArrayRef(), external_operands.getArrayRef(),
external_outputs.getArrayRef(), &builder);
EncapsulateFuncAndSerialize(shape_op, &serialized_func_module);
}
builder.setInsertionPoint(&op);
auto host_compute =
CreateHostCompute(builder, op.getLoc(), external_operands,
external_outputs.getArrayRef(), args_communication_key,
retvals_communication_key, serialized_func_module);
  // Insert ops into the host-side computation to receive data from the device.
builder.setInsertionPoint(insertion_point);
llvm::SmallVector<Type, 4> host_operand_types;
for (const auto& operand : external_operands)
host_operand_types.push_back(operand.getType());
Operation* recv_at_host = CreateRecvAtHostOp(
builder, op.getLoc(), host_operand_types, compilation_key, device_ordinal,
args_communication_key);
Block* original_op_block = op.getBlock();
Operation* after_op = recv_at_host;
for (Operation* cluster_op : clustered_ops) {
cluster_op->moveAfter(after_op);
cluster_op->removeAttr(StringAttr::get(op.getContext(), kDeviceAttr));
after_op = cluster_op;
}
if (!external_outputs.empty()) {
CreateSendFromHostOp(builder, op.getLoc(), external_outputs.getArrayRef(),
compilation_key, device_ordinal,
retvals_communication_key);
}
if (external_operands.empty()) {
recv_at_host->erase();
} else {
ReplaceExternalOperandUsage(external_operands,
/*recv_at_host=*/recv_at_host,
/*insertion_point=*/insertion_point,
/*original_op_block=*/original_op_block);
}
ReplaceExternalOutputUsage(external_outputs, host_compute);
if (external_operands.empty() && external_outputs.empty()) {
host_compute.erase();
} else {
++communication_key_index;
}
}
// Moves outside compiled ops in `src` to `insertion_point` in the host
// computation (ops may temporarily remain within `tpu_cluster` and be moved
// by a subsequent call to this method). Communication ops are added both in
// `src` and at `insertion_point`, using `compilation_key`, `device_ordinal`
// and `communication_key_index`, which is incremented when used.
// Communication ops are added only when needed and at the locations needed;
// there are checks to ensure that duplicate communication between device and
// host is not added.
LogicalResult MoveOpsToHost(tf_device::ClusterOp tpu_cluster, Block* src,
Operation* insertion_point, Value compilation_key,
Value device_ordinal,
int& communication_key_index) {
// Contains all of the outside compiled operations that should be moved to the
// host using a single `_XlaHostComputeMlir` op. This should only contain a
// single op except in the case where some of the input/output shapes are
// non-static.
llvm::SmallSetVector<Operation*, 4> clustered_ops;
for (Operation& op : llvm::make_early_inc_range(*src)) {
if (HasOutsideCompilationAncestorExclusive(&op) ||
!op.hasAttrOfType<StringAttr>(kXlaOutsideCompilationAttr))
continue;
    // Move the current `clustered_ops` before adding an op whose operands are
    // all statically shaped, since we can't ensure in all cases that the
    // static shapes have been sent back to the host. See the
    // @static_shapes_sandwiched_outside_compilation MLIR test for an example.
if (!HasDynamicExternalValues(&op) && !clustered_ops.empty()) {
llvm::SmallSetVector<Value, 4> external_operands =
GetExternalOperands(tpu_cluster, clustered_ops);
llvm::SmallSetVector<Value, 4> external_outputs =
GetExternalOutputs(clustered_ops);
MoveOpsToHost(clustered_ops, external_operands, external_outputs,
insertion_point, compilation_key, device_ordinal,
communication_key_index);
clustered_ops.clear();
}
clustered_ops.insert(&op);
// Get the outputs that need to be communicated from host -> device.
llvm::SmallSetVector<Value, 4> external_outputs =
GetExternalOutputs(clustered_ops);
if (ShouldCloseCluster(external_outputs.getArrayRef())) {
// Get the operands that need to be communicated from device -> host.
llvm::SmallSetVector<Value, 4> external_operands =
GetExternalOperands(tpu_cluster, clustered_ops);
MoveOpsToHost(clustered_ops, external_operands, external_outputs,
insertion_point, compilation_key, device_ordinal,
communication_key_index);
clustered_ops.clear();
}
}
return success();
}
// Decomposes control flow in `tpu_cluster` into separate device and host
// (outside compiled) control flow ops, with communication between device and
// host for data dependencies. Both the device and host control flow initially
// remain within `tpu_cluster`; a subsequent call to MoveOpsToHost moves the
// host-side control flow to the host launch in tf_device.parallel_execute.
// Uses `compilation_key`, `device_ordinal` and `communication_key_index` when
// creating communication ops.
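// Schematically, for a while loop whose body contains outside compiled ops,
// the device-side cond region sends its condition value to the host:
//   Device cond: "tf.XlaSendToHost"(%cond) {key = "while_condition_channel_N"}
//   Host cond:   %cond' = "tf._XlaRecvAtHost"(...); "tf.Yield"(%cond')
// so the host-side while iterates in lockstep with the device-side while. A
// host-side IfRegion similarly receives its predicate over an
// "if_predicate_channel_N" transfer.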
LogicalResult DecomposeControlFlow(tf_device::ClusterOp tpu_cluster,
Value compilation_key, Value device_ordinal,
int& communication_key_index) {
auto result = tpu_cluster.GetBody().walk([&](Operation* op) {
if (auto if_op = llvm::dyn_cast<TF::IfRegionOp>(op)) {
if (!HasOutsideCompilationNested(op)) return WalkResult::advance();
OpBuilder builder(if_op);
auto host_if = CloneEmptyIfWithPredicate(if_op, builder);
if (failed(MoveOpsToHost(tpu_cluster, &if_op.then_branch().front(),
host_if.then_branch().front().getTerminator(),
compilation_key, device_ordinal,
communication_key_index)))
return WalkResult::interrupt();
if (failed(MoveOpsToHost(tpu_cluster, &if_op.else_branch().front(),
host_if.else_branch().front().getTerminator(),
compilation_key, device_ordinal,
communication_key_index)))
return WalkResult::interrupt();
MarkOutsideCompiled(host_if.getOperation());
}
if (auto while_op = llvm::dyn_cast<TF::WhileRegionOp>(op)) {
if (!HasOutsideCompilationNested(op)) return WalkResult::advance();
OpBuilder builder(while_op);
auto host_while = CloneEmptyWhile(while_op.is_stateless(),
while_op.parallel_iterations(),
while_op.getLoc(), builder);
const auto condition_send_recv_key =
llvm::formatv("while_condition_channel_{0}",
communication_key_index++)
.str();
auto& cond = host_while.cond();
cond.push_back(new Block);
auto condition = while_op.cond().front().getTerminator()->getOperand(0);
builder.setInsertionPoint(while_op.cond().front().getTerminator());
builder.create<TF::XlaSendToHostOp>(while_op.getLoc(), condition,
condition_send_recv_key);
builder.setInsertionPointToEnd(&cond.front());
auto recv_condition_at_host = CreateRecvAtHostOp(
builder, while_op.getLoc(), TypeRange{condition.getType()},
compilation_key, device_ordinal, condition_send_recv_key);
builder.create<TF::YieldOp>(while_op.getLoc(),
recv_condition_at_host->getResults());
if (failed(MoveOpsToHost(tpu_cluster, &while_op.cond().front(),
recv_condition_at_host, compilation_key,
device_ordinal, communication_key_index)))
return WalkResult::interrupt();
if (failed(MoveOpsToHost(tpu_cluster, &while_op.body().front(),
host_while.body().front().getTerminator(),
compilation_key, device_ordinal,
communication_key_index)))
return WalkResult::interrupt();
MarkOutsideCompiled(host_while.getOperation());
}
return WalkResult::advance();
});
if (result.wasInterrupted()) return failure();
return success();
}
// Removes outside compilation from all ops inside `host_launch_op`. Should
// only be run after all outside compiled ops have been moved to
// `host_launch_op`.
void RemoveOutsideCompilation(tf_device::LaunchOp host_launch_op) {
host_launch_op.GetBody().walk([&](Operation* op) {
if (op->hasAttrOfType<StringAttr>(kXlaOutsideCompilationAttr)) {
op->removeAttr(
StringAttr::get(op->getContext(), kXlaOutsideCompilationAttr));
}
});
}
// Creates a `parallel_execute` op with a region for host computation and
// a region for `tpu_cluster` computation by extracting outside compiled ops to
// host computation.
LogicalResult CreateParallelExecuteForOutsideCompilation(
ModuleOp module, tf_device::ClusterOp tpu_cluster,
llvm::StringRef host_device) {
OpBuilder builder(tpu_cluster);
  // Create parallel_execute regions, one for the host computation for outside
  // compilation and one for the original TPU cluster computation.
const int num_regions = 2;
auto parallel_execute_op = builder.create<tf_device::ParallelExecuteOp>(
tpu_cluster.getLoc(), num_regions, tpu_cluster.results().getTypes());
Block& host_computation_block =
parallel_execute_op.GetRegionBlockWithIndex(0);
builder.setInsertionPointToEnd(&host_computation_block);
// Create a single launch op for all outside compiled ops.
tf_device::LaunchOp host_launch_op =
CreateLaunchOpForOutsideCluster(builder, tpu_cluster, host_device);
builder.setInsertionPoint(host_launch_op.GetBody().getTerminator());
auto compilation_key_op =
CreateCompilationKeyPlaceholder(tpu_cluster.getLoc(), builder);
Value compilation_key = compilation_key_op.program();
auto device_ordinal_op = builder.create<TF::_TPUDeviceOrdinalPlaceholderOp>(
tpu_cluster.getLoc(), RankedTensorType::get({}, builder.getI64Type()));
Value device_ordinal = nullptr;
if (tpu_cluster->getParentOfType<tf_device::ReplicateOp>()) {
device_ordinal = device_ordinal_op.device_ordinal();
}
int communication_key_index = 0;
// Decompose control flow into device and host control flow when outside
// compilation is included.
if (failed(DecomposeControlFlow(tpu_cluster, compilation_key, device_ordinal,
communication_key_index)))
return failure();
// Move all outside compiled ops including control flow to host launch.
if (failed(MoveOpsToHost(tpu_cluster, &tpu_cluster.GetBody(),
host_launch_op.GetBody().getTerminator(),
compilation_key, device_ordinal,
communication_key_index)))
return failure();
if (communication_key_index == 0) compilation_key_op.erase();
if (communication_key_index == 0 || device_ordinal == nullptr)
device_ordinal_op.erase();
RemoveOutsideCompilation(host_launch_op);
builder.setInsertionPointToEnd(&host_computation_block);
builder.create<tf_device::ReturnOp>(tpu_cluster.getLoc(), ArrayRef<Value>{});
  // Move the TPU cluster into the second (device) parallel_execute block.
Block& parallel_execute_tpu_block =
parallel_execute_op.GetRegionBlockWithIndex(1);
builder.setInsertionPointToEnd(&parallel_execute_tpu_block);
builder.create<tf_device::ReturnOp>(tpu_cluster.getLoc(),
tpu_cluster.getResults());
tpu_cluster.getOperation()->moveBefore(
parallel_execute_tpu_block.getTerminator());
// Remap cluster results with parallel_execute results if user is outside of
// parallel_execute.
for (auto result :
llvm::zip(tpu_cluster.getResults(), parallel_execute_op.getResults())) {
Value tpu_cluster_result = std::get<0>(result);
Value parallel_execute_result = std::get<1>(result);
for (auto& use : llvm::make_early_inc_range(tpu_cluster_result.getUses()))
if (!parallel_execute_op.getOperation()->isProperAncestor(use.getOwner()))
use.set(parallel_execute_result);
}
return success();
}
void TPUExtractOutsideCompilation::runOnOperation() {
  // Get runtime device information from the closest parent module.
auto module = getOperation();
mlir::TF::RuntimeDevices devices;
if (failed(tensorflow::GetDevicesFromOp(module, &devices)))
return signalPassFailure();
module.walk([&](tf_device::ClusterOp tpu_cluster) {
if (HasOutsideCompilationNested(tpu_cluster.getOperation())) {
std::string host_device;
if (failed(tensorflow::GetHostDeviceOutsideComputation(
devices, tpu_cluster, &host_device)))
return signalPassFailure();
if (failed(CreateParallelExecuteForOutsideCompilation(module, tpu_cluster,
host_device)))
return signalPassFailure();
}
});
  // Remove the `_xla_outside_compilation` attribute from all remaining ops.
  // These ops are outside of the device cluster, where the attribute has no
  // meaning and can lead to errors later on. Such ops were likely lifted out
  // of the tf_device.cluster in an earlier pass.
module.walk(
[](Operation* op) { op->removeAttr("_xla_outside_compilation"); });
}
} // namespace
std::unique_ptr<OperationPass<ModuleOp>>
CreateTPUExtractOutsideCompilationPass() {
return std::make_unique<TPUExtractOutsideCompilation>();
}
} // namespace TFTPU
} // namespace mlir