Fix MklSoftmax primitive cache key (add source format), drop a redundant null check, and clean up comments
diff --git a/tensorflow/core/kernels/mkl_softmax_op.cc b/tensorflow/core/kernels/mkl_softmax_op.cc
index 96d5945..1a18835 100644
--- a/tensorflow/core/kernels/mkl_softmax_op.cc
+++ b/tensorflow/core/kernels/mkl_softmax_op.cc
@@ -48,10 +48,7 @@
   explicit MklSoftmaxPrimitive(const MklSoftmaxParams& fwdParams)
       : cpu_engine_(engine::cpu, 0) {
     context_.fwd_stream.reset(new stream(stream::kind::eager));
-
-    if (context_.softmax_fwd == nullptr) {
-      Setup(fwdParams);
-    }
+    Setup(fwdParams);
   }
 
   ~MklSoftmaxPrimitive() {}
@@ -66,7 +63,7 @@
 
     context_.fwd_stream->submit(context_.fwd_primitives);
 
-    // after execution, set data handle back
+    // After execution, reset both data handles to the dummy placeholder
     context_.src_mem->set_data_handle(DummyData);
     context_.dst_mem->set_data_handle(DummyData);
   }
@@ -77,14 +74,14 @@
 
  private:
   struct SoftmaxFwdContext {
-    // MKLDNN memory
+    // MKL-DNN memory
     std::shared_ptr<memory> src_mem;
     std::shared_ptr<memory> dst_mem;
 
-    // desc & prmitive desc
+    // Softmax forward operation descriptor
     std::shared_ptr<mkldnn::softmax_forward::desc> fwd_desc;
 
-    // memory desc
+    // Memory desc
     std::shared_ptr<memory::desc> src_md;
 
     // Softmax primitive
@@ -106,23 +103,23 @@
 
   // Softmax forward primitive setup
   void Setup(const MklSoftmaxParams& fwdParams) {
-    // create memory descriptors for softmax data with specified format
+    // Create memory descriptors for softmax data with specified format
     context_.src_md.reset(new memory::desc({fwdParams.src_dims},
                                            MklDnnType<T>(), fwdParams.src_fmt));
 
-    // create a softmax
+    // Create the softmax forward descriptor and its primitive descriptor
     context_.fwd_desc.reset(new mkldnn::softmax_forward::desc(
         prop_kind::forward_scoring, *context_.src_md, fwdParams.axis));
     context_.fwd_pd.reset(new mkldnn::softmax_forward::primitive_desc(
         *context_.fwd_desc, cpu_engine_));
 
-    // create memory primitive based on dummy data
+    // Create memory primitive based on dummy data
     context_.src_mem.reset(
         new memory({*context_.src_md, cpu_engine_}, DummyData));
     context_.dst_mem.reset(
         new memory(context_.fwd_pd.get()->dst_primitive_desc(), DummyData));
 
-    // create softmax primitive and add it to net
+    // Create softmax primitive and add it to net
     context_.softmax_fwd.reset(new mkldnn::softmax_forward(
         *context_.fwd_pd, *context_.src_mem, *context_.dst_mem));
 
@@ -137,11 +134,11 @@
 class MklSoftmaxPrimitiveFactory : public MklPrimitiveFactory<T> {
  public:
   static MklSoftmaxPrimitive<T>* Get(const MklSoftmaxParams& fwdParams) {
-    MklSoftmaxPrimitive<T>* softmax_forward = nullptr;
-
     // Get a softmax fwd primitive from the cached pool
-    softmax_forward = static_cast<MklSoftmaxPrimitive<T>*>(
-        MklSoftmaxPrimitiveFactory<T>::GetInstance().GetSoftmaxFwd(fwdParams));
+    MklSoftmaxPrimitive<T>* softmax_forward =
+        static_cast<MklSoftmaxPrimitive<T>*>(
+            MklSoftmaxPrimitiveFactory<T>::GetInstance().GetSoftmaxFwd(
+                fwdParams));
     if (softmax_forward == nullptr) {
       softmax_forward = new MklSoftmaxPrimitive<T>(fwdParams);
       MklSoftmaxPrimitiveFactory<T>::GetInstance().SetSoftmaxFwd(
@@ -164,7 +161,9 @@
     FactoryKeyCreator key_creator;
     key_creator.AddAsKey(prefix);
     key_creator.AddAsKey(fwdParams.src_dims);
-    key_creator.AddAsKey(fwdParams.axis);
+    key_creator.AddAsKey<int>(static_cast<int>(fwdParams.src_fmt));
+    key_creator.AddAsKey<int>(fwdParams.axis);
+
     return key_creator.GetKey();
   }
 
@@ -252,29 +251,27 @@
           return;
       }
 
-      // If input is in MKL layout, then simply grab input layout; otherwise,
-      // construct input Tf layout. For TF layout, although input shape
-      // (src_dims) required is in MKL-DNN order, the layout is Tensorflow's
-      // layout
+      // If input is in MKL layout, then simply get the format from input;
+      // otherwise, use the TF layout (layout_type) selected above.
       auto src_fmt = src_mkl_shape.IsMklTensor()
                          ? static_cast<mkldnn::memory::format>(
                                src_mkl_shape.GetMklLayout().data.format)
                          : layout_type;
 
-      // get a softmax fwd from primitive pool
+      // Get a softmax fwd from primitive pool
       MklSoftmaxParams fwdParams(src_dims, src_fmt, axis);
       MklSoftmaxPrimitive<T>* softmax_fwd =
           MklSoftmaxPrimitiveFactory<T>::Get(fwdParams);
 
-      // add: output
+      // Add output
       Tensor* output_tensor = nullptr;
       MklDnnShape output_mkl_shape;
       TensorShape output_tf_shape;  // shape of output TF tensor.
 
       auto dst_pd = softmax_fwd->GetSoftmaxFwdPd()->dst_primitive_desc();
 
-      // if input is MKL shape, output is also MKL shape.
-      // if input is TF shape, output is also TF shape
+      // If input is MKL shape, output is also MKL shape.
+      // If input is TF shape, output is also TF shape.
       if (src_mkl_shape.IsMklTensor()) {
         output_mkl_shape.SetMklTensor(true);
         output_mkl_shape.SetMklLayout(&dst_pd);
@@ -292,7 +289,7 @@
       const T* src_data = src_tensor.flat<T>().data();
       T* dst_data = reinterpret_cast<T*>(output_tensor->flat<T>().data());
 
-      // execute softmax
+      // Execute softmax
       softmax_fwd->Execute(src_data, dst_data);
     } catch (mkldnn::error& e) {
       string error_msg = "Status: " + std::to_string(e.status) + ", message: " +