/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <memory>
#include "absl/memory/memory.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Casting.h"
#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
#include "mlir/Dialect/Traits.h" // from @llvm-project
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
#include "mlir/IR/BuiltinTypes.h" // from @llvm-project
#include "mlir/IR/Operation.h" // from @llvm-project
#include "mlir/IR/PatternMatch.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Support/LogicalResult.h" // from @llvm-project
#include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h"
namespace mlir {
namespace {
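// Rewrites an op whose results satisfy broadcast semantics so that an
// explicit tf.BroadcastTo feeding one of its operands is bypassed whenever
// the op's own implicit broadcasting produces the same result shape. An
// illustrative, hand-written example of the rewrite:
//
//   %b = "tf.BroadcastTo"(%arg0, %shape)
//       : (tensor<1x2xf32>, tensor<2xi64>) -> tensor<3x2xf32>
//   %r = "tf.Add"(%b, %arg1)
//       : (tensor<3x2xf32>, tensor<3x2xf32>) -> tensor<3x2xf32>
//
// becomes
//
//   %r = "tf.Add"(%arg0, %arg1)
//       : (tensor<1x2xf32>, tensor<3x2xf32>) -> tensor<3x2xf32>
//
// since tf.Add broadcasts its operands to the same tensor<3x2xf32> result on
// its own.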
class ConvertResultsBroadcastableShapeOp : public RewritePattern {
public:
  explicit ConvertResultsBroadcastableShapeOp(MLIRContext* context)
: RewritePattern(MatchAnyOpTypeTag(), 1, context) {}
LogicalResult matchAndRewrite(Operation* op,
PatternRewriter& rewriter) const override;
private:
template <typename Op>
LogicalResult RewriteEqOp(Operation* op, PatternRewriter& rewriter) const;
LogicalResult RewriteOp(
Operation* op, PatternRewriter& rewriter,
const std::function<bool(ArrayRef<int64_t>, ArrayRef<int64_t>,
SmallVectorImpl<int64_t>&)>&
get_broadcasted_shape) const;
LogicalResult RewriteBatchMatMulV2Op(Operation* op,
PatternRewriter& rewriter) const;
};
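
// A function pass that greedily applies the pattern above, folding away
// tf.BroadcastTo ops whose effect is subsumed by the implicit broadcasting of
// their users.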
class BroadcastFoldPass : public TF::BroadcastFoldPassBase<BroadcastFoldPass> {
public:
void runOnOperation() override;
};

LogicalResult ConvertResultsBroadcastableShapeOp::matchAndRewrite(
Operation* op, PatternRewriter& rewriter) const {
if (op->hasTrait<OpTrait::ResultsBroadcastableShape>())
return RewriteOp(op, rewriter, OpTrait::util::getBroadcastedShape);
  // tf.Equal and tf.NotEqual ops only satisfy the ResultsBroadcastableShape
  // trait when incompatible_shape_error is `true` (which is also enforced by
  // the verifier).
if (succeeded(RewriteEqOp<TF::EqualOp>(op, rewriter))) return success();
if (succeeded(RewriteEqOp<TF::NotEqualOp>(op, rewriter))) return success();
if (succeeded(RewriteBatchMatMulV2Op(op, rewriter))) return success();
return failure();
}

LogicalResult ConvertResultsBroadcastableShapeOp::RewriteBatchMatMulV2Op(
Operation* op, PatternRewriter& rewriter) const {
auto matmul_op = llvm::dyn_cast<TF::BatchMatMulV2Op>(op);
if (!matmul_op) return failure();
  // Computes the broadcasted output shape of tf.BatchMatMulV2. `shape_x` is
  // the shape of the op's first (left-hand-side) operand and `shape_y` is the
  // shape of its second (right-hand-side) operand.
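  // An illustrative example (with adj_x and adj_y both false): shape_x =
  // [2, 1, 3, 4] and shape_y = [1, 5, 4, 6] have outer dimensions [2, 1] and
  // [1, 5], which broadcast to [2, 5]; the contraction dimensions match
  // (4 == 4), so result_shape becomes [2, 5, 3, 6].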
const auto get_broadcasted_shape =
[&](ArrayRef<int64_t> shape_x, ArrayRef<int64_t> shape_y,
SmallVectorImpl<int64_t>& result_shape) {
if (shape_x.size() < 2 || shape_y.size() < 2) {
return false;
}
        // Checks that the outer dimensions (i.e., all dimensions except the
        // last two) are broadcastable, and if so computes their broadcasted
        // shape into `result_shape`.
if (!OpTrait::util::getBroadcastedShape(
shape_x.drop_back(2), shape_y.drop_back(2), result_shape)) {
return false;
}
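        // Rows and columns of each matrix operand, taking the adjoint flags
        // into account.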
        const int64_t x_row =
            matmul_op.adj_x() ? shape_x.back() : *(shape_x.rbegin() + 1);
        const int64_t x_col =
            !matmul_op.adj_x() ? shape_x.back() : *(shape_x.rbegin() + 1);
        const int64_t y_row =
            matmul_op.adj_y() ? shape_y.back() : *(shape_y.rbegin() + 1);
        const int64_t y_col =
            !matmul_op.adj_y() ? shape_y.back() : *(shape_y.rbegin() + 1);
        // Checks that the matrix multiplication has a valid contraction
        // (i.e., the inner dimensions agree).
if (x_col != y_row) {
result_shape.clear();
return false;
}
result_shape.push_back(x_row);
result_shape.push_back(y_col);
return true;
};
return RewriteOp(op, rewriter, get_broadcasted_shape);
}

template <typename Op>
LogicalResult ConvertResultsBroadcastableShapeOp::RewriteEqOp(
Operation* op, PatternRewriter& rewriter) const {
auto eq_op = llvm::dyn_cast_or_null<Op>(op);
if (eq_op && eq_op.incompatible_shape_error())
return RewriteOp(op, rewriter, OpTrait::util::getBroadcastedShape);
return failure();
}

LogicalResult ConvertResultsBroadcastableShapeOp::RewriteOp(
Operation* op, PatternRewriter& rewriter,
const std::function<bool(ArrayRef<int64_t>, ArrayRef<int64_t>,
SmallVectorImpl<int64_t>&)>& get_broadcasted_shape)
const {
if (op->getNumOperands() != 2 || op->getResultTypes().size() != 1)
return failure();
// Check that the result shape is fully defined.
auto result_type =
op->getResultTypes().front().dyn_cast_or_null<RankedTensorType>();
if (!result_type || !result_type.hasStaticShape()) return failure();
bool changed = false;
for (uint64_t i = 0, e = op->getNumOperands(); i < e; ++i) {
// Check that the i'th operand is a broadcast.
auto broadcast = llvm::dyn_cast_or_null<TF::BroadcastToOp>(
op->getOpOperand(i).get().getDefiningOp());
if (!broadcast) continue;
    // Check that the operand of the broadcast has a fully defined shape.
auto broadcast_arg_type =
broadcast.input().getType().dyn_cast_or_null<RankedTensorType>();
if (!broadcast_arg_type || !broadcast_arg_type.hasStaticShape()) continue;
    // Check that the other argument has a fully defined shape.
auto argument_type = op->getOpOperand(1 - i)
.get()
.getType()
.dyn_cast_or_null<RankedTensorType>();
if (!argument_type || !argument_type.hasStaticShape()) continue;
// Get the unbroadcasted shapes in the operand order.
std::array<llvm::ArrayRef<int64_t>, 2> operand_shapes;
operand_shapes[i] = broadcast_arg_type.getShape();
operand_shapes[1 - i] = argument_type.getShape();
    // Check that the input of the broadcast and the other operand are
    // broadcast compatible.
llvm::SmallVector<int64_t, 4> broadcasted_shape;
if (!get_broadcasted_shape(operand_shapes[0], operand_shapes[1],
broadcasted_shape))
continue;
    // Check that an implicit broadcast between the operand of the broadcast
    // and the other argument would result in the same shape as the result
    // type.
if (broadcasted_shape != result_type.getShape()) continue;
// Update the operand of the op to be the operand of the broadcast.
rewriter.updateRootInPlace(
op, [&]() { op->getOpOperand(i).set(broadcast.input()); });
changed = true;
}
return success(changed);
}

void BroadcastFoldPass::runOnOperation() {
RewritePatternSet patterns(&getContext());
auto func = getOperation();
patterns.add<ConvertResultsBroadcastableShapeOp>(func.getContext());
(void)applyPatternsAndFoldGreedily(func, std::move(patterns));
}

} // namespace

namespace TF {
std::unique_ptr<OperationPass<func::FuncOp>> CreateBroadcastFoldPass() {
  return std::make_unique<BroadcastFoldPass>();
}
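
// A minimal usage sketch (illustrative; assumes an existing PassManager `pm`
// rooted at the enclosing ModuleOp):
//
//   pm.addNestedPass<func::FuncOp>(TF::CreateBroadcastFoldPass());
//
// The pass can also be invoked from tf-opt via its registered
// tf-broadcast-fold command-line flag.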
} // namespace TF
} // namespace mlir