blob: 04456be1d43849de5692742210113a79158dc65c [file] [log] [blame]
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This pass forms `tf_executor.island` per region of
// `tf_device.parallel_execute`.
//
// For example, the following:
//
// %0 = tf_executor.island {
// tf_executor.yield
// }
// %1:2 = tf_executor.island {
// %2 = "tf.opA"(%arg0) : (tensor<i1>) -> tensor<i1>
// tf_executor.yield %2 : tensor<i1>
// }
// %3:2 = tf_executor.island(%0) {
// %4 = "tf_device.parallel_execute"() ({
// %5 = "tf.opB"() : () -> tensor<i1>
// tf_device.return %5 : tensor<i1>
// }, {
// %5 = "tf.opC"(%1#0) : (tensor<i1>) -> tensor<i32>
// tf_device.return
// }) {} : () -> (tensor<i1>)
// tf_executor.yield %4 : tensor<i1>
// }
// tf_executor.fetch %3#0 : tensor<i1>
//
// gets lowered to:
//
// %0 = tf_executor.island {
// tf_executor.yield
// }
// %1:2 = tf_executor.island {
// %2 = "tf.opA"(%arg0) : (tensor<i1>) -> tensor<i1>
// tf_executor.yield %2 : tensor<i1>
// }
//
// // Island for the first region of above parallel_execute.
// %3:2 = tf_executor.island(%0) {
// %4 = "tf.opB"() : () -> tensor<i1>
// tf_executor.yield %4 : tensor<i1>
// }
//
// // Island for the second region of above parallel_execute.
// %5 = tf_executor.island(%0) {
// %6 = "tf.opC"(%1#0) : (tensor<i1>) -> tensor<i32>
// tf_executor.yield
// }
//
// tf_executor.fetch %3#0, %5 : tensor<i1>, !tf_executor.control
//
// When a tf_device.parallel_execute op is enclosed in a tf_device.replicate,
// this pass runs after the `replicate-to-island` and
// `tf-executor-break-up-islands` passes.
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir/IR/Block.h" // from @llvm-project
#include "mlir/IR/Builders.h" // from @llvm-project
#include "mlir/IR/Value.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Support/LLVM.h" // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h"
namespace mlir {
namespace TFDevice {
namespace {
// Pass that splits every island wrapping a single tf_device.parallel_execute
// op into one tf_executor.island per parallel_execute region.
class ParallelExecuteToIslandsPass
    : public TF::ParallelExecuteToIslandsPassBase<
          ParallelExecuteToIslandsPass> {
 public:
  void runOnOperation() override;
};
// Turns every region of `parallel_execute_op` into its own tf_executor.island
// so that the regions of the parallel_execute op get executed concurrently.
//
// Each new island is inserted immediately before `island_op`. It reuses
// `island_op`'s location, control type and control inputs, and its result
// types are the types of the values the region returns. Region bodies are
// moved (not copied) into the new islands. The new islands are appended to
// `executes` in region order.
void ExpandParallelExecuteToIslands(
    tf_executor::IslandOp island_op,
    tf_device::ParallelExecuteOp parallel_execute_op, OpBuilder* builder,
    llvm::SmallVectorImpl<tf_executor::IslandOp>& executes) {
  const int region_count = parallel_execute_op.getOperation()->getNumRegions();
  executes.reserve(region_count);
  for (int region_idx = 0; region_idx < region_count; ++region_idx) {
    Block& region_block =
        parallel_execute_op.GetRegionBlockWithIndex(region_idx);

    // Replace the region's terminator with a tf_executor.yield that forwards
    // the same operands, so the block can become an island body.
    Operation* old_terminator = region_block.getTerminator();
    builder->setInsertionPoint(old_terminator);
    tf_executor::YieldOp new_yield = builder->create<tf_executor::YieldOp>(
        old_terminator->getLoc(), old_terminator->getOperands());
    old_terminator->erase();

    // Build the island right before the original one, then move the whole
    // region body into the newly created island.
    builder->setInsertionPoint(island_op);
    tf_executor::IslandOp region_island =
        builder->create<tf_executor::IslandOp>(
            island_op.getLoc(), new_yield.getOperandTypes(),
            island_op.control().getType(), island_op.controlInputs());
    region_island.body().takeBody(*region_block.getParent());
    executes.push_back(region_island);
  }
}
// Replaces `island_op`, which wraps `parallel_execute_op`, with one island per
// parallel_execute region, then erases `island_op`.
//
// - Uses of `island_op`'s data results are redirected, positionally, to the
//   concatenated data results of the new islands.
// - If `island_op`'s control token had users, they are redirected to a new
//   empty "sink" island whose control inputs are the new islands' controls.
// - New islands that end up with no users at all have their control tokens
//   appended to the enclosing graph's fetch so they still execute.
void CreateIslandsFromParallelExecute(
    tf_executor::IslandOp island_op,
    tf_device::ParallelExecuteOp parallel_execute_op) {
  OpBuilder builder(island_op);

  llvm::SmallVector<tf_executor::IslandOp, 4> region_islands;
  ExpandParallelExecuteToIslands(island_op, parallel_execute_op, &builder,
                                 region_islands);

  // Concatenate the new islands' data results in region order and match them
  // positionally against the original island's results.
  // NOTE(review): assumes the original island yields exactly the
  // parallel_execute results, in order.
  llvm::SmallVector<Value, 8> flattened_outputs;
  flattened_outputs.reserve(
      parallel_execute_op.getOperation()->getNumResults());
  for (tf_executor::IslandOp region_island : region_islands) {
    auto region_outputs = region_island.outputs();
    flattened_outputs.append(region_outputs.begin(), region_outputs.end());
  }
  for (auto old_and_new : llvm::zip(island_op.outputs(), flattened_outputs)) {
    Value old_output = std::get<0>(old_and_new);
    Value new_output = std::get<1>(old_and_new);
    old_output.replaceAllUsesWith(new_output);
  }

  // Keep existing control dependencies on the original island intact by
  // routing them through an empty sink island that takes every new island's
  // control token as a control input.
  if (!island_op.control().use_empty()) {
    llvm::SmallVector<Value, 8> region_controls;
    for (tf_executor::IslandOp region_island : region_islands)
      region_controls.push_back(region_island.control());
    builder.setInsertionPoint(island_op);
    auto sink = builder.create<tf_executor::IslandOp>(
        island_op.getLoc(), llvm::ArrayRef<Type>{},
        island_op.control().getType(), region_controls);
    sink.body().push_back(new Block);
    builder.setInsertionPointToEnd(&sink.GetBody());
    builder.create<tf_executor::YieldOp>(island_op.getLoc(),
                                         llvm::ArrayRef<Value>{});
    island_op.control().replaceAllUsesWith(sink.control());
  }

  // New islands without any users are pinned to the graph fetch so they still
  // execute.
  llvm::SmallVector<Value, 8> dangling_controls;
  for (tf_executor::IslandOp region_island : region_islands) {
    if (!region_island.use_empty()) continue;
    dangling_controls.push_back(region_island.control());
  }
  if (!dangling_controls.empty()) {
    auto graph_op = island_op->getParentOfType<tf_executor::GraphOp>();
    tf_executor::FetchOp old_fetch = graph_op.GetFetch();
    auto new_fetch_operands = llvm::to_vector<8>(old_fetch.getOperands());
    new_fetch_operands.append(dangling_controls.begin(),
                              dangling_controls.end());
    builder.setInsertionPoint(old_fetch);
    builder.create<tf_executor::FetchOp>(old_fetch.getLoc(),
                                         new_fetch_operands);
    old_fetch.erase();
  }

  island_op.erase();
}
void ParallelExecuteToIslandsPass::runOnOperation() {
// Find islands with a single `tf_device.parallel_execute` and create
// individual islands per execute region of the parallel_execute.
llvm::SmallVector<tf_executor::IslandOp, 4> parallel_execute_op_islands;
getOperation().walk([&](tf_executor::GraphOp graph_op) {
for (auto island_op : graph_op.getOps<tf_executor::IslandOp>()) {
if (!island_op.WrapsSingleOp()) continue;
if (isa<tf_device::ParallelExecuteOp>(&island_op.GetBody().front()))
parallel_execute_op_islands.push_back(island_op);
}
});
for (tf_executor::IslandOp island_op : parallel_execute_op_islands) {
auto parallel_execute_op =
cast<tf_device::ParallelExecuteOp>(island_op.GetBody().front());
CreateIslandsFromParallelExecute(island_op, parallel_execute_op);
}
}
} // anonymous namespace
// Creates a new instance of the pass that expands each island wrapping a
// single tf_device.parallel_execute op into per-region tf_executor.islands.
std::unique_ptr<OperationPass<func::FuncOp>>
CreateParallelExecuteToIslandsPass() {
  auto pass = std::make_unique<ParallelExecuteToIslandsPass>();
  return pass;
}
} // namespace TFDevice
} // namespace mlir