/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/transforms/remapper/pass.h"
#include <memory>
#include <string>
#include <type_traits>
#include <utility>
#include "mlir/Dialect/PDL/IR/PDL.h" // from @llvm-project
#include "mlir/Dialect/PDLInterp/IR/PDLInterp.h" // from @llvm-project
#include "mlir/IR/BuiltinTypes.h" // from @llvm-project
#include "mlir/IR/OperationSupport.h" // from @llvm-project
#include "mlir/IR/PatternMatch.h" // from @llvm-project
#include "mlir/Parser/Parser.h" // from @llvm-project
#include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/ir/dialect.h"
#include "tensorflow/core/ir/tf_op_wrapper.h"
#include "tensorflow/core/transforms/pass_detail.h"
#include "tensorflow/core/transforms/remapper/remapping_helper.h"
#include "tensorflow/core/transforms/utils/pdll/utils.h"
#include "tensorflow/core/transforms/utils/utils.h"
#include "tensorflow/core/util/util.h"
namespace mlir {
namespace tfg {
namespace mkl {
#include "tensorflow/core/transforms/remapper/pdll/MklPDLLPatterns.h.inc"
} // namespace mkl
// Convert Sigmoid+Mul to Swish
// Mul(x, Sigmoid(x)) --> _MklSwish(x)
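// For illustration only (simplified TFG form, hypothetical value names), the
// rewrite is roughly
//   %s, %ctl0 = Sigmoid(%x)
//   %y, %ctl1 = Mul(%x, %s)    // or Mul(%s, %x); Mul is commutative
// becoming
//   %y, %ctl1 = _MklSwish(%x)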
class MatchMulSigmoid : public RewritePattern {
public:
explicit MatchMulSigmoid(MLIRContext *context)
: RewritePattern("tfg.Mul", PatternBenefit(/*benefit=*/1), context),
sigmoid_name_("tfg.Sigmoid", context) {}
LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override {
    TypeAttr dtype_attr = op->getAttrOfType<TypeAttr>("T");
    // Guard against a missing "T" attribute before inspecting its value.
    if (!dtype_attr) return failure();
    if (!dtype_attr.getValue().isa<Float32Type>() &&
        !dtype_attr.getValue().isa<BFloat16Type>()) {
      return failure();
    }
if (!util::OpHasDevice(op, tensorflow::DEVICE_CPU)) return failure();
TFOp mul_wrapper(op);
Value sigmoid = op->getOperand(0);
Value x = op->getOperand(1);
auto sigmoidOperandEqToX = [&](Value sigmoid, Value x) {
Operation *op = sigmoid.getDefiningOp();
return op && op->getName() == sigmoid_name_ && op->getOperand(0) == x;
};
if (!sigmoidOperandEqToX(sigmoid, x)) {
      // Mul is commutative, so the Sigmoid may be either operand. Swap the
      // operands and check again.
std::swap(sigmoid, x);
if (!sigmoidOperandEqToX(sigmoid, x)) return failure();
}
SmallVector<Value> operands;
// Set up non-control operand.
operands.push_back(x);
// Control operands come after regular operands.
llvm::append_range(operands, mul_wrapper.getControlOperands());
Operation *new_op =
rewriter.create(op->getLoc(), rewriter.getStringAttr("tfg._MklSwish"),
operands, op->getResultTypes(), op->getAttrs());
rewriter.replaceOp(op, new_op->getResults());
return success();
}
private:
  // Caches the OperationName handle for "tfg.Sigmoid" to avoid repeated string
  // comparisons during matching.
OperationName sigmoid_name_;
};
// This enum class is used as a template parameter; each item aliases the
// corresponding tfg op name.
// TODO(intel-tf): Add more items as needed.
enum class OpKind { Relu, Relu6, Elu, LeakyRelu, Tanh, Sigmoid };
inline std::string GetTfgOpName(OpKind op_kind) {
switch (op_kind) {
case OpKind::Relu:
return "tfg.Relu";
case OpKind::Relu6:
return "tfg.Relu6";
case OpKind::Elu:
return "tfg.Elu";
case OpKind::LeakyRelu:
return "tfg.LeakyRelu";
case OpKind::Tanh:
return "tfg.Tanh";
case OpKind::Sigmoid:
return "tfg.Sigmoid";
default:
return "tfg.NoOp";
}
}
class RemapperPatternBase : public RewritePattern {
public:
RemapperPatternBase(StringRef opName, OpPropertyHelper &helper,
PatternBenefit benefit = PatternBenefit(1))
: RewritePattern(opName, benefit, helper.getDialect()->getContext()),
helper_(helper) {}
RemapperPatternBase(MatchAnyOpTypeTag tag, OpPropertyHelper &helper,
PatternBenefit benefit = PatternBenefit(1))
: RewritePattern(tag, benefit, helper.getDialect()->getContext()),
helper_(helper) {}
protected:
OpPropertyHelper helper_;
};
static std::unique_ptr<OperationState> GetContractionBiasAddOpState(
OpBuilder &builder, const OpPropertyHelper &helper,
Operation *contraction_op, Operation *bias_add_op) {
  // The fused op name depends on the original contraction op name.
std::string fused_op_name;
if (helper.getDialect()->IsConv2D(contraction_op)) {
fused_op_name = "tfg._FusedConv2D";
} else if (helper.getDialect()->IsMatMul(contraction_op)) {
fused_op_name = "tfg._FusedMatMul";
} else if (helper.getDialect()->IsDepthwiseConv2dNative(contraction_op)) {
fused_op_name = "tfg._FusedDepthwiseConv2dNative";
} else if (helper.getDialect()->IsConv3D(contraction_op)) {
fused_op_name = "tfg._FusedConv3D";
} else {
return nullptr;
}
SmallVector<Location> fused_locs{contraction_op->getLoc(),
bias_add_op->getLoc()};
auto state = std::make_unique<OperationState>(builder.getFusedLoc(fused_locs),
fused_op_name);
SmallVector<Value> operands;
Value input = contraction_op->getOperand(0);
Value filter = contraction_op->getOperand(1);
Value bias = bias_add_op->getOperand(1);
operands.push_back(input);
operands.push_back(filter);
operands.push_back(bias);
state->addOperands(operands);
state->addOperands(TFOp(contraction_op).getControlOperands());
state->addOperands(TFOp(bias_add_op).getControlOperands());
state->addTypes(bias_add_op->getResultTypes());
state->attributes = contraction_op->getAttrs();
state->attributes.set("fused_ops", builder.getStrArrayAttr({"BiasAdd"}));
state->attributes.set("num_args", builder.getI32IntegerAttr(1));
// Default values for epsilon and leakyrelu_alpha
state->attributes.set("epsilon", builder.getF32FloatAttr(0.0001));
state->attributes.set("leakyrelu_alpha", builder.getF32FloatAttr(0.2));
return state;
}
// Contraction + BiasAdd
// TODO(intel-tf): Support Contraction + {Add, AddV2} fusion when it has the
// same semantics as Contraction + BiasAdd.
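// For illustration only (simplified TFG form, hypothetical value names), the
// rewrite performed here is roughly
//   %c, %ctl0 = Conv2D(%input, %filter)
//   %y, %ctl1 = BiasAdd(%c, %bias)
// becoming
//   %y, %ctl1 = _FusedConv2D(%input, %filter, %bias) {fused_ops = ["BiasAdd"]}
// with MatMul, DepthwiseConv2dNative, and Conv3D contractions handled
// analogously.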
class ContractionBiasAddRewriter : public RemapperPatternBase {
public:
explicit ContractionBiasAddRewriter(OpPropertyHelper &helper)
: RemapperPatternBase("tfg.BiasAdd", helper, PatternBenefit(1)) {}
  // Constructor used by a derived pattern rewriter class that may have a
  // different root operation name. Currently, the pattern is matched from the
  // root op to its inputs.
explicit ContractionBiasAddRewriter(StringRef op_name,
OpPropertyHelper &helper,
PatternBenefit benefit)
: RemapperPatternBase(op_name, helper, benefit) {}
using Pattern = ContractionBiasAdd;
bool matchPattern(Operation *op, Pattern &pattern) const {
    // Do not fuse when BiasAdd has control operands or control result users.
if (helper_.HasControlOperandsOrResultUsers(op)) return false;
Operation *contraction_op = op->getOperand(0).getDefiningOp();
if (!helper_.IsContraction(contraction_op) ||
helper_.HasControlOperandsOrResultUsers(contraction_op) ||
!helper_.HaveSameDataType(op, contraction_op) ||
!helper_.HasAtMostOneUserOfResult0(contraction_op)) {
return false;
}
pattern.contraction = contraction_op;
pattern.bias_add = op;
return true;
}
LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override {
Pattern pattern;
if (!matchPattern(op, pattern)) return failure();
if (!helper_.IsDeviceCompatible(pattern)) return failure();
std::unique_ptr<OperationState> state = GetContractionBiasAddOpState(
rewriter, helper_, pattern.contraction, pattern.bias_add);
Operation *fused_op = rewriter.create(*state);
TFOp(fused_op).setName(TFOp(op).nameAttr());
rewriter.replaceOp(op, fused_op->getResults());
return success();
}
};
// BasePattern + Activation
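// For illustration only (simplified TFG form, hypothetical value names), with
// ContractionBiasAddRewriter as the base pattern this rewrites roughly
//   %c, %ctl0 = MatMul(%a, %b)
//   %d, %ctl1 = BiasAdd(%c, %bias)
//   %y, %ctl2 = Relu(%d)
// into
//   %y, %ctl2 = _FusedMatMul(%a, %b, %bias) {fused_ops = ["BiasAdd", "Relu"]}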
template <typename BasePatternRewriter, OpKind activation>
class BasePatternActivationRewriter : public BasePatternRewriter {
public:
explicit BasePatternActivationRewriter(OpPropertyHelper &helper)
: BasePatternRewriter(GetTfgOpName(activation), helper,
PatternBenefit(1)) {}
using BasePattern = typename BasePatternRewriter::Pattern;
using Pattern = std::conditional_t<
std::is_same<BasePatternRewriter, ContractionBiasAddRewriter>::value,
ContractionBiasAddActivation, void>;
bool matchPattern(Operation *op, BasePattern &base_pattern,
Pattern &pattern) const {
    // Although template instantiation guarantees that only a valid activation
    // is set as the root operation, a sanity check is added here.
if (this->helper_.getDialect()->IsNoOp(op)) return false;
if (this->helper_.HasControlOperandsOrResultUsers(op)) return false;
// TODO(intel-tf): Add support for more patterns.
if constexpr (std::is_same<BasePattern, ContractionBiasAdd>::value &&
std::is_same<Pattern, ContractionBiasAddActivation>::value) {
Operation *bias_add_op = op->getOperand(0).getDefiningOp();
if (!this->helper_.getDialect()->IsBiasAdd(bias_add_op) ||
!this->helper_.HaveSameDataType(op, bias_add_op) ||
!this->helper_.HasAtMostOneUserOfResult0(bias_add_op)) {
return false;
}
if (!BasePatternRewriter::matchPattern(bias_add_op, base_pattern)) {
return false;
}
pattern.contraction = base_pattern.contraction;
pattern.bias_add = base_pattern.bias_add;
pattern.activation = op;
return true;
}
return false;
}
LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override {
BasePattern base_pattern;
Pattern pattern;
if (!matchPattern(op, base_pattern, pattern)) return failure();
if constexpr (!std::is_same<BasePatternRewriter,
ContractionBiasAddRewriter>::value) {
return failure();
}
if (!this->helper_.IsDeviceCompatible(pattern)) return failure();
Operation *&contraction_op = pattern.contraction;
Operation *&bias_add_op = pattern.bias_add;
Operation *&activation_op = pattern.activation;
const std::string activation_op_name =
activation_op->getName().stripDialect().str();
// Currently, supported activations are:
// _FusedMatMul: Relu, Relu6, Elu, LeakyRelu, Tanh, and Sigmoid
// _Fused*Conv*: Relu, Relu6, Elu and LeakyRelu
if ((activation_op_name == "Tanh" || activation_op_name == "Sigmoid") &&
!this->helper_.getDialect()->IsMatMul(contraction_op)) {
return failure();
}
std::unique_ptr<OperationState> state = GetContractionBiasAddOpState(
rewriter, this->helper_, contraction_op, bias_add_op);
SmallVector<Location> fused_locs{state->location, activation_op->getLoc()};
state->location = rewriter.getFusedLoc(fused_locs);
state->attributes.set(
"fused_ops", rewriter.getStrArrayAttr({"BiasAdd", activation_op_name}));
if (this->helper_.getDialect()->IsLeakyRelu(activation_op)) {
state->attributes.set("leakyrelu_alpha", activation_op->getAttr("alpha"));
}
Operation *fused_op = rewriter.create(*state);
TFOp(fused_op).setName(TFOp(op).nameAttr());
rewriter.replaceOp(op, fused_op->getResults());
return success();
}
};
template <template <OpKind> class PatternT, OpKind... op_kinds,
typename... Args>
static void InsertPatterns(RewritePatternSet &patterns, Args &&...args) {
patterns.insert<PatternT<op_kinds>...>(std::forward<Args>(args)...);
}
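// For example (illustrative only),
//   InsertPatterns<PatternT, OpKind::Relu, OpKind::Tanh>(patterns, helper)
// expands to
//   patterns.insert<PatternT<OpKind::Relu>, PatternT<OpKind::Tanh>>(helper);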
template <OpKind activation>
using ContractionBiasAddActivationRewriter =
BasePatternActivationRewriter<ContractionBiasAddRewriter, activation>;
class Remapper : public RemapperBase<Remapper> {
public:
Remapper() = default;
explicit Remapper(bool enable_onednn_patterns, bool xla_auto_clustering) {
enable_onednn_patterns_ = enable_onednn_patterns;
xla_auto_clustering_ = xla_auto_clustering;
}
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<pdl::PDLDialect, pdl_interp::PDLInterpDialect>();
}
LogicalResult initialize(MLIRContext *context) override {
helper_ = OpPropertyHelper(context->getOrLoadDialect<TFGraphDialect>(),
enable_onednn_patterns_, xla_auto_clustering_);
RewritePatternSet patterns(context);
populateRemapperPatterns(context, patterns);
RegisterPDLLUtils(patterns);
final_patterns_ = std::move(patterns);
return success();
}
void runOnOperation() override;
private:
void populateRemapperPatterns(MLIRContext *context,
RewritePatternSet &patterns) {
if (verify_pdll_patterns_only_) {
populateRemapperPDLLPatterns(patterns);
return;
}
if (enable_onednn_patterns_) {
patterns.insert<MatchMulSigmoid>(context);
      // TODO(chiahungduan): Currently, the only pattern implemented in PDLL is
      // the same one as `MatchMulSigmoid`. Remove one of them once it is
      // decided which one is preferred.
populateRemapperPDLLPatterns(patterns);
}
patterns.insert<ContractionBiasAddRewriter>(helper_);
    // Insert multiple pattern rewriters, one template instantiation per
    // activation op.
InsertPatterns<ContractionBiasAddActivationRewriter, OpKind::Relu,
OpKind::Relu6, OpKind::Elu, OpKind::LeakyRelu, OpKind::Tanh,
OpKind::Sigmoid>(patterns, helper_);
}
void populateRemapperPDLLPatterns(RewritePatternSet &patterns) {
mkl::populateGeneratedPDLLPatterns(patterns);
}
FrozenRewritePatternSet final_patterns_;
OpPropertyHelper helper_;
};
void Remapper::runOnOperation() {
if (failed(applyPatternsAndFoldGreedily(getOperation(), final_patterns_))) {
signalPassFailure();
}
}
std::unique_ptr<Pass> CreateRemapperPass(bool enable_onednn_patterns,
bool xla_auto_clustering) {
return std::make_unique<Remapper>(enable_onednn_patterns,
xla_auto_clustering);
}
} // namespace tfg
} // namespace mlir