enable leaky relu in remapper
diff --git a/tensorflow/core/grappler/optimizers/remapper.cc b/tensorflow/core/grappler/optimizers/remapper.cc
index 46c7afb..62cba11 100644
--- a/tensorflow/core/grappler/optimizers/remapper.cc
+++ b/tensorflow/core/grappler/optimizers/remapper.cc
@@ -361,7 +361,7 @@
}
bool IsSupportedActivation(const NodeDef& node) {
- return IsRelu(node) || IsRelu6(node) || IsElu(node);
+ return IsRelu(node) || IsRelu6(node) || IsElu(node) || IsLeakyRelu(node);
}
inline bool HasControlFaninOrFanout(const utils::MutableNodeView& node_view) {
@@ -450,6 +450,14 @@
IsInPreserveSet(ctx, bias_add_node_def))
return false;
+ // Get the contraction node
+ const auto* contraction_node_view =
+ bias_add_node_view->GetRegularFanin(0).node_view();
+ const auto* contraction_node_def = contraction_node_view->node();
+
+ // Currently, only conv + bias + leakyrelu is enabled
+ if (!IsConv2D(*contraction_node_def) && IsLeakyRelu(*node_def)) return false;
+
// Check that data type and data format are supported on assigned device.
const ContractionWithBiasAddAndActivation pattern{base.contraction,
base.bias_add, node_index};
@@ -719,6 +727,16 @@
return false;
}
+ // Get the contraction node
+ const auto* bias_add_node_view =
+ add_node_view->GetRegularFanin(base.port_id).node_view();
+ const auto* contraction_node_view =
+ bias_add_node_view->GetRegularFanin(0).node_view();
+ const auto* contraction_node_def = contraction_node_view->node();
+
+ // Currently, only conv + bias + add + leakyrelu is enabled
+ if (!IsConv2D(*contraction_node_def) && IsLeakyRelu(*node_def)) return false;
+
// We successfully found a Conv2D+BiasAdd+AddN+activation pattern.
const ContractionWithBiasAndAddActivation pattern{
base.contraction, base.bias_add, base.add, base.port_id, node_index};
@@ -919,7 +937,8 @@
return false;
}
-void CopyConv2DAttributes(const NodeDef& conv2d, NodeDef* fused_conv2d) {
+void CopyConv2DAttributes(const NodeDef& conv2d, NodeDef* fused_conv2d,
+ const NodeDef* activation = nullptr) {
DCHECK(IsConv2D(conv2d)) << "Input node must be a Conv2D";
auto* attr = fused_conv2d->mutable_attr();
@@ -932,10 +951,15 @@
(*attr)["dilations"] = src_attr.at("dilations");
(*attr)["data_format"] = src_attr.at("data_format");
(*attr)["use_cudnn_on_gpu"] = src_attr.at("use_cudnn_on_gpu");
+ if (activation != nullptr && IsLeakyRelu(*activation)) {
+ auto& activation_attr = activation->attr();
+ (*attr)["leakyrelu_alpha"] = activation_attr.at("alpha");
+ }
}
void CopyDepthwiseConv2dNativeAttributes(const NodeDef& dw_conv2d,
- NodeDef* fused_dw_conv2d) {
+ NodeDef* fused_dw_conv2d,
+ const NodeDef* activation = nullptr) {
DCHECK(IsDepthwiseConv2dNative(dw_conv2d))
<< "Input node must be a DepthwiseConv2dNative";
@@ -947,6 +971,10 @@
(*attr)["padding"] = src_attr.at("padding");
(*attr)["dilations"] = src_attr.at("dilations");
(*attr)["data_format"] = src_attr.at("data_format");
+ if (activation != nullptr && IsLeakyRelu(*activation)) {
+ auto& activation_attr = activation->attr();
+ (*attr)["leakyrelu_alpha"] = activation_attr.at("alpha");
+ }
}
void CopyFusedBatchNormAttributes(const NodeDef& fused_batch_norm,
@@ -1049,6 +1077,7 @@
const NodeDef& contraction = graph->node(matched.contraction);
const NodeDef& bias_add = graph->node(matched.bias_add);
const NodeDef& activation = graph->node(matched.activation);
+
VLOG(2) << "Fuse " << contraction.op() << " with BiasAdd and "
<< activation.op() << ":"
<< " activation=" << activation.name()
@@ -1064,7 +1093,8 @@
if (IsConv2D(contraction)) {
fused_op.set_op(kFusedConv2D);
- CopyConv2DAttributes(contraction, &fused_op);
+ // leaky relu has a special attribute alpha
+ CopyConv2DAttributes(contraction, &fused_op, &activation);
} else if (IsDepthwiseConv2dNative(contraction)) {
fused_op.set_op(kFusedDepthwiseConv2dNative);
CopyDepthwiseConv2dNativeAttributes(contraction, &fused_op);
@@ -1284,7 +1314,7 @@
fused_conv2d.add_input(add.input(1 - matched.port_id));
CopyConv2DAttributes(contraction, &fused_conv2d);
- SetFusedOpAttributes(&fused_conv2d, {"BiasAdd", "Add", "Relu"}, 2);
+ SetFusedOpAttributes(&fused_conv2d, {"BiasAdd", "Add", activation.op()}, 2);
utils::Mutation* mutation = ctx->graph_view.GetMutationBuilder();
Status status;