Enable Eigen MatMul + Bias + LeakyRelu fusion
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index 2ea2a8a..a21366d 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -3551,6 +3551,7 @@
":ops_util",
":quantized_ops",
"//tensorflow/cc:cc_ops",
+ "//tensorflow/cc:cc_ops_internal",
"//tensorflow/cc:client_session",
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
diff --git a/tensorflow/core/kernels/matmul_op_fused.cc b/tensorflow/core/kernels/matmul_op_fused.cc
index 9ba9ed6..b24797d 100644
--- a/tensorflow/core/kernels/matmul_op_fused.cc
+++ b/tensorflow/core/kernels/matmul_op_fused.cc
@@ -86,7 +86,12 @@
BiasAddArgs<T> bias_add_args;
if (BiasAddArgs<T>::IsSupported(fusion)) {
- OP_REQUIRES_OK(context, InitBiasAddArgs(context, &bias_add_args));
+ if (fusion == FusedComputationType::kBiasAddWithLeakyRelu) {
+ OP_REQUIRES_OK(context, InitBiasAddArgs(context, &bias_add_args,
+ &fusion_args.leakyrelu_alpha));
+ } else {
+ OP_REQUIRES_OK(context, InitBiasAddArgs(context, &bias_add_args));
+ }
}
switch (fusion) {
@@ -102,6 +107,10 @@
case FusedComputationType::kBiasAddWithElu:
executeWithOutputKernel(WithBiasAddAndElu<T>(bias_add_args));
break;
+ case FusedComputationType::kBiasAddWithLeakyRelu:
+ executeWithOutputKernel(
+ WithBiasAddAndLeakyRelu<T>(bias_add_args));
+ break;
case FusedComputationType::kUndefined:
OP_REQUIRES_OK(context, errors::Internal("Fusion type is undefined"));
break;
@@ -148,10 +157,13 @@
using FCT = FusedComputationType;
if (std::is_same<Device, CPUDevice>::value) {
- patterns = {{FCT::kBiasAdd, {"BiasAdd"}},
- {FCT::kBiasAddWithRelu, {"BiasAdd", "Relu"}},
- {FCT::kBiasAddWithRelu6, {"BiasAdd", "Relu6"}},
- {FCT::kBiasAddWithElu, {"BiasAdd", "Elu"}}};
+ patterns = {
+ {FCT::kBiasAdd, {"BiasAdd"}},
+ {FCT::kBiasAddWithRelu, {"BiasAdd", "Relu"}},
+ {FCT::kBiasAddWithRelu6, {"BiasAdd", "Relu6"}},
+ {FCT::kBiasAddWithElu, {"BiasAdd", "Elu"}},
+ {FCT::kBiasAddWithLeakyRelu, {"BiasAdd", "LeakyRelu"}},
+ };
}
OP_REQUIRES_OK(context, InitializeFusedComputation(
diff --git a/tensorflow/core/kernels/matmul_op_test.cc b/tensorflow/core/kernels/matmul_op_test.cc
index 4f986e3..a18ec39 100644
--- a/tensorflow/core/kernels/matmul_op_test.cc
+++ b/tensorflow/core/kernels/matmul_op_test.cc
@@ -14,6 +14,7 @@
==============================================================================*/
#include "absl/algorithm/container.h"
+#include "tensorflow/cc/ops/nn_ops_internal.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
#include "tensorflow/core/framework/tensor.h"
@@ -137,6 +138,8 @@
ops::Relu6(root.WithOpName("with_activation"), with_bias);
} else if (activation_type == "Elu") {
ops::Elu(root.WithOpName("with_activation"), with_bias);
+ } else if (activation_type == "LeakyRelu") {
+ ops::internal::LeakyRelu(root.WithOpName("with_activation"), with_bias);
} else {
ops::Identity(root.WithOpName("with_activation"), with_bias);
}
@@ -291,7 +294,7 @@
}
TYPED_TEST_P(FusedMatMulWithBiasOpTest, MatMul256x256x256WithActivation) {
- for (const string& activation : {"Relu", "Relu6", "Elu"}) {
+ for (const string& activation : {"Relu", "Relu6", "Elu", "LeakyRelu"}) {
this->VerifyConv2DWithBiasAndActivation(256, 256, 256, false, false,
activation);
this->VerifyConv2DWithBiasAndActivation(256, 256, 256, true, false,
@@ -304,21 +307,21 @@
}
TYPED_TEST_P(FusedMatMulWithBiasOpTest, MatMul1x256x256WithActivation) {
- for (const string& activation : {"Relu", "Relu6", "Elu"}) {
+ for (const string& activation : {"Relu", "Relu6", "Elu", "LeakyRelu"}) {
this->VerifyConv2DWithBiasAndActivation(1, 256, 256, false, false,
activation);
}
}
TYPED_TEST_P(FusedMatMulWithBiasOpTest, MatMul256x256x1WithActivation) {
- for (const string& activation : {"Relu", "Relu6", "Elu"}) {
+ for (const string& activation : {"Relu", "Relu6", "Elu", "LeakyRelu"}) {
this->VerifyConv2DWithBiasAndActivation(256, 256, 1, false, false,
activation);
}
}
TYPED_TEST_P(FusedMatMulWithBiasOpTest, MatMul1x256x1WithActivation) {
- for (const string& activation : {"Relu", "Relu6", "Elu"}) {
+ for (const string& activation : {"Relu", "Relu6", "Elu", "LeakyRelu"}) {
this->VerifyConv2DWithBiasAndActivation(1, 256, 1, false, false,
activation);
}