Fixes based on review

Replace the repeated (int)BN_FLAGS::... casts with GET_FLAG/IS_SET helper
macros, correct the misspelled ENABLE_MKL_DNN_V1 guard to ENABLE_MKLDNN_V1,
annotate each #endif with its matching condition, and #undef the helper
macros at the end of the file.
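For context, a minimal standalone sketch of the helper-macro pattern this
patch introduces. The enum values and the Context struct below are
illustrative stand-ins (in the real file, BN_FLAGS aliases mkldnn's
batch-normalization flag enum and context_ is the primitive's member
struct); the flag-selection ternary mirrors Setup():

    #include <cstdio>

    // Stand-in for mkldnn's batch-normalization flags; values illustrative.
    enum class BN_FLAGS : int { use_scale_shift = 0x1, use_global_stats = 0x2 };

    // Stand-in for the primitive context that holds the flag bitmask.
    struct Context { int flags = 0; } context_;

    #define GET_FLAG(bn_flag) static_cast<int>(BN_FLAGS::bn_flag)
    #define IS_SET(cflag) (context_.flags & GET_FLAG(cflag))

    int main() {
      bool training = false;
      // Same selection as Setup(): training uses scale/shift only;
      // inference additionally reads the saved global mean/variance.
      context_.flags = training ? GET_FLAG(use_scale_shift)
                                : (GET_FLAG(use_scale_shift) |
                                   GET_FLAG(use_global_stats));
      if (IS_SET(use_scale_shift)) std::printf("scale/shift weights in use\n");
      if (IS_SET(use_global_stats)) std::printf("global stats in use\n");
    }

    // Mirror the patch: undefine the helpers once they are no longer needed.
    #undef IS_SET
    #undef GET_FLAG
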
diff --git a/tensorflow/core/kernels/mkl_fused_batch_norm_op.cc b/tensorflow/core/kernels/mkl_fused_batch_norm_op.cc
index 3ce85b7..1787ec4 100644
--- a/tensorflow/core/kernels/mkl_fused_batch_norm_op.cc
+++ b/tensorflow/core/kernels/mkl_fused_batch_norm_op.cc
@@ -23,6 +23,9 @@
 #include "tensorflow/core/util/tensor_format.h"
 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
 
+#define GET_FLAG(bn_flag) static_cast<int>(BN_FLAGS::bn_flag)
+#define IS_SET(cflag) (context_.flags & GET_FLAG(cflag))
+
 using mkldnn::batch_normalization_backward;
 using mkldnn::batch_normalization_forward;
 using mkldnn::prop_kind;
@@ -81,12 +84,12 @@
         static_cast<void*>(const_cast<T*>(src_data)));
     context_.dst_mem->set_data_handle(static_cast<void*>(dst_data));
 
-    if (context_.flags & static_cast<int>(BN_FLAGS::use_scale_shift))
+    if (IS_SET(use_scale_shift))
       context_.weights_mem->set_data_handle(
           static_cast<void*>(const_cast<U*>(weights_data)));
 
     if ((context_.pkind == prop_kind::forward_training) ||
-        (context_.flags & static_cast<int>(BN_FLAGS::use_global_stats))) {
+        (IS_SET(use_global_stats))) {
       context_.mean_mem->set_data_handle(static_cast<void*>(mean_data));
       context_.variance_mem->set_data_handle(static_cast<void*>(variance_data));
     }
@@ -95,16 +98,16 @@
     execute_primitives(context_.fwd_primitives, context_.fwd_stream, context_.net_args);
 #else
     context_.fwd_stream->submit(context_.fwd_primitives);
-#endif  // ENABLE_MKLDNN_V1
+#endif  // ENABLE_MKLDNN_V1
 
     context_.src_mem->set_data_handle(DummyData);
     context_.dst_mem->set_data_handle(DummyData);
 
-    if (context_.flags & (int)BN_FLAGS::use_scale_shift)
+    if (IS_SET(use_scale_shift))
       context_.weights_mem->set_data_handle(DummyData);
 
     if ((context_.pkind == prop_kind::forward_training) ||
-        (context_.flags & (int)BN_FLAGS::use_global_stats)) {
+        (IS_SET(use_global_stats))) {
       context_.mean_mem->set_data_handle(DummyData);
       context_.variance_mem->set_data_handle(DummyData);
     }
@@ -122,7 +125,7 @@
   mkldnn_memory_format_t GetDstFmt() const {
     return (*context_.dst_mem).get_primitive_desc().desc().data.format;
   }
-#endif
+#endif  // !ENABLE_MKLDNN_V1
 
   std::shared_ptr<BatchNormFwdPd> GetBatchNormFwdPd() const {
     return context_.fwd_pd;
@@ -154,7 +157,7 @@
 
 #ifdef ENABLE_MKLDNN_V1
     std::vector<std::unordered_map<int, memory>> net_args;
-#endif
+#endif  // ENABLE_MKLDNN_V1
 
     BatchNormFwdContext()
         : flags(0),
@@ -169,9 +172,9 @@
   };
 
   void Setup(const MklBatchNormFwdParams& fwdParams) {
-    context_.flags = fwdParams.training ? (int)BN_FLAGS::use_scale_shift
-                                        : ((int)BN_FLAGS::use_scale_shift |
-                                           (int)BN_FLAGS::use_global_stats);
+    context_.flags = fwdParams.training ? GET_FLAG(use_scale_shift)
+                                        : (GET_FLAG(use_scale_shift) |
+                                           GET_FLAG(use_global_stats));
     context_.pkind = fwdParams.training ? prop_kind::forward_training
                                         : prop_kind::forward_scoring;
 
@@ -181,7 +184,7 @@
                                fwdParams.src_format);
 #else
     auto src_md = memory::desc({fwdParams.src_dims}, MklDnnType<T>());
-#endif
+#endif  // !ENABLE_MKLDNN_V1
 
     // Create forward BatchNorm descriptor and primitive descriptor.
 #ifdef ENABLE_MKLDNN_V1
@@ -191,7 +194,7 @@
 #else
     auto fwd_desc = batch_normalization_forward::desc(
         context_.pkind, src_md, fwdParams.eps, context_.flags);
-#endif
+#endif  // ENABLE_MKLDNN_V1
 
     context_.fwd_pd.reset(new BatchNormFwdPd(fwd_desc, cpu_engine_));
 
@@ -203,13 +206,13 @@
 
     memory::dims s_dims = {2, fwdParams.depth};
     memory::dims m_dims = {1, fwdParams.depth};
-    if (context_.flags & (int)BN_FLAGS::use_scale_shift) {
+    if (IS_SET(use_scale_shift)) {
       context_.weights_mem.reset(new MEMORY_CONSTRUCTOR_USING_MEM_PD(
           s_dims, U, MEMORY_FORMAT::nc, cpu_engine_, DummyData));
     }
 
     if (fwdParams.training ||
-        (context_.flags & (int)BN_FLAGS::use_global_stats)) {
+        (IS_SET(use_global_stats))) {
       context_.mean_mem.reset(new MEMORY_CONSTRUCTOR_USING_MEM_PD(
           m_dims, U, MEMORY_FORMAT::nc, cpu_engine_, DummyData));
 
@@ -219,9 +222,9 @@
 
     // BatchNorm forward primitive.
     if (!fwdParams.training &&
-        !(context_.flags & (int)BN_FLAGS::use_global_stats)) {
+        !(IS_SET(use_global_stats))) {
 #ifdef ENABLE_MKLDNN_V1
-      if ((context_.flags & (int)BN_FLAGS::use_scale_shift) &&
+      if ((IS_SET(use_scale_shift)) &&
           mkldnn_use_scaleshift) {
         context_.net_args.push_back(
             {{MKLDNN_ARG_SRC, *context_.src_mem},
@@ -235,8 +238,8 @@
       }
       context_.bn_fwd.reset(new batch_normalization_forward(*context_.fwd_pd));
 #else
-      if ((context_.flags & (int)BN_FLAGS::use_scale_shift) &&
-          (int)BN_FLAGS::use_scale_shift) {
+      if ((IS_SET(use_scale_shift)) &&
+          GET_FLAG(use_scale_shift)) {
         context_.bn_fwd.reset(new batch_normalization_forward(
             *context_.fwd_pd, *context_.src_mem, *context_.weights_mem,
             *context_.dst_mem));
@@ -244,11 +247,11 @@
         context_.bn_fwd.reset(new batch_normalization_forward(
             *context_.fwd_pd, *context_.src_mem, *context_.dst_mem));
       }
-#endif  // ENABLE_MKLDNN_V1
-    } else if (context_.flags & (int)BN_FLAGS::use_global_stats) {
+#endif  // ENABLE_MKLDNN_V1
+    } else if (IS_SET(use_global_stats)) {
 #ifdef ENABLE_MKLDNN_V1
-      if ((context_.flags & (int)BN_FLAGS::use_scale_shift) &&
-          (int)BN_FLAGS::use_scale_shift) {
+      if ((IS_SET(use_scale_shift)) &&
+          GET_FLAG(use_scale_shift)) {
         context_.net_args.push_back(
             {{MKLDNN_ARG_SRC, *context_.src_mem},
              {MKLDNN_ARG_MEAN, *context_.mean_mem},
@@ -266,8 +269,8 @@
       }
       context_.bn_fwd.reset(new batch_normalization_forward(*context_.fwd_pd));
 #else
-      if ((context_.flags & (int)BN_FLAGS::use_scale_shift) &&
-          (int)BN_FLAGS::use_scale_shift) {
+      if ((IS_SET(use_scale_shift)) &&
+          GET_FLAG(use_scale_shift)) {
         context_.bn_fwd.reset(new batch_normalization_forward(
             *context_.fwd_pd, *context_.src_mem,
             (const primitive::at)*context_.mean_mem,
@@ -279,11 +282,11 @@
             (const primitive::at)*context_.mean_mem,
             (const primitive::at)*context_.variance_mem, *context_.dst_mem));
       }
-#endif  // ENABLE_MKLDNN_V1
+#endif  // ENABLE_MKLDNN_V1
     } else {
 #ifdef ENABLE_MKLDNN_V1
-      if ((context_.flags & (int)BN_FLAGS::use_scale_shift) &&
-          (int)BN_FLAGS::use_scale_shift) {
+      if ((IS_SET(use_scale_shift)) &&
+          GET_FLAG(use_scale_shift)) {
         context_.net_args.push_back(
             {{MKLDNN_ARG_SRC, *context_.src_mem},
              {MKLDNN_ARG_WEIGHTS, *context_.weights_mem},
@@ -300,8 +303,8 @@
       }
       context_.bn_fwd.reset(new batch_normalization_forward(*context_.fwd_pd));
 #else
-      if ((context_.flags & (int)BN_FLAGS::use_scale_shift) &&
-          (int)BN_FLAGS::use_scale_shift) {
+      if ((IS_SET(use_scale_shift)) &&
+          GET_FLAG(use_scale_shift)) {
         context_.bn_fwd.reset(new batch_normalization_forward(
             *context_.fwd_pd, *context_.src_mem, *context_.weights_mem,
             *context_.dst_mem, *context_.mean_mem, *context_.variance_mem));
@@ -310,7 +313,7 @@
             *context_.fwd_pd, *context_.src_mem, *context_.dst_mem,
             *context_.mean_mem, *context_.variance_mem));
       }
-#endif  // ENABLE_MKLDNN_V1
+#endif  // ENABLE_MKLDNN_V1
     }
     context_.fwd_primitives.push_back(*context_.bn_fwd);
   }
@@ -377,7 +380,7 @@
   float eps;
   bool training;
 
-#ifndef ENABLE_MKL_DNN_V1
+#ifndef ENABLE_MKLDNN_V1
   memory::format src_format;
 
   MklBatchNormBwdParams(memory::dims src_dims, memory::dims diff_dst_dims,
@@ -397,7 +400,7 @@
         depth(depth),
         eps(eps),
         training(training) {}
-#endif
+#endif  // !ENABLE_MKLDNN_V1
 };
 
 template <typename T, typename U>
@@ -435,7 +438,7 @@
     context_.diff_dst_mem->set_data_handle(
         static_cast<void*>(const_cast<T*>(diff_dst_data)));
 
-    if (context_.flags & (int)BN_FLAGS::use_scale_shift) {
+    if (IS_SET(use_scale_shift)) {
       context_.weights_mem->set_data_handle(
           static_cast<void*>(const_cast<U*>(weights_data)));
       context_.diff_weights_mem->set_data_handle(
@@ -453,14 +456,14 @@
     }
 #else
     context_.bwd_stream->submit(context_.bwd_primitives);
-#endif  // ENABLE_MKLDNN_V1
+#endif  // ENABLE_MKLDNN_V1
 
     // After execution, set data handle back to DummyData.
     context_.src_mem->set_data_handle(DummyData);
     context_.mean_mem->set_data_handle(DummyData);
     context_.variance_mem->set_data_handle(DummyData);
     context_.diff_dst_mem->set_data_handle(DummyData);
-    if (context_.flags & (int)BN_FLAGS::use_scale_shift) {
+    if (IS_SET(use_scale_shift)) {
       context_.weights_mem->set_data_handle(DummyData);
       context_.diff_weights_mem->set_data_handle(DummyData);
     }
@@ -475,7 +478,7 @@
   mkldnn_memory_format_t GetDiffDstMemoryFormat() const {
     return context_.diff_dst_mem->get_primitive_desc().desc().data.format;
   }
-#endif
+#endif  // !ENABLE_MKLDNN_V1
 
   std::shared_ptr<BatchNormBwdPd> GetBatchNormBwdPd() const {
     return context_.bwd_pd;
@@ -509,7 +512,7 @@
 
 #ifdef ENABLE_MKLDNN_V1
     std::vector<std::unordered_map<int, memory>> net_args;
-#endif
+#endif  // ENABLE_MKLDNN_V1
 
     BatchNormBwdContext()
         : src_mem(nullptr),
@@ -523,12 +526,12 @@
   };
 
   void Setup(const MklBatchNormBwdParams& bwdParams) {
-    context_.flags = bwdParams.training ? (int)BN_FLAGS::use_scale_shift
-                                        : ((int)BN_FLAGS::use_scale_shift |
-                                           (int)BN_FLAGS::use_global_stats);
+    context_.flags = bwdParams.training ? GET_FLAG(use_scale_shift)
+                                        : (GET_FLAG(use_scale_shift) |
+                                           GET_FLAG(use_global_stats));
 
     // Memory descriptors.
-#ifndef ENABLE_MKL_DNN_V1
+#ifndef ENABLE_MKLDNN_V1
     auto src_md = memory::desc({bwdParams.src_dims}, MklDnnType<T>(),
                                bwdParams.src_format);
     auto diff_dst_md = memory::desc({bwdParams.diff_dst_dims}, MklDnnType<T>(),
@@ -536,7 +539,7 @@
 #else
     auto src_md = memory::desc({bwdParams.src_dims}, MklDnnType<T>());
     auto diff_dst_md = memory::desc({bwdParams.diff_dst_dims}, MklDnnType<T>());
-#endif
+#endif  // !ENABLE_MKLDNN_V1
     auto variance_desc =
         memory::desc({1, bwdParams.depth}, MklDnnType<U>(), MEMORY_FORMAT::nc);
     auto mean_desc =
@@ -590,7 +593,7 @@
         *context_.bwd_pd, *context_.src_mem, *context_.mean_mem,
         *context_.variance_mem, *context_.diff_dst_mem, *context_.weights_mem,
         *context_.diff_src_mem, *context_.diff_weights_mem));
-#endif
+#endif  // ENABLE_MKLDNN_V1
     context_.bwd_primitives.push_back(*context_.bn_bwd);
   }
 
@@ -1302,16 +1305,16 @@
     mkl_shape_p.SetMklTensor(false);
     AllocateOutputSetMklShape(context, kP1Index, &p1_tensor, TensorShape({}),
                               mkl_shape_p);
-#ifndef ENABLE_MKL_DNN_V1
+#ifndef ENABLE_MKLDNN_V1
     std::fill_n(p1_tensor->flat<U>().data(), p1_tensor->shape().num_elements(),
                 static_cast<U>(0));
-#endif
+#endif  // !ENABLE_MKLDNN_V1
     AllocateOutputSetMklShape(context, kP2Index, &p2_tensor, TensorShape({}),
                               mkl_shape_p);
-#ifndef ENABLE_MKL_DNN_V1
+#ifndef ENABLE_MKLDNN_V1
     std::fill_n(p2_tensor->flat<U>().data(), p2_tensor->shape().num_elements(),
                 static_cast<U>(0));
-#endif
+#endif  // !ENABLE_MKLDNN_V1
   }
 
   memory::dims GetMeanVarianceDims() { return memory::dims({1, depth_}); }
@@ -1398,4 +1401,7 @@
 
 }  // namespace tensorflow
 
-#endif  // INTEL_MKL
+#undef GET_FLAG
+#undef IS_SET
+
+#endif  // INTEL_MKL
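
For reference, the #endif annotations added above follow the usual
convention: the trailing comment names the condition of the matching #if
(negated for #ifndef), even when control reaches the #endif through the
#else branch, so the v0.x and v1.x API paths stay easy to pair up. In
shorthand:

    #ifdef ENABLE_MKLDNN_V1
      // MKL-DNN v1.x API path
    #else
      // MKL-DNN v0.x API path
    #endif  // ENABLE_MKLDNN_V1

    #ifndef ENABLE_MKLDNN_V1
      // v0.x-only members
    #endif  // !ENABLE_MKLDNN_V1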