/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <cstdio>
#include <iostream>

#include "llvm/ADT/StringRef.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/Diagnostics.h"  // from @llvm-project
#include "mlir/IR/PatternMatch.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Support/LLVM.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h"

namespace mlir {
namespace TF {

namespace {

// Note: This implements the fusions performed in the old Remapper Grappler
// pass. That pass has special cases for GPU as well as for different target
// configurations on both CPU and GPU (Intel MKL, ROCm, etc.). This MLIR pass
// covers (some of) the general CPU case and at the moment does not account for
// any target-specific configurations.
// This pass is being ported over from the Grappler Remapper pass based on
// need/usage. File a bug to request porting over additional fusions.
// TODO(b/158265178): Support GPU-specific fusions.
// TODO(b/158266710): Support CPU MKL configurations.

// Optimizes TF computations by replacing subgraphs/nodes with more efficient
// fused implementations, decreasing the number of operations needed to
// perform a computation.
struct FusedKernelMatcherPass
: public FusedKernelMatcherPassBase<FusedKernelMatcherPass> {
void runOnOperation() override;
};

bool IsActivationFunction(Operation *op) {
return isa<EluOp, ReluOp, Relu6Op>(op);
}

// Finds and returns an activation op that uses the result of `op`. If there are
// multiple such activations, one is returned (with no guarantee as to which
// one). If there are no activation functions that use the output, returns
// nullptr.
Operation *GetActivation(Value op) {
for (auto &use : op.getUses()) {
if (IsActivationFunction(use.getOwner())) return use.getOwner();
}
return nullptr;
}

// Finds and returns a BiasAdd that uses the result of `op` as the `value`
// input. If there are multiple such BiasAdds, one is returned (with no
// guarantee as to which one). If there are no BiasAdds that use the output,
// returns a null BiasAddOp.
BiasAddOp GetBiasAdd(Value op) {
for (auto &use : op.getUses()) {
auto bias_add = dyn_cast_or_null<BiasAddOp>(use.getOwner());
    // If it's a BiasAdd, check that `op` is its `value` input (and not, e.g.,
    // its `bias` input).
if (bias_add && bias_add.value() == op) return bias_add;
}
// No BiasAddOps found among uses.
return BiasAddOp();
}

// Performs a fusion of the following pattern(s), if possible:
// <Contraction> + BiasAdd + <Activation> -> <FusedContraction>
//
// Note that fusion with an activation is preferred, but a contraction and
// BiasAdd can also be replaced by the fused contraction op if there is no
// subsequent activation function.
// i.e., this pattern also supports the following fusion:
//   <Contraction> + BiasAdd -> <FusedContraction>
//
// TODO(b/158266331): Support fusing activation chains of arbitrary length.
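//
// For illustration, a successful Conv2D match rewrites IR of roughly this
// shape (types and most attributes elided):
//
//   %0 = "tf.Conv2D"(%input, %filter) {data_format = "NHWC", ...}
//   %1 = "tf.BiasAdd"(%0, %bias) {data_format = "NHWC"}
//   %2 = "tf.Relu"(%1)
//
// into a single fused contraction:
//
//   %2 = "tf._FusedConv2D"(%input, %filter, %bias)
//            {fused_ops = ["BiasAdd", "Relu"], epsilon = 0.0 : f32, ...}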
template <typename SrcOpT, typename FusedOpT>
class FuseContractionWithBiasAdd : public OpRewritePattern<SrcOpT> {
public:
using OpRewritePattern<SrcOpT>::OpRewritePattern;

  // Subclasses should override this method if there are any op-specific
  // compatibility requirements between the contraction op and the BiasAdd op.
virtual bool AreFuseCompatible(SrcOpT contraction_op, BiasAddOp bias_add,
PatternRewriter &rewriter) const {
return true;
}

  LogicalResult matchAndRewrite(SrcOpT contraction,
PatternRewriter &rewriter) const override {
auto context = rewriter.getContext();

    // We support fusion only if the contraction operation is inside one of the
    // expected operations with regions. Other region-holding operations can
    // have semantics that are not compatible with fusion (e.g. region
    // compilation).
if (!isa<func::FuncOp, IfOp, WhileOp>(contraction->getParentOp())) {
return rewriter.notifyMatchFailure(
contraction,
"fused operation must be nested inside a function, If or While");
}
// If the contraction is used in multiple places, fusing it will only create
// more contraction nodes, which is slower.
if (!contraction.getResult().hasOneUse())
return rewriter.notifyMatchFailure(contraction,
"result is used by multiple ops");
BiasAddOp bias_add = GetBiasAdd(contraction.getResult());
if (!bias_add) {
return rewriter.notifyMatchFailure(
contraction, "does not feed into a tf.BiasAdd/tf.BiasAddV1 op");
}
if (!AreFuseCompatible(contraction, bias_add, rewriter)) {
return rewriter.notifyMatchFailure(
contraction, "cannot fuse with the subsequent BiasAdd op");
}
SmallVector<Location, 3> locations{contraction.getLoc(), bias_add.getLoc()};
SmallVector<Attribute, 2> fused_ops{StringAttr::get(
context, bias_add.getOperation()->getName().stripDialect())};
// BiasAdd may or may not feed into an activation function.
auto activation = GetActivation(bias_add);
// If there is an activation, only fuse it if this is the only op to use the
// result of the BiasAdd.
bool fuse_activation = activation && bias_add.output().hasOneUse();
Type result_type;
// Include info about the activation function if applicable.
if (fuse_activation) {
locations.push_back(activation->getLoc());
fused_ops.push_back(
StringAttr::get(context, activation->getName().stripDialect()));
result_type = activation->getResultTypes().front();
} else {
result_type = bias_add.getResult().getType();
}
auto fused_loc = rewriter.getFusedLoc(locations);
// The fused contraction has the same operands as the original contraction
// with `bias` from the BiasAddOp appended.
SmallVector<Value, 4> operands(contraction.operand_begin(),
contraction.operand_end());
operands.push_back(bias_add.bias());
// The fused contraction has the same attributes as the original
// contraction, with two additions: the list of ops which have been fused
// together; epsilon (only with FusedBatchNorm).
std::vector<NamedAttribute> attrs = contraction->getAttrs();
ArrayAttr fused_ops_attr = ArrayAttr::get(context, fused_ops);
attrs.push_back(
NamedAttribute(StringAttr::get(context, "fused_ops"), fused_ops_attr));
// Epsilon is used only in fusions with the FusedBatchNorm op, so we zero it
// here.
Attribute epsilon = rewriter.getF32FloatAttr(0);
attrs.push_back(
NamedAttribute(StringAttr::get(context, "epsilon"), epsilon));
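    // For a Conv2D + BiasAdd + Relu match, for example, the attributes added
    // here would be:
    //   fused_ops = ["BiasAdd", "Relu"], epsilon = 0.0 : f32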

    // Insert the fused operation right before the BiasAdd operation to
    // guarantee that the bias value dominates the fused operation. We already
    // verified that the original operation has a single use, so this is safe.
    // `bias_add` was checked to be non-null above, so no null check is needed
    // before setting the insertion point.
    rewriter.setInsertionPoint(bias_add);
Value fused_op = rewriter.create<FusedOpT>(fused_loc, result_type,
ValueRange(operands), attrs);
auto op_to_replace = fuse_activation ? activation : bias_add;
rewriter.replaceOp(op_to_replace, ValueRange({fused_op}));
return success();
}
};

// Performs a fusion of the following pattern(s), if possible:
// Conv2D + BiasAdd + <Activation> -> _FusedConv2D
class FuseConv2DBiasAdd
: public FuseContractionWithBiasAdd<Conv2DOp, _FusedConv2DOp> {
public:
using FuseContractionWithBiasAdd<Conv2DOp,
_FusedConv2DOp>::FuseContractionWithBiasAdd;
  // Verify that the Conv2D and BiasAdd data formats match. This is necessary
  // for the ops to fuse correctly: the fused Conv2D op has a single data
  // format attribute that is shared by the fused operations.
bool AreFuseCompatible(Conv2DOp conv, BiasAddOp bias_add,
PatternRewriter &rewriter) const override {
// Verify that the data formats match and are valid for fusion.
if (conv.data_format() != bias_add.data_format()) {
      (void)rewriter.notifyMatchFailure(conv, [&](Diagnostic &diag) {
        diag << "BiasAdd data format " << bias_add.data_format()
             << " does not match Conv2D data format " << conv.data_format();
      });
return false;
}
// Verify the data type is supported.
if (!conv.T().isF32() && !conv.T().isF64()) {
(void)rewriter.notifyMatchFailure(conv, [&](Diagnostic &diag) {
diag << "supported data types for _FusedConv2D are float and double, "
<< " but got " << conv.T();
});
return false;
}
return true;
}
};

// Performs a fusion of the following pattern(s), if possible:
// MatMulOp + BiasAdd + <Activation> -> _FusedMatMulOp
class FuseMatMulBiasAdd
: public FuseContractionWithBiasAdd<MatMulOp, _FusedMatMulOp> {
using FuseContractionWithBiasAdd<MatMulOp,
_FusedMatMulOp>::FuseContractionWithBiasAdd;
bool AreFuseCompatible(MatMulOp matmul, BiasAddOp bias_add,
PatternRewriter &rewriter) const override {
// FusedMatMul kernel supports limited set of data types.
if (!matmul.T().isF32() && !matmul.T().isBF16()) {
(void)rewriter.notifyMatchFailure(matmul, [&](Diagnostic &diag) {
diag << "supported data types for _FusedMatMul are float and bfloat16, "
<< " but got " << matmul.T();
});
return false;
}
return true;
}
};
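
// For illustration (hypothetical IR, types and most attributes elided), a
// successful MatMul + BiasAdd + Relu match produces:
//   %out = "tf._FusedMatMul"(%lhs, %rhs, %bias)
//              {fused_ops = ["BiasAdd", "Relu"], epsilon = 0.0 : f32, ...}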

void FusedKernelMatcherPass::runOnOperation() {
RewritePatternSet patterns(&getContext());
auto func = getOperation();
patterns.add<FuseConv2DBiasAdd, FuseMatMulBiasAdd>(&getContext());
(void)applyPatternsAndFoldGreedily(func, std::move(patterns));
}

}  // namespace

std::unique_ptr<OperationPass<func::FuncOp>> CreateFusedKernelMatcherPass() {
return std::make_unique<FusedKernelMatcherPass>();
}
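
// Typical usage is to schedule this pass on each function from a larger
// pipeline, e.g. (a minimal sketch, assuming a standard mlir::PassManager
// named `pm`):
//
//   pm.addNestedPass<func::FuncOp>(CreateFusedKernelMatcherPass());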

}  // namespace TF
}  // namespace mlir