Minor updates for review comments.
diff --git a/tensorflow/core/transforms/remapper/pass.cc b/tensorflow/core/transforms/remapper/pass.cc
index 166b830..4f08c40 100644
--- a/tensorflow/core/transforms/remapper/pass.cc
+++ b/tensorflow/core/transforms/remapper/pass.cc
@@ -247,40 +247,37 @@
BasePattern base_pattern;
Pattern pattern;
if (!matchPattern(op, base_pattern, pattern)) return failure();
- if constexpr (std::is_same<BasePatternRewriter,
- ContractionBiasAddRewriter>::value) {
- if (!this->helper_.IsDeviceCompatible(pattern)) return failure();
- Operation *&contraction_op = pattern.contraction;
- Operation *&bias_add_op = pattern.bias_add;
- Operation *&activation_op = pattern.activation;
- const std::string activation_op_name =
- activation_op->getName().stripDialect().str();
- // Currently, supported activations are:
- // _FusedMatMul: Relu, Relu6, Elu, LeakyRelu, Tanh, and Sigmoid
- // _Fused*Conv*: Relu, Relu6, Elu and LeakyRelu
- if ((activation_op_name == "Tanh" || activation_op_name == "Sigmoid") &&
- !this->helper_.getDialect()->IsMatMul(contraction_op)) {
- return failure();
- }
-
- std::unique_ptr<OperationState> state = GetContractionBiasAddOpState(
- rewriter, this->helper_, contraction_op, bias_add_op);
- SmallVector<Location> fused_locs{state->location,
- activation_op->getLoc()};
- state->location = rewriter.getFusedLoc(fused_locs);
- state->attributes.set("fused_ops", rewriter.getStrArrayAttr(
- {"BiasAdd", activation_op_name}));
- if (this->helper_.getDialect()->IsLeakyRelu(activation_op)) {
- state->attributes.set("leakyrelu_alpha",
- activation_op->getAttr("alpha"));
- }
- Operation *fused_op = rewriter.create(*state);
- TFOp(fused_op).setName(TFOp(op).nameAttr());
- rewriter.replaceOp(op, fused_op->getResults());
- return success();
+ if constexpr (!std::is_same<BasePatternRewriter,
+ ContractionBiasAddRewriter>::value) {
+ return failure();
+ }
+ if (!this->helper_.IsDeviceCompatible(pattern)) return failure();
+ Operation *&contraction_op = pattern.contraction;
+ Operation *&bias_add_op = pattern.bias_add;
+ Operation *&activation_op = pattern.activation;
+ const std::string activation_op_name =
+ activation_op->getName().stripDialect().str();
+ // Currently, supported activations are:
+ // _FusedMatMul: Relu, Relu6, Elu, LeakyRelu, Tanh, and Sigmoid
+ // _Fused*Conv*: Relu, Relu6, Elu and LeakyRelu
+ if ((activation_op_name == "Tanh" || activation_op_name == "Sigmoid") &&
+ !this->helper_.getDialect()->IsMatMul(contraction_op)) {
+ return failure();
}
- return failure();
+ std::unique_ptr<OperationState> state = GetContractionBiasAddOpState(
+ rewriter, this->helper_, contraction_op, bias_add_op);
+ SmallVector<Location> fused_locs{state->location, activation_op->getLoc()};
+ state->location = rewriter.getFusedLoc(fused_locs);
+ state->attributes.set(
+ "fused_ops", rewriter.getStrArrayAttr({"BiasAdd", activation_op_name}));
+ if (this->helper_.getDialect()->IsLeakyRelu(activation_op)) {
+ state->attributes.set("leakyrelu_alpha", activation_op->getAttr("alpha"));
+ }
+ Operation *fused_op = rewriter.create(*state);
+ TFOp(fused_op).setName(TFOp(op).nameAttr());
+ rewriter.replaceOp(op, fused_op->getResults());
+ return success();
}
};
diff --git a/tensorflow/core/transforms/remapper/remapping_helper.h b/tensorflow/core/transforms/remapper/remapping_helper.h
index 5313635..ae7642e 100644
--- a/tensorflow/core/transforms/remapper/remapping_helper.h
+++ b/tensorflow/core/transforms/remapper/remapping_helper.h
@@ -120,12 +120,9 @@
// This function is currently used by contraction ops.
bool IsGpuCompatibleDataType(Operation* contraction_op,
const StringRef& attr_name = "T") const {
- Type dtype;
- if (auto attr = contraction_op->getAttrOfType<TypeAttr>(attr_name)) {
- dtype = attr.getValue();
- } else {
- return false;
- }
+ auto attr = contraction_op->getAttrOfType<TypeAttr>(attr_name);
+ if (!attr) return false;
+ Type dtype = attr.getValue();
if (dialect_->IsConv2D(contraction_op)) {
return dtype.isa<Float32Type>();
} else if (dialect_->IsMatMul(contraction_op)) {
@@ -138,13 +135,9 @@
// This function is currently used by contraction ops.
bool IsCpuCompatibleDataType(Operation* contraction_op,
const StringRef& attr_name = "T") const {
- Type dtype;
- if (auto attr = contraction_op->getAttrOfType<TypeAttr>(attr_name)) {
- dtype = attr.getValue();
- } else {
- return false;
- }
-
+ auto attr = contraction_op->getAttrOfType<TypeAttr>(attr_name);
+ if (!attr) return false;
+ Type dtype = attr.getValue();
if (is_onednn_enabled_) {
// Only contraction ops (MatMul, Conv2D, Conv3D, and
// DepthwiseConv2dNative) and BatchMatMul are supported. BatchMatMul