Fusing Mul(x, Sigmoid(x)) into Swish(x)
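
This pattern remaps y = Mul(x, Sigmoid(x)) (and the commuted form
y = Mul(Sigmoid(x), x)) on Intel CPUs into a single _MklSwish node, where
swish(x) = x * sigmoid(x) = x / (1 + exp(-x)), implemented via the oneDNN
eltwise_swish primitive.

As a minimal sketch of the rewrite (hypothetical node names, using the TF
C++ client API the same way the test below does), a graph built like this:

    tensorflow::Scope s = tensorflow::Scope::NewRootScope();
    auto x   = ops::Placeholder(s.WithOpName("x"), DT_FLOAT);
    auto sig = ops::Sigmoid(s.WithOpName("sig"), x);
    auto y   = ops::Mul(s.WithOpName("y"), x, sig);
    // After the remapper pass runs, "y" is a single _MklSwish node that
    // reads "x" directly, and "sig" is removed.
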
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index ba41e35..7c374fc 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -691,6 +691,7 @@
         "//tensorflow/core/kernels/mkl:mkl_reshape_op",
         "//tensorflow/core/kernels/mkl:mkl_slice_op",
         "//tensorflow/core/kernels/mkl:mkl_softmax_op",
+        "//tensorflow/core/kernels/mkl:mkl_swish_op",        
         "//tensorflow/core/kernels/mkl:mkl_transpose_op",
         "//tensorflow/core/kernels/mkl:mkl_batch_matmul_op",
         "//tensorflow/core/kernels/mkl:mkl_einsum_op",
diff --git a/tensorflow/core/grappler/optimizers/mkl_remapper_test.cc b/tensorflow/core/grappler/optimizers/mkl_remapper_test.cc
index 4a90a50..a63df9c 100644
--- a/tensorflow/core/grappler/optimizers/mkl_remapper_test.cc
+++ b/tensorflow/core/grappler/optimizers/mkl_remapper_test.cc
@@ -737,6 +737,95 @@
 // changed by other optimizers before the remapper optimizer.
 TEST_F(FusedMatMulBiasAddAndGeluTest, Float32GeluExact) { RunTest<DT_FLOAT>(); }
 
+class MklRemapperSwishTest : public GrapplerTest {
+ protected:
+  template <DataType DTYPE>
+  void RunTest() {
+    using ::tensorflow::ops::Placeholder;
+
+    tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+    auto mul_shape = ops::Placeholder::Shape({64, 64});
+
+    // We will test four situations:
+    //  1. y = x * sigmoid(x)
+    //  2. y = sigmoid(x) * x
+    //  3. y = sigmoid(x) * sigmoid(sigmoid(x))
+    //  4. y = sigmoid(sigmoid(x)) * sigmoid(x)
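+    // Cases 1, 2, and 4 are expected to fuse into _MklSwish; case 3 is not
+    // (see the checks below).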
+    auto input = Placeholder(s.WithOpName("input"), DTYPE, mul_shape);
+    auto sigmoid1 = ops::Sigmoid(s.WithOpName("sigmoid1"), input);
+    auto sigmoid2 = ops::Sigmoid(s.WithOpName("sigmoid2"), input);
+    auto sigmoid3_1 = ops::Sigmoid(s.WithOpName("sigmoid3_1"), input);
+    auto sigmoid3_2 = ops::Sigmoid(s.WithOpName("sigmoid3_2"), sigmoid3_1);
+    auto sigmoid4_1 = ops::Sigmoid(s.WithOpName("sigmoid4_1"), input);
+    auto sigmoid4_2 = ops::Sigmoid(s.WithOpName("sigmoid4_2"), sigmoid4_1);
+    auto mul1 = ops::Mul(s.WithOpName("mul1"), input, sigmoid1);
+    auto mul2 = ops::Mul(s.WithOpName("mul2"), sigmoid2, input);
+    auto mul3 = ops::Mul(s.WithOpName("mul3"), sigmoid3_1, sigmoid3_2);
+    auto mul4 = ops::Mul(s.WithOpName("mul4"), sigmoid4_2, sigmoid4_1);
+    auto fetch1 = ops::Identity(s.WithOpName("fetch1"), mul1);
+    auto fetch2 = ops::Identity(s.WithOpName("fetch2"), mul2);
+    auto fetch3 = ops::Identity(s.WithOpName("fetch3"), mul3);
+    auto fetch4 = ops::Identity(s.WithOpName("fetch4"), mul4);
+    auto mul_t = GenerateTensorWithSetRandom<DTYPE>({64, 64});
+
+    GrapplerItem item;
+    item.fetch = {"fetch1", "fetch2", "fetch3", "fetch4"};
+    item.feed = {{"input", mul_t}};
+    TF_ASSERT_OK(s.ToGraphDef(&item.graph));
+
+    // Place all nodes on CPU.
+    for (int i = 0; i < item.graph.node_size(); ++i) {
+      item.graph.mutable_node(i)->set_device("/device:CPU:0");
+    }
+
+    Remapper optimizer(RewriterConfig::ON);
+    GraphDef output;
+    TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
+
+    int found = 0;
+    for (const NodeDef& node : output.node()) {
+      if (node.name() == "mul1") {
+        EXPECT_EQ(node.op(), "_MklSwish");
+        ASSERT_EQ(node.input_size(), 1);
+        EXPECT_EQ(node.input(0), "input");
+        ++found;
+      }
+      if (node.name() == "mul2") {
+        EXPECT_EQ(node.op(), "_MklSwish");
+        ASSERT_EQ(node.input_size(), 1);
+        EXPECT_EQ(node.input(0), "input");
+        ++found;
+      }
+      // mul3 is not replaced with _MklSwish because the pattern matcher's
+      // handling of commutative ops is limited.
+      if (node.name() == "mul4") {
+        EXPECT_EQ(node.op(), "_MklSwish");
+        ASSERT_EQ(node.input_size(), 1);
+        EXPECT_EQ(node.input(0), "sigmoid4_1");
+        ++found;
+      }
+    }
+    EXPECT_EQ(found, 3);
+
+    auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed);
+    ASSERT_EQ(tensors_expected.size(), 4);
+    auto tensors = EvaluateNodes(output, item.fetch, item.feed);
+    ASSERT_EQ(tensors.size(), 4);
+    float atol = 1e-6, rtol = 1e-6;
+    if (DTYPE == DT_BFLOAT16) {
+      atol = 1e-2;
+      rtol = 1e-2;
+    }
+    test::ExpectClose(tensors[0], tensors_expected[0], atol, rtol);
+    test::ExpectClose(tensors[1], tensors_expected[1], atol, rtol);
+    test::ExpectClose(tensors[2], tensors_expected[2], atol, rtol);
+    test::ExpectClose(tensors[3], tensors_expected[3], atol, rtol);
+  }
+};
+
+TEST_F(MklRemapperSwishTest, F32) { RunTest<DT_FLOAT>(); }
+TEST_F(MklRemapperSwishTest, BF16) { RunTest<DT_BFLOAT16>(); }
+
 }  // namespace grappler
 }  // namespace tensorflow
 #endif  // INTEL_MKL && ENABLE_MKL
diff --git a/tensorflow/core/grappler/optimizers/remapper.cc b/tensorflow/core/grappler/optimizers/remapper.cc
index 789c4e5..605f59f 100644
--- a/tensorflow/core/grappler/optimizers/remapper.cc
+++ b/tensorflow/core/grappler/optimizers/remapper.cc
@@ -57,6 +57,9 @@
 //   (1) FusedBatchNorm + <Activation>
 //   (2) FusedBatchNorm + SideInput + <Activation>
 //
+// Sigmoid + Mul -> _MklSwish (this fusion is supported only on Intel CPUs)
+//
 // In all cases, the supported activation functions are Relu, Relu6, and Elu.
 //
 // Both Conv2D and MatMul implemented as Tensor contraction (on CPU), so all the
@@ -1090,6 +1093,65 @@
   return (found_gelu_exact || found_gelu_approximate);
 }
 
+bool FindSigmoidAndMul(RemapperContext* ctx, int node_index,
+                       std::map<string, int>* matched_nodes_map,
+                       std::set<int>* remove_node_indices) {
+  using utils::MatchingDirection;
+  using utils::NodeStatus;
+  // clang-format off
+  //                Convert Sigmoid+Mul to Swish
+  //          From Graph                          To Graph
+  //          -----------                         ---------
+  //    Conv2D  <-  Filter(const)           Conv2D  <-  Filter(const)
+  //      !                                   !
+  //      V                                   V
+  //    BiasAdd <-  bias(const)             BiasAdd <-  bias(const)
+  //      !                                   !
+  //      V                                   !
+  //  ---- ----                               !
+  //  !       !                               !
+  //  !       V                               !
+  //  !    Sigmoid                            !
+  //  !       !                               !
+  //  ---   ---                               !
+  //     !  !                                 !
+  //     !  !                                 !
+  //     V  V                                 V
+  //      Mul                             _MklSwish
+  //      !                                   !
+  //      V                                   V
+
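+  // Note: both pattern leaves carry the label "input", so the pattern only
+  // fires when the Mul and the Sigmoid consume the same tensor. Handling of
+  // commutative ops is limited; see the mul3/mul4 cases in
+  // MklRemapperSwishTest for an input order that is not matched.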
+  utils::OpTypePattern sigmoidmul_pattern{ "Mul", "mul_to_swish", NodeStatus::kReplace,
+    {
+      { "Sigmoid", "sigmoid", NodeStatus::kRemove,
+        {
+          { "*", "input", NodeStatus::kRemain}
+        }
+      },
+      { "*", "input", NodeStatus::kRemain}
+    }
+  };
+  // clang-format on
+  // Check that the Mul node has a supported data type.
+  auto* mul_node_def = ctx->graph_view.GetNode(node_index)->node();
+  if (!HasDataType(mul_node_def, DT_FLOAT) &&
+      !HasDataType(mul_node_def, DT_BFLOAT16))
+    return false;
+
+  if (!NodeIsOnCpu(mul_node_def)) return false;
+
+  bool found_op_type_match = false;
+  utils::SubGraphMatcher<MatchingDirection::kFollowInputs> graph_matcher(
+      &(ctx->graph_view));
+  matched_nodes_map->clear();
+  remove_node_indices->clear();
+  found_op_type_match = graph_matcher.GetMatchedNodes(
+      sigmoidmul_pattern, {}, ctx->graph_view.GetNode(node_index),
+      matched_nodes_map, remove_node_indices);
+
+  return found_op_type_match;
+}
+
 bool FindFusedBatchNorm(const RemapperContext& ctx, int node_index,
                         FusedBatchNorm* matched) {
   const auto* node_view = ctx.graph_view.GetNode(node_index);
@@ -1985,6 +2047,39 @@
   return Status::OK();
 }
 
+Status ReplaceSigmoidMulWithSwish(
+    RemapperContext* ctx, const std::map<string, int>& matched_nodes_map,
+    const std::set<int>& remove_node_indices,
+    std::vector<bool>* invalidated_nodes, std::vector<bool>* nodes_to_delete) {
+  const GraphDef* graph = ctx->graph_view.graph();
+  const NodeDef* mul =
+      ctx->graph_view.GetNode(matched_nodes_map.at("mul_to_swish"))->node();
+  const NodeDef* sigmoid =
+      ctx->graph_view.GetNode(matched_nodes_map.at("sigmoid"))->node();
+
+  NodeDef fused_op;
+  fused_op.set_name(mul->name());
+  fused_op.set_op("_MklSwish");
+  fused_op.set_device(mul->device());
+  fused_op.add_input(sigmoid->input(0));
+
+  auto* attr = fused_op.mutable_attr();
+  (*attr)["T"] = mul->attr().at("T");
+
+  utils::Mutation* mutation = ctx->graph_view.GetMutationBuilder();
+  Status status;
+  mutation->AddNode(std::move(fused_op), &status);
+  TF_RETURN_IF_ERROR(status);
+  TF_RETURN_IF_ERROR(mutation->Apply());
+
+  (*invalidated_nodes)[matched_nodes_map.at("mul_to_swish")] = true;
+
+  for (const auto& node_index : remove_node_indices) {
+    (*nodes_to_delete)[node_index] = true;
+  }
+  return Status::OK();
+}
+
 Status AddFusedBatchNormExNode(RemapperContext* ctx,
                                const FusedBatchNormEx& matched,
                                std::vector<bool>* invalidated_nodes,
@@ -2580,6 +2675,17 @@
             &nodes_to_delete, is_gelu_approximate));
         continue;
       }
+
+      // Remap the Mul(x, Sigmoid(x)) pattern into a single _MklSwish(x) node.
+      std::map<string, int> sigmoidmul_matched_nodes_map;
+      std::set<int> sigmoidmul_remove_node_indices;
+      if (FindSigmoidAndMul(&ctx, i, &sigmoidmul_matched_nodes_map,
+                            &sigmoidmul_remove_node_indices)) {
+        TF_RETURN_IF_ERROR(ReplaceSigmoidMulWithSwish(
+            &ctx, sigmoidmul_matched_nodes_map, sigmoidmul_remove_node_indices,
+            &invalidated_nodes, &nodes_to_delete));
+        continue;
+      }
     }
 
     // Infer properties lazily in case they are not needed.
diff --git a/tensorflow/core/kernels/mkl/BUILD b/tensorflow/core/kernels/mkl/BUILD
index 6e1e095..447d584 100644
--- a/tensorflow/core/kernels/mkl/BUILD
+++ b/tensorflow/core/kernels/mkl/BUILD
@@ -237,6 +237,20 @@
     deps = ["@com_google_absl//absl/strings"] + MKL_TEST_DEPS,
 )
 
+tf_cc_test_mkl(
+    name = "mkl_swish_op_test",
+    size = "small",
+    srcs = ["mkl_swish_op_test.cc"],
+    linkstatic = 1,  # Fixes dyld error on MacOS.
+    deps = [
+           ":mkl_eltwise_activation_base_op",
+           ":mkl_swish_op",
+           "//tensorflow/cc:math_ops",
+           "//tensorflow/core/kernels:cwise_op",
+           "//tensorflow/core:direct_session",
+    ] + MKL_TEST_DEPS,
+)
+
 tf_mkl_kernel_library(
     name = "mkl_tfconv_op",
     prefix = "mkl_tfconv",
@@ -308,6 +322,19 @@
 )
 
 tf_mkl_kernel_library(
+    name = "mkl_eltwise_activation_base_op",
+    prefix = "mkl_eltwise_activation_base",
+    deps = MKL_DEPS,
+)
+
+tf_mkl_kernel_library(
+    name = "mkl_swish_op",
+    hdrs = ["mkl_eltwise_activation_base_op.h"],
+    prefix = "mkl_swish",
+    deps = MKL_DEPS,
+)
+
+tf_mkl_kernel_library(
     name = "mkl_softmax_op",
     prefix = "mkl_softmax",
     deps = MKL_SHORT_DEPS + ["//third_party/eigen3"],
diff --git a/tensorflow/core/kernels/mkl/mkl_eltwise_activation_base_op.h b/tensorflow/core/kernels/mkl/mkl_eltwise_activation_base_op.h
new file mode 100644
index 0000000..8337153
--- /dev/null
+++ b/tensorflow/core/kernels/mkl/mkl_eltwise_activation_base_op.h
@@ -0,0 +1,288 @@
+/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+    http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// See docs in ../ops/nn_ops.cc.
+#ifdef INTEL_MKL
+
+#include <unordered_map>
+
+#include "mkldnn.hpp"
+#include "tensorflow/core/framework/numeric_op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/util/mkl_util.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+
+using mkldnn::algorithm;
+using mkldnn::eltwise_forward;
+using mkldnn::memory;
+using mkldnn::prop_kind;
+using mkldnn::stream;
+
+using EltwiseFwdPd = mkldnn::eltwise_forward::primitive_desc;
+
+namespace tensorflow {
+
+template <typename T>
+class MklEltwiseFwdParams {
+ public:
+  memory::dims src_dims;
+  memory::desc src_md;
+  algorithm alg_kind;
+  float alpha;
+  float beta;
+
+  MklEltwiseFwdParams(memory::dims src_dims, memory::desc src_md,
+                      algorithm alg_kind, float alpha, float beta)
+      : src_dims(src_dims),
+        src_md(src_md),
+        alg_kind(alg_kind),
+        alpha(alpha),
+        beta(beta) {}
+};
+
+template <typename T>
+class MklEltwiseFwdPrimitive : public MklPrimitive {
+ public:
+  explicit MklEltwiseFwdPrimitive(const MklEltwiseFwdParams<T>& fwdParams)
+      : MklPrimitive(engine(engine::kind::cpu, 0)) {
+    // create eltwise primitive
+    if (context_.eltwise_fwd == nullptr) {
+      Setup(fwdParams);
+    }
+  }
+
+  ~MklEltwiseFwdPrimitive() {}
+
+  // Eltwise forward execute
+  //   src_data:  input data buffer of src
+  //   dst_data:  output data buffer of dst
+  void Execute(const T* src_data, T* dst_data, OpKernelContext* op_context) {
+    context_.src_mem->set_data_handle(
+        static_cast<void*>(const_cast<T*>(src_data)));
+    context_.dst_mem->set_data_handle(static_cast<void*>(dst_data));
+    DCHECK_EQ(context_.fwd_primitives.size(),
+              context_.fwd_primitives_args.size());
+
+    std::vector<primitive> net;
+    net.push_back(eltwise_forward(*context_.fwd_pd));
+    std::vector<MemoryArgsMap> net_args;
+    net_args.push_back({{MKLDNN_ARG_SRC, *context_.src_mem},
+                        {MKLDNN_ARG_DST, *context_.dst_mem}});
+    // Execute the eltwise_fwd primitive.
+    ExecutePrimitive(net, &net_args, GetEngine(), op_context);
+
+    // After execution, set data handle back.
+    context_.src_mem->set_data_handle(DummyData);
+    context_.dst_mem->set_data_handle(DummyData);
+  }
+
+  std::shared_ptr<EltwiseFwdPd> GetEltwiseFwdPd() { return context_.fwd_pd; }
+
+ private:
+  // Primitive reuse context for eltwise forward ops.
+  struct EltwiseFwdContext {
+    // MKLDNN memory
+    std::shared_ptr<memory> src_mem;
+    std::shared_ptr<memory> dst_mem;
+
+    // desc & primitive desc
+    std::shared_ptr<mkldnn::eltwise_forward::desc> fwd_desc;
+    std::shared_ptr<EltwiseFwdPd> fwd_pd;
+
+    // memory desc
+    std::shared_ptr<memory::desc> src_md;
+    std::shared_ptr<memory::desc> dst_md;
+
+    // memory primitive desc
+    std::shared_ptr<memory::desc> src_mpd;
+
+    // Eltwise primitive
+    std::shared_ptr<mkldnn::primitive> eltwise_fwd;
+
+    std::vector<mkldnn::primitive> fwd_primitives;
+
+    std::vector<std::unordered_map<int, memory>> fwd_primitives_args;
+
+    EltwiseFwdContext()
+        : src_mem(nullptr),
+          dst_mem(nullptr),
+          fwd_desc(nullptr),
+          fwd_pd(nullptr),
+          src_md(nullptr),
+          dst_md(nullptr),
+          src_mpd(nullptr),
+          eltwise_fwd(nullptr) {}
+  };
+
+  // Eltwise forward primitive setup
+  void Setup(const MklEltwiseFwdParams<T>& fwdParams) {
+    // create memory descriptors for eltwise data with specified format
+    context_.src_md.reset(new memory::desc(fwdParams.src_md.data));
+    context_.src_mpd.reset(new memory::desc(*context_.src_md));
+
+    // Create an eltwise forward descriptor and primitive descriptor
+    context_.fwd_desc.reset(new eltwise_forward::desc(
+        prop_kind::forward, fwdParams.alg_kind, *context_.src_md,
+        fwdParams.alpha, fwdParams.beta));
+    context_.fwd_pd.reset(new EltwiseFwdPd(*context_.fwd_desc, cpu_engine_));
+    auto fwd_pd = context_.fwd_pd.get();
+
+    // Create memory primitive based on dummy data
+    context_.src_mem.reset(
+        new memory(fwd_pd->src_desc(), cpu_engine_, DummyData));
+    context_.dst_mem.reset(
+        new memory(fwd_pd->dst_desc(), cpu_engine_, DummyData));
+    // Create eltwise primitive and add it to net
+    context_.eltwise_fwd.reset(new eltwise_forward(*context_.fwd_pd));
+    context_.fwd_primitives_args.push_back(
+        {{MKLDNN_ARG_SRC, *context_.src_mem},
+         {MKLDNN_ARG_DST, *context_.dst_mem}});
+    context_.fwd_primitives.push_back(*context_.eltwise_fwd);
+  }
+
+  struct EltwiseFwdContext context_;
+};
+
+template <typename T>
+class MklEltwiseFwdPrimitiveFactory : public MklPrimitiveFactory<T> {
+ public:
+  static MklEltwiseFwdPrimitive<T>* Get(
+      const MklEltwiseFwdParams<T>& fwdParams) {
+    MklEltwiseFwdPrimitive<T>* eltwise_forward = nullptr;
+
+    // Get an eltwise fwd primitive from the cached pool.
+    eltwise_forward = static_cast<MklEltwiseFwdPrimitive<T>*>(
+        MklEltwiseFwdPrimitiveFactory<T>::GetInstance().GetEltwiseFwd(
+            fwdParams));
+    if (eltwise_forward == nullptr) {
+      eltwise_forward = new MklEltwiseFwdPrimitive<T>(fwdParams);
+      MklEltwiseFwdPrimitiveFactory<T>::GetInstance().SetEltwiseFwd(
+          fwdParams, eltwise_forward);
+    }
+
+    return eltwise_forward;
+  }
+
+  static MklEltwiseFwdPrimitiveFactory& GetInstance() {
+    static MklEltwiseFwdPrimitiveFactory instance_;
+    return instance_;
+  }
+
+ private:
+  MklEltwiseFwdPrimitiveFactory() {}
+  ~MklEltwiseFwdPrimitiveFactory() {}
+
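+  // The cache key encodes every parameter that affects the generated
+  // primitive: source dims, algorithm kind, alpha, and beta.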
+  static string CreateKey(const MklEltwiseFwdParams<T>& fwdParams) {
+    string prefix = "eltwise_fwd";
+    FactoryKeyCreator key_creator;
+    key_creator.AddAsKey(prefix);
+    key_creator.AddAsKey(fwdParams.src_dims);
+    key_creator.AddAsKey<int>(static_cast<int>(fwdParams.alg_kind));
+    key_creator.AddAsKey<float>(static_cast<float>(fwdParams.alpha));
+    key_creator.AddAsKey<float>(static_cast<float>(fwdParams.beta));
+    return key_creator.GetKey();
+  }
+
+  MklPrimitive* GetEltwiseFwd(const MklEltwiseFwdParams<T>& fwdParams) {
+    string key = CreateKey(fwdParams);
+    return this->GetOp(key);
+  }
+
+  void SetEltwiseFwd(const MklEltwiseFwdParams<T>& fwdParams,
+                     MklPrimitive* op) {
+    string key = CreateKey(fwdParams);
+    this->SetOp(key, op);
+  }
+};
+
+template <typename Device, typename T, algorithm alg_kind>
+class MklEltwiseFwdActivationOpBase : public OpKernel {
+ public:
+  ~MklEltwiseFwdActivationOpBase() {}
+
+  explicit MklEltwiseFwdActivationOpBase(OpKernelConstruction* context,
+                                         float alpha, float beta)
+      : OpKernel(context), alpha_(alpha), beta_(beta) {}
+  virtual void Compute_Scalar(OpKernelContext* context) = 0;
+
+  void Compute(OpKernelContext* context) override {
+    try {
+      const Tensor& src_tensor = context->input(0);
+      TensorShape src_shape = src_tensor.shape();
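+      // 0-D (scalar) inputs bypass oneDNN and are computed directly by the
+      // subclass.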
+      if (src_tensor.dims() == 0) {
+        Compute_Scalar(context);
+        return;
+      }
+      // Allocate output (dst) tensor
+      TensorShape dst_shape = src_shape;
+      Tensor* dst_tensor = nullptr;
+      // Nothing to compute, return.
+      if (src_shape.num_elements() == 0) {
+        OP_REQUIRES_OK(context,
+                       context->allocate_output(
+                           GetTensorDataIndex(0, context->num_outputs()),
+                           dst_shape, &dst_tensor));
+        return;
+      }
+      // Set DNN primitive - src
+      MklDnnData<T> src(&cpu_engine);
+      memory::dims src_dims;
+      memory::desc src_md({}, memory::data_type::undef,
+                          memory::format_tag::undef);
+
+      src_dims = TFShapeToMklDnnDims(src_tensor.shape());
+      auto src_strides = CalculateTFStrides(src_dims);
+
+      // Create blocked memory descriptor
+      src_md = MklDnnData<T>::CreateBlockedMemDesc(src_dims, src_strides);
+
+      // Try to get an eltwise forward primitive from the caching pool.
+      MklEltwiseFwdParams<T> fwdParams(src_dims, src_md, alg_kind, alpha_,
+                                       beta_);
+      MklEltwiseFwdPrimitive<T>* eltwise_fwd =
+          MklEltwiseFwdPrimitiveFactory<T>::Get(fwdParams);
+
+      const T* src_data = src_tensor.flat<T>().data();
+
+      OP_REQUIRES_OK(context, context->allocate_output(
+                                  GetTensorDataIndex(0, context->num_outputs()),
+                                  dst_shape, &dst_tensor));
+
+      T* dst_data = dst_tensor->flat<T>().data();
+      // Execute the eltwise primitive.
+      eltwise_fwd->Execute(src_data, dst_data, context);
+    } catch (mkldnn::error& e) {
+      string error_msg = "Status: " + std::to_string(e.status) + ", message: " +
+                         string(e.message) + ", in file " + string(__FILE__) +
+                         ":" + std::to_string(__LINE__);
+      OP_REQUIRES_OK(
+          context,
+          errors::Aborted("Operation received an exception:", error_msg));
+    }
+  }
+
+ private:
+  engine cpu_engine = engine(engine::kind::cpu, 0);
+
+ protected:
+  float alpha_;
+  float beta_;
+};
+
+// TODO: Implement the eltwise bwd / eltwiseGrad class.
+
+}  // namespace tensorflow
+
+#endif  // INTEL_MKL
diff --git a/tensorflow/core/kernels/mkl/mkl_swish_op.cc b/tensorflow/core/kernels/mkl/mkl_swish_op.cc
new file mode 100644
index 0000000..874d159
--- /dev/null
+++ b/tensorflow/core/kernels/mkl/mkl_swish_op.cc
@@ -0,0 +1,67 @@
+/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+    http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// See docs in ../ops/nn_ops.cc.
+#ifdef INTEL_MKL
+
+#include "tensorflow/core/kernels/mkl/mkl_eltwise_activation_base_op.h"
+
+namespace tensorflow {
+
+template <typename Device, typename T>
+class MklSwishOp
+    : public MklEltwiseFwdActivationOpBase<Device, T,
+                                           mkldnn::algorithm::eltwise_swish> {
+ public:
+  ~MklSwishOp() {}
+
+  explicit MklSwishOp(OpKernelConstruction* context)
+      : MklEltwiseFwdActivationOpBase<Device, T,
+                                      mkldnn::algorithm::eltwise_swish>(
+            context, 1.0f, 0.0f) {}
+
+  virtual void Compute_Scalar(OpKernelContext* context) {
+    const Tensor& src_tensor = context->input(0);
+
+    // Get shapes of input tensors
+    TensorShape src_shape = src_tensor.shape();
+
+    Tensor* dst_tensor = nullptr;
+    void* user_i =
+        static_cast<void*>(const_cast<T*>(src_tensor.flat<T>().data()));
+
+    TensorShape dst_shape = src_shape;
+
+    OP_REQUIRES_OK(context, context->allocate_output(
+                                GetTensorDataIndex(0, context->num_outputs()),
+                                dst_shape, &dst_tensor));
+
+    // swish(x) = x * sigmoid(x).
+    void* out_o = static_cast<void*>(dst_tensor->flat<T>().data());
+    T feature = (static_cast<T*>(user_i))[0];
+    T e1 = Eigen::numext::exp(-feature);
+    (static_cast<T*>(out_o))[0] = feature / (static_cast<T>(1) + e1);
+    return;
+  }
+};
+
+// Register DNN kernels for the supported operations and types.
+#define REGISTER_SWISH_MKL_SUPPORTED_KERNELS_TYPES(type)              \
+  REGISTER_KERNEL_BUILDER(                                            \
+      Name("_MklSwish").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
+      MklSwishOp<CPUDevice, type>);
+TF_CALL_float(REGISTER_SWISH_MKL_SUPPORTED_KERNELS_TYPES);
+TF_CALL_bfloat16(REGISTER_SWISH_MKL_SUPPORTED_KERNELS_TYPES);
+
+}  // namespace tensorflow
+
+#endif  // INTEL_MKL
diff --git a/tensorflow/core/kernels/mkl/mkl_swish_op_test.cc b/tensorflow/core/kernels/mkl/mkl_swish_op_test.cc
new file mode 100644
index 0000000..132f7f7
--- /dev/null
+++ b/tensorflow/core/kernels/mkl/mkl_swish_op_test.cc
@@ -0,0 +1,114 @@
+/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+    http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifdef INTEL_MKL
+
+#include "absl/strings/match.h"
+#include "mkldnn.hpp"
+#include "tensorflow/cc/ops/const_op.h"
+#include "tensorflow/cc/ops/math_ops.h"
+#include "tensorflow/cc/ops/nn_ops.h"
+#include "tensorflow/cc/ops/standard_ops.h"
+#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
+#include "tensorflow/core/framework/fake_input.h"
+#include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/kernels/mkl/mkl_eltwise_activation_base_op.h"
+#include "tensorflow/core/kernels/ops_testutil.h"
+#include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/platform/cpu_info.h"
+#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/stacktrace_handler.h"
+#include "tensorflow/core/platform/str_util.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/test_benchmark.h"
+#include "tensorflow/core/public/session.h"
+#include "tensorflow/core/util/mkl_util.h"
+
+// This is a special case: Eigen does not provide a Swish kernel, so we
+// compare the default TensorFlow (Eigen) Sigmoid+Mul graph with the fused
+// MKL kernel on CPU.
+//
+// Use the command below to compare MKL and Eigen performance:
+// $ bazel run --action_env=TF_ENABLE_ONEDNN_OPTS=1 -c opt \
+//  //tensorflow/core/kernels/mkl:mkl_swish_op_test -- --benchmark_filter=all
+//
+
+namespace tensorflow {
+
+// --------------------------------------------------------------------------//
+//  Test Swish Kernels accuracy and performance                              //
+// --------------------------------------------------------------------------//
+template <typename T>
+static Graph* SwishGraph(const string& kind, const TensorShape& shape) {
+  auto* graph = new Graph(OpRegistry::Global());
+
+  DataType dtype = DataTypeToEnum<T>::v();
+  Tensor input_t(dtype, shape);
+  input_t.flat<T>().setRandom();
+  Node* input = test::graph::Constant(graph, input_t, "input");
+  const bool isDefault = (kind == "Default");
+
+  Node* sigmoid;
+  Node* mul;
+  Node* swish;
+  if (isDefault) {
+    TF_CHECK_OK(NodeBuilder(graph->NewName("Default_sigmoid"), "Sigmoid")
+                    .Input(input)
+                    .Attr("T", dtype)
+                    .Finalize(graph, &sigmoid));
+
+    TF_CHECK_OK(NodeBuilder(graph->NewName("Default_mul"), "Mul")
+                    .Input(input)
+                    .Input(sigmoid)
+                    .Attr("T", dtype)
+                    .Finalize(graph, &mul));
+    return graph;
+  }
+  // Mkl Swish op.
+  TF_CHECK_OK(NodeBuilder(graph->NewName("Mkl_swish"), "_MklSwish")
+                  .Input(input)
+                  .Attr("T", dtype)
+                  .Finalize(graph, &swish));
+  return graph;
+}
+
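+// BM_SWISH builds either the unfused Sigmoid+Mul graph ("Default") or the
+// fused _MklSwish graph ("Mkl") over an A x B x C x D tensor and reports
+// elements processed per iteration.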
+#define BM_SWISH(kind, A, B, C, D, type, T)                                \
+  static void BM_SWISH_##kind##_##type##_##A##_##B##_##C##_##D##_##T(      \
+      ::testing::benchmark::State& state) {                                \
+    int64 num_computed_elements = (A) * (B) * (C) * (D);                   \
+    int64 flops_per_iter = num_computed_elements;                          \
+                                                                           \
+    test::Benchmark(#type, SwishGraph<T>(#kind, {A, B, C, D})).Run(state); \
+    state.SetItemsProcessed(state.iterations() * flops_per_iter);          \
+  }                                                                        \
+  BENCHMARK(BM_SWISH_##kind##_##type##_##A##_##B##_##C##_##D##_##T)
+
+#define BENCHMARK_SWISH(A, B, C, D, type, T) \
+  BM_SWISH(Default, A, B, C, D, type, T);    \
+  BM_SWISH(Mkl, A, B, C, D, type, T);
+
+#define BENCHMARK_DTYPE(T)                    \
+  BENCHMARK_SWISH(1, 16, 16, 3, cpu, T);      \
+  BENCHMARK_SWISH(16, 32, 32, 1, cpu, T);     \
+  BENCHMARK_SWISH(16, 64, 64, 128, cpu, T);   \
+  BENCHMARK_SWISH(32, 64, 64, 128, cpu, T);   \
+  BENCHMARK_SWISH(32, 256, 256, 128, cpu, T); \
+  BENCHMARK_SWISH(32, 512, 512, 128, cpu, T);
+
+BENCHMARK_DTYPE(float)
+BENCHMARK_DTYPE(bfloat16)
+
+}  // namespace tensorflow
+
+#endif  // INTEL_MKL
diff --git a/tensorflow/core/ops/mkl_nn_ops.cc b/tensorflow/core/ops/mkl_nn_ops.cc
index 1397643..fc9d951 100644
--- a/tensorflow/core/ops/mkl_nn_ops.cc
+++ b/tensorflow/core/ops/mkl_nn_ops.cc
@@ -1779,6 +1779,17 @@
 expected to invoke these operators.
 )doc");
 
+REGISTER_OP("_MklSwish")
+    .Input("features: T")
+    .Output("activations: T")
+    .Attr("T: {float, bfloat16} = DT_FLOAT")
+    .SetShapeFn(shape_inference::UnchangedShape)
+    .Doc(R"doc(
+MKL version of Swish operator. Uses MKL DNN APIs to implement Swish operator.
+NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
+expected to invoke these operators.
+)doc");
+
 }  // namespace tensorflow
 
 #endif  // INTEL_MKL