| /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. |
| |
| Licensed under the Apache License, Version 2.0 (the "License"); |
| you may not use this file except in compliance with the License. |
| You may obtain a copy of the License at |
| |
| http://www.apache.org/licenses/LICENSE-2.0 |
| |
| Unless required by applicable law or agreed to in writing, software |
| distributed under the License is distributed on an "AS IS" BASIS, |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| See the License for the specific language governing permissions and |
| limitations under the License. |
| ==============================================================================*/ |
| #ifdef INTEL_MKL |
| #include "mkldnn.hpp" |
| #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" |
| #include "tensorflow/core/framework/op_kernel.h" |
| #include "tensorflow/core/framework/register_types.h" |
| #include "tensorflow/core/framework/tensor.h" |
| #include "tensorflow/core/framework/tensor_types.h" |
| #include "tensorflow/core/util/mkl_util.h" |
| #include "tensorflow/core/util/tensor_format.h" |
| |
| using mkldnn::batch_normalization_backward; |
| using mkldnn::batch_normalization_forward; |
| using mkldnn::prop_kind; |
| using mkldnn::stream; |
| using mkldnn::use_global_stats; |
| using mkldnn::use_scale_shift; |
| |
| namespace tensorflow { |
| using CPUDevice = Eigen::ThreadPoolDevice; |
| |
| struct MklBatchNormFwdParams { |
| memory::dims src_dims; |
| int depth; |
| float eps; |
| bool training; |
| |
| MklBatchNormFwdParams(const memory::dims& src_dims, int depth, float eps, |
| bool training) |
| : src_dims(src_dims), depth(depth), eps(eps), training(training) {} |
| }; |
| |
| template <typename T, typename U> |
| class MklFusedBatchNormFwdPrimitive : public MklPrimitive { |
| public: |
| explicit MklFusedBatchNormFwdPrimitive(const MklBatchNormFwdParams& fwdParams) |
| : cpu_engine_(engine::cpu, 0) { |
| context_.fwd_stream.reset(new mkldnn::stream(mkldnn::stream::kind::eager)); |
| if (context_.bn_fwd == nullptr) Setup(fwdParams); |
| } |
| |
| ~MklFusedBatchNormFwdPrimitive() {} |
| |
| // BatchNormalization forward execute |
| // src_data: input data buffer of src |
| // weights_data: input data buffer of weights |
| // dst_data: output data buffer of dst |
| // mean_data: output data buffer of means |
| // variance_data: output data buffer of variances |
| void Execute(const T* src_data, const U* weights_data, T* dst_data, |
| U* mean_data, U* variance_data) { |
| context_.src_mem->set_data_handle( |
| static_cast<void*>(const_cast<T*>(src_data))); |
| context_.dst_mem->set_data_handle(static_cast<void*>(dst_data)); |
| |
| if (context_.flags & use_scale_shift) |
| context_.weights_mem->set_data_handle( |
| static_cast<void*>(const_cast<U*>(weights_data))); |
| |
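| // In training mode the primitive writes the computed mean/variance into |
| // these buffers; with use_global_stats (inference) it reads them as |
| // inputs. Either way the handles are attached before execution. |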
| if ((context_.pkind == prop_kind::forward_training) || |
| (context_.flags & use_global_stats)) { |
| context_.mean_mem->set_data_handle(static_cast<void*>(mean_data)); |
| context_.variance_mem->set_data_handle(static_cast<void*>(variance_data)); |
| } |
| |
| // execution |
| context_.fwd_stream->submit(context_.fwd_primitives); |
| |
| context_.src_mem->set_data_handle(DummyData); |
| context_.dst_mem->set_data_handle(DummyData); |
| |
| if (context_.flags & use_scale_shift) |
| context_.weights_mem->set_data_handle(DummyData); |
| |
| if ((context_.pkind == prop_kind::forward_training) || |
| (context_.flags & use_global_stats)) { |
| context_.mean_mem->set_data_handle(DummyData); |
| context_.variance_mem->set_data_handle(DummyData); |
| } |
| } |
| |
| memory::primitive_desc GetDstPd() const { |
| return (*context_.dst_mem).get_primitive_desc(); |
| } |
| |
| mkldnn_memory_format_t GetSrcFmt() const { |
| return (*context_.src_mem).get_primitive_desc().desc().data.format; |
| } |
| |
| mkldnn_memory_format_t GetDstFmt() const { |
| return (*context_.dst_mem).get_primitive_desc().desc().data.format; |
| } |
| |
| private: |
| // Primitive reuse context for BatchNorm fwd op |
| struct BatchNormFwdContext { |
| // flags indicating whether it is training or inference mode |
| int64 flags; |
| |
| // algorithm |
| mkldnn::prop_kind pkind; |
| |
| // Mkldnn Memory |
| std::shared_ptr<mkldnn::memory> src_mem; |
| std::shared_ptr<mkldnn::memory> weights_mem; |
| std::shared_ptr<mkldnn::memory> dst_mem; |
| std::shared_ptr<mkldnn::memory> mean_mem; |
| std::shared_ptr<mkldnn::memory> variance_mem; |
| |
| // BatchNorm forward primitive |
| std::shared_ptr<mkldnn::primitive> bn_fwd; |
| std::shared_ptr<mkldnn::stream> fwd_stream; |
| std::vector<mkldnn::primitive> fwd_primitives; |
| |
| BatchNormFwdContext() |
| : flags(0), |
| pkind(mkldnn::forward_training), |
| src_mem(nullptr), |
| weights_mem(nullptr), |
| dst_mem(nullptr), |
| mean_mem(nullptr), |
| variance_mem(nullptr), |
| bn_fwd(nullptr), |
| fwd_stream(nullptr) {} |
| }; |
| |
| void Setup(const MklBatchNormFwdParams& fwdParams) { |
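| // use_scale_shift is always set because FusedBatchNorm always supplies |
| // scale and offset; for inference, use_global_stats is added so the |
| // primitive consumes the estimated mean/variance instead of computing |
| // them from the batch. |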
| context_.flags = fwdParams.training ? use_scale_shift |
| : (use_scale_shift | use_global_stats); |
| context_.pkind = fwdParams.training ? prop_kind::forward_training |
| : prop_kind::forward_scoring; |
| |
| // memory desc |
| auto src_md = memory::desc({fwdParams.src_dims}, MklDnnType<T>(), |
| get_desired_format(fwdParams.src_dims[1])); |
| |
| // fwd desc & primitive desc |
| auto fwd_desc = batch_normalization_forward::desc( |
| context_.pkind, src_md, fwdParams.eps, context_.flags); |
| auto fwd_pd = |
| batch_normalization_forward::primitive_desc(fwd_desc, cpu_engine_); |
| |
| // memory primitive |
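| // (created with DummyData placeholders; the real buffers are attached |
| // via set_data_handle() in Execute(), which lets the same cached |
| // primitive be reused across calls) |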
| context_.src_mem.reset(new memory({src_md, cpu_engine_}, DummyData)); |
| context_.dst_mem.reset(new memory(fwd_pd.dst_primitive_desc(), DummyData)); |
| |
| if (context_.flags & use_scale_shift) { |
| auto weights_desc = memory::desc({2, fwdParams.depth}, MklDnnType<U>(), |
| memory::format::nc); |
| context_.weights_mem.reset( |
| new memory({weights_desc, cpu_engine_}, DummyData)); |
| } |
| |
| if (fwdParams.training || (context_.flags & use_global_stats)) { |
| auto mean_desc = memory::desc({1, fwdParams.depth}, MklDnnType<U>(), |
| memory::format::nc); |
| context_.mean_mem.reset(new memory({mean_desc, cpu_engine_}, DummyData)); |
| |
| auto variance_desc = |
| memory::desc({1, fwdParams.depth}, MklDnnType<U>(), memory::format::nc); |
| context_.variance_mem.reset( |
| new memory({variance_desc, cpu_engine_}, DummyData)); |
| } |
| |
| // BatchNorm forward primitive |
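| // The constructor overload chosen below depends on whether mean/variance |
| // are inputs (use_global_stats), outputs (training), or absent, and on |
| // whether the packed scale/shift weights are supplied. |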
| if (!fwdParams.training && !(context_.flags & use_global_stats)) { |
| if ((context_.flags & use_scale_shift) && mkldnn_use_scaleshift) { |
| context_.bn_fwd.reset(new batch_normalization_forward( |
| fwd_pd, *context_.src_mem, *context_.weights_mem, |
| *context_.dst_mem)); |
| } else { |
| context_.bn_fwd.reset(new batch_normalization_forward( |
| fwd_pd, *context_.src_mem, *context_.dst_mem)); |
| } |
| } else if (context_.flags & use_global_stats) { |
| if ((context_.flags & use_scale_shift) && mkldnn_use_scaleshift) { |
| context_.bn_fwd.reset(new batch_normalization_forward( |
| fwd_pd, *context_.src_mem, (const primitive::at)*context_.mean_mem, |
| (const primitive::at)*context_.variance_mem, *context_.weights_mem, |
| *context_.dst_mem)); |
| } else { |
| context_.bn_fwd.reset(new batch_normalization_forward( |
| fwd_pd, *context_.src_mem, (const primitive::at)*context_.mean_mem, |
| (const primitive::at)*context_.variance_mem, *context_.dst_mem)); |
| } |
| } else { |
| if ((context_.flags & use_scale_shift) && mkldnn_use_scaleshift) { |
| context_.bn_fwd.reset(new batch_normalization_forward( |
| fwd_pd, *context_.src_mem, *context_.weights_mem, *context_.dst_mem, |
| *context_.mean_mem, *context_.variance_mem)); |
| } else { |
| context_.bn_fwd.reset(new batch_normalization_forward( |
| fwd_pd, *context_.src_mem, *context_.dst_mem, *context_.mean_mem, |
| *context_.variance_mem)); |
| } |
| } |
| |
| context_.fwd_primitives.push_back(*context_.bn_fwd); |
| } |
| |
| mkldnn::memory::desc get_desc_data(const mkldnn::memory& m) const { |
| return m.get_primitive_desc().desc().data; |
| } |
| |
| struct BatchNormFwdContext context_; |
| engine cpu_engine_; |
| }; |
| |
| template <typename T, typename U> |
| class MklFusedBatchNormFwdPrimitiveFactory : public MklPrimitiveFactory<T> { |
| public: |
| static MklFusedBatchNormFwdPrimitive<T, U>* Get( |
| const MklBatchNormFwdParams& fwdParams) { |
| auto bn_fwd = static_cast<MklFusedBatchNormFwdPrimitive<T, U>*>( |
| MklFusedBatchNormFwdPrimitiveFactory<T, U>::GetInstance() |
| .GetBatchNormFwd(fwdParams)); |
| |
| if (bn_fwd == nullptr) { |
| bn_fwd = new MklFusedBatchNormFwdPrimitive<T, U>(fwdParams); |
| MklFusedBatchNormFwdPrimitiveFactory<T, U>::GetInstance().SetBatchNormFwd( |
| fwdParams, bn_fwd); |
| } |
| return bn_fwd; |
| } |
| |
| static MklFusedBatchNormFwdPrimitiveFactory& GetInstance() { |
| static MklFusedBatchNormFwdPrimitiveFactory instance_; |
| return instance_; |
| } |
| |
| private: |
| MklFusedBatchNormFwdPrimitiveFactory() {} |
| ~MklFusedBatchNormFwdPrimitiveFactory() {} |
| |
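| // The key encodes every parameter that affects the generated primitive, |
| // so a cached primitive is reused only for identical shapes, types, |
| // epsilon, and training mode. |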
| static string CreateKey(const MklBatchNormFwdParams& fwdParams) { |
| string prefix = "bn_fwd"; |
| FactoryKeyCreator key_creator; |
| key_creator.AddAsKey(prefix); |
| key_creator.AddAsKey(fwdParams.src_dims); |
| key_creator.AddAsKey<int>(fwdParams.depth); |
| key_creator.AddAsKey<float>(fwdParams.eps); |
| key_creator.AddAsKey<bool>(fwdParams.training); |
| key_creator.AddAsKey(typeid(T).name()); |
| key_creator.AddAsKey(typeid(U).name()); |
| return key_creator.GetKey(); |
| } |
| |
| MklPrimitive* GetBatchNormFwd(const MklBatchNormFwdParams& fwdParams) { |
| string key = CreateKey(fwdParams); |
| return this->GetOp(key); |
| } |
| |
| void SetBatchNormFwd(const MklBatchNormFwdParams& fwdParams, |
| MklPrimitive* op) { |
| string key = CreateKey(fwdParams); |
| this->SetOp(key, op); |
| } |
| }; |
| |
| struct MklBatchNormBwdParams { |
| memory::dims src_dims; |
| memory::dims diff_dst_dims; |
| int depth; |
| float eps; |
| bool training; |
| |
| MklBatchNormBwdParams(memory::dims src_dims, memory::dims diff_dst_dims, |
| int depth, float eps, bool training) |
| : src_dims(src_dims), |
| diff_dst_dims(diff_dst_dims), |
| depth(depth), |
| eps(eps), |
| training(training) {} |
| }; |
| |
| template <typename T, typename U> |
| class MklFusedBatchNormBwdPrimitive : public MklPrimitive { |
| public: |
| explicit MklFusedBatchNormBwdPrimitive(const MklBatchNormBwdParams& bwdParams) |
| : cpu_engine_(engine::cpu, 0) { |
| context_.bwd_stream.reset(new mkldnn::stream(mkldnn::stream::kind::eager)); |
| if (context_.bn_bwd == nullptr) Setup(bwdParams); |
| } |
| |
| ~MklFusedBatchNormBwdPrimitive() {} |
| |
| // BatchNormalization backward execute |
| // src_data: input data buffer of src |
| // mean_data: input data buffer of mean |
| // variance_data: input data buffer of variance |
| // diff_dst_data: input data buffer of diff_dst |
| // weights_data: input data buffer of weights |
| // diff_src_data: output data buffer of diff_src |
| // diff_weights_data: output data buffer of diff_weights |
| // res_space_data: output data buffer of reserved_space_3. |
| // TODO: reserved_space_3 (temporary memory to hold |
| // intermediate results) is not implemented |
| // on CPU as of now. |
| void Execute(const T* src_data, const U* mean_data, const U* variance_data, |
| const T* diff_dst_data, const U* weights_data, T* diff_src_data, |
| U* diff_weights_data, U* res_space_data) { |
| context_.src_mem->set_data_handle( |
| static_cast<void*>(const_cast<T*>(src_data))); |
| context_.mean_mem->set_data_handle( |
| static_cast<void*>(const_cast<U*>(mean_data))); |
| context_.variance_mem->set_data_handle( |
| static_cast<void*>(const_cast<U*>(variance_data))); |
| context_.diff_dst_mem->set_data_handle( |
| static_cast<void*>(const_cast<T*>(diff_dst_data))); |
| |
| // TODO: type for weights? |
| if (context_.flags & use_scale_shift) { |
| context_.weights_mem->set_data_handle( |
| static_cast<void*>(const_cast<U*>(weights_data))); |
| context_.diff_weights_mem->set_data_handle( |
| static_cast<void*>(diff_weights_data)); |
| } |
| |
| context_.diff_src_mem->set_data_handle(static_cast<void*>(diff_src_data)); |
| |
| // execution |
| context_.bwd_stream->submit(context_.bwd_primitives); |
| |
| context_.src_mem->set_data_handle(DummyData); |
| context_.mean_mem->set_data_handle(DummyData); |
| context_.variance_mem->set_data_handle(DummyData); |
| context_.diff_dst_mem->set_data_handle(DummyData); |
| if (context_.flags & use_scale_shift) { |
| context_.weights_mem->set_data_handle(DummyData); |
| context_.diff_weights_mem->set_data_handle(DummyData); |
| } |
| context_.diff_src_mem->set_data_handle(DummyData); |
| } |
| |
| mkldnn_memory_format_t GetSrcFmt() { |
| return (*context_.src_mem).get_primitive_desc().desc().data.format; |
| } |
| |
| mkldnn_memory_format_t GetDiffDstFmt() { |
| return (*context_.diff_dst_mem).get_primitive_desc().desc().data.format; |
| } |
| |
| memory::primitive_desc GetDiffSrcPd() { |
| return (*context_.diff_src_mem).get_primitive_desc(); |
| } |
| |
| private: |
| struct BatchNormBwdContext { |
| // Flags to indicate whether it is training or inference |
| int64 flags; |
| |
| // MKLDNN memory |
| std::shared_ptr<mkldnn::memory> src_mem; |
| std::shared_ptr<mkldnn::memory> mean_mem; |
| std::shared_ptr<mkldnn::memory> variance_mem; |
| std::shared_ptr<mkldnn::memory> diff_dst_mem; |
| std::shared_ptr<mkldnn::memory> weights_mem; |
| std::shared_ptr<mkldnn::memory> diff_weights_mem; |
| std::shared_ptr<mkldnn::memory> diff_src_mem; |
| |
| // Batch Norm primitive |
| std::shared_ptr<mkldnn::primitive> bn_bwd; |
| std::vector<mkldnn::primitive> bwd_primitives; |
| std::shared_ptr<mkldnn::stream> bwd_stream; |
| |
| BatchNormBwdContext() |
| : src_mem(nullptr), |
| mean_mem(nullptr), |
| variance_mem(nullptr), |
| diff_dst_mem(nullptr), |
| weights_mem(nullptr), |
| diff_weights_mem(nullptr), |
| diff_src_mem(nullptr), |
| bwd_stream(nullptr) {} |
| }; |
| |
| void Setup(const MklBatchNormBwdParams& bwdParams) { |
| context_.flags = bwdParams.training ? use_scale_shift |
| : (use_scale_shift | use_global_stats); |
| |
| // memory desc |
| auto src_md = memory::desc({bwdParams.src_dims}, MklDnnType<T>(), |
| get_desired_format(bwdParams.src_dims[1])); |
| auto diff_dst_md = |
| memory::desc({bwdParams.diff_dst_dims}, MklDnnType<T>(), |
| get_desired_format(bwdParams.diff_dst_dims[1])); |
| auto variance_desc = |
| memory::desc({1, bwdParams.depth}, MklDnnType<U>(), memory::format::nc); |
| auto mean_desc = |
| memory::desc({1, bwdParams.depth}, MklDnnType<U>(), memory::format::nc); |
| auto weights_desc = |
| memory::desc({2, bwdParams.depth}, MklDnnType<U>(), memory::format::nc); |
| auto diff_weights_desc = weights_desc; |
| |
| // fwd desc & primitive desc |
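| // (MKL-DNN requires a forward primitive descriptor as a hint when |
| // creating the backward primitive descriptor, so one is built here even |
| // though the forward pass itself is executed elsewhere) |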
| auto fwd_desc = batch_normalization_forward::desc( |
| prop_kind::forward_training, src_md, bwdParams.eps, |
| bwdParams.training ? use_scale_shift |
| : (use_scale_shift | use_global_stats)); |
| auto fwd_pd = |
| batch_normalization_forward::primitive_desc(fwd_desc, cpu_engine_); |
| |
| // BatchNorm backward primitive. |
| // |
| // For inference, specify use_global_stats: |
| // 1. on forward propagation, use the mean and variance provided as inputs; |
| // 2. on backward propagation, treat mean and variance as constants. |
| // This reduces the amount of MKL computation. |
| auto bwd_desc = batch_normalization_backward::desc( |
| prop_kind::backward, diff_dst_md, src_md, bwdParams.eps, |
| bwdParams.training ? use_scale_shift |
| : (use_scale_shift | use_global_stats)); |
| auto bn_bwd_pd = batch_normalization_backward::primitive_desc( |
| bwd_desc, cpu_engine_, fwd_pd); |
| |
| // memory primitive |
| context_.src_mem.reset(new memory({src_md, cpu_engine_}, DummyData)); |
| context_.diff_dst_mem.reset( |
| new memory({diff_dst_md, cpu_engine_}, DummyData)); |
| context_.variance_mem.reset( |
| new memory({variance_desc, cpu_engine_}, DummyData)); |
| context_.mean_mem.reset(new memory({mean_desc, cpu_engine_}, DummyData)); |
| context_.weights_mem.reset( |
| new memory({weights_desc, cpu_engine_}, DummyData)); |
| context_.diff_weights_mem.reset( |
| new memory({diff_weights_desc, cpu_engine_}, DummyData)); |
| context_.diff_src_mem.reset(new memory({src_md, cpu_engine_}, DummyData)); |
| |
| context_.bn_bwd.reset(new batch_normalization_backward( |
| bn_bwd_pd, *context_.src_mem, *context_.mean_mem, |
| *context_.variance_mem, *context_.diff_dst_mem, *context_.weights_mem, |
| *context_.diff_src_mem, *context_.diff_weights_mem)); |
| context_.bwd_primitives.push_back(*context_.bn_bwd); |
| } |
| |
| struct BatchNormBwdContext context_; |
| engine cpu_engine_; |
| }; |
| |
| template <typename T, typename U> |
| class MklFusedBatchNormBwdPrimitiveFactory : public MklPrimitiveFactory<T> { |
| public: |
| static MklFusedBatchNormBwdPrimitive<T, U>* Get( |
| const MklBatchNormBwdParams& bwdParams) { |
| auto bn_bwd = static_cast<MklFusedBatchNormBwdPrimitive<T, U>*>( |
| MklFusedBatchNormBwdPrimitiveFactory<T, U>::GetInstance() |
| .GetBatchNormBwd(bwdParams)); |
| if (bn_bwd == nullptr) { |
| bn_bwd = new MklFusedBatchNormBwdPrimitive<T, U>(bwdParams); |
| MklFusedBatchNormBwdPrimitiveFactory<T, U>::GetInstance().SetBatchNormBwd( |
| bwdParams, bn_bwd); |
| } |
| return bn_bwd; |
| } |
| |
| static MklFusedBatchNormBwdPrimitiveFactory& GetInstance() { |
| static MklFusedBatchNormBwdPrimitiveFactory instance_; |
| return instance_; |
| } |
| |
| private: |
| MklFusedBatchNormBwdPrimitiveFactory() {} |
| ~MklFusedBatchNormBwdPrimitiveFactory() {} |
| |
| static string CreateKey(const MklBatchNormBwdParams& bwdParams) { |
| string prefix = "bn_bwd"; |
| FactoryKeyCreator key_creator; |
| key_creator.AddAsKey(prefix); |
| key_creator.AddAsKey(bwdParams.src_dims); |
| key_creator.AddAsKey(bwdParams.diff_dst_dims); |
| key_creator.AddAsKey<int>(bwdParams.depth); |
| key_creator.AddAsKey<float>(bwdParams.eps); |
| key_creator.AddAsKey<bool>(bwdParams.training); |
| key_creator.AddAsKey(typeid(T).name()); |
| key_creator.AddAsKey(typeid(U).name()); |
| return key_creator.GetKey(); |
| } |
| |
| MklPrimitive* GetBatchNormBwd(const MklBatchNormBwdParams& bwdParams) { |
| string key = CreateKey(bwdParams); |
| return this->GetOp(key); |
| } |
| |
| void SetBatchNormBwd(const MklBatchNormBwdParams& bwdParams, |
| MklPrimitive* op) { |
| string key = CreateKey(bwdParams); |
| this->SetOp(key, op); |
| } |
| }; |
| |
| // A reserved_space template parameter is added to support FusedBatchNormV3 |
| // with MKL. This differs from the default implementation, where the V3 |
| // classes are derived; a template parameter moves the choice to |
| // compile time rather than runtime. |
| template <typename Device, typename T, typename U, bool reserved_space> |
| class MklFusedBatchNormOp : public OpKernel { |
| public: |
| explicit MklFusedBatchNormOp(OpKernelConstruction* context) |
| : OpKernel(context) { |
| float epsilon; |
| OP_REQUIRES_OK(context, context->GetAttr("epsilon", &epsilon)); |
| epsilon_ = epsilon; |
| string tensor_format; |
| OP_REQUIRES_OK(context, context->GetAttr("data_format", &tensor_format)); |
| OP_REQUIRES(context, FormatFromString(tensor_format, &tensor_format_), |
| errors::InvalidArgument("Invalid data format")); |
| OP_REQUIRES_OK(context, context->GetAttr("is_training", &is_training_)); |
| depth_ = 0; |
| mean_values_ = nullptr; |
| variance_values_ = nullptr; |
| } |
| |
| void Compute(OpKernelContext* context) override { |
| try { |
| const size_t kSrcIndex = 0; // index of src input tensor |
| const size_t kScaleIndex = 1; // index of scale tensor |
| const size_t kShiftIndex = 2; // index of shift tensor |
| const size_t kMeanIndex = 3; // index of est_mean tensor |
| const size_t kVarianceIndex = 4; // index of est_variance tensor |
| |
| const Tensor& src_tensor = MklGetInput(context, kSrcIndex); |
| const Tensor& scale_tensor = MklGetInput(context, kScaleIndex); |
| const Tensor& shift_tensor = MklGetInput(context, kShiftIndex); |
| const Tensor& est_mean_tensor = MklGetInput(context, kMeanIndex); |
| const Tensor& est_variance_tensor = MklGetInput(context, kVarianceIndex); |
| |
| TensorShape tf_shape_src; |
| MklDnnShape dnn_shape_src; |
| GetMklShape(context, kSrcIndex, &dnn_shape_src); |
| |
| if (dnn_shape_src.IsMklTensor()) { |
| tf_shape_src = dnn_shape_src.GetTfShape(); |
| OP_REQUIRES(context, dnn_shape_src.GetDimension() == 4, |
| errors::InvalidArgument("input must be 4-dimensional", |
| src_tensor.shape().DebugString())); |
| } else { |
| tf_shape_src = src_tensor.shape(); |
| OP_REQUIRES(context, src_tensor.dims() == 4, |
| errors::InvalidArgument("input must be 4-dimensional", |
| src_tensor.shape().DebugString())); |
| } |
| OP_REQUIRES(context, scale_tensor.dims() == 1, |
| errors::InvalidArgument("scale must be 1-dimensional", |
| scale_tensor.shape().DebugString())); |
| OP_REQUIRES(context, shift_tensor.dims() == 1, |
| errors::InvalidArgument("offset must be 1-dimensional", |
| shift_tensor.shape().DebugString())); |
| OP_REQUIRES( |
| context, est_mean_tensor.dims() == 1, |
| errors::InvalidArgument("estimated_mean must be 1-dimensional", |
| est_mean_tensor.shape().DebugString())); |
| OP_REQUIRES( |
| context, est_variance_tensor.dims() == 1, |
| errors::InvalidArgument("estimated_variance must be 1-dimensional", |
| est_variance_tensor.shape().DebugString())); |
| |
| if (is_training_) { |
| OP_REQUIRES( |
| context, est_mean_tensor.dim_size(0) == 0, |
| errors::InvalidArgument("estimated_mean must be empty for training", |
| est_mean_tensor.shape().DebugString())); |
| OP_REQUIRES(context, est_variance_tensor.dim_size(0) == 0, |
| errors::InvalidArgument( |
| "estimated_variance must be empty for training", |
| est_variance_tensor.shape().DebugString())); |
| } |
| |
| // special case: input with 0 elements (e.g., 0 batch size) |
| Tensor* dst_tensor = nullptr; |
| if (tf_shape_src.num_elements() == 0) { |
| HandleEmptyInput(context, tf_shape_src, scale_tensor.shape(), |
| &dst_tensor); |
| return; |
| } |
| |
| if (dnn_shape_src.IsMklTensor()) |
| depth_ = dnn_shape_src.DimSize(MklDnnDims::Dim_C); |
| else |
| ExtractParams(context); |
| |
| // Indices of output tensors |
| const size_t kDstIndex = 0; |
| |
| // allocate the output TF tensors (batch mean/variance, saved |
| // mean/variance, and, for V3, the reserved space) |
| Tensor* batch_mean_tensor = nullptr; |
| Tensor* batch_variance_tensor = nullptr; |
| Tensor* saved_mean_tensor = nullptr; |
| Tensor* saved_variance_tensor = nullptr; |
| Tensor* reserved_space_tensor = nullptr; |
| AllocateTFOutputs(context, scale_tensor.shape(), &batch_mean_tensor, |
| &batch_variance_tensor, &saved_mean_tensor, |
| &saved_variance_tensor, &reserved_space_tensor); |
| |
| if (is_training_) |
| SetMeanVariance(*batch_mean_tensor, *batch_variance_tensor); |
| else |
| SetMeanVariance(est_mean_tensor, est_variance_tensor); |
| |
| MklDnnData<T> src(&cpu_engine); |
| MklDnnData<U> weights(&cpu_engine); |
| |
| memory::format format_m; |
| if (dnn_shape_src.IsMklTensor()) { |
| if (dnn_shape_src.IsTensorInNCHWFormat()) { |
| format_m = memory::format::nchw; |
| } else { |
| format_m = memory::format::nhwc; |
| } |
| } else { |
| format_m = TFDataFormatToMklDnnDataFormat(tensor_format_); |
| } |
| |
| // set src primitive |
| memory::dims src_dims = |
| dnn_shape_src.IsMklTensor() |
| ? dnn_shape_src.GetSizesAsMklDnnDims() |
| : TFShapeToMklDnnDimsInNCHW(src_tensor.shape(), tensor_format_); |
| |
| auto src_md = dnn_shape_src.IsMklTensor() |
| ? dnn_shape_src.GetMklLayout() |
| : memory::desc(src_dims, MklDnnType<T>(), format_m); |
| |
| // MKL-DNN packs scale & shift as "weights": |
| // <scale>...<scale><shift>...<shift> |
| weights.AllocateBuffer(2 * depth_ * sizeof(U)); |
| U* weights_data = reinterpret_cast<U*>(weights.GetAllocatedBuffer()); |
| const U* scale_tf = scale_tensor.flat<U>().data(); |
| const U* shift_tf = shift_tensor.flat<U>().data(); |
| |
| std::memcpy(weights_data, scale_tf, depth_ * sizeof(U)); |
| std::memcpy(weights_data + depth_, shift_tf, depth_ * sizeof(U)); |
| char* saved_mean_data_tf = |
| reinterpret_cast<char*>(saved_mean_tensor->flat<U>().data()); |
| std::memcpy(saved_mean_data_tf, reinterpret_cast<char*>(mean_values_), |
| depth_ * sizeof(U)); |
| |
| char* saved_variance_data_tf = |
| reinterpret_cast<char*>(saved_variance_tensor->flat<U>().data()); |
| std::memcpy(saved_variance_data_tf, |
| reinterpret_cast<char*>(variance_values_), |
| depth_ * sizeof(U)); |
| |
| // get batchnorm op from the pool |
| MklBatchNormFwdParams fwdParams(src_dims, depth_, epsilon_, is_training_); |
| MklFusedBatchNormFwdPrimitive<T, U>* bn_fwd = |
| MklFusedBatchNormFwdPrimitiveFactory<T, U>::Get(fwdParams); |
| |
| // check if src needs to be reordered to the format expected by the primitive |
| const T* src_data = src_tensor.flat<T>().data(); |
| if (src_md.data.format != bn_fwd->GetSrcFmt()) { |
| src.SetUsrMem(src_md, &src_tensor); |
| auto src_target = memory::primitive_desc( |
| {{src_dims}, |
| MklDnnType<T>(), |
| static_cast<memory::format>(bn_fwd->GetSrcFmt())}, |
| cpu_engine); |
| src.CheckReorderToOpMem(src_target); |
| src_data = const_cast<T*>( |
| reinterpret_cast<T*>(src.GetOpMem().get_data_handle())); |
| } |
| |
| // allocate output (dst) tensor; always set it as MKL-DNN layout |
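| // The TF tensor is allocated as a flat buffer sized to the MKL-DNN |
| // layout, while dnn_shape_dst carries the logical shape and format so |
| // downstream MKL ops can interpret the data. |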
| MklDnnShape dnn_shape_dst; |
| TensorShape tf_shape_dst; |
| dnn_shape_dst.SetMklTensor(true); |
| auto dst_pd = bn_fwd->GetDstPd(); |
| dnn_shape_dst.SetMklLayout(&dst_pd); |
| dnn_shape_dst.SetElemType(MklDnnType<T>()); |
| auto ndims = dnn_shape_src.IsMklTensor() ? dnn_shape_src.GetDimension() |
| : src_tensor.shape().dims(); |
| dnn_shape_dst.SetTfLayout(ndims, src_dims, format_m); |
| tf_shape_dst.AddDim(dst_pd.get_size() / sizeof(T)); |
| AllocateOutputSetMklShape(context, kDstIndex, &dst_tensor, tf_shape_dst, |
| dnn_shape_dst); |
| |
| U* weights_op_data = weights_data; |
| U* mean_op_data = saved_mean_tensor->flat<U>().data(); |
| U* variance_op_data = saved_variance_tensor->flat<U>().data(); |
| T* dst_data = dst_tensor->flat<T>().data(); |
| |
| // execution |
| bn_fwd->Execute(src_data, weights_op_data, dst_data, mean_op_data, |
| variance_op_data); |
| |
| // copy batch_mean data |
| U* batch_mean_data_tf = batch_mean_tensor->flat<U>().data(); |
| std::memcpy(reinterpret_cast<char*>(batch_mean_data_tf), |
| reinterpret_cast<char*>(saved_mean_data_tf), |
| depth_ * sizeof(U)); |
| // TODO(yli135): OpMem is same as usr mem since its format |
| // is hard-coded as nc when the primitive is created. |
| |
| // copy batch_variance data with Bessel's correction |
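| // MKL-DNN computes the population (biased) variance over |
| // N = batch * height * width elements per channel; TF's batch_variance |
| // output is the unbiased estimate, so it is scaled by N / (N - 1). |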
| float adjust_factor = 1.0; |
| if (is_training_) { |
| size_t orig_size = src_dims[0] * src_dims[2] * src_dims[3]; |
| size_t adjust_size = orig_size - 1; |
| adjust_factor = (static_cast<float>(orig_size)) / adjust_size; |
| } |
| |
| auto variance_data = reinterpret_cast<U*>(saved_variance_data_tf); |
| auto batch_variance_data = batch_variance_tensor->flat<U>().data(); |
| if (is_training_) { |
| for (int k = 0; k < depth_; k++) { |
| batch_variance_data[k] = |
| variance_data[k] * static_cast<U>(adjust_factor); |
| } |
| } else { |
| std::memcpy(batch_variance_data, variance_data, depth_ * sizeof(U)); |
| } |
| } catch (mkldnn::error& e) { |
| string error_msg = "Status: " + std::to_string(e.status) + |
| ", message: " + string(e.message) + ", in file " + |
| string(__FILE__) + ":" + std::to_string(__LINE__); |
| OP_REQUIRES_OK( |
| context, |
| errors::Aborted("Operation received an exception:", error_msg)); |
| } |
| } |
| |
| private: |
| float epsilon_; |
| TensorFormat tensor_format_; |
| bool is_training_; |
| U* mean_values_; |
| U* variance_values_; |
| size_t depth_; // batch normalization is done per channel. |
| engine cpu_engine = engine(engine::cpu, 0); |
| |
| void ExtractParams(OpKernelContext* context) { |
| const Tensor& input = MklGetInput(context, 0); |
| depth_ = static_cast<int>(GetTensorDim(input, tensor_format_, 'C')); |
| } |
| |
| void SetMeanVariance(const Tensor& mean, const Tensor& variance) { |
| mean_values_ = reinterpret_cast<U*>(const_cast<U*>(mean.flat<U>().data())); |
| variance_values_ = |
| reinterpret_cast<U*>(const_cast<U*>(variance.flat<U>().data())); |
| } |
| |
| void HandleEmptyInput(OpKernelContext* context, TensorShape tf_shape_src, |
| TensorShape tf_shape_scale, Tensor** dst_tensor) { |
| CHECK_NOTNULL(dst_tensor); |
| |
| const size_t kDstIndex = 0; |
| MklDnnShape dnn_shape_dst; |
| dnn_shape_dst.SetMklTensor(false); |
| AllocateOutputSetMklShape(context, kDstIndex, dst_tensor, tf_shape_src, |
| dnn_shape_dst); |
| CHECK_NOTNULL(*dst_tensor); |
| memset(const_cast<char*>((*dst_tensor)->tensor_data().data()), 0, |
| (*dst_tensor)->tensor_data().size()); |
| |
| Tensor* batch_mean_tensor = nullptr; |
| Tensor* batch_variance_tensor = nullptr; |
| Tensor* saved_mean_tensor = nullptr; |
| Tensor* saved_variance_tensor = nullptr; |
| Tensor* reserved_space_tensor = nullptr; |
| AllocateTFOutputs(context, tf_shape_scale, &batch_mean_tensor, |
| &batch_variance_tensor, &saved_mean_tensor, |
| &saved_variance_tensor, &reserved_space_tensor); |
| } |
| |
| void AllocateTFOutputs(OpKernelContext* context, TensorShape tf_shape_scale, |
| Tensor** batch_mean_tensor, |
| Tensor** batch_variance_tensor, |
| Tensor** saved_mean_tensor, |
| Tensor** saved_variance_tensor, |
| Tensor** reserved_space_tensor) { |
| CHECK_NOTNULL(batch_mean_tensor); |
| CHECK_NOTNULL(batch_variance_tensor); |
| CHECK_NOTNULL(saved_mean_tensor); |
| CHECK_NOTNULL(saved_variance_tensor); |
| |
| const size_t kBatchMeanIndex = 1; |
| const size_t kBatchVarianceIndex = 2; |
| const size_t kSavedMeanIndex = 3; |
| const size_t kSavedVarianceIndex = 4; |
| const size_t kReservedSpaceIndex = 5; |
| |
| // allocate batch mean output tensor |
| MklDnnShape mkl_shape_batch_mean; |
| mkl_shape_batch_mean.SetMklTensor(false); |
| AllocateOutputSetMklShape(context, kBatchMeanIndex, batch_mean_tensor, |
| tf_shape_scale, mkl_shape_batch_mean); |
| CHECK_NOTNULL(*batch_mean_tensor); |
| // set NAN mean value in case of empty input tensor |
| int num_elements = tf_shape_scale.num_elements(); |
| auto batch_mean_data = (*batch_mean_tensor)->flat<U>().data(); |
| std::fill_n(batch_mean_data, num_elements, static_cast<U>(NAN)); |
| |
| // allocate batch variance output tensor |
| MklDnnShape mkl_shape_batch_variance; |
| mkl_shape_batch_variance.SetMklTensor(false); |
| AllocateOutputSetMklShape(context, kBatchVarianceIndex, |
| batch_variance_tensor, tf_shape_scale, |
| mkl_shape_batch_variance); |
| CHECK_NOTNULL(*batch_variance_tensor); |
| // set NAN variance value in case of empty input tensor |
| auto batch_variance_data = (*batch_variance_tensor)->flat<U>().data(); |
| std::fill_n(batch_variance_data, num_elements, static_cast<U>(NAN)); |
| |
| // Mean and variance (without Bessel's correction) saved for backward |
| // computation to serve as pre-computed mean and variance. |
| MklDnnShape mkl_shape_saved_mean; |
| mkl_shape_saved_mean.SetMklTensor(false); |
| AllocateOutputSetMklShape(context, kSavedMeanIndex, saved_mean_tensor, |
| tf_shape_scale, mkl_shape_saved_mean); |
| CHECK_NOTNULL(*saved_mean_tensor); |
| // set NAN mean value in case of empty input tensor |
| auto saved_mean_data = (*saved_mean_tensor)->flat<U>().data(); |
| std::fill_n(saved_mean_data, num_elements, static_cast<U>(NAN)); |
| |
| MklDnnShape mkl_shape_saved_variance; |
| mkl_shape_saved_variance.SetMklTensor(false); |
| AllocateOutputSetMklShape(context, kSavedVarianceIndex, |
| saved_variance_tensor, tf_shape_scale, |
| mkl_shape_saved_variance); |
| CHECK_NOTNULL(*saved_variance_tensor); |
| // set NAN variance value in case of empty input tensor |
| auto saved_variance_data = (*saved_variance_tensor)->flat<U>().data(); |
| std::fill_n(saved_variance_data, num_elements, static_cast<U>(NAN)); |
| |
| // Changes to support the reserved_space_3 output of FusedBatchNormV3. |
| // TODO: This output's functionality (holding intermediate results) is |
| // not implemented on CPU, so the allocated memory is simply |
| // filled with NaNs. |
| if (reserved_space) { |
| DCHECK(reserved_space_tensor != nullptr); |
| |
| MklDnnShape mkl_shape_reserved_space; |
| mkl_shape_reserved_space.SetMklTensor(false); |
| AllocateOutputSetMklShape(context, kReservedSpaceIndex, |
| reserved_space_tensor, tf_shape_scale, |
| mkl_shape_reserved_space); |
| DCHECK((*reserved_space_tensor) != nullptr); |
| auto saved_reserved_space_data = |
| (*reserved_space_tensor)->flat<U>().data(); |
| std::fill_n(saved_reserved_space_data, num_elements, static_cast<U>(NAN)); |
| } |
| } |
| }; |
| |
| template <typename Device, typename T, typename U, bool reserved_space> |
| class MklFusedBatchNormGradOp : public OpKernel { |
| public: |
| explicit MklFusedBatchNormGradOp(OpKernelConstruction* context) |
| : OpKernel(context) { |
| float epsilon; |
| OP_REQUIRES_OK(context, context->GetAttr("epsilon", &epsilon)); |
| epsilon_ = epsilon; |
| string tensor_format; |
| OP_REQUIRES_OK(context, context->GetAttr("data_format", &tensor_format)); |
| OP_REQUIRES(context, FormatFromString(tensor_format, &tensor_format_), |
| errors::InvalidArgument("Invalid data format")); |
| OP_REQUIRES_OK(context, context->GetAttr("is_training", &is_training_)); |
| depth_ = 0; |
| } |
| |
| void Compute(OpKernelContext* context) override { |
| try { |
| const size_t kDiffDstIndex = 0; // index of diff_dst tensor |
| const size_t kSrcIndex = 1; // index of src input tensor |
| const size_t kScaleIndex = 2; // index of scale tensor |
| const size_t kMeanIndex = 3; // index of saved_mean tensor |
| const size_t kVarianceIndex = 4; // index of saved_variance tensor |
| const size_t kReservedSpaceIndex = 5; // index of reserved space 3 tensor |
| |
| const Tensor& diff_dst_tensor = MklGetInput(context, kDiffDstIndex); |
| const Tensor& src_tensor = MklGetInput(context, kSrcIndex); |
| const Tensor& scale_tensor = MklGetInput(context, kScaleIndex); |
| const Tensor& saved_mean_tensor = MklGetInput(context, kMeanIndex); |
| const Tensor& saved_variance_tensor = |
| MklGetInput(context, kVarianceIndex); |
| const Tensor& reserved_space_tensor = |
| (reserved_space) ? MklGetInput(context, kReservedSpaceIndex) |
| : Tensor(); |
| |
| MklDnnShape dnn_shape_src, dnn_shape_diff_dst; |
| GetMklShape(context, kSrcIndex, &dnn_shape_src); |
| GetMklShape(context, kDiffDstIndex, &dnn_shape_diff_dst); |
| |
| TensorShape tf_shape_src, tf_shape_diff_dst; |
| if (dnn_shape_diff_dst.IsMklTensor()) { |
| tf_shape_diff_dst = dnn_shape_diff_dst.GetTfShape(); |
| OP_REQUIRES( |
| context, dnn_shape_diff_dst.GetDimension() == 4, |
| errors::InvalidArgument("input must be 4-dimensional", |
| diff_dst_tensor.shape().DebugString())); |
| } else { |
| tf_shape_diff_dst = diff_dst_tensor.shape(); |
| OP_REQUIRES( |
| context, diff_dst_tensor.dims() == 4, |
| errors::InvalidArgument("input must be 4-dimensional", |
| diff_dst_tensor.shape().DebugString())); |
| } |
| |
| if (dnn_shape_src.IsMklTensor()) { |
| tf_shape_src = dnn_shape_src.GetTfShape(); |
| OP_REQUIRES(context, dnn_shape_src.GetDimension() == 4, |
| errors::InvalidArgument("input must be 4-dimensional", |
| src_tensor.shape().DebugString())); |
| } else { |
| tf_shape_src = src_tensor.shape(); |
| OP_REQUIRES(context, src_tensor.dims() == 4, |
| errors::InvalidArgument("input must be 4-dimensional", |
| src_tensor.shape().DebugString())); |
| } |
| |
| OP_REQUIRES(context, scale_tensor.dims() == 1, |
| errors::InvalidArgument("scale must be 1-dimensional", |
| scale_tensor.shape().DebugString())); |
| OP_REQUIRES( |
| context, saved_mean_tensor.dims() == 1, |
| errors::InvalidArgument("saved mean must be 1-dimensional", |
| saved_mean_tensor.shape().DebugString())); |
| |
| OP_REQUIRES( |
| context, saved_variance_tensor.dims() == 1, |
| errors::InvalidArgument("saved variance must be 1-dimensional", |
| saved_variance_tensor.shape().DebugString())); |
| |
| Tensor* diff_src_tensor = nullptr; |
| // special case: input with 0 element and 0 batch size |
| if (tf_shape_src.num_elements() == 0 || |
| tf_shape_diff_dst.num_elements() == 0) { |
| HandleEmptyInput(context, tf_shape_src, scale_tensor.shape(), |
| &diff_src_tensor); |
| return; |
| } |
| |
| if (dnn_shape_src.IsMklTensor()) { |
| depth_ = dnn_shape_src.DimSize(MklDnnDims::Dim_C); |
| } else if (dnn_shape_diff_dst.IsMklTensor()) { |
| depth_ = dnn_shape_diff_dst.DimSize(MklDnnDims::Dim_C); |
| } else { |
| ExtractParams(context); |
| } |
| |
| memory::format format_m; |
| if (dnn_shape_src.IsMklTensor()) { |
| if (dnn_shape_src.IsTensorInNCHWFormat()) |
| format_m = memory::format::nchw; |
| else |
| format_m = memory::format::nhwc; |
| } else { |
| format_m = TFDataFormatToMklDnnDataFormat(tensor_format_); |
| } |
| |
| MklDnnData<T> src(&cpu_engine); |
| MklDnnData<T> diff_dst(&cpu_engine); |
| MklDnnData<U> weights(&cpu_engine); |
| MklDnnData<U> diff_weights(&cpu_engine); |
| |
| memory::dims src_dims = |
| dnn_shape_src.IsMklTensor() |
| ? dnn_shape_src.GetSizesAsMklDnnDims() |
| : TFShapeToMklDnnDimsInNCHW(src_tensor.shape(), tensor_format_); |
| memory::dims diff_dst_dims = |
| dnn_shape_diff_dst.IsMklTensor() |
| ? dnn_shape_diff_dst.GetSizesAsMklDnnDims() |
| : TFShapeToMklDnnDimsInNCHW(diff_dst_tensor.shape(), |
| tensor_format_); |
| |
| // set src and diff_dst primitive descriptors |
| memory::desc src_md = |
| dnn_shape_src.IsMklTensor() |
| ? dnn_shape_src.GetMklLayout() |
| : memory::desc(src_dims, MklDnnType<T>(), format_m); |
| memory::desc diff_dst_md = |
| dnn_shape_diff_dst.IsMklTensor() |
| ? dnn_shape_diff_dst.GetMklLayout() |
| : memory::desc(diff_dst_dims, MklDnnType<T>(), format_m); |
| |
| // weights -- MKL-DNN packs scale and shift as "weights": |
| // <scale>...<scale><shift>...<shift> |
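| // The backward primitive only uses the scale portion; the shift slots |
| // are zero-filled here just to satisfy the packed (2 x depth) layout. |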
| weights.AllocateBuffer(2 * depth_ * sizeof(U)); |
| U* weights_data_tf = reinterpret_cast<U*>(weights.GetAllocatedBuffer()); |
| const U* scale_tf = scale_tensor.flat<U>().data(); |
| for (int k = 0; k < depth_; k++) { |
| weights_data_tf[k] = scale_tf[k]; |
| weights_data_tf[k + depth_] = static_cast<U>(0); |
| } |
| |
| diff_weights.AllocateBuffer(2 * depth_ * sizeof(U)); |
| |
| MklBatchNormBwdParams bwdParams(src_dims, diff_dst_dims, depth_, epsilon_, |
| is_training_); |
| MklFusedBatchNormBwdPrimitive<T, U>* bn_bwd = |
| MklFusedBatchNormBwdPrimitiveFactory<T, U>::Get(bwdParams); |
| |
| // check if src/diff_dst need to be reordered |
| const T* src_data = src_tensor.flat<T>().data(); |
| if (src_md.data.format != bn_bwd->GetSrcFmt()) { |
| src.SetUsrMem(src_md, &src_tensor); |
| auto src_target = memory::primitive_desc( |
| {{src_dims}, |
| MklDnnType<T>(), |
| static_cast<memory::format>(bn_bwd->GetSrcFmt())}, |
| cpu_engine); |
| src.CheckReorderToOpMem(src_target); |
| src_data = const_cast<T*>( |
| reinterpret_cast<T*>(src.GetOpMem().get_data_handle())); |
| } |
| |
| const T* diff_dst_data = diff_dst_tensor.flat<T>().data(); |
| if (diff_dst_md.data.format != bn_bwd->GetDiffDstFmt()) { |
| diff_dst.SetUsrMem(diff_dst_md, &diff_dst_tensor); |
| auto diff_dst_target = memory::primitive_desc( |
| {{diff_dst_dims}, |
| MklDnnType<T>(), |
| static_cast<memory::format>(bn_bwd->GetDiffDstFmt())}, |
| cpu_engine); |
| diff_dst.CheckReorderToOpMem(diff_dst_target); |
| diff_dst_data = const_cast<T*>( |
| reinterpret_cast<T*>(diff_dst.GetOpMem().get_data_handle())); |
| } |
| |
| // Indices of output tensors |
| const size_t kDiffSrcIndex = 0; // index of diff_src tensor |
| |
| // allocate output tensor: diff_src, always set as MKL-DNN layout |
| MklDnnShape dnn_shape_diff_src; |
| TensorShape tf_shape_diff_src; |
| dnn_shape_diff_src.SetMklTensor(true); |
| auto diff_src_pd = bn_bwd->GetDiffSrcPd(); |
| dnn_shape_diff_src.SetMklLayout(&diff_src_pd); |
| dnn_shape_diff_src.SetElemType(MklDnnType<T>()); |
| dnn_shape_diff_src.SetTfLayout(src_dims.size(), src_dims, format_m); |
| dnn_shape_diff_src.SetTfDimOrder(src_dims.size(), tensor_format_); |
| tf_shape_diff_src.AddDim(diff_src_pd.get_size() / sizeof(T)); |
| AllocateOutputSetMklShape(context, kDiffSrcIndex, &diff_src_tensor, |
| tf_shape_diff_src, dnn_shape_diff_src); |
| |
| U* mean_data = |
| static_cast<U*>(const_cast<U*>(saved_mean_tensor.flat<U>().data())); |
| U* variance_data = static_cast<U*>( |
| const_cast<U*>(saved_variance_tensor.flat<U>().data())); |
| U* weights_data = weights_data_tf; |
| T* diff_src_data = static_cast<T*>(diff_src_tensor->flat<T>().data()); |
| U* diff_weights_data = static_cast<U*>(diff_weights.GetAllocatedBuffer()); |
| |
| U* res_space_data = |
| ((reserved_space) ? static_cast<U*>(const_cast<U*>( |
| reserved_space_tensor.flat<U>().data())) |
| : nullptr); |
| |
| // Execute |
| bn_bwd->Execute(src_data, mean_data, variance_data, diff_dst_data, |
| weights_data, diff_src_data, diff_weights_data, |
| res_space_data); |
| |
| // allocate output TF tensors: diff_scale and diff_shift |
| Tensor* diff_scale_tensor = nullptr; |
| Tensor* diff_shift_tensor = nullptr; |
| AllocateTFOutputs(context, scale_tensor.shape(), &diff_scale_tensor, |
| &diff_shift_tensor); |
| |
| // copy data: diff_scale and diff_shift |
| auto diff_scale_data = diff_scale_tensor->flat<U>().data(); |
| auto diff_shift_data = diff_shift_tensor->flat<U>().data(); |
| std::memcpy(reinterpret_cast<char*>(diff_scale_data), |
| reinterpret_cast<char*>(diff_weights_data), |
| depth_ * sizeof(U)); |
| std::memcpy(reinterpret_cast<char*>(diff_shift_data), |
| reinterpret_cast<char*>(diff_weights_data + depth_), |
| depth_ * sizeof(U)); |
| } catch (mkldnn::error& e) { |
| string error_msg = "Status: " + std::to_string(e.status) + |
| ", message: " + string(e.message) + ", in file " + |
| string(__FILE__) + ":" + std::to_string(__LINE__); |
| OP_REQUIRES_OK( |
| context, |
| errors::Aborted("Operation received an exception:", error_msg)); |
| } |
| } |
| |
| private: |
| float epsilon_; |
| TensorFormat tensor_format_; |
| size_t depth_; // batch normalization is done per channel. |
| bool is_training_; |
| engine cpu_engine = engine(engine::cpu, 0); |
| |
| void ExtractParams(OpKernelContext* context) { |
| const Tensor& input = MklGetInput(context, 0); |
| depth_ = static_cast<int>(GetTensorDim(input, tensor_format_, 'C')); |
| } |
| |
| void HandleEmptyInput(OpKernelContext* context, TensorShape tf_shape_src, |
| TensorShape tf_shape_scale_shift, |
| Tensor** diff_src_tensor) { |
| const size_t kDiffSrcIndex = 0; |
| |
| MklDnnShape dnn_shape_diff_src; |
| dnn_shape_diff_src.SetMklTensor(false); |
| AllocateOutputSetMklShape(context, kDiffSrcIndex, diff_src_tensor, |
| tf_shape_src, dnn_shape_diff_src); |
| auto diff_src_data = (*diff_src_tensor)->flat<T>().data(); |
| std::fill_n(diff_src_data, (*diff_src_tensor)->shape().num_elements(), |
| static_cast<T>(0)); |
| |
| Tensor* diff_scale_tensor = nullptr; |
| Tensor* diff_shift_tensor = nullptr; |
| AllocateTFOutputs(context, tf_shape_scale_shift, &diff_scale_tensor, |
| &diff_shift_tensor); |
| } |
| |
| void AllocateTFOutputs(OpKernelContext* context, |
| TensorShape tf_shape_scale_shift, |
| Tensor** diff_scale_tensor, |
| Tensor** diff_shift_tensor) { |
| CHECK_NOTNULL(diff_scale_tensor); |
| CHECK_NOTNULL(diff_shift_tensor); |
| |
| const size_t kDiffScaleIndex = 1; |
| const size_t kDiffShiftIndex = 2; |
| const size_t kP1Index = 3; |
| const size_t kP2Index = 4; |
| |
| // separate out scale and shift grad and copy to individual tensors |
| MklDnnShape mkl_shape_diff_scale; |
| mkl_shape_diff_scale.SetMklTensor(false); |
| AllocateOutputSetMklShape(context, kDiffScaleIndex, diff_scale_tensor, |
| tf_shape_scale_shift, mkl_shape_diff_scale); |
| CHECK_NOTNULL(*diff_scale_tensor); |
| auto diff_scale_data = (*diff_scale_tensor)->flat<U>().data(); |
| std::fill_n(diff_scale_data, (*diff_scale_tensor)->shape().num_elements(), |
| static_cast<U>(0)); |
| |
| MklDnnShape mkl_shape_diff_shift; |
| mkl_shape_diff_shift.SetMklTensor(false); |
| AllocateOutputSetMklShape(context, kDiffShiftIndex, diff_shift_tensor, |
| tf_shape_scale_shift, mkl_shape_diff_shift); |
| CHECK_NOTNULL(*diff_shift_tensor); |
| auto diff_shift_data = (*diff_shift_tensor)->flat<U>().data(); |
| std::fill_n(diff_shift_data, (*diff_shift_tensor)->shape().num_elements(), |
| static_cast<U>(0)); |
| |
| // Placeholders for estimated_mean and estimated_variance, which are |
| // used for inference and thus not needed here for gradient computation. |
| Tensor *p1_tensor = nullptr, *p2_tensor = nullptr; |
| MklDnnShape mkl_shape_p; |
| mkl_shape_p.SetMklTensor(false); |
| AllocateOutputSetMklShape(context, kP1Index, &p1_tensor, TensorShape({}), |
| mkl_shape_p); |
| AllocateOutputSetMklShape(context, kP2Index, &p2_tensor, TensorShape({}), |
| mkl_shape_p); |
| } |
| |
| memory::dims GetMeanVarianceDims() { return memory::dims({1, depth_}); } |
| }; |
| |
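| // In the registrations below, T is the input/output data type and U (for |
| // the V2/V3 kernels) is the type used for scale, offset, mean, and |
| // variance; the original kernels use T for both. |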
| #define REGISTER_MKL_FUSED_BATCHNORM_CPU(T) \ |
| REGISTER_KERNEL_BUILDER( \ |
| Name("_MklFusedBatchNorm") \ |
| .Device(DEVICE_CPU) \ |
| .TypeConstraint<T>("T") \ |
| .Label(mkl_op_registry::kMklLayoutDependentOpLabel), \ |
| MklFusedBatchNormOp<CPUDevice, T, T, false>); |
| |
| TF_CALL_float(REGISTER_MKL_FUSED_BATCHNORM_CPU); |
| TF_CALL_bfloat16(REGISTER_MKL_FUSED_BATCHNORM_CPU); |
| #undef REGISTER_MKL_FUSED_BATCHNORM_CPU |
| |
| #define REGISTER_MKL_FUSED_BATCHNORM_V2_CPU(T, U) \ |
| REGISTER_KERNEL_BUILDER( \ |
| Name("_MklFusedBatchNormV2") \ |
| .Device(DEVICE_CPU) \ |
| .TypeConstraint<T>("T") \ |
| .TypeConstraint<U>("U") \ |
| .Label(mkl_op_registry::kMklLayoutDependentOpLabel), \ |
| MklFusedBatchNormOp<CPUDevice, T, U, false>); |
| |
| REGISTER_MKL_FUSED_BATCHNORM_V2_CPU(float, float); |
| REGISTER_MKL_FUSED_BATCHNORM_V2_CPU(bfloat16, float); |
| #undef REGISTER_MKL_FUSED_BATCHNORM_V2_CPU |
| |
| #define REGISTER_MKL_FUSED_BATCHNORM_GRAD_CPU(T) \ |
| REGISTER_KERNEL_BUILDER( \ |
| Name("_MklFusedBatchNormGrad") \ |
| .Device(DEVICE_CPU) \ |
| .TypeConstraint<T>("T") \ |
| .Label(mkl_op_registry::kMklLayoutDependentOpLabel), \ |
| MklFusedBatchNormGradOp<CPUDevice, T, T, false>); |
| |
| TF_CALL_float(REGISTER_MKL_FUSED_BATCHNORM_GRAD_CPU); |
| TF_CALL_bfloat16(REGISTER_MKL_FUSED_BATCHNORM_GRAD_CPU); |
| #undef REGISTER_MKL_FUSED_BATCHNORM_GRAD_CPU |
| |
| #define REGISTER_MKL_FUSED_BATCHNORM_GRAD_V2_CPU(T, U) \ |
| REGISTER_KERNEL_BUILDER( \ |
| Name("_MklFusedBatchNormGradV2") \ |
| .Device(DEVICE_CPU) \ |
| .TypeConstraint<T>("T") \ |
| .TypeConstraint<U>("U") \ |
| .Label(mkl_op_registry::kMklLayoutDependentOpLabel), \ |
| MklFusedBatchNormGradOp<CPUDevice, T, U, false>); |
| |
| REGISTER_MKL_FUSED_BATCHNORM_GRAD_V2_CPU(float, float); |
| REGISTER_MKL_FUSED_BATCHNORM_GRAD_V2_CPU(bfloat16, float); |
| #undef REGISTER_MKL_FUSED_BATCHNORM_GRAD_V2_CPU |
| |
| // TODO: FusedBatchNormV3 has an additional output that is used to |
| // hold intermediate results. This parameter functionality is |
| // not implemented on CPU. |
| #define REGISTER_MKL_FUSED_BATCHNORM_V3_CPU(T, U) \ |
| REGISTER_KERNEL_BUILDER( \ |
| Name("_MklFusedBatchNormV3") \ |
| .Device(DEVICE_CPU) \ |
| .TypeConstraint<T>("T") \ |
| .TypeConstraint<U>("U") \ |
| .Label(mkl_op_registry::kMklLayoutDependentOpLabel), \ |
| MklFusedBatchNormOp<CPUDevice, T, U, true>); |
| |
| REGISTER_MKL_FUSED_BATCHNORM_V3_CPU(float, float); |
| REGISTER_MKL_FUSED_BATCHNORM_V3_CPU(bfloat16, float); |
| #undef REGISTER_MKL_FUSED_BATCHNORM_V3_CPU |
| |
| #define REGISTER_MKL_FUSED_BATCHNORM_GRAD_V3_CPU(T, U) \ |
| REGISTER_KERNEL_BUILDER( \ |
| Name("_MklFusedBatchNormGradV3") \ |
| .Device(DEVICE_CPU) \ |
| .TypeConstraint<T>("T") \ |
| .TypeConstraint<U>("U") \ |
| .Label(mkl_op_registry::kMklLayoutDependentOpLabel), \ |
| MklFusedBatchNormGradOp<CPUDevice, T, U, true>); |
| |
| REGISTER_MKL_FUSED_BATCHNORM_GRAD_V3_CPU(float, float); |
| REGISTER_MKL_FUSED_BATCHNORM_GRAD_V3_CPU(bfloat16, float); |
| #undef REGISTER_MKL_FUSED_BATCHNORM_GRAD_V3_CPU |
| |
| } // namespace tensorflow |
| |
| #endif // INTEL_MKL |