blob: bdbf73859ef1ad9bd9f3c0eaa195729276a5de39 [file] [log] [blame]
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This file implements logic for lowering TensorFlow dialect's communication
// ops (TF/XLA) to the HLO dialect.
#include <memory>
#include <string>
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/None.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/FormatVariadic.h"
#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/Builders.h" // from @llvm-project
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
#include "mlir/IR/BuiltinTypes.h" // from @llvm-project
#include "mlir/IR/TypeUtilities.h" // from @llvm-project
#include "mlir/IR/Value.h" // from @llvm-project
#include "mlir/IR/Visitors.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Support/LLVM.h" // from @llvm-project
#include "mlir/Support/LogicalResult.h" // from @llvm-project
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/xla/transforms/tf_xla_passes_detail.h"
#include "tensorflow/compiler/mlir/xla/type_to_shape.h"
#include "tensorflow/compiler/xla/client/sharding_builder.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/side_effect_util.h"
namespace mlir {
using func::FuncOp;
namespace mhlo {
namespace {
constexpr char kShardingAttr[] = "mhlo.sharding";
constexpr char kFrontendAttributesAttr[] = "mhlo.frontend_attributes";
// A pass that legalizes TF/XLA communication ops, propagate their respective
// tokens (for ordering), and rewrite their respective functions and control
// flow ops when necessary.
// Note, this currently does not handle nested modules/functions or region based
// ops other than certain control flow ops (`mhlo.if`, `mhlo.while`).
class LegalizeTFCommunication
: public LegalizeTFCommunicationPassBase<LegalizeTFCommunication> {
void runOnOperation() override;
};
// Checks if an op is a TF/XLA communication op.
bool IsCommunicationOp(Operation* op) {
return isa<TF::_XlaHostComputeMlirOp, TF::XlaSendToHostOp,
TF::XlaRecvFromHostOp>(op);
}
// Checks if an op is a supported HLO control flow op.
bool IsControlFlowOp(Operation* op) { return isa<IfOp, WhileOp>(op); }
// Collects control flow op ancestors of a given op, up until FuncOp. If any
// ancestor is not a control flow op or a FuncOp, or of a single block region,
// an error will be returned.
LogicalResult GetControlFlowAncestors(
Operation* op, llvm::SmallPtrSetImpl<Operation*>& control_flow_ops,
llvm::SmallPtrSetImpl<Block*>& control_flow_blocks) {
Block* block = op->getBlock();
Operation* parent = block->getParentOp();
while (block && parent && !isa<func::FuncOp>(parent)) {
if (!IsControlFlowOp(parent))
return op->emitOpError()
<< "expects ancestor(s) to be of ['" << IfOp::getOperationName()
<< "', '" << func::FuncOp::getOperationName() << "']";
if (!llvm::hasSingleElement(block->getParent()->getBlocks()))
return op->emitOpError() << "expects single block region ancestor(s)";
control_flow_ops.insert(parent);
control_flow_blocks.insert(block);
parent = block->getParentOp();
block = parent->getBlock();
}
return success();
}
// Finds communication ops in a function. `control_flow_ops` and
// `control_flow_blocks` will be populated with control flow op ancestors for
// every communication op.
LogicalResult FindCommunicationOps(
func::FuncOp func, llvm::SmallPtrSetImpl<Operation*>& control_flow_ops,
llvm::SmallPtrSetImpl<Block*>& control_flow_blocks,
bool& has_communication_ops) {
auto result = func.walk([&](Operation* op) {
if (!IsCommunicationOp(op)) return WalkResult::advance();
has_communication_ops = true;
if (failed(
GetControlFlowAncestors(op, control_flow_ops, control_flow_blocks)))
return WalkResult::interrupt();
return WalkResult::advance();
});
return failure(result.wasInterrupted());
}
// Helper struct holding a function to be rewritten, it's control flow ops that
// lead to a communication op or function call with a communication op
// (transitively), and an optional clone of itself. If `clone` is set, function
// calls to `original` will be replaced with `clone`.
struct FuncToRewrite {
func::FuncOp original;
llvm::SmallPtrSet<Operation*, 4> control_flow_ops;
llvm::SmallPtrSet<Block*, 4> control_flow_blocks;
func::FuncOp clone;
};
// Finds all functions that need to be rewritten with communication ops and
// and associated tokens.
LogicalResult GetFunctionsToRewrite(
ModuleOp module,
llvm::SmallDenseMap<StringRef, FuncToRewrite>& funcs_to_rewrite) {
// Find functions containing communication ops.
SmallVector<func::FuncOp, 4> funcs_to_visit;
for (func::FuncOp func : module.getOps<func::FuncOp>()) {
FuncToRewrite func_to_rewrite{/*original=*/func, /*control_flow_ops=*/{},
/*control_flow_blocks=*/{},
/*clone=*/nullptr};
bool has_communication_ops = false;
if (failed(FindCommunicationOps(func, func_to_rewrite.control_flow_ops,
func_to_rewrite.control_flow_blocks,
has_communication_ops)))
return failure();
if (!has_communication_ops) continue;
funcs_to_rewrite.insert({func.getName(), func_to_rewrite});
funcs_to_visit.push_back(func);
}
// Find functions that call functions with communication ops, transitively.
while (!funcs_to_visit.empty()) {
SmallVector<func::FuncOp, 4> new_funcs_to_visit;
for (func::FuncOp& func : funcs_to_visit) {
auto uses = func.getSymbolUses(module);
if (!uses) continue;
for (auto& use : *uses) {
// Only `mlir::func::CallOp` is supported as this requires knowing how
// to rewrite arguments and results to a function.
if (!isa<mlir::func::CallOp>(use.getUser())) continue;
auto caller_parent_func =
use.getUser()->getParentOfType<func::FuncOp>();
if (!caller_parent_func) continue;
FuncToRewrite func_to_rewrite{/*original=*/caller_parent_func,
/*control_flow_ops=*/{},
/*control_flow_blocks=*/{},
/*clone=*/nullptr};
if (failed(GetControlFlowAncestors(
use.getUser(), func_to_rewrite.control_flow_ops,
func_to_rewrite.control_flow_blocks)))
return failure();
auto it = funcs_to_rewrite.insert(
{caller_parent_func.getName(), func_to_rewrite});
if (it.second) {
new_funcs_to_visit.push_back(caller_parent_func);
} else {
it.first->getSecond().control_flow_ops.insert(
func_to_rewrite.control_flow_ops.begin(),
func_to_rewrite.control_flow_ops.end());
it.first->getSecond().control_flow_blocks.insert(
func_to_rewrite.control_flow_blocks.begin(),
func_to_rewrite.control_flow_blocks.end());
}
}
}
funcs_to_visit.swap(new_funcs_to_visit);
}
// Clone public functions that need to be rewritten. Function calls to this
// function will be replaced with the cloned function.
SymbolTable symbol_table(module);
for (auto& func : funcs_to_rewrite) {
if (func.getSecond().original.isPublic() &&
!func.getSecond().original.symbolKnownUseEmpty(module)) {
auto clone = func.getSecond().original.clone();
clone.setPrivate();
symbol_table.insert(clone);
func.getSecond().clone = clone;
}
}
return success();
}
// Assigns op sharding to an op for a given device core.
void SetOpSharding(Operation* op, int64_t tpu_core) {
std::string sharding_serialized =
::xla::sharding_builder::AssignDevice(tpu_core).SerializeAsString();
op->setAttr(kShardingAttr,
StringAttr::get(op->getContext(), sharding_serialized));
}
// Assigns frontend attributes holding information about data type and
// TensorFlow rendezvous channel name. The TensorFlow rendezvous channel name is
// handled differently as individual names are used per data send and receive.
void SetFrontendAttributes(Operation* op, int32_t index, StringRef key,
Type type, bool device_to_host,
StringRef host_handler_name) {
MLIRContext* context = op->getContext();
std::string formatted_key =
device_to_host ? llvm::formatv("{0}_dtoh_{1}", key, index).str()
: llvm::formatv("{0}_htod_{1}", key, index).str();
auto rendezvous_name = StringAttr::get(context, formatted_key);
auto rendezvous_name_attr = NamedAttribute(
StringAttr::get(context, xla::kXlaHostTransferRendezvousNameAttr),
rendezvous_name);
auto element_type = getElementTypeOrSelf(type);
auto xla_element_type = ::xla::TypeToPrimitiveType(element_type);
const std::string& xla_element_type_str =
::xla::primitive_util::LowercasePrimitiveTypeName(xla_element_type);
auto original_type = StringAttr::get(context, xla_element_type_str);
auto original_type_attr = NamedAttribute(
StringAttr::get(context, xla::kXlaHostTransferOriginalTypeAttr),
original_type);
auto host_handler_name_value =
StringAttr::get(context, host_handler_name.str());
auto host_handler_name_attr = NamedAttribute(
StringAttr::get(context, xla::kXlaHostTransferHandlerNameAttr),
host_handler_name_value);
auto frontend_attributes = DictionaryAttr::get(
context,
ArrayRef<NamedAttribute>{rendezvous_name_attr, original_type_attr,
host_handler_name_attr});
op->setAttr(kFrontendAttributesAttr, frontend_attributes);
}
// Creates a `mhlo.send` op for sending value `operand`. If `tpu_core` is set,
// op sharding for the respective device will be set.
Value CreateSendOp(OpBuilder& builder, int64_t& channel_id, Location loc,
Value operand, StringRef key, size_t index,
const Optional<int64_t>& tpu_core, Value token,
StringRef host_handler_name) {
// type 2 == DEVICE_TO_HOST
auto channel_handle = ChannelHandle::get(
/*handle=*/builder.getI64IntegerAttr(channel_id++),
/*type=*/builder.getI64IntegerAttr(2), builder.getContext());
auto send = builder.create<SendOp>(
loc, token.getType(), operand, token, channel_handle,
/*is_host_transfer=*/builder.getBoolAttr(true));
SetFrontendAttributes(send, index, key, operand.getType(),
/*device_to_host=*/true, host_handler_name);
if (tpu_core) SetOpSharding(send, *tpu_core);
return send.getResult();
}
// Creates a `mhlo.recv` op for receiving a value. If `tpu_core` is set, op
// sharding for the respective device will be set.
Value CreateRecvOp(OpBuilder& builder, int64_t& channel_id, Location loc,
Value result, StringRef key, size_t index,
const Optional<int64_t>& tpu_core, Value token,
StringRef host_handler_name) {
// type 3 == HOST_TO_DEVICE
auto channel_handle = ChannelHandle::get(
/*handle=*/builder.getI64IntegerAttr(channel_id++),
/*type=*/builder.getI64IntegerAttr(3), builder.getContext());
auto result_type = result.getType();
SmallVector<Type, 2> recv_result_type = {result_type, token.getType()};
auto recv =
builder.create<RecvOp>(loc, recv_result_type, token, channel_handle,
/*is_host_transfer=*/builder.getBoolAttr(true));
SetFrontendAttributes(recv, index, key, result_type,
/*device_to_host=*/false, host_handler_name);
if (tpu_core) SetOpSharding(recv, *tpu_core);
result.replaceAllUsesWith(recv.getResult(0));
return recv.getResult(1);
}
// Creates a new token if necessary, acting as a sink to previous tokens. If
// there is only one token in `tokens`, the only token is returned. If `tokens`
// is empty, `original_token` is returned instead.
Value CreateSinkToken(OpBuilder& builder, Location loc, ArrayRef<Value> tokens,
Value original_token) {
if (tokens.empty()) {
return original_token;
} else if (llvm::hasSingleElement(tokens)) {
return tokens[0];
} else {
return builder.create<AfterAllOp>(loc, original_token.getType(), tokens)
.getResult();
}
}
// Replaces `tf._XlaHostComputeMlir` with individual `mhlo.send` and `mhlo.recv`
// ops per operand and result. Unique Channel IDs are assigned per transfer.
// Sink tokens are created across all `mhlo.send` ops first and then by
// all `mhlo.recv` ops.
Value RewriteHostComputeOp(OpBuilder& builder, int64_t& channel_id,
TF::_XlaHostComputeMlirOp host_compute,
Value token) {
builder.setInsertionPoint(host_compute);
Location loc = host_compute.getLoc();
int64_t tpu_core = host_compute.tpu_coreAttr().getInt();
SmallVector<Value, 4> send_tokens;
for (auto operand : llvm::enumerate(host_compute.inputs())) {
auto send_token =
CreateSendOp(builder, channel_id, loc, operand.value(),
host_compute.send_key(), operand.index(), tpu_core, token,
xla::kXlaHostTransferTfRendezvousHandlerName);
send_tokens.push_back(send_token);
}
token = CreateSinkToken(builder, loc, send_tokens, token);
SmallVector<Value, 4> recv_tokens;
for (auto result : llvm::enumerate(host_compute.outputs())) {
auto recv_token =
CreateRecvOp(builder, channel_id, loc, result.value(),
host_compute.recv_key(), result.index(), tpu_core, token,
xla::kXlaHostTransferTfRendezvousHandlerName);
recv_tokens.push_back(recv_token);
}
token = CreateSinkToken(builder, loc, recv_tokens, token);
host_compute.erase();
return token;
}
// Replaces `tf.XlaSendToHost` with a `mhlo.send`.
Value RewriteSendToHostOp(OpBuilder& builder, int64_t& channel_id,
TF::XlaSendToHostOp send_to_host, Value token) {
builder.setInsertionPoint(send_to_host);
token = CreateSendOp(builder, channel_id, send_to_host.getLoc(),
send_to_host.input(), send_to_host.key(),
/*index=*/0, /*tpu_core=*/llvm::None, token,
xla::kXlaHostTransferTfRendezvousHandlerName);
send_to_host.erase();
return token;
}
// Replaces `tf.XlaRecvFromHost` with a `mhlo.recv`.
Value RewriteRecvFromHostOp(OpBuilder& builder, int64_t& channel_id,
TF::XlaRecvFromHostOp recv_from_host, Value token) {
builder.setInsertionPoint(recv_from_host);
token = CreateRecvOp(builder, channel_id, recv_from_host.getLoc(),
recv_from_host.output(), recv_from_host.key(),
/*index=*/0, /*tpu_core=*/llvm::None, token,
xla::kXlaHostTransferTfRendezvousHandlerName);
recv_from_host.erase();
return token;
}
// Replaces a `mlir::func::CallOp` with one that has an extra `!mhlo.token`
// operand and `!mhlo.token` result. If `new_symbol` is set, the new call will
// be updated to call the `new_symbol` instead.
Value RewriteCallOp(OpBuilder& builder, func::CallOp call,
const Optional<StringRef>& new_symbol, Value token) {
builder.setInsertionPoint(call);
auto new_operands = llvm::to_vector(call.getArgOperands());
new_operands.push_back(token);
auto new_result_types = llvm::to_vector(call.getResultTypes());
new_result_types.push_back(token.getType());
auto new_call = builder.create<func::CallOp>(
call.getLoc(), new_result_types,
new_symbol ? *new_symbol : call.getCallee(), new_operands);
for (auto results : llvm::zip(call.getResults(), new_call.getResults()))
std::get<0>(results).replaceAllUsesWith(std::get<1>(results));
call.erase();
return new_call.getResults().back();
}
// Helper struct holding state of which op to visit to next. If `op` is in a
// control flow op region, `region_idx` will be set with the respective region
// index. `token` will be current token from the last communication op/control
// flow op transitive communication ops.
struct OpVisitorState {
Optional<unsigned> region_idx;
Value token;
Operation* op;
};
// Creates a tuple from a sequence of values.
Value CreateTuple(OpBuilder& builder, Location loc, ArrayRef<Value> operands) {
return builder.create<TupleOp>(loc, operands).getResult();
}
// Extends `values` with the value `token` attached. If `flatten_tuple` is
// false, `values` will have a single element, say `value`. If `value` is not a
// tuple, a new tuple is formed with `token`. If `values` is a tuple, it is
// extended instead. New tuple values created are cached.
SmallVector<Value> GetValueWithToken(
OpBuilder& builder, ArrayRef<Value> values, Value token,
llvm::SmallDenseMap<Value, Value>& rewritten_values, bool flatten_tuple) {
if (flatten_tuple) {
auto operands = llvm::to_vector(values);
operands.push_back(token);
return operands;
}
auto value = values[0];
// If value with token already exists, reuse it.
auto it = rewritten_values.find(value);
if (it != rewritten_values.end()) return {it->getSecond()};
auto create_tuple = [&](ArrayRef<Value> operands) {
auto new_result = CreateTuple(builder, value.getLoc(), operands);
rewritten_values.insert({value, new_result});
return new_result;
};
auto tuple_type = value.getType().dyn_cast<TupleType>();
// `value` is not a tuple, create a new tuple.
if (!tuple_type) return {create_tuple({value, token})};
// Extend tuple if `value` is a tuple.
// If `value` is an op result and the owner is a `mhlo.tuple`, simply unpack
// the tuple.
if (auto tuple_op = value.getDefiningOp<TupleOp>()) {
auto tuple_operands = llvm::to_vector(tuple_op.getOperands());
tuple_operands.push_back(token);
return {create_tuple(tuple_operands)};
}
// `value` is not created via a `mhlo.tuple` directly, unpack individual
// elements directly with `mhlo.get_tuple_element`.
SmallVector<Value, 4> tuple_operands;
for (auto idx : llvm::seq<int32_t>(0, tuple_type.getTypes().size()))
tuple_operands.push_back(
builder.create<GetTupleElementOp>(value.getLoc(), value, idx)
.getResult());
tuple_operands.push_back(token);
return {create_tuple(tuple_operands)};
}
// Extends the 'types' to include a `mhlo.token` type. If `flatten_tuple` is
// false, `types` will have a single element, say `type`. If `type` is not a
// tuple type, a new tuple type with `type` and `mhlo.token` type is created
// instead.
SmallVector<Type> GetTypeWithToken(OpBuilder& builder, ArrayRef<Type> types,
bool flatten_tuple) {
SmallVector<Type> new_result_types;
auto token_type = TokenType::get(builder.getContext());
if (flatten_tuple) {
auto result_types = llvm::to_vector(types);
result_types.push_back(token_type);
return result_types;
}
auto type = types[0];
if (auto tuple_type = type.dyn_cast<TupleType>()) {
auto result_types = llvm::to_vector(tuple_type.getTypes());
result_types.push_back(token_type);
return {builder.getTupleType(result_types)};
}
return {builder.getTupleType({type, token_type})};
}
// Creates a slice of a tuple `value` with `mhlo.get_tuple_element` from index 0
// to `end`, exclusive.
Value CreateSubTuple(OpBuilder& builder, Value value, size_t end) {
SmallVector<Value, 4> tuple_operands;
for (auto idx : llvm::seq<int32_t>(0, end))
tuple_operands.push_back(
builder.create<GetTupleElementOp>(value.getLoc(), value, idx)
.getResult());
return CreateTuple(builder, value.getLoc(), tuple_operands);
}
// Replaces uses of `values` with `replacements`. If `flatten_tuple` is false,
// `values` will have a single element, say `value`. If `value` is not a tuple
// type, an explicit `mhlo.get_tuple_element` is created to unpack the tuple and
// return the first element. Otherwise, `mhlo.get_tuple_element` users are
// simply updated with `replacement`, and all other users are updated with a
// slice of `replacement`.
void ReplaceWithTupleResult(OpBuilder& builder, ArrayRef<Value> values,
ArrayRef<Value> replacements, bool flatten_tuple) {
if (flatten_tuple) {
for (size_t result_index = 0; result_index < values.size(); result_index++)
values[result_index].replaceAllUsesWith(replacements[result_index]);
return;
}
auto value = values[0];
auto replacement = replacements[0];
auto tuple_type = value.getType().dyn_cast<TupleType>();
if (!tuple_type) {
if (!value.use_empty()) {
auto new_element = builder.create<GetTupleElementOp>(replacement.getLoc(),
replacement, 0);
value.replaceAllUsesWith(new_element.getResult());
}
return;
}
Value sub_tuple;
for (auto& use : llvm::make_early_inc_range(value.getUses())) {
if (isa<GetTupleElementOp>(use.getOwner())) {
use.set(replacement);
continue;
}
if (!sub_tuple)
sub_tuple = CreateSubTuple(builder, replacement, tuple_type.size());
use.set(sub_tuple);
}
}
// Replaces control flow op block arguments with new block arguments
// of types `types`. The last element of the new block argument (token) is
// returned.
Value UpdateControlFlowBlockArgWithToken(OpBuilder& builder, Block& block,
ArrayRef<Type> types) {
builder.setInsertionPointToStart(&block);
auto old_args_size = block.getNumArguments();
block.addArguments(
types, SmallVector<Location>(types.size(), block.getParent()->getLoc()));
auto old_args = ArrayRef<Value>(block.getArguments().begin(),
block.getArguments().begin() + old_args_size);
auto new_args = ArrayRef<Value>(block.getArguments().begin() + old_args_size,
block.getArguments().end());
assert(!new_args.empty());
ReplaceWithTupleResult(builder, old_args, new_args, /*flatten_tuple=*/true);
auto new_arg = new_args[new_args.size() - 1];
block.eraseArguments(
llvm::to_vector(llvm::seq((unsigned)0, (unsigned)old_args_size)));
return new_arg;
}
// Updates control flow op terminator with an extra element `token`.
void RewriteControlFlowTerminator(OpBuilder& builder, Operation* terminator,
Value token, bool flatten_tuple) {
assert(flatten_tuple || terminator->getNumOperands() == 1);
assert(flatten_tuple || terminator->getBlock()->getNumArguments() == 1);
// `mhlo.while` cond terminator does not need to be rewritten as it always
// returns a tensor<i1> predicate value.
if (auto while_parent = dyn_cast_or_null<WhileOp>(terminator->getParentOp()))
if (terminator->getParentRegion() == &while_parent.cond()) return;
builder.setInsertionPoint(terminator);
llvm::SmallDenseMap<Value, Value> rewritten_operands;
auto new_results =
GetValueWithToken(builder, llvm::to_vector(terminator->getOperands()),
token, rewritten_operands, flatten_tuple);
terminator->setOperands(new_results);
}
// Rewrites a `mhlo.if` op to receive and forward a `mhlo.token`. As If op does
// not have any operands other than the predicate, hence we implicitly capture
// the parent token. Also we use the same implicit token for use in the If op's
// regions.
void RewriteRegionIfOp(OpBuilder& builder, IfOp region_if,
SmallVectorImpl<OpVisitorState>& ops_to_visit,
Value token) {
llvm::SmallDenseMap<Value, Value> rewritten_operands;
auto new_result_types =
GetTypeWithToken(builder, llvm::to_vector(region_if.getResultTypes()),
/*flatten_tuple=*/true);
// Create new `mhlo.if` op with extra token operands and result.
auto new_if = builder.create<IfOp>(region_if.getLoc(), new_result_types,
region_if.pred());
// Move all regions from the old `mhlo.if` op to its replacement.
new_if.true_branch().takeBody(region_if.true_branch());
new_if.false_branch().takeBody(region_if.false_branch());
// Forward result from old `mhlo.if` with replacement.
SmallVector<Value> old_if_results = region_if.getResults();
SmallVector<Value> new_if_results = new_if.getResults();
ReplaceWithTupleResult(builder, old_if_results, new_if_results,
/*flatten_tuple=*/true);
// auto new_token = new_if_results[new_if_results.size() - 1];
region_if.erase();
// Next op to visit. The replacement is visited but at its first region.
// The new region use the same implicit token used by the If op.
ops_to_visit.push_back({/*region_idx=*/0, token, new_if});
}
// Rewrites a `mhlo.if`/`mhlo.while` region to receive and forward a
// `mhlo.token`. The block argument is updated to have an extra `mhlo.token`
// element. If the region block is to be rewritten, the next op to visit is set
// to the first op in the block. Otherwise the terminator is updated to forward
// `token`.
void RewriteControlFlowOpRegion(
OpBuilder& builder, Operation* region_op, unsigned region_idx,
ArrayRef<Type> block_arg_types,
SmallVectorImpl<OpVisitorState>& ops_to_visit,
const llvm::SmallPtrSetImpl<Block*>& control_flow_blocks, Value token) {
ops_to_visit.push_back({region_idx + 1, token, region_op});
Region& region = region_op->getRegion(region_idx);
assert(llvm::hasSingleElement(region));
auto block_token = UpdateControlFlowBlockArgWithToken(builder, region.front(),
block_arg_types);
if (control_flow_blocks.contains(&region.front())) {
ops_to_visit.push_back(
{/*region_idx=*/llvm::None, block_token, &region.front().front()});
return;
}
RewriteControlFlowTerminator(builder, region.front().getTerminator(),
block_token, /*flatten_tuple=*/true);
}
// For mlir::IfOp or mlir::CaseOp, replace the use of their region's block
// argument (of type token) with 'implicit_operand'.
void ReplaceBlockArgumentsWithImplicitOperands(mlir::Operation* op,
unsigned region_idx,
Value implicit_operand) {
assert((mlir::dyn_cast<mlir::mhlo::IfOp>(*op) ||
mlir::dyn_cast<mlir::mhlo::CaseOp>(*op)) &&
"Unexpected mlir op in "
"HloFunctionImporter::ReplaceBlockArgumentsWithImplicitOperands!");
auto& region = op->getRegion(region_idx);
region.getArgument(0).replaceAllUsesWith(implicit_operand);
region.front().eraseArguments(
llvm::to_vector(llvm::seq<unsigned>(0, region.getNumArguments())));
}
// Rewrites an `mhlo.if` op or its region. If `region_idx` is not set, the op
// operands and results are rewritten. If `region_idx` is set, region
// `region_idx` is rewritten to take in and return an additional token. Returns
// true if the op or its region was rewritten.
bool ProcessRegionIfOp(OpBuilder& builder, IfOp region_if,
Optional<unsigned> region_idx,
SmallVectorImpl<OpVisitorState>& ops_to_visit,
const llvm::SmallPtrSetImpl<Block*>& control_flow_blocks,
Value token) {
builder.setInsertionPoint(region_if);
if (!region_idx) {
RewriteRegionIfOp(builder, region_if, ops_to_visit, token);
return true;
}
if (*region_idx < region_if.getNumRegions()) {
// For the region-blocks of If op, we create a dummy token argument. Later
// we replace that block-argument's uses with the same (implicitly captured)
// token 'token', used for If op, and erase the argument.
// Note that 'RewriteControlFlowOpRegion' sets the token, used for the first
// operation of region_idx'th region, to the dummy block-argument. As we
// erase that argument, we also need to make sure that the token used for
// the next operation is set to 'token'.
RewriteControlFlowOpRegion(builder, region_if, *region_idx,
{token.getType()}, ops_to_visit,
control_flow_blocks, token);
ReplaceBlockArgumentsWithImplicitOperands(region_if.getOperation(),
*region_idx, token);
auto next_visitor_state = ops_to_visit.back();
next_visitor_state.token = token;
ops_to_visit.pop_back();
ops_to_visit.push_back(next_visitor_state);
return true;
}
return false;
}
// Rewrites a `mhlo.while` op to receive and forward a `mhlo.token`. Operands to
// the op for all of its regions are extended to have an extra operand `token`.
void RewriteRegionWhileOp(OpBuilder& builder, WhileOp region_while,
SmallVectorImpl<OpVisitorState>& ops_to_visit,
Value token) {
llvm::SmallDenseMap<Value, Value> rewritten_operands;
// Rewrite region operand to have an extra operand `token`.
auto new_val_operands =
GetValueWithToken(builder, llvm::to_vector(region_while.getOperands()),
token, rewritten_operands,
/*flatten_tuple=*/true);
auto new_result_types =
GetTypeWithToken(builder, llvm::to_vector(region_while.getResultTypes()),
/*flatten_tuple*/ true);
// Create new `mhlo.while` op with extra token operand and result.
auto new_while = builder.create<WhileOp>(region_while.getLoc(),
new_result_types, new_val_operands);
// Move all regions from the old `mhlo.while` op to its replacement.
new_while.cond().takeBody(region_while.cond());
new_while.body().takeBody(region_while.body());
// Forward result from old `mhlo.while` with replacement.
SmallVector<Value> old_while_results = region_while.getResults();
SmallVector<Value> new_while_results = new_while.getResults();
ReplaceWithTupleResult(builder, old_while_results, new_while_results,
/*flatten_tuple*/ true);
auto new_token = new_while_results[new_while_results.size() - 1];
region_while.erase();
// Next op to visit. The replacement is visited but at its first region. The
// token result of the new region if is propagated.
ops_to_visit.push_back({/*region_idx=*/0, new_token, new_while});
}
// Rewrites an `mhlo.while` op or its region. If `region_idx` is not set, the op
// operands and results are rewritten. If `region_idx` is set, region
// `region_idx` is rewritten to take in and return an additional token. Returns
// true if the op or its region was rewritten.
bool ProcessRegionWhileOp(
OpBuilder& builder, WhileOp region_while, Optional<unsigned> region_idx,
SmallVectorImpl<OpVisitorState>& ops_to_visit,
const llvm::SmallPtrSetImpl<Block*>& control_flow_blocks, Value token) {
builder.setInsertionPoint(region_while);
if (!region_idx) {
RewriteRegionWhileOp(builder, region_while, ops_to_visit, token);
return true;
}
if (*region_idx < region_while.getNumRegions()) {
SmallVector<Type> arg_types;
for (auto arg : region_while.arg()) arg_types.push_back(arg.getType());
RewriteControlFlowOpRegion(builder, region_while, *region_idx, arg_types,
ops_to_visit, control_flow_blocks, token);
return true;
}
return false;
}
// Updates function type based on current function body block arguments and
// terminator operand types.
void UpdateFunctionType(OpBuilder& builder, func::FuncOp func,
Block& func_body) {
auto new_argument_types = llvm::to_vector(func_body.getArgumentTypes());
auto new_result_types =
llvm::to_vector(func_body.getTerminator()->getOperandTypes());
func.setType(FunctionType::get(builder.getContext(), new_argument_types,
new_result_types));
}
// Replaces a function terminator `return` with another `return` that has an
// extra `mhlo.token` operand.
void RewriteFunctionTerminator(OpBuilder& builder,
mlir::func::ReturnOp terminator, Value token) {
auto new_results = llvm::to_vector(terminator.getOperands());
new_results.push_back(token);
builder.setInsertionPoint(terminator);
builder.create<mlir::func::ReturnOp>(terminator.getLoc(), new_results);
terminator.erase();
}
// Rewrites a function body and communication ops inside. Region control flow
// are updated when necessary, to propagate tokens. The function may either be
// rewritten to create a token or take in and return a token, depending on its
// visibility and if there are any callers.
LogicalResult RewriteFunction(
OpBuilder& builder, int64_t& channel_id, ModuleOp module, FuncOp func,
const llvm::SmallDenseMap<StringRef, FuncToRewrite>& funcs,
const llvm::SmallPtrSetImpl<Operation*>& control_flow_ops,
const llvm::SmallPtrSetImpl<Block*>& control_flow_blocks, bool is_clone) {
MLIRContext* context = module.getContext();
if (!llvm::hasSingleElement(func.getBody()))
return func.emitError()
<< "'" << FuncOp::getOperationName()
<< "' ops with more than one block are not supported";
bool rewrite_block =
is_clone || (!func.isPublic() && !func.symbolKnownUseEmpty(module));
Block& func_body = func.front();
builder.setInsertionPointToStart(&func_body);
auto token_type = TokenType::get(context);
// If a function is public, it's signature should not be modified, and instead
// a token will be created. Otherwise a token block argument is inserted.
Value init_token =
rewrite_block ? func_body.addArgument(token_type, func.getLoc())
: builder.create<CreateTokenOp>(func.getLoc(), token_type)
.getResult();
// Stack to keep track of region based control flow op nesting and current
// op to visit.
SmallVector<OpVisitorState, 4> ops_to_visit{
{/*region_idx=*/llvm::None, init_token, &func_body.front()}};
while (!ops_to_visit.empty()) {
OpVisitorState op_to_visit = ops_to_visit.pop_back_val();
Operation* curr_op = op_to_visit.op;
Value token = op_to_visit.token;
// Ops may be removed, so the next op is kept track of beforehand.
Operation* next_op = curr_op->getNextNode();
if (auto host_compute = dyn_cast<TF::_XlaHostComputeMlirOp>(curr_op)) {
token = RewriteHostComputeOp(builder, channel_id, host_compute, token);
} else if (auto send_to_host = dyn_cast<TF::XlaSendToHostOp>(curr_op)) {
token = RewriteSendToHostOp(builder, channel_id, send_to_host, token);
} else if (auto recv_from_host = dyn_cast<TF::XlaRecvFromHostOp>(curr_op)) {
token = RewriteRecvFromHostOp(builder, channel_id, recv_from_host, token);
} else if (auto call = dyn_cast<mlir::func::CallOp>(curr_op)) {
// Only `mlir::func::CallOp` is supported as this requires knowing how to
// rewrite arguments and results to a function.
auto it = funcs.find(call.getCallee());
if (it != funcs.end()) {
func::FuncOp clone = it->getSecond().clone;
Optional<StringRef> symbol_name =
clone ? Optional<StringRef>(clone.getName()) : llvm::None;
// If the function being called is to be cloned, update the call to also
// point to the cloned function.
token = RewriteCallOp(builder, call, symbol_name, token);
}
} else if (auto region_if = dyn_cast<IfOp>(curr_op)) {
if (op_to_visit.region_idx || control_flow_ops.contains(region_if)) {
auto exist_unprocessed_region =
ProcessRegionIfOp(builder, region_if, op_to_visit.region_idx,
ops_to_visit, control_flow_blocks, token);
// Once all the IfOp regions are processed (i.e.
// 'exist_unprocessed_region' == false), select returned token-value
// from IfOp as the token to be used for the following op.
if (!exist_unprocessed_region) {
token = curr_op->getResult(curr_op->getNumResults() - 1);
} else {
continue;
}
}
} else if (auto region_while = dyn_cast<WhileOp>(curr_op)) {
if (op_to_visit.region_idx || control_flow_ops.contains(region_while))
if (ProcessRegionWhileOp(builder, region_while, op_to_visit.region_idx,
ops_to_visit, control_flow_blocks, token))
continue;
} else if (auto region_terminator = dyn_cast<mhlo::ReturnOp>(curr_op)) {
bool flatten_tuple = isa<mhlo::WhileOp, mhlo::IfOp, mhlo::CaseOp>(
region_terminator->getParentOp());
RewriteControlFlowTerminator(builder, region_terminator, token,
flatten_tuple);
// There is no next op after the control flow op terminator, simply let
// stack have one less element.
continue;
} else if (auto func_terminator = dyn_cast<mlir::func::ReturnOp>(curr_op)) {
if (rewrite_block)
RewriteFunctionTerminator(builder, func_terminator, token);
// There is no next op after the function terminator, simply let stack
// have one less element/be empty.
continue;
}
// Visit next op.
ops_to_visit.push_back({/*region_idx=*/llvm::None, token, next_op});
}
if (rewrite_block) UpdateFunctionType(builder, func, func_body);
return success();
}
// Checks if a function call is pointing to a function with communication ops.
bool IsFunctionCallWithCommunication(
Operation* op,
const llvm::SmallDenseMap<StringRef, FuncToRewrite>& funcs_to_rewrite) {
if (auto call = dyn_cast<mlir::func::CallOp>(op))
return funcs_to_rewrite.count(call.getCallee());
return false;
}
// Collects all control flow op ancestors of communication ops or function calls
// with communication ops (transitively).
void GetCommunicationControlFlowOps(
func::FuncOp func,
const llvm::SmallDenseMap<StringRef, FuncToRewrite>& funcs_to_rewrite,
llvm::SmallPtrSetImpl<Operation*>& control_flow_ops,
llvm::SmallPtrSetImpl<Block*>& control_flow_blocks) {
func.walk([&](Operation* op) {
if (IsCommunicationOp(op) ||
IsFunctionCallWithCommunication(op, funcs_to_rewrite))
if (failed(GetControlFlowAncestors(op, control_flow_ops,
control_flow_blocks)))
llvm_unreachable(
"checking original function for control flow ancestors should have "
"errored first");
});
}
void LegalizeTFCommunication::runOnOperation() {
auto module = getOperation();
llvm::SmallDenseMap<StringRef, FuncToRewrite> funcs_to_rewrite;
if (failed(GetFunctionsToRewrite(module, funcs_to_rewrite)))
return signalPassFailure();
// Module level counter to make sure Channel IDs are unique.
int64_t channel_id = 1;
OpBuilder builder(&getContext());
for (const auto& func_and_name : funcs_to_rewrite) {
const auto& func_to_rewrite = func_and_name.getSecond();
func::FuncOp func = func_to_rewrite.original;
if (failed(RewriteFunction(builder, channel_id, module, func,
funcs_to_rewrite,
func_to_rewrite.control_flow_ops,
func_to_rewrite.control_flow_blocks,
/*is_clone=*/false)))
return signalPassFailure();
func::FuncOp clone = func_and_name.getSecond().clone;
if (!clone) continue;
llvm::SmallPtrSet<Operation*, 4> clone_control_flow_ops;
llvm::SmallPtrSet<Block*, 4> clone_control_flow_blocks;
GetCommunicationControlFlowOps(clone, funcs_to_rewrite,
clone_control_flow_ops,
clone_control_flow_blocks);
if (failed(RewriteFunction(builder, channel_id, module, clone,
funcs_to_rewrite, clone_control_flow_ops,
clone_control_flow_blocks,
/*is_clone=*/true)))
llvm_unreachable(
"rewriting of original function should have errored first");
}
}
} // namespace
std::unique_ptr<OperationPass<ModuleOp>> CreateLegalizeTFCommunicationPass() {
return std::make_unique<LegalizeTFCommunication>();
}
} // namespace mhlo
} // namespace mlir