Fixes based on review
diff --git a/tensorflow/core/kernels/mkl_fused_batch_norm_op.cc b/tensorflow/core/kernels/mkl_fused_batch_norm_op.cc
index 3ce85b7..1787ec4 100644
--- a/tensorflow/core/kernels/mkl_fused_batch_norm_op.cc
+++ b/tensorflow/core/kernels/mkl_fused_batch_norm_op.cc
@@ -23,6 +23,9 @@
#include "tensorflow/core/util/tensor_format.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#define GET_FLAG(bn_flag) static_cast<int>(BN_FLAGS::bn_flag)
+#define IS_SET(cflag) (context_.flags & GET_FLAG(cflag))
+
using mkldnn::batch_normalization_backward;
using mkldnn::batch_normalization_forward;
using mkldnn::prop_kind;
@@ -81,12 +84,12 @@
static_cast<void*>(const_cast<T*>(src_data)));
context_.dst_mem->set_data_handle(static_cast<void*>(dst_data));
- if (context_.flags & static_cast<int>(BN_FLAGS::use_scale_shift))
+ if (IS_SET(use_scale_shift))
context_.weights_mem->set_data_handle(
static_cast<void*>(const_cast<U*>(weights_data)));
if ((context_.pkind == prop_kind::forward_training) ||
- (context_.flags & static_cast<int>(BN_FLAGS::use_global_stats))) {
+ (IS_SET(use_global_stats))) {
context_.mean_mem->set_data_handle(static_cast<void*>(mean_data));
context_.variance_mem->set_data_handle(static_cast<void*>(variance_data));
}
@@ -95,16 +98,16 @@
execute_primitives(context_.fwd_primitives, context_.fwd_stream, context_.net_args);
#else
context_.fwd_stream->submit(context_.fwd_primitives);
-#endif // ENABLE_MKLDNN_V1
+#endif // ENABLE_MKLDNN_V1
context_.src_mem->set_data_handle(DummyData);
context_.dst_mem->set_data_handle(DummyData);
- if (context_.flags & (int)BN_FLAGS::use_scale_shift)
+ if (IS_SET(use_scale_shift))
context_.weights_mem->set_data_handle(DummyData);
if ((context_.pkind == prop_kind::forward_training) ||
- (context_.flags & (int)BN_FLAGS::use_global_stats)) {
+ (IS_SET(use_global_stats))) {
context_.mean_mem->set_data_handle(DummyData);
context_.variance_mem->set_data_handle(DummyData);
}
@@ -122,7 +125,7 @@
mkldnn_memory_format_t GetDstFmt() const {
return (*context_.dst_mem).get_primitive_desc().desc().data.format;
}
-#endif
+#endif // !ENABLE_MKLDNN_V1
std::shared_ptr<BatchNormFwdPd> GetBatchNormFwdPd() const {
return context_.fwd_pd;
@@ -154,7 +157,7 @@
#ifdef ENABLE_MKLDNN_V1
std::vector<std::unordered_map<int, memory>> net_args;
-#endif
+#endif // ENABLE_MKLDNN_V1
BatchNormFwdContext()
: flags(0),
@@ -169,9 +172,9 @@
};
void Setup(const MklBatchNormFwdParams& fwdParams) {
- context_.flags = fwdParams.training ? (int)BN_FLAGS::use_scale_shift
- : ((int)BN_FLAGS::use_scale_shift |
- (int)BN_FLAGS::use_global_stats);
+ context_.flags = fwdParams.training ? GET_FLAG(use_scale_shift)
+ : (GET_FLAG(use_scale_shift) |
+ GET_FLAG(use_global_stats));
context_.pkind = fwdParams.training ? prop_kind::forward_training
: prop_kind::forward_scoring;
@@ -181,7 +184,7 @@
fwdParams.src_format);
#else
auto src_md = memory::desc({fwdParams.src_dims}, MklDnnType<T>());
-#endif
+#endif // !ENABLE_MKLDNN_V1
// Create forward BatchNorm descriptor and primitive descriptor.
#ifdef ENABLE_MKLDNN_V1
@@ -191,7 +194,7 @@
#else
auto fwd_desc = batch_normalization_forward::desc(
context_.pkind, src_md, fwdParams.eps, context_.flags);
-#endif
+#endif // ENABLE_MKLDNN_V1
context_.fwd_pd.reset(new BatchNormFwdPd(fwd_desc, cpu_engine_));
@@ -203,13 +206,13 @@
memory::dims s_dims = {2, fwdParams.depth};
memory::dims m_dims = {1, fwdParams.depth};
- if (context_.flags & (int)BN_FLAGS::use_scale_shift) {
+ if (IS_SET(use_scale_shift)) {
context_.weights_mem.reset(new MEMORY_CONSTRUCTOR_USING_MEM_PD(
s_dims, U, MEMORY_FORMAT::nc, cpu_engine_, DummyData));
}
if (fwdParams.training ||
- (context_.flags & (int)BN_FLAGS::use_global_stats)) {
+ (IS_SET(use_global_stats))) {
context_.mean_mem.reset(new MEMORY_CONSTRUCTOR_USING_MEM_PD(
m_dims, U, MEMORY_FORMAT::nc, cpu_engine_, DummyData));
@@ -219,9 +222,9 @@
// BatchNorm forward primitive.
if (!fwdParams.training &&
- !(context_.flags & (int)BN_FLAGS::use_global_stats)) {
+ !(IS_SET(use_global_stats))) {
#ifdef ENABLE_MKLDNN_V1
- if ((context_.flags & (int)BN_FLAGS::use_scale_shift) &&
+ if ((IS_SET(use_scale_shift)) &&
mkldnn_use_scaleshift) {
context_.net_args.push_back(
{{MKLDNN_ARG_SRC, *context_.src_mem},
@@ -235,8 +238,8 @@
}
context_.bn_fwd.reset(new batch_normalization_forward(*context_.fwd_pd));
#else
- if ((context_.flags & (int)BN_FLAGS::use_scale_shift) &&
- (int)BN_FLAGS::use_scale_shift) {
+ if ((IS_SET(use_scale_shift)) &&
+ GET_FLAG(use_scale_shift)) {
context_.bn_fwd.reset(new batch_normalization_forward(
*context_.fwd_pd, *context_.src_mem, *context_.weights_mem,
*context_.dst_mem));
@@ -244,11 +247,11 @@
context_.bn_fwd.reset(new batch_normalization_forward(
*context_.fwd_pd, *context_.src_mem, *context_.dst_mem));
}
-#endif // ENABLE_MKLDNN_V1
- } else if (context_.flags & (int)BN_FLAGS::use_global_stats) {
+#endif // ENABLE_MKLDNN_V1
+ } else if (IS_SET(use_global_stats)) {
#ifdef ENABLE_MKLDNN_V1
- if ((context_.flags & (int)BN_FLAGS::use_scale_shift) &&
- (int)BN_FLAGS::use_scale_shift) {
+ if ((IS_SET(use_scale_shift)) &&
+ GET_FLAG(use_scale_shift)) {
context_.net_args.push_back(
{{MKLDNN_ARG_SRC, *context_.src_mem},
{MKLDNN_ARG_MEAN, *context_.mean_mem},
@@ -266,8 +269,8 @@
}
context_.bn_fwd.reset(new batch_normalization_forward(*context_.fwd_pd));
#else
- if ((context_.flags & (int)BN_FLAGS::use_scale_shift) &&
- (int)BN_FLAGS::use_scale_shift) {
+ if ((IS_SET(use_scale_shift)) &&
+ GET_FLAG(use_scale_shift)) {
context_.bn_fwd.reset(new batch_normalization_forward(
*context_.fwd_pd, *context_.src_mem,
(const primitive::at)*context_.mean_mem,
@@ -279,11 +282,11 @@
(const primitive::at)*context_.mean_mem,
(const primitive::at)*context_.variance_mem, *context_.dst_mem));
}
-#endif // ENABLE_MKLDNN_V1
+#endif // ENABLE_MKLDNN_V1
} else {
#ifdef ENABLE_MKLDNN_V1
- if ((context_.flags & (int)BN_FLAGS::use_scale_shift) &&
- (int)BN_FLAGS::use_scale_shift) {
+ if ((IS_SET(use_scale_shift)) &&
+ GET_FLAG(use_scale_shift)) {
context_.net_args.push_back(
{{MKLDNN_ARG_SRC, *context_.src_mem},
{MKLDNN_ARG_WEIGHTS, *context_.weights_mem},
@@ -300,8 +303,8 @@
}
context_.bn_fwd.reset(new batch_normalization_forward(*context_.fwd_pd));
#else
- if ((context_.flags & (int)BN_FLAGS::use_scale_shift) &&
- (int)BN_FLAGS::use_scale_shift) {
+ if ((IS_SET(use_scale_shift)) &&
+ GET_FLAG(use_scale_shift)) {
context_.bn_fwd.reset(new batch_normalization_forward(
*context_.fwd_pd, *context_.src_mem, *context_.weights_mem,
*context_.dst_mem, *context_.mean_mem, *context_.variance_mem));
@@ -310,7 +313,7 @@
*context_.fwd_pd, *context_.src_mem, *context_.dst_mem,
*context_.mean_mem, *context_.variance_mem));
}
-#endif // ENABLE_MKLDNN_V1
+#endif // ENABLE_MKLDNN_V1
}
context_.fwd_primitives.push_back(*context_.bn_fwd);
}
@@ -377,7 +380,7 @@
float eps;
bool training;
-#ifndef ENABLE_MKL_DNN_V1
+#ifndef ENABLE_MKLDNN_V1
memory::format src_format;
MklBatchNormBwdParams(memory::dims src_dims, memory::dims diff_dst_dims,
@@ -397,7 +400,7 @@
depth(depth),
eps(eps),
training(training) {}
-#endif
+#endif // !ENABLE_MKLDNN_V1
};
template <typename T, typename U>
@@ -435,7 +438,7 @@
context_.diff_dst_mem->set_data_handle(
static_cast<void*>(const_cast<T*>(diff_dst_data)));
- if (context_.flags & (int)BN_FLAGS::use_scale_shift) {
+ if (IS_SET(use_scale_shift)) {
context_.weights_mem->set_data_handle(
static_cast<void*>(const_cast<U*>(weights_data)));
context_.diff_weights_mem->set_data_handle(
@@ -453,14 +456,14 @@
}
#else
context_.bwd_stream->submit(context_.bwd_primitives);
-#endif // ENABLE_MKLDNN_V1
+#endif // ENABLE_MKLDNN_V1
// After execution, set data handle back to DummyData.
context_.src_mem->set_data_handle(DummyData);
context_.mean_mem->set_data_handle(DummyData);
context_.variance_mem->set_data_handle(DummyData);
context_.diff_dst_mem->set_data_handle(DummyData);
- if (context_.flags & (int)BN_FLAGS::use_scale_shift) {
+ if (IS_SET(use_scale_shift)) {
context_.weights_mem->set_data_handle(DummyData);
context_.diff_weights_mem->set_data_handle(DummyData);
}
@@ -475,7 +478,7 @@
mkldnn_memory_format_t GetDiffDstMemoryFormat() const {
return context_.diff_dst_mem->get_primitive_desc().desc().data.format;
}
-#endif
+#endif // !ENABLE_MKLDNN_V1
std::shared_ptr<BatchNormBwdPd> GetBatchNormBwdPd() const {
return context_.bwd_pd;
@@ -509,7 +512,7 @@
#ifdef ENABLE_MKLDNN_V1
std::vector<std::unordered_map<int, memory>> net_args;
-#endif
+#endif // ENABLE_MKLDNN_V1
BatchNormBwdContext()
: src_mem(nullptr),
@@ -523,12 +526,12 @@
};
void Setup(const MklBatchNormBwdParams& bwdParams) {
- context_.flags = bwdParams.training ? (int)BN_FLAGS::use_scale_shift
- : ((int)BN_FLAGS::use_scale_shift |
- (int)BN_FLAGS::use_global_stats);
+ context_.flags = bwdParams.training ? GET_FLAG(use_scale_shift)
+ : (GET_FLAG(use_scale_shift) |
+ GET_FLAG(use_global_stats));
// Memory descriptors.
-#ifndef ENABLE_MKL_DNN_V1
+#ifndef ENABLE_MKLDNN_V1
auto src_md = memory::desc({bwdParams.src_dims}, MklDnnType<T>(),
bwdParams.src_format);
auto diff_dst_md = memory::desc({bwdParams.diff_dst_dims}, MklDnnType<T>(),
@@ -536,7 +539,7 @@
#else
auto src_md = memory::desc({bwdParams.src_dims}, MklDnnType<T>());
auto diff_dst_md = memory::desc({bwdParams.diff_dst_dims}, MklDnnType<T>());
-#endif
+#endif // !ENABLE_MKLDNN_V1
auto variance_desc =
memory::desc({1, bwdParams.depth}, MklDnnType<U>(), MEMORY_FORMAT::nc);
auto mean_desc =
@@ -590,7 +593,7 @@
*context_.bwd_pd, *context_.src_mem, *context_.mean_mem,
*context_.variance_mem, *context_.diff_dst_mem, *context_.weights_mem,
*context_.diff_src_mem, *context_.diff_weights_mem));
-#endif
+#endif // ENABLE_MKLDNN_V1
context_.bwd_primitives.push_back(*context_.bn_bwd);
}
@@ -1302,16 +1305,16 @@
mkl_shape_p.SetMklTensor(false);
AllocateOutputSetMklShape(context, kP1Index, &p1_tensor, TensorShape({}),
mkl_shape_p);
-#ifndef ENABLE_MKL_DNN_V1
+#ifndef ENABLE_MKLDNN_V1
std::fill_n(p1_tensor->flat<U>().data(), p1_tensor->shape().num_elements(),
static_cast<U>(0));
-#endif
+#endif // !ENABLE_MKLDNN_V1
AllocateOutputSetMklShape(context, kP2Index, &p2_tensor, TensorShape({}),
mkl_shape_p);
-#ifndef ENABLE_MKL_DNN_V1
+#ifndef ENABLE_MKLDNN_V1
std::fill_n(p2_tensor->flat<U>().data(), p2_tensor->shape().num_elements(),
static_cast<U>(0));
-#endif
+#endif // !ENABLE_MKLDNN_V1
}
memory::dims GetMeanVarianceDims() { return memory::dims({1, depth_}); }
@@ -1398,4 +1401,7 @@
} // namespace tensorflow
-#endif // INTEL_MKL
+#undef GET_FLAG
+#undef IS_SET
+
+#endif // INTEL_MKL