blob: e943885064a7044f89bc8bc310703a3e73b645be [file] [log] [blame]
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This is a pass to reduce operands without changing the outcome.
#include <vector>
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/BitVector.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Casting.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/Block.h" // from @llvm-project
#include "mlir/IR/Builders.h" // from @llvm-project
#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project
#include "mlir/IR/BuiltinTypes.h" // from @llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/IR/OpDefinition.h" // from @llvm-project
#include "mlir/IR/Operation.h" // from @llvm-project
#include "mlir/IR/OperationSupport.h" // from @llvm-project
#include "mlir/IR/PatternMatch.h" // from @llvm-project
#include "mlir/IR/Region.h" // from @llvm-project
#include "mlir/IR/TypeRange.h" // from @llvm-project
#include "mlir/IR/Value.h" // from @llvm-project
#include "mlir/IR/Visitors.h" // from @llvm-project
#include "mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Support/LLVM.h" // from @llvm-project
#include "mlir/Support/LogicalResult.h" // from @llvm-project
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
namespace mlir {
namespace TFL {
namespace {
struct ReduceWhileOperandsPass
: public PassWrapper<ReduceWhileOperandsPass, FunctionPass> {
public:
StringRef getArgument() const final { return "tfl-reduce-while"; }
StringRef getDescription() const final {
// TODO(b/200919263): Declare Reduce While Operands Pass in Table-Gen
return "Reduce the number of operands and results of a whlieOp.";
}
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<TFL::TensorFlowLiteDialect, TF::TensorFlowDialect>();
}
void runOnFunction() override;
};
LogicalResult FindImplicityProducers(
const std::vector<uint64_t> &explicitly_consumed_ids,
std::vector<bool> &is_consumed_id,
const std::vector<std::vector<uint64_t>> &dependency_graph) {
std::vector<uint64_t> queue;
queue.reserve(is_consumed_id.size());
for (auto id : explicitly_consumed_ids) {
is_consumed_id[id] = true;
queue.push_back(id);
}
while (!queue.empty()) {
auto i = queue.back();
queue.pop_back();
// If there is a consumer which cannot be found in dependency graph, return
// false.
if (i >= dependency_graph.size()) {
return failure();
}
for (auto j : dependency_graph.at(i)) {
if (is_consumed_id[j]) continue;
queue.push_back(j);
is_consumed_id[j] = true;
}
}
return success();
}
void FindProducers(Value start_node, std::vector<uint64_t> &neighbors) {
llvm::DenseSet<Value> visited;
std::vector<Value> queue;
queue.push_back(start_node);
visited.insert(start_node);
while (!queue.empty()) {
auto node = queue.back();
queue.pop_back();
if (auto arg = node.dyn_cast_or_null<BlockArgument>()) {
neighbors.push_back(arg.getArgNumber());
continue;
}
if (!node.getDefiningOp()) continue;
for (Value operand : node.getDefiningOp()->getOperands()) {
if (visited.contains(operand)) continue;
queue.push_back(operand);
visited.insert(operand);
}
}
}
void FindConsumedOp(Operation *start_op,
llvm::DenseSet<Operation *> &consumed_ops) {
if (consumed_ops.contains(start_op)) return;
std::vector<Operation *> queue;
queue.push_back(start_op);
consumed_ops.insert(start_op);
while (!queue.empty()) {
auto op = queue.back();
queue.pop_back();
for (Value operand : op->getOperands()) {
if (!operand.getDefiningOp()) continue;
auto def_op = operand.getDefiningOp();
if (consumed_ops.contains(def_op)) continue;
queue.push_back(def_op);
consumed_ops.insert(def_op);
}
}
}
inline bool IsConstant(Operation *op) { return matchPattern(op, m_Constant()); }
bool AllOperationSafe(Block &block) {
auto walk_result = block.walk([&](Operation *op) {
// op has SideEffect.
if (!isa_and_nonnull<TFL::WhileOp>(op) &&
!op->hasTrait<OpTrait::IsTerminator>() &&
!MemoryEffectOpInterface::hasNoEffect(op)) {
return WalkResult::interrupt();
}
// op has implict arguments not listed in operands.
// Fact: if every op's operands are defined in the same block as op,
// then no operation has implicit arugments (constant doesn't count).
for (auto operand : op->getOperands()) {
if (operand.dyn_cast_or_null<BlockArgument>()) continue;
auto operand_op = operand.getDefiningOp();
if (IsConstant(operand_op)) continue;
if (operand_op->getBlock() != op->getBlock()) {
return WalkResult::interrupt();
}
}
return WalkResult::advance();
});
return !walk_result.wasInterrupted();
}
// It reduces the following pattern:
//
// S = (0, 0, 0)
// while S[2] < 3:
// a0 = S[0] * 2
// a1 = a0 + S[1]
// a2 = S[2] + 1
// S = (a0, a1, a2)
// return S[0]
//
// the 2nd operand (i = 1) as well as its related op (a1 = a0 + S[1])
// can be removed since only S[0] is returned.
// It cannot be removed by loop-invariant-code-motion pass since every value
// is used and changed in the while loop.
// Moreover, we require
// 1. no implicit argument: For every operation in whileOp, all dependent values
// (except for constant) are explicitly passed in.
// 2. no side effect: Every operation inside whileOp can be safely
// remove when it is useEmpty().
// 3. no call func inside while.
bool ReduceWhileOperands(TFL::WhileOp while_op) {
std::vector<uint64_t> explicitly_consumed_ids;
Block &cond = while_op.cond().front();
Block &body = while_op.body().front();
auto n = while_op.getNumOperands();
if (!AllOperationSafe(cond) || !AllOperationSafe(body)) return false;
// Find all Consumed indices.
// i is consumed element if result(i) is used outside whileOp or
// arugment(i) is used in whileOp.cond().
for (auto i = 0; i < n; ++i) {
if (!while_op.getResult(i).use_empty() ||
!cond.getArgument(i).use_empty()) {
explicitly_consumed_ids.push_back(i);
}
}
// Empty consumed_element_ids implies none of results is used.
if (explicitly_consumed_ids.empty()) {
while_op.erase();
return true;
}
// If every element is consumed, one can't reduce any operand.
if (explicitly_consumed_ids.size() == n) {
return false;
}
// Build the dependency graph.
// If result(i) is depend on argument(j) in While.body(), then we put
// directed edge (i->j) into the graph.
std::vector<std::vector<uint64_t>> dependency_graph;
dependency_graph.reserve(n);
Operation &yield_op = body.back();
auto results = yield_op.getOperands();
for (auto i = 0; i < n; ++i) {
std::vector<uint64_t> neighbors;
neighbors.reserve(n);
FindProducers(results[i], neighbors);
dependency_graph.push_back(neighbors);
}
std::vector<bool> is_consumed_id(n, false);
if (failed(FindImplicityProducers(explicitly_consumed_ids, is_consumed_id,
dependency_graph))) {
return false;
}
// Find all consumed operations in while body.
llvm::DenseSet<Operation *> consumed_ops;
// We'll pass in the erase_indices to erase several operands simultaneously.
llvm::BitVector erase_indices(n);
consumed_ops.insert(&yield_op);
for (auto i = 0; i < n; ++i) {
if (!is_consumed_id[i]) {
erase_indices.set(i);
} else if (results[i].getDefiningOp()) {
FindConsumedOp(results[i].getDefiningOp(), consumed_ops);
}
}
// Remove elements and operations in while_body that are not indispensable.
yield_op.eraseOperands(erase_indices);
// Remove ops from bottom to top.
for (Operation &op :
llvm::make_early_inc_range(reverse(body.getOperations())))
// Constant will not be removed in case it is implicitly used.
if (!consumed_ops.contains(&op) && !IsConstant(&op)) {
op.erase();
}
body.eraseArguments(erase_indices);
cond.eraseArguments(erase_indices);
llvm::SmallVector<Value> new_operands;
llvm::SmallVector<Type> new_result_types;
new_operands.reserve(n - erase_indices.size());
new_result_types.reserve(n - erase_indices.size());
// After reducing, the number of results is decreased. The i-th result of old
// WhileOp becomes the j-th (j<=i) result of new WhileOp. This information is
// stored in id_map (id_map[i] = j).
std::vector<uint64_t> id_map(n, 0);
uint64_t j = 0;
for (auto i = 0; i < n; ++i) {
if (is_consumed_id[i]) {
id_map[i] = j++;
new_operands.push_back(while_op.getOperand(i));
new_result_types.push_back(while_op.getResultTypes()[i]);
}
}
auto new_while_op = OpBuilder(while_op).create<WhileOp>(
while_op.getLoc(), new_result_types, new_operands, while_op->getAttrs());
new_while_op.cond().takeBody(while_op.cond());
new_while_op.body().takeBody(while_op.body());
for (auto i = 0; i < n; ++i) {
if (!while_op.getResult(i).use_empty()) {
auto j = id_map[i];
while_op.getResult(i).replaceAllUsesWith(new_while_op.getResult(j));
}
}
while_op.erase();
return erase_indices.any();
}
void ReduceWhileOperandsPass::runOnFunction() {
auto fn = getFunction();
fn.walk([&](TFL::WhileOp while_op) { ReduceWhileOperands(while_op); });
}
static PassRegistration<ReduceWhileOperandsPass> pass;
} // namespace
std::unique_ptr<OperationPass<FuncOp>> CreateReduceWhileOperandsPass() {
return std::make_unique<ReduceWhileOperandsPass>();
}
} // namespace TFL
} // namespace mlir