Fusing Mul(x, Sigmoid(x)) into Swish(x)
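
This change adds a grappler remapper pattern that rewrites the subgraph
Mul(x, Sigmoid(x)) on CPU into a single _MklSwish(x) node, where

    swish(x) = x * sigmoid(x) = x / (1 + exp(-x))

The fused op is backed by oneDNN's eltwise_swish forward primitive and is
registered for float and bfloat16. The patch also adds a reusable
eltwise-activation base kernel, remapper and kernel tests, and benchmarks
comparing the fused kernel against the Eigen Sigmoid + Mul composition.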
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index ba41e35..7c374fc 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -691,6 +691,7 @@
"//tensorflow/core/kernels/mkl:mkl_reshape_op",
"//tensorflow/core/kernels/mkl:mkl_slice_op",
"//tensorflow/core/kernels/mkl:mkl_softmax_op",
+ "//tensorflow/core/kernels/mkl:mkl_swish_op",
"//tensorflow/core/kernels/mkl:mkl_transpose_op",
"//tensorflow/core/kernels/mkl:mkl_batch_matmul_op",
"//tensorflow/core/kernels/mkl:mkl_einsum_op",
diff --git a/tensorflow/core/grappler/optimizers/mkl_remapper_test.cc b/tensorflow/core/grappler/optimizers/mkl_remapper_test.cc
index 4a90a50..a63df9c 100644
--- a/tensorflow/core/grappler/optimizers/mkl_remapper_test.cc
+++ b/tensorflow/core/grappler/optimizers/mkl_remapper_test.cc
@@ -737,6 +737,95 @@
// changed by other optimizers before the remapper optimizer.
TEST_F(FusedMatMulBiasAddAndGeluTest, Float32GeluExact) { RunTest<DT_FLOAT>(); }
+class MklRemapperSwishTest : public GrapplerTest {
+ protected:
+ template <DataType DTYPE>
+ void RunTest() {
+ using ::tensorflow::ops::Placeholder;
+
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ auto mul_shape = ops::Placeholder::Shape({64, 64});
+
+ // We will test four situations:
+ // 1. y = x * sigmoid(x)
+ // 2. y = sigmoid(x) * x
+ // 3. y = sigmoid(x) * sigmoid(sigmoid(x))
+ // 4. y = sigmoid(sigmoid(x)) * sigmoid(x)
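+ // Cases 1, 2 and 4 are expected to be fused into _MklSwish; case 3 is not
+ // (see the comment on mul3 below).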
+ auto input = Placeholder(s.WithOpName("input"), DTYPE, mul_shape);
+ auto sigmoid1 = ops::Sigmoid(s.WithOpName("sigmoid1"), input);
+ auto sigmoid2 = ops::Sigmoid(s.WithOpName("sigmoid2"), input);
+ auto sigmoid3_1 = ops::Sigmoid(s.WithOpName("sigmoid3_1"), input);
+ auto sigmoid3_2 = ops::Sigmoid(s.WithOpName("sigmoid3_2"), sigmoid3_1);
+ auto sigmoid4_1 = ops::Sigmoid(s.WithOpName("sigmoid4_1"), input);
+ auto sigmoid4_2 = ops::Sigmoid(s.WithOpName("sigmoid4_2"), sigmoid4_1);
+ auto mul1 = ops::Mul(s.WithOpName("mul1"), input, sigmoid1);
+ auto mul2 = ops::Mul(s.WithOpName("mul2"), sigmoid2, input);
+ auto mul3 = ops::Mul(s.WithOpName("mul3"), sigmoid3_1, sigmoid3_2);
+ auto mul4 = ops::Mul(s.WithOpName("mul4"), sigmoid4_2, sigmoid4_1);
+ auto fetch1 = ops::Identity(s.WithOpName("fetch1"), mul1);
+ auto fetch2 = ops::Identity(s.WithOpName("fetch2"), mul2);
+ auto fetch3 = ops::Identity(s.WithOpName("fetch3"), mul3);
+ auto fetch4 = ops::Identity(s.WithOpName("fetch4"), mul4);
+ auto mul_t = GenerateTensorWithSetRandom<DTYPE>({64, 64});
+
+ GrapplerItem item;
+ item.fetch = {"fetch1", "fetch2", "fetch3", "fetch4"};
+ item.feed = {{"input", mul_t}};
+ TF_ASSERT_OK(s.ToGraphDef(&item.graph));
+
+ // Place all nodes on CPU.
+ for (int i = 0; i < item.graph.node_size(); ++i) {
+ item.graph.mutable_node(i)->set_device("/device:CPU:0");
+ }
+
+ Remapper optimizer(RewriterConfig::ON);
+ GraphDef output;
+ TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
+
+ int found = 0;
+ for (const NodeDef& node : output.node()) {
+ if (node.name() == "mul1") {
+ EXPECT_EQ(node.op(), "_MklSwish");
+ ASSERT_EQ(node.input_size(), 1);
+ EXPECT_EQ(node.input(0), "input");
+ ++found;
+ }
+ if (node.name() == "mul2") {
+ EXPECT_EQ(node.op(), "_MklSwish");
+ ASSERT_EQ(node.input_size(), 1);
+ EXPECT_EQ(node.input(0), "input");
+ ++found;
+ }
+ // mul3 is not replaced with _MklSwish because of a limitation of the
+ // pattern matcher when handling commutative ops.
+ if (node.name() == "mul4") {
+ EXPECT_EQ(node.op(), "_MklSwish");
+ ASSERT_EQ(node.input_size(), 1);
+ EXPECT_EQ(node.input(0), "sigmoid4_1");
+ ++found;
+ }
+ }
+ EXPECT_EQ(found, 3);
+
+ auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed);
+ ASSERT_EQ(tensors_expected.size(), 4);
+ auto tensors = EvaluateNodes(output, item.fetch, item.feed);
+ ASSERT_EQ(tensors.size(), 4);
+ float atol = 1e-6, rtol = 1e-6;
+ if (DTYPE == DT_BFLOAT16) {
+ atol = 1e-2;
+ rtol = 1e-2;
+ }
+ test::ExpectClose(tensors[0], tensors_expected[0], atol, rtol);
+ test::ExpectClose(tensors[1], tensors_expected[1], atol, rtol);
+ test::ExpectClose(tensors[2], tensors_expected[2], atol, rtol);
+ test::ExpectClose(tensors[3], tensors_expected[3], atol, rtol);
+ }
+};
+
+TEST_F(MklRemapperSwishTest, F32) { RunTest<DT_FLOAT>(); }
+TEST_F(MklRemapperSwishTest, BF16) { RunTest<DT_BFLOAT16>(); }
+
} // namespace grappler
} // namespace tensorflow
#endif // INTEL_MKL && ENABLE_MKL
diff --git a/tensorflow/core/grappler/optimizers/remapper.cc b/tensorflow/core/grappler/optimizers/remapper.cc
index 789c4e5..605f59f 100644
--- a/tensorflow/core/grappler/optimizers/remapper.cc
+++ b/tensorflow/core/grappler/optimizers/remapper.cc
@@ -57,6 +57,9 @@
// (1) FusedBatchNorm + <Activation>
// (2) FusedBatchNorm + SideInput + <Activation>
//
+// Sigmoid + Mul -> _MklSwish
+// Note: this fusion is only supported on Intel CPUs.
+//
// In all cases, the supported activation functions are Relu, Relu6, and Elu.
//
// Both Conv2D and MatMul implemented as Tensor contraction (on CPU), so all the
@@ -1090,6 +1093,65 @@
return (found_gelu_exact || found_gelu_approximate);
}
+bool FindSigmoidAndMul(RemapperContext* ctx, int node_index,
+ std::map<string, int>* matched_nodes_map,
+ std::set<int>* remove_node_indices) {
+ using utils::MatchingDirection;
+ using utils::NodeStatus;
+ // clang-format off
+ // Convert Sigmoid+Mul to Swish
+ // From Graph To Graph
+ // ----------- ---------
+ // Conv2D <- Filter(const) Conv2D <- Filter(const)
+ // ! !
+ // V V
+ // BiasAdd <- bias(const) BiasAdd <- bias(const)
+ // ! !
+ // V !
+ // ---- ---- !
+ // ! ! !
+ // ! V !
+ // ! Sigmoid !
+ // ! ! !
+ // --- --- !
+ // ! ! !
+ // ! ! !
+ // V V V
+ // Mul _MklSwish
+ // ! !
+ // V V
+
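+ // Note: both "*" pattern nodes carry the same label, "input". The matcher
+ // binds a label to a single graph node, so the pattern only matches when
+ // the Mul's other operand is exactly the Sigmoid's input, i.e.
+ // Mul(x, Sigmoid(x)).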
+ utils::OpTypePattern sigmoidmul_pattern{ "Mul", "mul_to_swish", NodeStatus::kReplace,
+ {
+ { "Sigmoid", "sigmoid", NodeStatus::kRemove,
+ {
+ { "*", "input", NodeStatus::kRemain}
+ }
+ },
+ { "*", "input", NodeStatus::kRemain}
+ }
+ };
+ // clang-format on
+ // Only float and bfloat16 Mul nodes are eligible for this fusion.
+ auto* mul_node_def = ctx->graph_view.GetNode(node_index)->node();
+ if (!HasDataType(mul_node_def, DT_FLOAT) &&
+ !HasDataType(mul_node_def, DT_BFLOAT16))
+ return false;
+
+ if (!NodeIsOnCpu(mul_node_def)) return false;
+
+ bool found_op_type_match = false;
+ utils::SubGraphMatcher<MatchingDirection::kFollowInputs> graph_matcher(
+ &(ctx->graph_view));
+ matched_nodes_map->clear();
+ remove_node_indices->clear();
+ found_op_type_match = graph_matcher.GetMatchedNodes(
+ sigmoidmul_pattern, {}, ctx->graph_view.GetNode(node_index),
+ matched_nodes_map, remove_node_indices);
+
+ return found_op_type_match;
+}
+
bool FindFusedBatchNorm(const RemapperContext& ctx, int node_index,
FusedBatchNorm* matched) {
const auto* node_view = ctx.graph_view.GetNode(node_index);
@@ -1985,6 +2047,39 @@
return Status::OK();
}
+Status ReplaceSigmoidMulWithSwish(
+ RemapperContext* ctx, const std::map<string, int>& matched_nodes_map,
+ const std::set<int>& remove_node_indices,
+ std::vector<bool>* invalidated_nodes, std::vector<bool>* nodes_to_delete) {
+ const GraphDef* graph = ctx->graph_view.graph();
+ const NodeDef* mul =
+ ctx->graph_view.GetNode(matched_nodes_map.at("mul_to_swish"))->node();
+ const NodeDef* sigmoid =
+ ctx->graph_view.GetNode(matched_nodes_map.at("sigmoid"))->node();
+
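+ // Build the fused node: it reuses the Mul's name, so existing consumers
+ // are preserved, and takes the Sigmoid's input (i.e. x) as its single
+ // input.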
+ NodeDef fused_op;
+ fused_op.set_name(mul->name());
+ fused_op.set_op("_MklSwish");
+ fused_op.set_device(mul->device());
+ fused_op.add_input(sigmoid->input(0));
+
+ auto* attr = fused_op.mutable_attr();
+ (*attr)["T"] = mul->attr().at("T");
+
+ utils::Mutation* mutation = ctx->graph_view.GetMutationBuilder();
+ Status status;
+ mutation->AddNode(std::move(fused_op), &status);
+ TF_RETURN_IF_ERROR(status);
+ TF_RETURN_IF_ERROR(mutation->Apply());
+
+ (*invalidated_nodes)[matched_nodes_map.at("mul_to_swish")] = true;
+
+ for (const auto& node_index : remove_node_indices) {
+ (*nodes_to_delete)[node_index] = true;
+ }
+ return Status::OK();
+}
+
Status AddFusedBatchNormExNode(RemapperContext* ctx,
const FusedBatchNormEx& matched,
std::vector<bool>* invalidated_nodes,
@@ -2580,6 +2675,17 @@
&nodes_to_delete, is_gelu_approximate));
continue;
}
+
+ // Remap the Mul(x, Sigmoid(x)) pattern into a single _MklSwish(x) node.
+ std::map<string, int> sigmoidmul_matched_nodes_map;
+ std::set<int> sigmoidmul_remove_node_indices;
+ if (FindSigmoidAndMul(&ctx, i, &sigmoidmul_matched_nodes_map,
+ &sigmoidmul_remove_node_indices)) {
+ TF_RETURN_IF_ERROR(ReplaceSigmoidMulWithSwish(
+ &ctx, sigmoidmul_matched_nodes_map, sigmoidmul_remove_node_indices,
+ &invalidated_nodes, &nodes_to_delete));
+ continue;
+ }
}
// Infer properties lazily in case they are not needed.
diff --git a/tensorflow/core/kernels/mkl/BUILD b/tensorflow/core/kernels/mkl/BUILD
index 6e1e095..447d584 100644
--- a/tensorflow/core/kernels/mkl/BUILD
+++ b/tensorflow/core/kernels/mkl/BUILD
@@ -237,6 +237,20 @@
deps = ["@com_google_absl//absl/strings"] + MKL_TEST_DEPS,
)
+tf_cc_test_mkl(
+ name = "mkl_swish_op_test",
+ size = "small",
+ srcs = ["mkl_swish_op_test.cc"],
+ linkstatic = 1, # Fixes dyld error on MacOS.
+ deps = [
+ ":mkl_eltwise_activation_base_op",
+ ":mkl_swish_op",
+ "//tensorflow/cc:math_ops",
+ "//tensorflow/core/kernels:cwise_op",
+ "//tensorflow/core:direct_session",
+ ] + MKL_TEST_DEPS,
+)
+
tf_mkl_kernel_library(
name = "mkl_tfconv_op",
prefix = "mkl_tfconv",
@@ -308,6 +322,19 @@
)
tf_mkl_kernel_library(
+ name = "mkl_eltwise_activation_base_op",
+ prefix = "mkl_eltwise_activation_base",
+ deps = MKL_DEPS,
+)
+
+tf_mkl_kernel_library(
+ name = "mkl_swish_op",
+ hdrs = ["mkl_eltwise_activation_base_op.h"],
+ prefix = "mkl_swish",
+ deps = MKL_DEPS,
+)
+
+tf_mkl_kernel_library(
name = "mkl_softmax_op",
prefix = "mkl_softmax",
deps = MKL_SHORT_DEPS + ["//third_party/eigen3"],
diff --git a/tensorflow/core/kernels/mkl/mkl_eltwise_activation_base_op.h b/tensorflow/core/kernels/mkl/mkl_eltwise_activation_base_op.h
new file mode 100644
index 0000000..8337153
--- /dev/null
+++ b/tensorflow/core/kernels/mkl/mkl_eltwise_activation_base_op.h
@@ -0,0 +1,288 @@
+/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+ http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// See docs in ../ops/nn_ops.cc.
+#ifdef INTEL_MKL
+
+#include <unordered_map>
+
+#include "mkldnn.hpp"
+#include "tensorflow/core/framework/numeric_op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/util/mkl_util.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+
+using mkldnn::algorithm;
+using mkldnn::eltwise_forward;
+using mkldnn::memory;
+using mkldnn::prop_kind;
+using mkldnn::stream;
+
+using EltwiseFwdPd = mkldnn::eltwise_forward::primitive_desc;
+
+namespace tensorflow {
+
+template <typename T>
+class MklEltwiseFwdParams {
+ public:
+ memory::dims src_dims;
+ memory::desc src_md;
+ algorithm alg_kind;
+ float alpha;
+ float beta;
+
+ MklEltwiseFwdParams(memory::dims src_dims, memory::desc src_md,
+ algorithm alg_kind, float alpha, float beta)
+ : src_dims(src_dims),
+ src_md(src_md),
+ alg_kind(alg_kind),
+ alpha(alpha),
+ beta(beta) {}
+};
+
+template <typename T>
+class MklEltwiseFwdPrimitive : public MklPrimitive {
+ public:
+ explicit MklEltwiseFwdPrimitive(const MklEltwiseFwdParams<T>& fwdParams)
+ : MklPrimitive(engine(engine::kind::cpu, 0)) {
+ // create eltwise primitive
+ if (context_.eltwise_fwd == nullptr) {
+ Setup(fwdParams);
+ }
+ }
+
+ ~MklEltwiseFwdPrimitive() {}
+
+ // Eltwise forward execute
+ // src_data: input data buffer of src
+ // dst_data: output data buffer of dst
+ void Execute(const T* src_data, T* dst_data, OpKernelContext* op_context) {
+ context_.src_mem->set_data_handle(
+ static_cast<void*>(const_cast<T*>(src_data)));
+ context_.dst_mem->set_data_handle(static_cast<void*>(dst_data));
+ DCHECK_EQ(context_.fwd_primitives.size(),
+ context_.fwd_primitives_args.size());
+
+ std::vector<primitive> net;
+ net.push_back(eltwise_forward(*context_.fwd_pd));
+ std::vector<MemoryArgsMap> net_args;
+ net_args.push_back({{MKLDNN_ARG_SRC, *context_.src_mem},
+ {MKLDNN_ARG_DST, *context_.dst_mem}});
+ // Execute the eltwise_fwd primitive.
+ ExecutePrimitive(net, &net_args, GetEngine(), op_context);
+
+ // After execution, set data handle back.
+ context_.src_mem->set_data_handle(DummyData);
+ context_.dst_mem->set_data_handle(DummyData);
+ }
+
+ std::shared_ptr<EltwiseFwdPd> GetEltwiseFwdPd() { return context_.fwd_pd; }
+
+ private:
+ // Primitive reuse context for eltwise fwd ops (currently used for Swish).
+ struct EltwiseFwdContext {
+ // MKLDNN memory
+ std::shared_ptr<memory> src_mem;
+ std::shared_ptr<memory> dst_mem;
+
+ // desc & primitive desc
+ std::shared_ptr<mkldnn::eltwise_forward::desc> fwd_desc;
+ std::shared_ptr<EltwiseFwdPd> fwd_pd;
+
+ // memory desc
+ std::shared_ptr<memory::desc> src_md;
+ std::shared_ptr<memory::desc> dst_md;
+
+ // memory primitive desc
+ std::shared_ptr<memory::desc> src_mpd;
+
+ // Eltwise primitive
+ std::shared_ptr<mkldnn::primitive> eltwise_fwd;
+
+ std::vector<mkldnn::primitive> fwd_primitives;
+
+ std::vector<std::unordered_map<int, memory>> fwd_primitives_args;
+
+ EltwiseFwdContext()
+ : src_mem(nullptr),
+ dst_mem(nullptr),
+ fwd_desc(nullptr),
+ fwd_pd(nullptr),
+ src_md(nullptr),
+ dst_md(nullptr),
+ src_mpd(nullptr),
+ eltwise_fwd(nullptr) {}
+ };
+
+ // Eltwise forward primitive setup
+ void Setup(const MklEltwiseFwdParams<T>& fwdParams) {
+ // create memory descriptors for eltwise data with specified format
+ context_.src_md.reset(new memory::desc(fwdParams.src_md.data));
+ context_.src_mpd.reset(new memory::desc(*context_.src_md));
+
+ // Create an eltwise forward descriptor and primitive descriptor
+ context_.fwd_desc.reset(new eltwise_forward::desc(
+ prop_kind::forward, fwdParams.alg_kind, *context_.src_md,
+ fwdParams.alpha, fwdParams.beta));
+ context_.fwd_pd.reset(new EltwiseFwdPd(*context_.fwd_desc, cpu_engine_));
+ auto fwd_pd = context_.fwd_pd.get();
+
+ // Create memory primitive based on dummy data
+ context_.src_mem.reset(
+ new memory(fwd_pd->src_desc(), cpu_engine_, DummyData));
+ context_.dst_mem.reset(
+ new memory(fwd_pd->dst_desc(), cpu_engine_, DummyData));
+ // Create eltwise primitive and add it to net
+ context_.eltwise_fwd.reset(new eltwise_forward(*context_.fwd_pd));
+ context_.fwd_primitives_args.push_back(
+ {{MKLDNN_ARG_SRC, *context_.src_mem},
+ {MKLDNN_ARG_DST, *context_.dst_mem}});
+ context_.fwd_primitives.push_back(*context_.eltwise_fwd);
+ }
+
+ struct EltwiseFwdContext context_;
+};
+
+template <typename T>
+class MklEltwiseFwdPrimitiveFactory : public MklPrimitiveFactory<T> {
+ public:
+ static MklEltwiseFwdPrimitive<T>* Get(
+ const MklEltwiseFwdParams<T>& fwdParams) {
+ MklEltwiseFwdPrimitive<T>* eltwise_forward = nullptr;
+
+ // Get an eltwise fwd primitive from the cached pool
+ eltwise_forward = static_cast<MklEltwiseFwdPrimitive<T>*>(
+ MklEltwiseFwdPrimitiveFactory<T>::GetInstance().GetEltwiseFwd(
+ fwdParams));
+ if (eltwise_forward == nullptr) {
+ eltwise_forward = new MklEltwiseFwdPrimitive<T>(fwdParams);
+ MklEltwiseFwdPrimitiveFactory<T>::GetInstance().SetEltwiseFwd(
+ fwdParams, eltwise_forward);
+ }
+
+ return eltwise_forward;
+ }
+
+ static MklEltwiseFwdPrimitiveFactory& GetInstance() {
+ static MklEltwiseFwdPrimitiveFactory instance_;
+ return instance_;
+ }
+
+ private:
+ MklEltwiseFwdPrimitiveFactory() {}
+ ~MklEltwiseFwdPrimitiveFactory() {}
+
+ static string CreateKey(const MklEltwiseFwdParams<T>& fwdParams) {
+ string prefix = "eltwise_fwd";
+ FactoryKeyCreator key_creator;
+ key_creator.AddAsKey(prefix);
+ key_creator.AddAsKey(fwdParams.src_dims);
+ key_creator.AddAsKey<int>(static_cast<int>(fwdParams.alg_kind));
+ key_creator.AddAsKey<float>(static_cast<float>(fwdParams.alpha));
+ key_creator.AddAsKey<float>(static_cast<float>(fwdParams.beta));
+ return key_creator.GetKey();
+ }
+
+ MklPrimitive* GetEltwiseFwd(const MklEltwiseFwdParams<T>& fwdParams) {
+ string key = CreateKey(fwdParams);
+ return this->GetOp(key);
+ }
+
+ void SetEltwiseFwd(const MklEltwiseFwdParams<T>& fwdParams,
+ MklPrimitive* op) {
+ string key = CreateKey(fwdParams);
+ this->SetOp(key, op);
+ }
+};
+
+template <typename Device, typename T, algorithm alg_kind>
+class MklEltwiseFwdActivationOpBase : public OpKernel {
+ public:
+ ~MklEltwiseFwdActivationOpBase() {}
+
+ explicit MklEltwiseFwdActivationOpBase(OpKernelConstruction* context,
+ float alpha, float beta)
+ : OpKernel(context), alpha_(alpha), beta_(beta) {}
+ virtual void Compute_Scalar(OpKernelContext* context) = 0;
+
+ void Compute(OpKernelContext* context) override {
+ try {
+ const Tensor& src_tensor = context->input(0);
+ TensorShape src_shape = src_tensor.shape();
+ if (src_tensor.dims() == 0) {
+ Compute_Scalar(context);
+ return;
+ }
+ // Allocate output (dst) tensor
+ TensorShape dst_shape = src_shape;
+ Tensor* dst_tensor = nullptr;
+ // Nothing to compute, return.
+ if (src_shape.num_elements() == 0) {
+ OP_REQUIRES_OK(context,
+ context->allocate_output(
+ GetTensorDataIndex(0, context->num_outputs()),
+ dst_shape, &dst_tensor));
+ return;
+ }
+ // Set DNN primitive - src
+ MklDnnData<T> src(&cpu_engine);
+ memory::dims src_dims;
+ memory::desc src_md({}, memory::data_type::undef,
+ memory::format_tag::undef);
+
+ src_dims = TFShapeToMklDnnDims(src_tensor.shape());
+ auto src_strides = CalculateTFStrides(src_dims);
+
+ // Create blocked memory descriptor
+ src_md = MklDnnData<T>::CreateBlockedMemDesc(src_dims, src_strides);
+
+ // Try to get an eltwise forward primitive from caching pool
+ MklEltwiseFwdParams<T> fwdParams(src_dims, src_md, alg_kind, alpha_,
+ beta_);
+ MklEltwiseFwdPrimitive<T>* eltwise_fwd =
+ MklEltwiseFwdPrimitiveFactory<T>::Get(fwdParams);
+
+ const T* src_data = src_tensor.flat<T>().data();
+
+ OP_REQUIRES_OK(context, context->allocate_output(
+ GetTensorDataIndex(0, context->num_outputs()),
+ dst_shape, &dst_tensor));
+
+ T* dst_data = dst_tensor->flat<T>().data();
+ // execute eltwise
+ eltwise_fwd->Execute(src_data, dst_data, context);
+ } catch (mkldnn::error& e) {
+ string error_msg = "Status: " + std::to_string(e.status) + ", message: " +
+ string(e.message) + ", in file " + string(__FILE__) +
+ ":" + std::to_string(__LINE__);
+ OP_REQUIRES_OK(
+ context,
+ errors::Aborted("Operation received an exception:", error_msg));
+ }
+ }
+
+ private:
+ engine cpu_engine = engine(engine::kind::cpu, 0);
+
+ protected:
+ float alpha_;
+ float beta_;
+};
+
+// TODO: Implement the eltwise bwd / eltwise grad class.
+
+} // namespace tensorflow
+
+#endif  // INTEL_MKL
\ No newline at end of file
diff --git a/tensorflow/core/kernels/mkl/mkl_swish_op.cc b/tensorflow/core/kernels/mkl/mkl_swish_op.cc
new file mode 100644
index 0000000..874d159
--- /dev/null
+++ b/tensorflow/core/kernels/mkl/mkl_swish_op.cc
@@ -0,0 +1,67 @@
+/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+ http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// See docs in ../ops/nn_ops.cc.
+#ifdef INTEL_MKL
+
+#include "tensorflow/core/kernels/mkl/mkl_eltwise_activation_base_op.h"
+
+namespace tensorflow {
+
+template <typename Device, typename T>
+class MklSwishOp
+ : public MklEltwiseFwdActivationOpBase<Device, T,
+ mkldnn::algorithm::eltwise_swish> {
+ public:
+ ~MklSwishOp() {}
+
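+ // oneDNN's eltwise_swish computes x * sigmoid(alpha * x); alpha = 1.0f
+ // yields the standard swish(x) = x * sigmoid(x), and beta is unused by
+ // this algorithm.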
+ explicit MklSwishOp(OpKernelConstruction* context)
+ : MklEltwiseFwdActivationOpBase<Device, T,
+ mkldnn::algorithm::eltwise_swish>(
+ context, 1.0f, 0.0f) {}
+
+ virtual void Compute_Scalar(OpKernelContext* context) {
+ const Tensor& src_tensor = context->input(0);
+
+ // Get the shape of the input tensor.
+ TensorShape src_shape = src_tensor.shape();
+
+ Tensor* dst_tensor = nullptr;
+ void* user_i =
+ static_cast<void*>(const_cast<T*>(src_tensor.flat<T>().data()));
+
+ TensorShape dst_shape = src_shape;
+
+ OP_REQUIRES_OK(context, context->allocate_output(
+ GetTensorDataIndex(0, context->num_outputs()),
+ dst_shape, &dst_tensor));
+
+ // swish(x) = x * sigmoid(x).
+ void* out_o = static_cast<void*>(dst_tensor->flat<T>().data());
+ T feature = (static_cast<T*>(user_i))[0];
+ T e1 = Eigen::numext::exp(-feature);
+ (static_cast<T*>(out_o))[0] = feature / (static_cast<T>(1) + e1);
+ return;
+ }
+};
+
+// Register DNN kernels for the supported operations and types.
+#define REGISTER_SWISH_MKL_SUPPORTED_KERNELS_TYPES(type) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("_MklSwish").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
+ MklSwishOp<CPUDevice, type>);
+TF_CALL_float(REGISTER_SWISH_MKL_SUPPORTED_KERNELS_TYPES);
+TF_CALL_bfloat16(REGISTER_SWISH_MKL_SUPPORTED_KERNELS_TYPES);
+
+} // namespace tensorflow
+
+#endif  // INTEL_MKL
\ No newline at end of file
diff --git a/tensorflow/core/kernels/mkl/mkl_swish_op_test.cc b/tensorflow/core/kernels/mkl/mkl_swish_op_test.cc
new file mode 100644
index 0000000..132f7f7
--- /dev/null
+++ b/tensorflow/core/kernels/mkl/mkl_swish_op_test.cc
@@ -0,0 +1,114 @@
+/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+ http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifdef INTEL_MKL
+
+#include "absl/strings/match.h"
+#include "mkldnn.hpp"
+#include "tensorflow/cc/ops/const_op.h"
+#include "tensorflow/cc/ops/math_ops.h"
+#include "tensorflow/cc/ops/nn_ops.h"
+#include "tensorflow/cc/ops/standard_ops.h"
+#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
+#include "tensorflow/core/framework/fake_input.h"
+#include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/kernels/mkl/mkl_eltwise_activation_base_op.h"
+#include "tensorflow/core/kernels/ops_testutil.h"
+#include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/platform/cpu_info.h"
+#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/stacktrace_handler.h"
+#include "tensorflow/core/platform/str_util.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/test_benchmark.h"
+#include "tensorflow/core/public/session.h"
+#include "tensorflow/core/util/mkl_util.h"
+
+// This is a special case: Eigen does not provide a Swish kernel, so the
+// "Default" benchmark composes Swish from the Eigen Sigmoid and Mul kernels
+// and compares it against the fused MKL kernel on CPU.
+//
+// Use the command below to compare MKL and Eigen performance:
+// $ bazel run --action_env=TF_ENABLE_ONEDNN_OPTS=1 -c opt \
+//  //tensorflow/core/kernels/mkl:mkl_swish_op_test -- --benchmark_filter=all
+//
+
+namespace tensorflow {
+
+// --------------------------------------------------------------------------//
+// Test Swish Kernels accuracy and performance //
+// --------------------------------------------------------------------------//
+template <typename T>
+static Graph* SwishGraph(const string& kind, const TensorShape& shape) {
+ auto* graph = new Graph(OpRegistry::Global());
+
+ DataType dtype = DataTypeToEnum<T>::v();
+ Tensor input_t(dtype, shape);
+ input_t.flat<T>().setRandom();
+ Node* input = test::graph::Constant(graph, input_t, "input");
+ const bool isDefault = (kind == "Default");
+
+ Node* sigmoid;
+ Node* mul;
+ Node* swish;
+ if (isDefault) {
+ TF_CHECK_OK(NodeBuilder(graph->NewName("Default_sigmoid"), "Sigmoid")
+ .Input(input)
+ .Attr("T", dtype)
+ .Finalize(graph, &sigmoid));
+
+ TF_CHECK_OK(NodeBuilder(graph->NewName("Default_mul"), "Mul")
+ .Input(input)
+ .Input(sigmoid)
+ .Attr("T", dtype)
+ .Finalize(graph, &mul));
+ return graph;
+ }
+ // Mkl Swish op.
+ TF_CHECK_OK(NodeBuilder(graph->NewName("Mkl_swish"), "_MklSwish")
+ .Input(input)
+ .Attr("T", dtype)
+ .Finalize(graph, &swish));
+ return graph;
+}
+
+#define BM_SWISH(kind, A, B, C, D, type, T) \
+ static void BM_SWISH_##kind##_##type##_##A##_##B##_##C##_##D##_##T( \
+ ::testing::benchmark::State& state) { \
+ int64 num_computed_elements = (A) * (B) * (C) * (D); \
+ int64 flops_per_iter = num_computed_elements; \
+ \
+ test::Benchmark(#type, SwishGraph<T>(#kind, {A, B, C, D})).Run(state); \
+ state.SetItemsProcessed(state.iterations() * flops_per_iter); \
+ } \
+ BENCHMARK(BM_SWISH_##kind##_##type##_##A##_##B##_##C##_##D##_##T)
+
+#define BENCHMARK_SWISH(A, B, C, D, type, T) \
+ BM_SWISH(Default, A, B, C, D, type, T); \
+ BM_SWISH(Mkl, A, B, C, D, type, T);
+
+#define BENCHMARK_DTYPE(T) \
+ BENCHMARK_SWISH(1, 16, 16, 3, cpu, T); \
+ BENCHMARK_SWISH(16, 32, 32, 1, cpu, T); \
+ BENCHMARK_SWISH(16, 64, 64, 128, cpu, T); \
+ BENCHMARK_SWISH(32, 64, 64, 128, cpu, T); \
+ BENCHMARK_SWISH(32, 256, 256, 128, cpu, T); \
+ BENCHMARK_SWISH(32, 512, 512, 128, cpu, T);
+
+BENCHMARK_DTYPE(float)
+BENCHMARK_DTYPE(bfloat16)
+
+} // namespace tensorflow
+
+#endif // INTEL_MKL
\ No newline at end of file
diff --git a/tensorflow/core/ops/mkl_nn_ops.cc b/tensorflow/core/ops/mkl_nn_ops.cc
index 1397643..fc9d951 100644
--- a/tensorflow/core/ops/mkl_nn_ops.cc
+++ b/tensorflow/core/ops/mkl_nn_ops.cc
@@ -1779,6 +1779,17 @@
expected to invoke these operators.
)doc");
+REGISTER_OP("_MklSwish")
+ .Input("features: T")
+ .Output("activations: T")
+ .Attr("T: {float, bfloat16} = DT_FLOAT")
+ .SetShapeFn(shape_inference::UnchangedShape)
+ .Doc(R"doc(
+MKL version of the Swish operator. Uses MKL DNN APIs to implement the Swish
+operator.
+NOTE: Do not invoke this operator directly in Python. The graph rewrite pass
+is expected to invoke this operator.
+)doc");
+
} // namespace tensorflow
#endif // INTEL_MKL