Enable LeakyRelu as a fused activation in the Grappler remapper
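
This extends the remapper's set of fusable activations (previously Relu,
Relu6, and Elu) with LeakyRelu. The fusion is restricted to Conv2D-based
contractions for now, and LeakyRelu's "alpha" attribute is forwarded to
the fused node as "leakyrelu_alpha".

For illustration, a minimal sketch of the user-level subgraph this change
targets, written against the public TensorFlow Python API (names and
shapes below are illustrative, not part of this patch):

    import tensorflow as tf

    # Conv2D + BiasAdd + LeakyRelu: the subgraph the remapper can now
    # collapse into a single fused contraction node.
    @tf.function
    def conv_bias_leaky_relu(x, w, b):
        y = tf.nn.conv2d(x, w, strides=[1, 1, 1, 1], padding="SAME")
        y = tf.nn.bias_add(y, b)
        # alpha is what gets propagated to the fused op's
        # "leakyrelu_alpha" attribute.
        return tf.nn.leaky_relu(y, alpha=0.2)

    x = tf.random.normal([1, 8, 8, 3])
    w = tf.random.normal([3, 3, 3, 16])  # HWIO filter
    b = tf.random.normal([16])
    print(conv_bias_leaky_relu(x, w, b).shape)  # (1, 8, 8, 16)

After this change, Grappler can rewrite the three ops above into a single
fused Conv2D whose fused_ops attribute is ["BiasAdd", "LeakyRelu"].
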
diff --git a/tensorflow/core/grappler/optimizers/remapper.cc b/tensorflow/core/grappler/optimizers/remapper.cc
index 46c7afb..62cba11 100644
--- a/tensorflow/core/grappler/optimizers/remapper.cc
+++ b/tensorflow/core/grappler/optimizers/remapper.cc
@@ -361,7 +361,7 @@
 }
 
 bool IsSupportedActivation(const NodeDef& node) {
-  return IsRelu(node) || IsRelu6(node) || IsElu(node);
+  return IsRelu(node) || IsRelu6(node) || IsElu(node) || IsLeakyRelu(node);
 }
 
 inline bool HasControlFaninOrFanout(const utils::MutableNodeView& node_view) {
@@ -450,6 +450,14 @@
       IsInPreserveSet(ctx, bias_add_node_def))
     return false;
 
+  // Get the contraction node feeding the BiasAdd.
+  const auto* contraction_node_view =
+      bias_add_node_view->GetRegularFanin(0).node_view();
+  const auto* contraction_node_def = contraction_node_view->node();
+
+  // Currently, LeakyRelu fusion is only enabled for Conv2D + BiasAdd.
+  if (!IsConv2D(*contraction_node_def) && IsLeakyRelu(*node_def)) return false;
+
   // Check that data type and data format are supported on assigned device.
   const ContractionWithBiasAddAndActivation pattern{base.contraction,
                                                     base.bias_add, node_index};
@@ -719,6 +727,16 @@
     return false;
   }
 
+  // Get the contraction node feeding the BiasAdd.
+  const auto* bias_add_node_view =
+      add_node_view->GetRegularFanin(base.port_id).node_view();
+  const auto* contraction_node_view =
+      bias_add_node_view->GetRegularFanin(0).node_view();
+  const auto* contraction_node_def = contraction_node_view->node();
+
+  // Currently, LeakyRelu fusion is only enabled for Conv2D + BiasAdd + Add.
+  if (!IsConv2D(*contraction_node_def) && IsLeakyRelu(*node_def)) return false;
+
   // We successfully found a Conv2D+BiasAdd+AddN+activation pattern.
   const ContractionWithBiasAndAddActivation pattern{
       base.contraction, base.bias_add, base.add, base.port_id, node_index};
@@ -919,7 +937,8 @@
   return false;
 }
 
-void CopyConv2DAttributes(const NodeDef& conv2d, NodeDef* fused_conv2d) {
+void CopyConv2DAttributes(const NodeDef& conv2d, NodeDef* fused_conv2d,
+                          const NodeDef* activation = nullptr) {
   DCHECK(IsConv2D(conv2d)) << "Input node must be a Conv2D";
 
   auto* attr = fused_conv2d->mutable_attr();
@@ -932,10 +951,15 @@
   (*attr)["dilations"] = src_attr.at("dilations");
   (*attr)["data_format"] = src_attr.at("data_format");
   (*attr)["use_cudnn_on_gpu"] = src_attr.at("use_cudnn_on_gpu");
+  if (activation != nullptr && IsLeakyRelu(*activation)) {
+    auto& activation_attr = activation->attr();
+    (*attr)["leakyrelu_alpha"] = activation_attr.at("alpha");
+  }
 }
 
 void CopyDepthwiseConv2dNativeAttributes(const NodeDef& dw_conv2d,
-                                         NodeDef* fused_dw_conv2d) {
+                                         NodeDef* fused_dw_conv2d,
+                                         const NodeDef* activation = nullptr) {
   DCHECK(IsDepthwiseConv2dNative(dw_conv2d))
       << "Input node must be a DepthwiseConv2dNative";
 
@@ -947,6 +971,10 @@
   (*attr)["padding"] = src_attr.at("padding");
   (*attr)["dilations"] = src_attr.at("dilations");
   (*attr)["data_format"] = src_attr.at("data_format");
+  if (activation != nullptr && IsLeakyRelu(*activation)) {
+    auto& activation_attr = activation->attr();
+    (*attr)["leakyrelu_alpha"] = activation_attr.at("alpha");
+  }
 }
 
 void CopyFusedBatchNormAttributes(const NodeDef& fused_batch_norm,
@@ -1049,6 +1077,7 @@
   const NodeDef& contraction = graph->node(matched.contraction);
   const NodeDef& bias_add = graph->node(matched.bias_add);
   const NodeDef& activation = graph->node(matched.activation);
+
   VLOG(2) << "Fuse " << contraction.op() << " with BiasAdd and "
           << activation.op() << ":"
           << " activation=" << activation.name()
@@ -1064,7 +1093,8 @@
 
   if (IsConv2D(contraction)) {
     fused_op.set_op(kFusedConv2D);
-    CopyConv2DAttributes(contraction, &fused_op);
+    // LeakyRelu carries an extra attribute (alpha) that must be copied over.
+    CopyConv2DAttributes(contraction, &fused_op, &activation);
   } else if (IsDepthwiseConv2dNative(contraction)) {
     fused_op.set_op(kFusedDepthwiseConv2dNative);
     CopyDepthwiseConv2dNativeAttributes(contraction, &fused_op);
@@ -1284,7 +1314,7 @@
   fused_conv2d.add_input(add.input(1 - matched.port_id));
 
-  CopyConv2DAttributes(contraction, &fused_conv2d);
-  SetFusedOpAttributes(&fused_conv2d, {"BiasAdd", "Add", "Relu"}, 2);
+  CopyConv2DAttributes(contraction, &fused_conv2d, &activation);
+  SetFusedOpAttributes(&fused_conv2d, {"BiasAdd", "Add", activation.op()}, 2);
 
   utils::Mutation* mutation = ctx->graph_view.GetMutationBuilder();
   Status status;