| /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
| |
| Licensed under the Apache License, Version 2.0 (the "License"); |
| you may not use this file except in compliance with the License. |
| You may obtain a copy of the License at |
| |
| http://www.apache.org/licenses/LICENSE-2.0 |
| |
| Unless required by applicable law or agreed to in writing, software |
| distributed under the License is distributed on an "AS IS" BASIS, |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| See the License for the specific language governing permissions and |
| limitations under the License. |
| ==============================================================================*/ |
| |
| // See docs in ../ops/nn_ops.cc. |
| #ifdef INTEL_MKL |
| |
| #include "tensorflow/core/kernels/mkl/mkl_conv_ops.h" |
| |
| #include <algorithm> |
| #include <map> |
| #include <string> |
| #include <unordered_map> |
| |
| #include "absl/strings/str_join.h" |
| #include "tensorflow/core/kernels/mkl/mkl_quantized_conv_ops.h" |
| #include "tensorflow/core/kernels/no_op.h" |
| #ifdef DNNL_AARCH64_USE_ACL |
| #include "tensorflow/core/platform/hash.h" |
| #include "tensorflow/core/platform/mutex.h" |
| #endif |
| |
| using dnnl::convolution_forward; |
| using dnnl::prop_kind; |
| using dnnl::stream; |
| using ConvFwdPd = dnnl::convolution_forward::primitive_desc; |
| using ReorderPd = dnnl::reorder::primitive_desc; |
| |
| namespace tensorflow { |
| // This structure aggregates multiple inputs to Conv2DFwd* methods. |
| struct MklConvFwdParams { |
| memory::dims src_dims; |
| memory::dims filter_dims; |
| memory::dims bias_dims; |
| memory::dims dst_dims; |
| memory::dims strides; |
| memory::dims dilations; |
| memory::dims padding_left; |
| memory::dims padding_right; |
| memory::dims fuse_bn_dims; |
| MklTensorFormat tf_fmt; |
| bool native_format; |
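| // Concatenation of the names of the Tinput/Tfilter/Tbias/Toutput types |
| // (see ExtendConvFwdParams); included in the primitive cache key. |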
| string dtypes = string(""); |
| #ifdef DNNL_AARCH64_USE_ACL |
| uint64 filter_hash; |
| #endif |
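| // Describes one fused post-op: its name ("activation", "sum", |
| // "output_scale", or "fuse_bn"), the oneDNN algorithm, its float |
| // parameters, and an optional partial cache key (used for output_scale). |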
| struct PostOpParam { |
| string name; |
| dnnl::algorithm alg; |
| std::vector<float> param; |
| std::string partial_key; |
| }; |
| std::vector<PostOpParam> post_op_params; |
| |
| MklConvFwdParams(memory::dims src_dims, memory::dims filter_dims, |
| memory::dims bias_dims, memory::dims dst_dims, |
| memory::dims strides, memory::dims dilations, |
| memory::dims padding_left, memory::dims padding_right, |
| memory::dims fuse_bn_dims, MklTensorFormat tf_fmt, |
| bool native_format) |
| : src_dims(src_dims), |
| filter_dims(filter_dims), |
| bias_dims(bias_dims), |
| dst_dims(dst_dims), |
| strides(strides), |
| dilations(dilations), |
| padding_left(padding_left), |
| padding_right(padding_right), |
| fuse_bn_dims(fuse_bn_dims), |
| tf_fmt(tf_fmt), |
| native_format(native_format) {} |
| }; |
| |
| // With quantization, input, filter, and output can have different types, |
| // so we use a different template parameter for each type. |
| template <typename Tinput, typename Tfilter, typename Tbias, typename Toutput> |
| class MklConvFwdPrimitive : public MklPrimitive { |
| public: |
| explicit MklConvFwdPrimitive(const MklConvFwdParams& convFwdDims) |
| : MklPrimitive(engine(engine::kind::cpu, 0)) { |
| // Create convolution primitive |
| if (context_.conv_fwd == nullptr) { |
| Setup(convFwdDims); |
| } |
| } |
| ~MklConvFwdPrimitive() {} |
| |
| dnnl::memory::desc GetScratchPadDesc() { |
| return context_.fwd_pd->scratchpad_desc(); |
| } |
| |
| // Convolution forward execute with bias |
| // src_data: input data buffer of src |
| // filter_data: input data buffer of filter (weights) |
| // bias_data: input data buffer of bias |
| // dst_data: output data buffer of dst |
| void Execute(const Tinput* src_data, const Tfilter* filter_data, |
| const Tbias* bias_data, const Toutput* dst_data, |
| std::shared_ptr<stream> fwd_stream, void* sp_data = nullptr) { |
| Execute(src_data, filter_data, bias_data, dst_data, nullptr, nullptr, |
| nullptr, nullptr, fwd_stream, sp_data); |
| } |
| |
| void Execute(const Tinput* src_data, const Tfilter* filter_data, |
| const Tbias* bias_data, const Toutput* dst_data, |
| const Tinput* bn_scale_data, const Tinput* bn_mean_data, |
| const Tinput* bn_offset_data, const Tinput* bn_rsqrt_data, |
| std::shared_ptr<stream> fwd_stream, void* sp_data) { |
| #ifdef DNNL_AARCH64_USE_ACL |
| // When a single global primitive cache is used, multiple threads may |
| // execute the same cached primitive concurrently, so execution must |
| // happen under the lock. |
| mutex_lock lock(primitive_execution_mu_); |
| #endif |
| #ifndef ENABLE_ONEDNN_OPENMP |
| // TODO(intel-tf): Create a common function and avoid the duplicate code |
| context_.src_mem->set_data_handle( |
| static_cast<void*>(const_cast<Tinput*>(src_data)), *fwd_stream); |
| context_.filter_mem->set_data_handle( |
| static_cast<void*>(const_cast<Tfilter*>(filter_data)), *fwd_stream); |
| if (bias_data != nullptr) { |
| context_.bias_mem->set_data_handle( |
| static_cast<void*>(const_cast<Tbias*>(bias_data)), *fwd_stream); |
| } |
| if (bn_scale_data != nullptr) { |
| context_.bn_scale_mem->set_data_handle( |
| static_cast<void*>(const_cast<Tinput*>(bn_scale_data)), *fwd_stream); |
| context_.bn_mean_mem->set_data_handle( |
| static_cast<void*>(const_cast<Tinput*>(bn_mean_data)), *fwd_stream); |
| context_.bn_rsqrt_mem->set_data_handle( |
| static_cast<void*>(const_cast<Tinput*>(bn_rsqrt_data)), *fwd_stream); |
| context_.bn_offset_mem->set_data_handle( |
| static_cast<void*>(const_cast<Tinput*>(bn_offset_data)), *fwd_stream); |
| } |
| context_.dst_mem->set_data_handle( |
| static_cast<void*>(const_cast<Toutput*>(dst_data)), *fwd_stream); |
| #else |
| context_.src_mem->set_data_handle( |
| static_cast<void*>(const_cast<Tinput*>(src_data))); |
| context_.filter_mem->set_data_handle( |
| static_cast<void*>(const_cast<Tfilter*>(filter_data))); |
| if (bias_data != nullptr) { |
| context_.bias_mem->set_data_handle( |
| static_cast<void*>(const_cast<Tbias*>(bias_data))); |
| } |
| if (bn_scale_data != nullptr) { |
| context_.bn_scale_mem->set_data_handle( |
| static_cast<void*>(const_cast<Tinput*>(bn_scale_data))); |
| context_.bn_mean_mem->set_data_handle( |
| static_cast<void*>(const_cast<Tinput*>(bn_mean_data))); |
| context_.bn_rsqrt_mem->set_data_handle( |
| static_cast<void*>(const_cast<Tinput*>(bn_rsqrt_data))); |
| context_.bn_offset_mem->set_data_handle( |
| static_cast<void*>(const_cast<Tinput*>(bn_offset_data))); |
| } |
| context_.dst_mem->set_data_handle( |
| static_cast<void*>(const_cast<Toutput*>(dst_data))); |
| #endif // !ENABLE_ONEDNN_OPENMP |
| if (sp_data) { |
| context_.sp_mem->set_data_handle(static_cast<void*>(sp_data), |
| *fwd_stream); |
| } |
| |
| DCHECK_EQ(context_.fwd_primitives.size(), |
| context_.fwd_primitives_args.size()); |
| for (size_t i = 0; i < context_.fwd_primitives.size(); ++i) { |
| context_.fwd_primitives.at(i).execute(*fwd_stream, |
| context_.fwd_primitives_args.at(i)); |
| } |
| |
| // After execution, reset the data handles back to DummyData. |
| context_.src_mem->set_data_handle(DummyData); |
| context_.filter_mem->set_data_handle(DummyData); |
| if (bias_data != nullptr) { |
| context_.bias_mem->set_data_handle(DummyData); |
| } |
| if (bn_scale_data != nullptr) { |
| context_.bn_scale_mem->set_data_handle(DummyData); |
| context_.bn_mean_mem->set_data_handle(DummyData); |
| context_.bn_rsqrt_mem->set_data_handle(DummyData); |
| context_.bn_offset_mem->set_data_handle(DummyData); |
| } |
| context_.dst_mem->set_data_handle(DummyData); |
| if (sp_data) { |
| context_.sp_mem->set_data_handle(DummyData); |
| } |
| } |
| |
| // Convolution forward execute without bias |
| // src_data: input data buffer of src |
| // filter_data: input data buffer of filter (weights) |
| // dst_data: output data buffer of dst |
| void Execute(const Tinput* src_data, const Tfilter* filter_data, |
| const Toutput* dst_data, std::shared_ptr<stream> fwd_stream, |
| void* sp_data) { |
| Execute(src_data, filter_data, nullptr, dst_data, nullptr, nullptr, nullptr, |
| nullptr, fwd_stream, sp_data); |
| } |
| |
| std::shared_ptr<ConvFwdPd> GetPrimitiveDesc() const { |
| return context_.fwd_pd; |
| } |
| |
| private: |
| // Primitive reuse context for Conv2D Fwd op |
| struct ConvFwdContext { |
| // MKL-DNN memory |
| std::shared_ptr<dnnl::memory> src_mem; |
| std::shared_ptr<dnnl::memory> filter_mem; |
| std::shared_ptr<dnnl::memory> bias_mem; |
| std::shared_ptr<dnnl::memory> dst_mem; |
| std::shared_ptr<dnnl::memory> sp_mem; |
| |
| // FusedBatchNorm related memory |
| std::shared_ptr<dnnl::memory> bn_scale_mem; |
| std::shared_ptr<dnnl::memory> bn_mean_mem; |
| std::shared_ptr<dnnl::memory> bn_rsqrt_mem; |
| std::shared_ptr<dnnl::memory> bn_offset_mem; |
| |
| // Desc & primitive desc |
| std::shared_ptr<dnnl::convolution_forward::desc> fwd_desc; |
| |
| // Memory desc |
| std::shared_ptr<dnnl::memory::desc> src_md; |
| std::shared_ptr<dnnl::memory::desc> filter_md; |
| std::shared_ptr<dnnl::memory::desc> bias_md; |
| std::shared_ptr<dnnl::memory::desc> dst_md; |
| |
| // TODO(intel-tf): Only need one? FusedBatchNorm related. |
| std::shared_ptr<dnnl::memory::desc> bn_scale_md; |
| std::shared_ptr<dnnl::memory::desc> bn_mean_md; |
| std::shared_ptr<dnnl::memory::desc> bn_rsqrt_md; |
| std::shared_ptr<dnnl::memory::desc> bn_offset_md; |
| |
| // Convolution primitive |
| std::shared_ptr<ConvFwdPd> fwd_pd; |
| std::shared_ptr<dnnl::primitive> conv_fwd; |
| |
| std::vector<dnnl::primitive> fwd_primitives; |
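| // One argument map per primitive above; fwd_primitives and |
| // fwd_primitives_args are executed in lockstep (see Execute()). |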
| std::vector<std::unordered_map<int, memory>> fwd_primitives_args; |
| |
| ConvFwdContext() |
| : src_mem(nullptr), |
| filter_mem(nullptr), |
| bias_mem(nullptr), |
| dst_mem(nullptr), |
| sp_mem(nullptr), |
| bn_scale_mem(nullptr), |
| bn_mean_mem(nullptr), |
| bn_rsqrt_mem(nullptr), |
| bn_offset_mem(nullptr), |
| fwd_desc(nullptr), |
| src_md(nullptr), |
| filter_md(nullptr), |
| bias_md(nullptr), |
| dst_md(nullptr), |
| bn_scale_md(nullptr), |
| bn_mean_md(nullptr), |
| bn_rsqrt_md(nullptr), |
| bn_offset_md(nullptr), |
| fwd_pd(nullptr), |
| conv_fwd(nullptr) {} |
| }; |
| |
| void Setup(const MklConvFwdParams& convFwdDims) { |
| memory::format_tag user_data_fmt; |
| if (convFwdDims.native_format) { |
| user_data_fmt = MklTensorFormatToMklDnnDataFormat(convFwdDims.tf_fmt); |
| } else { |
| // Create memory descriptors for convolution data w/ no specified format |
| user_data_fmt = memory::format_tag::any; |
| } |
| context_.src_md.reset(new memory::desc( |
| {convFwdDims.src_dims}, MklDnnType<Tinput>(), user_data_fmt)); |
| |
| context_.filter_md.reset(new memory::desc({convFwdDims.filter_dims}, |
| MklDnnType<Tfilter>(), |
| memory::format_tag::any)); |
| |
| context_.dst_md.reset(new memory::desc( |
| {convFwdDims.dst_dims}, MklDnnType<Toutput>(), user_data_fmt)); |
| |
| if (!convFwdDims.bias_dims.empty()) { |
| context_.bias_md.reset(new memory::desc({convFwdDims.bias_dims}, |
| MklDnnType<Tbias>(), |
| memory::format_tag::any)); |
| // Create a convolution descriptor |
| context_.fwd_desc.reset(new convolution_forward::desc( |
| prop_kind::forward, dnnl::algorithm::convolution_direct, |
| *context_.src_md, *context_.filter_md, *context_.bias_md, |
| *context_.dst_md, convFwdDims.strides, convFwdDims.dilations, |
| convFwdDims.padding_left, convFwdDims.padding_right)); |
| } else { |
| context_.fwd_desc.reset(new convolution_forward::desc( |
| prop_kind::forward, dnnl::algorithm::convolution_direct, |
| *context_.src_md, *context_.filter_md, *context_.dst_md, |
| convFwdDims.strides, convFwdDims.dilations, convFwdDims.padding_left, |
| convFwdDims.padding_right)); |
| } |
| |
| if (!convFwdDims.fuse_bn_dims.empty()) { |
| const memory::format_tag fused_bn_arg_fmt = |
| convFwdDims.native_format |
| ? user_data_fmt |
| : MklTensorFormatToMklDnnDataFormat(convFwdDims.tf_fmt); |
| |
| context_.bn_scale_md.reset(new memory::desc( |
| {convFwdDims.fuse_bn_dims}, MklDnnType<Tinput>(), fused_bn_arg_fmt)); |
| context_.bn_mean_md.reset(new memory::desc( |
| {convFwdDims.fuse_bn_dims}, MklDnnType<Tinput>(), fused_bn_arg_fmt)); |
| context_.bn_rsqrt_md.reset(new memory::desc( |
| {convFwdDims.fuse_bn_dims}, MklDnnType<Tinput>(), fused_bn_arg_fmt)); |
| context_.bn_offset_md.reset(new memory::desc( |
| {convFwdDims.fuse_bn_dims}, MklDnnType<Tinput>(), fused_bn_arg_fmt)); |
| } |
| |
| // Check if there are any fusions as post-ops |
| auto const& post_op_params = convFwdDims.post_op_params; |
| dnnl::primitive_attr post_ops_attr; |
| dnnl::post_ops post_ops; |
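| // Use a user-managed scratchpad so that TensorFlow can allocate it from |
| // its own allocator (see UserScratchPad in MklConvOp::Compute). |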
| post_ops_attr.set_scratchpad_mode(dnnl::scratchpad_mode::user); |
| if (!post_op_params.empty()) { |
| for (auto const& post_op_param : post_op_params) { |
| if (post_op_param.name == "activation") { |
| DCHECK_EQ(post_op_param.param.size(), 3); |
| float op_scale = post_op_param.param[0]; |
| float op_alpha = post_op_param.param[1]; |
| float op_beta = post_op_param.param[2]; |
| post_ops.append_eltwise(op_scale, post_op_param.alg, op_alpha, |
| op_beta); |
| } else if (post_op_param.name == "sum") { |
| DCHECK_EQ(post_op_param.param.size(), 1); |
| float op_scale = post_op_param.param[0]; |
| post_ops.append_sum(op_scale); |
| } else if (post_op_param.name == "output_scale") { |
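| // Mask 0 applies a single common output scale; mask 2 (1 << 1) applies |
| // per-output-channel scales along dimension 1 of the destination. |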
| if (post_op_param.param.size() == 1) { |
| post_ops_attr.set_output_scales(0, post_op_param.param); |
| } else { |
| post_ops_attr.set_output_scales(2, post_op_param.param); |
| } |
| } else if (post_op_param.name == "fuse_bn") { |
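| // FusedBatchNorm is expressed as four binary post-ops on the conv |
| // output: subtract mean, multiply by rsqrt(variance + epsilon), |
| // multiply by scale (gamma), then add offset (beta). |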
| post_ops.append_binary(dnnl::algorithm::binary_sub, |
| *context_.bn_mean_md); |
| post_ops.append_binary(dnnl::algorithm::binary_mul, |
| *context_.bn_rsqrt_md); |
| post_ops.append_binary(dnnl::algorithm::binary_mul, |
| *context_.bn_scale_md); |
| post_ops.append_binary(dnnl::algorithm::binary_add, |
| *context_.bn_offset_md); |
| } else { |
| DCHECK((post_op_param.name == "activation") || |
| (post_op_param.name == "sum") || |
| (post_op_param.name == "output_scale") || |
| (post_op_param.name == "fuse_bn")); |
| } |
| } |
| post_ops_attr.set_post_ops(post_ops); |
| } |
| context_.fwd_pd.reset( |
| new ConvFwdPd(*context_.fwd_desc, post_ops_attr, cpu_engine_)); |
| |
| // Create memory primitive based on dummy data |
| context_.src_mem.reset( |
| new memory(context_.fwd_pd.get()->src_desc(), cpu_engine_, DummyData)); |
| context_.filter_mem.reset(new memory(context_.fwd_pd.get()->weights_desc(), |
| cpu_engine_, DummyData)); |
| context_.dst_mem.reset( |
| new memory(context_.fwd_pd.get()->dst_desc(), cpu_engine_, DummyData)); |
| |
| context_.conv_fwd.reset(new convolution_forward(*context_.fwd_pd)); |
| auto scratchpad_md = context_.fwd_pd->scratchpad_desc(); |
| context_.sp_mem.reset( |
| new dnnl::memory(scratchpad_md, cpu_engine_, DummyData)); |
| |
| // Create convolution primitive and add it to net |
| if (!convFwdDims.bias_dims.empty()) { |
| context_.bias_mem.reset(new memory( |
| {{convFwdDims.bias_dims}, MklDnnType<Tbias>(), memory::format_tag::x}, |
| cpu_engine_, DummyData)); |
| context_.fwd_primitives_args.push_back( |
| {{DNNL_ARG_SRC, *context_.src_mem}, |
| {DNNL_ARG_WEIGHTS, *context_.filter_mem}, |
| {DNNL_ARG_BIAS, *context_.bias_mem}, |
| {DNNL_ARG_SCRATCHPAD, *context_.sp_mem}, |
| {DNNL_ARG_DST, *context_.dst_mem}}); |
| } else if (!convFwdDims.fuse_bn_dims.empty()) { |
| context_.bn_scale_mem.reset( |
| new memory(*context_.bn_scale_md, cpu_engine_, DummyData)); |
| context_.bn_mean_mem.reset( |
| new memory(*context_.bn_mean_md, cpu_engine_, DummyData)); |
| context_.bn_offset_mem.reset( |
| new memory(*context_.bn_offset_md, cpu_engine_, DummyData)); |
| context_.bn_rsqrt_mem.reset( |
| new memory(*context_.bn_rsqrt_md, cpu_engine_, DummyData)); |
| |
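| // The DNNL_ARG_ATTR_MULTIPLE_POST_OP indices (0-3) must match the order |
| // in which the fuse_bn binary post-ops were appended above: sub(mean), |
| // mul(rsqrt), mul(scale), add(offset). |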
| context_.fwd_primitives_args.push_back( |
| {{DNNL_ARG_SRC, *context_.src_mem}, |
| {DNNL_ARG_WEIGHTS, *context_.filter_mem}, |
| {DNNL_ARG_DST, *context_.dst_mem}, |
| {DNNL_ARG_SCRATCHPAD, *context_.sp_mem}, |
| {DNNL_ARG_ATTR_MULTIPLE_POST_OP(0) | DNNL_ARG_SRC_1, |
| *context_.bn_mean_mem}, |
| {DNNL_ARG_ATTR_MULTIPLE_POST_OP(1) | DNNL_ARG_SRC_1, |
| *context_.bn_rsqrt_mem}, |
| {DNNL_ARG_ATTR_MULTIPLE_POST_OP(2) | DNNL_ARG_SRC_1, |
| *context_.bn_scale_mem}, |
| {DNNL_ARG_ATTR_MULTIPLE_POST_OP(3) | DNNL_ARG_SRC_1, |
| *context_.bn_offset_mem}}); |
| } else { |
| context_.fwd_primitives_args.push_back( |
| {{DNNL_ARG_SRC, *context_.src_mem}, |
| {DNNL_ARG_WEIGHTS, *context_.filter_mem}, |
| {DNNL_ARG_SCRATCHPAD, *context_.sp_mem}, |
| {DNNL_ARG_DST, *context_.dst_mem}}); |
| } |
| context_.fwd_primitives.push_back(*context_.conv_fwd); |
| } |
| |
| struct ConvFwdContext context_; |
| |
| #ifdef DNNL_AARCH64_USE_ACL |
| // Guards Execute() |
| mutex primitive_execution_mu_; |
| #endif |
| }; |
| |
| // TODO(intel-tf): We should not require passing a type to MklPrimitiveFactory. |
| // But removing the need for type in MklPrimitiveFactory is going to require |
| // change to every MKL op. So not doing it now. Instead passing float. |
| template <typename Tinput, typename Tfilter, typename Tbias, typename Toutput> |
| class MklConvFwdPrimitiveFactory : public MklPrimitiveFactory<float> { |
| public: |
| static MklConvFwdPrimitive<Tinput, Tfilter, Tbias, Toutput>* Get( |
| const MklConvFwdParams& convFwdDims, bool do_not_cache) { |
| MklConvFwdPrimitive<Tinput, Tfilter, Tbias, Toutput>* conv_fwd = nullptr; |
| |
| if (do_not_cache) { |
| // Always create a new primitive |
| conv_fwd = |
| new MklConvFwdPrimitive<Tinput, Tfilter, Tbias, Toutput>(convFwdDims); |
| } else { |
| // Try to find a suitable primitive in the pool |
| conv_fwd = |
| dynamic_cast<MklConvFwdPrimitive<Tinput, Tfilter, Tbias, Toutput>*>( |
| MklConvFwdPrimitiveFactory<Tinput, Tfilter, Tbias, |
| Toutput>::GetInstance() |
| .GetConvFwd(convFwdDims)); |
| if (conv_fwd == nullptr) { |
| conv_fwd = new MklConvFwdPrimitive<Tinput, Tfilter, Tbias, Toutput>( |
| convFwdDims); |
| MklConvFwdPrimitiveFactory<Tinput, Tfilter, Tbias, |
| Toutput>::GetInstance() |
| .SetConvFwd(convFwdDims, conv_fwd); |
| } |
| } |
| |
| return conv_fwd; |
| } |
| |
| private: |
| MklConvFwdPrimitiveFactory() {} |
| ~MklConvFwdPrimitiveFactory() {} |
| |
| static const int kDilationH = 0, kDilationW = 1; |
| |
| static MklConvFwdPrimitiveFactory& GetInstance() { |
| static MklConvFwdPrimitiveFactory instance_; |
| return instance_; |
| } |
| |
| static string CreateKey(const MklConvFwdParams& convFwdDims) { |
| string prefix = "conv_fwd_"; |
| FactoryKeyCreator key_creator; |
| key_creator.AddAsKey(prefix); |
| key_creator.AddAsKey(convFwdDims.src_dims); |
| key_creator.AddAsKey(convFwdDims.filter_dims); |
| #ifdef DNNL_AARCH64_USE_ACL |
| key_creator.AddAsKey(convFwdDims.filter_hash); |
| #endif |
| key_creator.AddAsKey(convFwdDims.bias_dims); |
| key_creator.AddAsKey(convFwdDims.dst_dims); |
| key_creator.AddAsKey(convFwdDims.strides); |
| key_creator.AddAsKey(convFwdDims.dilations); |
| key_creator.AddAsKey(convFwdDims.padding_left); |
| key_creator.AddAsKey(convFwdDims.padding_right); |
| key_creator.AddAsKey(convFwdDims.dtypes); |
| if (convFwdDims.native_format) { |
| key_creator.AddAsKey(convFwdDims.tf_fmt); |
| } |
| |
| // Generate keys for post-ops |
| for (auto const& post_op_param : convFwdDims.post_op_params) { |
| key_creator.AddAsKey(post_op_param.name); |
| if (post_op_param.name == "activation") { |
| DCHECK_EQ(post_op_param.param.size(), 3); |
| for (auto& param : post_op_param.param) { |
| key_creator.AddAsKey(param); |
| } |
| } else if (post_op_param.name == "sum") { |
| DCHECK_EQ(post_op_param.param.size(), 1); |
| for (auto& param : post_op_param.param) { |
| key_creator.AddAsKey(param); |
| } |
| } else if (post_op_param.name == "output_scale") { |
| key_creator.AddAsKey(post_op_param.partial_key); |
| } else if (post_op_param.name == "fuse_bn") { |
| key_creator.AddAsKey(post_op_param.name); |
| key_creator.AddAsKey(convFwdDims.fuse_bn_dims); |
| } else { |
| return string("not_a_key"); |
| } |
| } |
| |
| return key_creator.GetKey(); |
| } |
| |
| MklPrimitive* GetConvFwd(const MklConvFwdParams& convFwdDims) { |
| string key = CreateKey(convFwdDims); |
| return this->GetOp(key); |
| } |
| |
| void SetConvFwd(const MklConvFwdParams& convFwdDims, MklPrimitive* op) { |
| string key = CreateKey(convFwdDims); |
| this->SetOp(key, op); |
| } |
| }; |
| |
| // Base class for convolution forward operations |
| template <typename Device, typename Tinput, typename Tfilter, typename Tbias, |
| typename Toutput, typename Ttemp_output, typename Tpadding, |
| bool bias_enabled, bool pad_enabled, bool is_depthwise, |
| bool native_format> |
| class MklConvOp : public OpKernel { |
| public: |
| ~MklConvOp() {} |
| |
| explicit MklConvOp(OpKernelConstruction* context) : OpKernel(context) { |
| OP_REQUIRES_OK(context, context->GetAttr("dilations", &dilations_)); |
| |
| // Conv and QuantizedConv ops have different padding attributes |
| // (`padding_list` versus `explicit_paddings`). But one and only one |
| // attribute is expected. |
| OP_REQUIRES( |
| context, |
| !(context->HasAttr("padding_list") && |
| context->HasAttr("explicit_paddings")), |
| errors::InvalidArgument("Can only have 1 `padding` list at most")); |
| if (context->HasAttr("padding_list")) { |
| OP_REQUIRES_OK(context, context->GetAttr("padding_list", &padding_list_)); |
| } |
| if (context->HasAttr("explicit_paddings")) { |
| OP_REQUIRES_OK(context, |
| context->GetAttr("explicit_paddings", &padding_list_)); |
| } |
| |
| OP_REQUIRES_OK(context, context->GetAttr("strides", &strides_)); |
| OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format_str_)); |
| OP_REQUIRES(context, FormatFromString(data_format_str_, &data_format_), |
| errors::InvalidArgument("Invalid data format")); |
| OP_REQUIRES(context, (strides_.size() == 4 || strides_.size() == 5), |
| errors::InvalidArgument("Sliding window strides field must " |
| "specify 4 or 5 dimensions")); |
| |
| const int64 stride_n = GetTensorDim(strides_, data_format_, 'N'); |
| const int64 stride_c = GetTensorDim(strides_, data_format_, 'C'); |
| OP_REQUIRES( |
| context, stride_n == 1 && stride_c == 1, |
| errors::Unimplemented("Current implementation does not yet support " |
| "strides in the batch and depth dimensions.")); |
| |
| OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_)); |
| is_filter_const_ = false; |
| if (AreWeightsFrozen()) { |
| is_filter_const_ = true; |
| } else if (context->HasAttr("is_filter_const")) { |
| OP_REQUIRES_OK(context, |
| context->GetAttr("is_filter_const", &is_filter_const_)); |
| } |
| |
| if (strides_.size() == 4) { |
| OP_REQUIRES(context, dilations_.size() == 4, |
| errors::InvalidArgument("Sliding window dilations field must " |
| "specify 4 dimensions")); |
| const int64 dilation_n = GetTensorDim(dilations_, data_format_, 'N'); |
| const int64 dilation_c = GetTensorDim(dilations_, data_format_, 'C'); |
| const int64 dilation_h = GetTensorDim(dilations_, data_format_, 'H'); |
| const int64 dilation_w = GetTensorDim(dilations_, data_format_, 'W'); |
| OP_REQUIRES(context, dilation_n == 1 && dilation_c == 1, |
| errors::InvalidArgument( |
| "Current implementation does not yet support " |
| "dilations in the batch and depth dimensions.")); |
| OP_REQUIRES( |
| context, dilation_h > 0 && dilation_w > 0, |
| errors::InvalidArgument("Dilation rates should be larger than 0.")); |
| } else if (strides_.size() == 5) { |
| OP_REQUIRES(context, dilations_.size() == 5, |
| errors::InvalidArgument("Dilation rates field must " |
| "specify 5 dimensions")); |
| OP_REQUIRES(context, |
| (GetTensorDim(dilations_, data_format_, 'N') == 1 && |
| GetTensorDim(dilations_, data_format_, 'C') == 1), |
| errors::InvalidArgument( |
| "Current implementation does not yet support " |
| "dilation rates in the batch and depth dimensions.")); |
| OP_REQUIRES( |
| context, |
| (GetTensorDim(dilations_, data_format_, '0') > 0 && |
| GetTensorDim(dilations_, data_format_, '1') > 0 && |
| GetTensorDim(dilations_, data_format_, '2') > 0), |
| errors::InvalidArgument("Dilation rates should be larger than 0.")); |
| } |
| } |
| |
| void Compute(OpKernelContext* context) override { |
| try { |
| // Input tensors |
| const Tensor& src_tensor = MklGetInput(context, kInputIndex_Src); |
| const Tensor& filter_tensor = MklGetInput(context, kInputIndex_Filter); |
| |
| OP_REQUIRES( |
| context, filter_tensor.NumElements() > 0, |
| errors::InvalidArgument("filter must not have zero elements " |
| "(i.e. all dimensions must be non-zero)")); |
| |
| MklDnnShape src_mkl_shape, filter_mkl_shape; |
| GetMklShape(context, kInputIndex_Src, &src_mkl_shape, native_format); |
| GetMklShape(context, kInputIndex_Filter, &filter_mkl_shape, |
| native_format); |
| |
| OP_REQUIRES(context, !filter_mkl_shape.IsMklTensor(), |
| errors::InvalidArgument("Filter should not be in " |
| "Mkl Layout")); |
| |
| MklDnnData<Tinput> src(&cpu_engine_); |
| MklDnnData<Tfilter> filter(&cpu_engine_); |
| |
| memory::dims src_dims, filter_dims, padding_left, padding_right, |
| dilations, strides; |
| memory::dims dst_dims_tf_order, dst_dims_mkl_order; |
| |
| // For any Conv with `EXPLICIT` padding, get padding from `padding_list` |
| // attribute. Otherwise, get it from one of the inputs. |
| bool pad_attr_enabled = false; |
| for (auto const& padding_val : padding_list_) { |
| if (padding_val) { |
| pad_attr_enabled = true; |
| break; |
| } |
| } |
| |
| if (fuse_pad_ || pad_attr_enabled) { |
| PadWithConvFusion(context, padding_left, padding_right, |
| pad_attr_enabled, data_format_str_); |
| } |
| |
| // Get shapes of input tensors in MKL-DNN order |
| MklDnnConvUtil conv_utl(context, strides_, padding_, data_format_, |
| dilations_); |
| auto src_tf_shape = GetTfShape(context, kInputIndex_Src, native_format); |
| auto filter_tf_shape = |
| GetTfShape(context, kInputIndex_Filter, native_format); |
| bool is_grouped_convolution = false; |
| conv_utl.GetConvFwdSizesInMklOrder( |
| src_tf_shape, filter_tf_shape, &src_dims, &filter_dims, &strides, |
| &dilations, &dst_dims_tf_order, &dst_dims_mkl_order, &padding_left, |
| &padding_right, &is_grouped_convolution, |
| (fuse_pad_ || pad_attr_enabled), is_depthwise); |
| |
| if (!context->status().ok()) return; |
| |
| // Check for corner case - if there is nothing to compute, return. |
| TensorShape dst_tf_shape = MklDnnDimsToTFShape(dst_dims_tf_order); |
| |
| // Corner cases: output with 0 elements and 0 batch size. |
| Tensor* dst_tensor = nullptr; |
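| // The converted (oneDNN-layout) filter is emitted as a second output only |
| // for float/bfloat16 kernels in non-native (MKL layout) mode, where the |
| // input, filter, and output types all match. |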
| bool emit_filter_output = (typeid(Tinput) == typeid(Tfilter) && |
| typeid(Tinput) == typeid(Toutput) && |
| (typeid(Tinput) == typeid(float) || |
| typeid(Tinput) == typeid(bfloat16))) && |
| !native_format; |
| if (dst_tf_shape.num_elements() == 0 || dst_dims_tf_order[0] == 0) { |
| MklDnnShape dst_mkl_shape; |
| dst_mkl_shape.SetMklTensor(false); |
| AllocateOutputSetMklShape(context, kOutputIndex_Dst, &dst_tensor, |
| src_tf_shape, dst_mkl_shape, native_format); |
| |
| // MklConv2D/3D also outputs converted filter as 2nd output. |
| filter_mkl_shape.SetMklTensor(false); |
| Tensor* output_filter_tensor = nullptr; |
| if (emit_filter_output) { |
| filter_mkl_shape.SetMklTensor(false); |
| AllocateOutputSetMklShape(context, kOutputIndex_Filter, |
| &output_filter_tensor, filter_tf_shape, |
| filter_mkl_shape); |
| } |
| return; |
| } |
| |
| bool is_conv2d = (strides_.size() == 4); |
| bool is_conv3d = (strides_.size() == 5); |
| |
| if (!is_conv2d && !is_conv3d) { |
| OP_REQUIRES( |
| context, !pad_enabled, |
| errors::InvalidArgument("Pad + Conv fusion only works for 2D/3D")); |
| OP_REQUIRES( |
| context, !fuse_pad_, |
| errors::InvalidArgument("Pad+Conv fusion only works for 2D/3D")); |
| } |
| |
| // TODO(intel-tf): 3-D support for depthwise convolution is not there yet. |
| if (is_depthwise) { |
| OP_REQUIRES(context, is_conv2d, |
| errors::InvalidArgument( |
| "Only 2D convolution is supported for depthwise.")); |
| } |
| |
| // Create memory for user data. |
| // Describe what the inputs and outputs of the convolution look like, and |
| // specify the buffers containing the actual input and output data. |
| auto tf_fmt = is_conv2d ? TFDataFormatToMklDnnDataFormat(data_format_) |
| : TFDataFormatToMklDnn3DDataFormat(data_format_); |
| |
| auto mkl_fmt_tag = MklTensorFormatToMklDnnDataFormat(tf_fmt); |
| // NOTE: `mkl_fmt_tag` will be `format_tag::undef` for an invalid data format. |
| OP_REQUIRES(context, mkl_fmt_tag != memory::format_tag::undef, |
| errors::InvalidArgument("Invalid data format")); |
| |
| // If input is in MKL layout, then simply grab the layout; otherwise, |
| // construct TF layout for input. |
| // For constructing TF layout for input, although input shape (src_dims) |
| // is required to be in MKL-DNN order, the input layout is actually in |
| // TF layout depending on the data format: |
| // Conv2D: NHWC or NCHW |
| // Conv3D: NDHWC or NCDHW |
| auto src_md = |
| src_mkl_shape.IsMklTensor() |
| ? src_mkl_shape.GetMklLayout() |
| : memory::desc(src_dims, MklDnnType<Tinput>(), mkl_fmt_tag); |
| src.SetUsrMem(src_md, &src_tensor); |
| |
| // Although the filter shape (filter_dims) is required in MKL-DNN order, |
| // the filter layout is TensorFlow's layout: HWIO, or HWIGO for |
| // depthwise/grouped convolutions. |
| auto filter_format = is_conv2d ? ((is_depthwise || is_grouped_convolution) |
| ? memory::format_tag::hwigo |
| : memory::format_tag::hwio) |
| : memory::format_tag::dhwio; |
| |
| DCHECK(!filter_mkl_shape.IsMklTensor()); |
| auto filter_md = |
| filter_mkl_shape.IsMklTensor() |
| ? filter_mkl_shape.GetMklLayout() |
| : memory::desc(filter_dims, MklDnnType<Tfilter>(), filter_format); |
| filter.SetUsrMem(filter_md, &filter_tensor); |
| |
| // MKL-DNN dilations start from 0. |
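| // TF dilations start from 1, so subtract 1 from each value; e.g. a TF |
| // dilation of 1 (no dilation) becomes an oneDNN dilation of 0. |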
| for (int i = 0; i < dilations.size(); ++i) --dilations[i]; |
| |
| // In some cases, primitive descriptor could potentially contain |
| // large buffers. As a result, we don't cache these primitives if the |
| // environment variable `TF_MKL_OPTIMIZE_PRIMITIVE_MEMUSE` is set to True. |
| // MKL-DNN allocates buffers in the following cases: |
| // 1. Legacy CPU without AVX512/AVX2, or |
| // 2. 1x1 convolution with strides != 1 |
| bool do_not_cache = |
| MklPrimitiveFactory<Tinput>::IsPrimitiveMemOptEnabled() && |
| (src_dims[MklDnnDims::Dim_N] > kSmallBatchSize) && |
| (MklPrimitiveFactory<Tinput>::IsLegacyPlatform() || |
| IsConv1x1StrideNot1(filter_dims, strides)); |
| |
| // Get a conv2d fwd from primitive pool |
| MklConvFwdPrimitive<Tinput, Tfilter, Tbias, Ttemp_output>* conv_fwd = |
| nullptr; |
| memory::dims bias_dims = {}; |
| if (fuse_biasadd_) { |
| conv_utl.GetBiasSizeInMklOrder(kInputIndex_Bias, &bias_dims); |
| } |
| memory::dims fuse_bn_dims = {}; |
| TensorShape fuse_bn_shape; |
| if (fuse_bn_) { |
| // Inputs to FusedBatchNorm have the same 1D shape |
| fuse_bn_shape = MklGetInput(context, kInputIndex_BN_Mean).shape(); |
| OP_REQUIRES(context, fuse_bn_shape.dims() == 1, |
| errors::InvalidArgument("FusedBatchNorm must be 1D, not: ", |
| fuse_bn_shape.DebugString())); |
| |
| // Note - MKL-DNN expects {1, C, 1, 1} for binary post-op even for NHWC |
| fuse_bn_dims = {1, fuse_bn_shape.dim_size(0), 1, 1}; |
| } |
| |
| MklConvFwdParams convFwdDims( |
| src_dims, filter_dims, fuse_biasadd_ ? bias_dims : NONE_DIMS, |
| dst_dims_mkl_order, strides, dilations, padding_left, padding_right, |
| fuse_bn_dims, tf_fmt, native_format); |
| |
| // TODO(intel-tf): Extend the basic parameters for data types and fusions |
| this->ExtendConvFwdParams(context, convFwdDims); |
| #ifdef DNNL_AARCH64_USE_ACL |
| // TODO(milpuz01): Remove once Arm Compute Library provides support for |
| // in-place updates |
| convFwdDims.filter_hash = Hash64( |
| filter_tensor.tensor_data().data(), |
| std::min(kFilterTensorHashLength, |
| static_cast<int>(filter_tensor.tensor_data().size()))); |
| #endif |
| |
| conv_fwd = |
| MklConvFwdPrimitiveFactory<Tinput, Tfilter, Tbias, Ttemp_output>::Get( |
| convFwdDims, do_not_cache); |
| // Allocate output tensors `dst_tensor` and `filter_out_tensor` |
| MklDnnShape output_mkl_shape; |
| std::shared_ptr<ConvFwdPd> conv_fwd_pd = conv_fwd->GetPrimitiveDesc(); |
| AllocateOutputTensor(context, *conv_fwd_pd, dst_dims_mkl_order, tf_fmt, |
| &output_mkl_shape, &dst_tensor); |
| |
| Tensor* filter_out_tensor = nullptr; |
| if (emit_filter_output) { |
| AllocateFilterOutputTensor(context, *conv_fwd_pd, |
| TFShapeToMklDnnDims(filter_tf_shape), |
| &filter_out_tensor); |
| } |
| |
| Ttemp_output* dst_data = |
| reinterpret_cast<Ttemp_output*>(dst_tensor->flat<Toutput>().data()); |
| |
| // Check whether src and filter need to be reordered. |
| Tinput* src_data = nullptr; |
| if (src_md != conv_fwd_pd->src_desc()) { |
| src.SetUsrMem(src_md, &src_tensor); |
| src.CheckReorderToOpMem(conv_fwd_pd->src_desc(), cpu_engine_, context); |
| src_data = static_cast<Tinput*>(src.GetOpMem().get_data_handle()); |
| } else { |
| src_data = static_cast<Tinput*>( |
| const_cast<Tinput*>(src_tensor.flat<Tinput>().data())); |
| } |
| |
| Tfilter* filter_data = nullptr; |
| if (filter_md != conv_fwd_pd->weights_desc()) { |
| bool is_filter_cached = false; |
| // If filter is a constant, we can avoid the conversion of filter from |
| // Tensorflow format to MKL format by caching the filter when it is |
| // converted for the first time. This cached filter can then be reused |
| // in subsequent iterations. |
| if (is_filter_const_) { |
| if (IsFilterCacheEmpty(context)) { |
| // Cache filter if it is not already cached. |
| CacheFilter(context, conv_fwd_pd, filter_data, filter_tensor, |
| filter, filter_md, filter_mkl_shape); |
| } |
| filter_data = GetCachedFilter(context, conv_fwd_pd->weights_desc()); |
| is_filter_cached = (filter_data != nullptr); |
| } |
| if (!is_filter_cached) { |
| filter.SetUsrMem(filter_md, &filter_tensor); |
| if (filter_out_tensor == nullptr) { |
| filter.CheckReorderToOpMem(conv_fwd_pd->weights_desc(), cpu_engine_, |
| context); |
| } else { |
| filter.CheckReorderToOpMem( |
| conv_fwd_pd->weights_desc(), |
| filter.GetTensorBuffer(filter_out_tensor), cpu_engine_, |
| context); |
| } |
| filter_data = |
| static_cast<Tfilter*>(filter.GetOpMem().get_data_handle()); |
| } |
| } else { |
| filter_data = static_cast<Tfilter*>( |
| const_cast<Tfilter*>(filter_tensor.flat<Tfilter>().data())); |
| } |
| |
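| // Allocate the scratchpad on the TF side; the primitive was created with |
| // scratchpad_mode::user, so oneDNN expects the caller to provide it. |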
| UserScratchPad<unsigned char> scratch_pad; |
| scratch_pad.AllocateSPTensor(conv_fwd, context); |
| |
| // Execute convolution |
| std::shared_ptr<stream> fwd_cpu_stream; |
| MklDnnThreadPool eigen_tp(context); |
| fwd_cpu_stream.reset(CreateStream(&eigen_tp, conv_fwd->GetEngine())); |
| if (fuse_biasadd_) { |
| const Tensor& bias_tensor = MklGetInput(context, kInputIndex_Bias); |
| Tbias* bias_data = |
| this->GetBiasHandle(context, conv_fwd_pd, bias_tensor); |
| conv_fwd->Execute(src_data, filter_data, bias_data, dst_data, |
| fwd_cpu_stream, scratch_pad.Get()); |
| } else if (fuse_bn_) { |
| const Tensor& bn_scale_tensor = |
| MklGetInput(context, kInputIndex_BN_Scale); |
| Tinput* bn_scale_data = static_cast<Tinput*>( |
| const_cast<Tinput*>(bn_scale_tensor.flat<Tinput>().data())); |
| const Tensor& bn_mean_tensor = |
| MklGetInput(context, kInputIndex_BN_Mean); |
| Tinput* bn_mean_data = static_cast<Tinput*>( |
| const_cast<Tinput*>(bn_mean_tensor.flat<Tinput>().data())); |
| const Tensor& bn_offset_tensor = |
| MklGetInput(context, kInputIndex_BN_Offset); |
| Tinput* bn_offset_data = static_cast<Tinput*>( |
| const_cast<Tinput*>(bn_offset_tensor.flat<Tinput>().data())); |
| |
| Tensor bn_rsqrt_tensor; |
| OP_REQUIRES_OK(context, |
| context->allocate_temp(DataTypeToEnum<Tinput>::v(), |
| fuse_bn_shape, &bn_rsqrt_tensor)); |
| Tinput* bn_rsqrt_data = static_cast<Tinput*>( |
| const_cast<Tinput*>(bn_rsqrt_tensor.flat<Tinput>().data())); |
| this->ComputeBNScale(context, epsilon_, kInputIndex_BN_Variance, |
| bn_rsqrt_data); |
| conv_fwd->Execute(src_data, filter_data, nullptr, dst_data, |
| bn_scale_data, bn_mean_data, bn_offset_data, |
| bn_rsqrt_data, fwd_cpu_stream, scratch_pad.Get()); |
| } else { |
| conv_fwd->Execute(src_data, filter_data, dst_data, fwd_cpu_stream, |
| scratch_pad.Get()); |
| } |
| |
| // Delete primitive since it is not cached. |
| if (do_not_cache) delete conv_fwd; |
| |
| } catch (dnnl::error& e) { |
| string error_msg = tensorflow::strings::StrCat( |
| "Status: ", e.status, ", message: ", string(e.message), ", in file ", |
| __FILE__, ":", __LINE__); |
| OP_REQUIRES_OK( |
| context, |
| errors::Aborted("Operation received an exception:", error_msg)); |
| } |
| } |
| |
| void PadWithConvFusion(OpKernelContext* context, memory::dims& padding_left, |
| memory::dims& padding_right, bool pad_attr_enabled, |
| string data_format_str_) { |
| Tpadding* paddings = nullptr; |
| if (pad_attr_enabled) { |
| paddings = padding_list_.data(); |
| } else { |
| const Tensor& paddings_tf = MklGetInput(context, input_index_pad_); |
| OP_REQUIRES(context, paddings_tf.dims() == 2, |
| errors::InvalidArgument("paddings must be 2-dimensional: ", |
| paddings_tf.shape().DebugString())); |
| // Flatten tensor to get individual paddings. |
| paddings = static_cast<Tpadding*>( |
| const_cast<Tpadding*>(paddings_tf.flat<Tpadding>().data())); |
| } |
| // If the data format is NHWC, indices 0, 1, 6 and 7 of paddings(_tf) |
| // will be zero. |
| // Example: |
| // paddings_tf = [ [0, 0] [1, 2] [3, 4] [0, 0] ], |
| // flat method = row-major, then: |
| // paddings = {0, 0, 1, 2, 3, 4, 0, 0}. |
| // Hence, the values are: top = 1, bottom = 2, left = 3, right = 4. |
| // |
| // Similarly, if the data format is NCHW, indices 0, 1, 2 and 3 of |
| // paddings(_tf) will be zero. |
| // i.e. for the above example, paddings = {0, 0, 0, 0, 1, 2, 3, 4}. |
| int64 pad_top = 0, pad_left = 0, pad_front = 0; |
| int64 pad_bottom = 0, pad_right = 0, pad_back = 0; |
| if (data_format_str_ == "NHWC") { |
| pad_top = paddings[2]; |
| pad_bottom = paddings[3]; |
| pad_left = paddings[4]; |
| pad_right = paddings[5]; |
| } else if (data_format_str_ == "NCHW") { |
| pad_top = paddings[4]; |
| pad_bottom = paddings[5]; |
| pad_left = paddings[6]; |
| pad_right = paddings[7]; |
| } else if (data_format_str_ == "NDHWC") { |
| pad_front = paddings[2]; |
| pad_back = paddings[3]; |
| pad_top = paddings[4]; |
| pad_bottom = paddings[5]; |
| pad_left = paddings[6]; |
| pad_right = paddings[7]; |
| } else if (data_format_str_ == "NCDHW") { |
| pad_front = paddings[4]; |
| pad_back = paddings[5]; |
| pad_top = paddings[6]; |
| pad_bottom = paddings[7]; |
| pad_left = paddings[8]; |
| pad_right = paddings[9]; |
| } |
| // Create padding arrays for MKL-DNN convolutions. |
| // MKL-DNN uses asymmetric padding. |
| if (data_format_str_ == "NHWC" || data_format_str_ == "NCHW") { |
| padding_left = {static_cast<int>(pad_top), static_cast<int>(pad_left)}; |
| padding_right = {static_cast<int>(pad_bottom), |
| static_cast<int>(pad_right)}; |
| } else if (data_format_str_ == "NDHWC" || data_format_str_ == "NCDHW") { |
| padding_left = {static_cast<int>(pad_front), static_cast<int>(pad_top), |
| static_cast<int>(pad_left)}; |
| padding_right = {static_cast<int>(pad_back), static_cast<int>(pad_bottom), |
| static_cast<int>(pad_right)}; |
| } |
| } |
| |
| protected: |
| void set_fuse_biasadd(bool fuse_biasadd) { fuse_biasadd_ = fuse_biasadd; } |
| void set_fuse_activation(bool fuse_activation, dnnl::algorithm activation_alg, |
| float alpha_or_upbound = 0.0) { |
| fuse_activation_ = fuse_activation; |
| activation_alg_ = activation_alg; |
| // This variable is used for alpha in leakyrelu or upper bound in relu6 |
| // depending on the context |
| alpha_or_upbound_ = alpha_or_upbound; |
| } |
| void set_fuse_pad(bool fuse_pad) { |
| fuse_pad_ = fuse_pad; |
| if (fuse_bn_) { |
| // If FusedBatchNorm is fused in PadWithFusedConv2D, pad is the 7th input |
| input_index_pad_ = 6; |
| } else if (fuse_add_ && fuse_biasadd_) { |
| // If Bias and Add are fused in PadWithFusedConv2D, pad is the 5th input |
| input_index_pad_ = 4; |
| } else { |
| // If Bias is fused in PadWithFusedConv2D, pad is the 4th input |
| input_index_pad_ = 3; |
| } |
| } |
| void set_fuse_add(bool fuse_add) { fuse_add_ = fuse_add; } |
| void set_fuse_bn(bool fuse_bn, float epsilon) { |
| fuse_bn_ = fuse_bn; |
| epsilon_ = epsilon; |
| } |
| |
| virtual void ComputeBNScale(OpKernelContext* context, float epsilon, |
| int bn_variance_index, Tinput* scale_buf_ptr) { |
| OP_REQUIRES( |
| context, false, |
| errors::Unimplemented("Compute BN scale not expected in base class")); |
| return; |
| } |
| |
| // This method is for the base class MklConvOp, which handles the |
| // floating point implementation of Conv. The quantized conv implementations |
| // will use overridden versions of this method. |
| virtual void ExtendConvFwdParams(OpKernelContext* context, |
| MklConvFwdParams& params) { |
| // Create a string from data types of input, filter, bias, and output. |
| params.dtypes.append(typeid(Tinput).name()); |
| params.dtypes.append(typeid(Tfilter).name()); |
| params.dtypes.append(typeid(Tbias).name()); |
| params.dtypes.append(typeid(Toutput).name()); |
| |
| // Add fusions as post ops |
| // NOTE: Fusion of BiasAdd is handled directly inside MklConvOp by |
| // checking the `fuse_biasadd_` flag. |
| if (fuse_add_) { |
| params.post_op_params.push_back( |
| {"sum", dnnl::algorithm::undef, {1.0}, ""}); |
| } |
| // NOTE - fuse_bn post_op entry must be before fuse_activation |
| if (fuse_bn_) { |
| params.post_op_params.push_back( |
| {"fuse_bn", dnnl::algorithm::undef, {1.0}, ""}); |
| } |
| if (fuse_activation_) { |
| params.post_op_params.push_back( |
| {"activation", activation_alg_, {1.0, alpha_or_upbound_, 0.0}, ""}); |
| } |
| } |
| |
| virtual Tbias* GetBiasHandle(OpKernelContext* context, |
| std::shared_ptr<ConvFwdPd>& conv2d_fwd_pd, |
| const Tensor& bias_tensor) { |
| if (fuse_biasadd_) { |
| return static_cast<Tbias*>( |
| const_cast<Tbias*>(bias_tensor.flat<Tbias>().data())); |
| } |
| return nullptr; |
| } |
| |
| virtual void AllocateOutputTensor(OpKernelContext* context, |
| const ConvFwdPd& conv_prim_desc, |
| const memory::dims& output_dims_mkl_order, |
| MklTensorFormat output_tf_format, |
| MklDnnShape* output_mkl_shape, |
| Tensor** output_tensor) { |
| DCHECK(output_tensor); |
| auto dst_md = conv_prim_desc.dst_desc(); |
| |
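| // If the intermediate type differs from the final output type (e.g. |
| // quantized convolutions accumulate in qint32 but output quint8), patch |
| // the destination descriptor's data type so it matches Toutput. |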
| if (!std::is_same<Ttemp_output, Toutput>::value) { |
| dst_md.data.data_type = |
| static_cast<dnnl_data_type_t>(MklDnnType<Toutput>()); |
| } |
| |
| // Allocate shape of MKL tensor |
| output_mkl_shape->SetMklTensor(true); |
| output_mkl_shape->SetMklLayout(&dst_md); |
| output_mkl_shape->SetElemType(MklDnnType<Toutput>()); |
| output_mkl_shape->SetTfLayout(output_dims_mkl_order.size(), |
| output_dims_mkl_order, output_tf_format); |
| |
| // Allocate shape of TF tensor |
| TensorShape output_tf_shape; |
| output_tf_shape.AddDim((dst_md.get_size() / sizeof(Toutput))); |
| if (native_format) { |
| output_tf_shape = output_mkl_shape->GetTfShape(); |
| } |
| |
| if (fuse_add_) { |
| const Tensor& add_tensor = MklGetInput(context, kInputIndex_Add); |
| MklDnnShape add_mkl_shape; |
| GetMklShape(context, kInputIndex_Add, &add_mkl_shape, native_format); |
| // Forward the summand tensor to the output only if it has no other |
| // references, otherwise make a copy of it. |
| if (native_format && context->forward_input_to_output_with_shape( |
| kInputIndex_Add, kOutputIndex_Dst, |
| output_tf_shape, output_tensor)) { |
| return; |
| } |
| // Check if reorder is needed |
| if (!native_format && add_mkl_shape == *output_mkl_shape && |
| ForwardMklTensorInToOutWithMklShape(context, kInputIndex_Add, |
| kOutputIndex_Dst, output_tensor, |
| add_mkl_shape, false)) { |
| return; |
| } else { |
| AllocateOutputSetMklShape(context, kOutputIndex_Dst, output_tensor, |
| output_tf_shape, *output_mkl_shape, |
| native_format); |
| auto output_format_tag = MklTensorFormatToMklDnnDataFormat( |
| output_mkl_shape->GetTfDataFormat()); |
| OP_REQUIRES(context, output_format_tag != memory::format_tag::undef, |
| errors::InvalidArgument( |
| "MklConvOp: AddN fusion: Invalid data format")); |
| auto add_md = |
| add_mkl_shape.IsMklTensor() |
| ? add_mkl_shape.GetMklLayout() |
| : memory::desc(output_dims_mkl_order, MklDnnType<Toutput>(), |
| output_format_tag); |
| void* add_buf = static_cast<void*>( |
| const_cast<Toutput*>(add_tensor.flat<Toutput>().data())); |
| void* dst_buf = |
| static_cast<void*>((*output_tensor)->flat<Ttemp_output>().data()); |
| if (native_format) { |
| // We simply deep-copy add_tensor to output_tensor without changing the |
| // memory layout, hence we use the same memory descriptor. |
| add_md = dst_md = |
| memory::desc({add_tensor.NumElements()}, MklDnnType<Toutput>(), |
| dnnl::memory::format_tag::x); |
| } |
| fuse_add_src_.reset(new memory(add_md, this->cpu_engine_, add_buf)); |
| fuse_add_dst_.reset(new memory(dst_md, this->cpu_engine_, dst_buf)); |
| auto reorder_desc = |
| ReorderPd(this->cpu_engine_, add_md, this->cpu_engine_, dst_md); |
| |
| CreateAndExecuteReorder(reorder_desc, *fuse_add_src_, *fuse_add_dst_, |
| this->cpu_engine_, context); |
| } |
| } else { |
| AllocateOutputSetMklShape(context, kOutputIndex_Dst, output_tensor, |
| output_tf_shape, *output_mkl_shape, |
| native_format); |
| } |
| } |
| |
| engine cpu_engine_ = engine(engine::kind::cpu, 0); |
| |
| private: |
| std::shared_ptr<dnnl::memory> fuse_add_src_; |
| std::shared_ptr<dnnl::memory> fuse_add_dst_; |
| std::vector<int32> strides_; |
| std::vector<int32> dilations_; |
| std::vector<Tpadding> padding_list_; |
| bool is_filter_const_; |
| mutex mu_; |
| Padding padding_; |
| string data_format_str_; |
| TensorFormat data_format_; |
| Tensor cached_filter_data_ TF_GUARDED_BY(mu_); |
| Tensor cached_filter_md_ TF_GUARDED_BY(mu_); |
| |
| // Initialize to values the template is instantiated with |
| bool fuse_biasadd_ = bias_enabled; |
| bool fuse_activation_ = false; |
| bool fuse_pad_ = pad_enabled; |
| bool fuse_add_ = false; |
| bool fuse_bn_ = false; |
| float epsilon_ = 0.0001; |
| |
| // This variable is used for alpha in leakyrelu or upper bound in relu6 |
| // depending on the context |
| float alpha_or_upbound_ = 0.0; |
| dnnl::algorithm activation_alg_ = dnnl::algorithm::undef; |
| |
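| // Index of the `paddings` input tensor; defaults to 2 and is adjusted by |
| // set_fuse_pad() depending on which other ops are fused. |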
| int input_index_pad_ = 2; |
| |
| const int kInputIndex_Src = 0, kInputIndex_Filter = 1, kInputIndex_Bias = 2; |
| const int kInputIndex_Add = 3; |
| const int kOutputIndex_Dst = 0, kOutputIndex_Filter = 1; |
| const int kDilationH = 0, kDilationW = 1; |
| |
| // Input indices for FusedBatchNorm |
| const int kInputIndex_BN_Scale = 2, kInputIndex_BN_Offset = 3; |
| const int kInputIndex_BN_Mean = 4, kInputIndex_BN_Variance = 5; |
| #ifdef DNNL_AARCH64_USE_ACL |
| const int kFilterTensorHashLength = 1024; |
| #endif |
| |
| MklTensorFormat GetFilterTfDataFormat(const MklDnnShape* filter_mkl_shape, |
| const ConvFwdPd& conv_prim_desc) const { |
| DCHECK(filter_mkl_shape); |
| return filter_mkl_shape->GetTfDataFormat(); |
| } |
| |
| // Allocate tensors for cached filter data and cached filter memory |
| // descriptor (data format) |
| void AllocateTensor(OpKernelContext* context, const ConvFwdPd& conv_prim_desc, |
| Tensor** filter_tensor, |
| const MklDnnShape* filter_mkl_shape) |
| TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { |
| DCHECK(filter_tensor); |
| TensorShape filter_tf_shape; |
| filter_tf_shape.AddDim( |
| (conv_prim_desc.weights_desc().get_size() / sizeof(Tfilter))); |
| OP_REQUIRES_OK( |
| context, context->allocate_temp(DataTypeToEnum<Tfilter>::value, |
| filter_tf_shape, &cached_filter_data_)); |
| |
| *filter_tensor = &cached_filter_data_; |
| |
| // There is no tensor format in DNNL 1.x, so we cache the complete filter |
| // descriptor as a flat byte array. |
| TensorShape cached_filter_md_shape; |
| memory::desc weights_desc = conv_prim_desc.weights_desc(); |
| // We don't use the .get_size() method of memory::desc since it returns the |
| // size required to store the primitive's input memory, which is much larger |
| // than the size of the memory::desc object itself. |
| cached_filter_md_shape.AddDim(sizeof(weights_desc) / sizeof(uint8)); |
| OP_REQUIRES_OK(context, |
| context->allocate_temp(DT_UINT8, cached_filter_md_shape, |
| &cached_filter_md_)); |
| *reinterpret_cast<memory::desc*>(cached_filter_md_.flat<uint8>().data()) = |
| weights_desc; |
| } |
| |
| void AllocateTensor(OpKernelContext* context, const ConvFwdPd& conv_prim_desc, |
| Tensor** filter_tensor) { |
| AllocateTensor(context, conv_prim_desc, filter_tensor, nullptr); |
| } |
| |
| void AllocateFilterOutputTensor(OpKernelContext* context, |
| const ConvFwdPd& conv_prim_desc, |
| const memory::dims& filter_dims_tf_order, |
| Tensor** filter_tensor) { |
| DCHECK(filter_tensor); |
| auto filter_md = conv_prim_desc.weights_desc(); |
| |
| // Allocate shape of MKL tensor |
| MklDnnShape filter_mkl_shape; |
| filter_mkl_shape.SetMklTensor(true); |
| filter_mkl_shape.SetMklLayout(&filter_md); |
| filter_mkl_shape.SetElemType(MklDnnType<Tfilter>()); |
| |
| // The format of the filter is actually OIhw8i8o, but TF doesn't support |
| // this format. Just use format::blocked for now because the layout |
| // is stored in the MKL data. |
| filter_mkl_shape.SetTfLayout(filter_dims_tf_order.size(), |
| filter_dims_tf_order, |
| MklTensorFormat::FORMAT_BLOCKED); |
| |
| // Allocate the data space for the filter to propagate as TF tensor. |
| TensorShape filter_tf_shape; |
| filter_tf_shape.AddDim((filter_md.get_size() / sizeof(Tfilter))); |
| |
| AllocateOutputSetMklShape(context, kOutputIndex_Filter, filter_tensor, |
| filter_tf_shape, filter_mkl_shape); |
| } |
| |
| // The TF_LOCKS_EXCLUDED annotation ensures that the lock (mu_) is not |
| // already held when entering the function, since it is acquired |
| // inside the function. |
| inline bool IsFilterCacheEmpty(OpKernelContext* context) |
| TF_LOCKS_EXCLUDED(mu_) { |
| tf_shared_lock lock(mu_); |
| const Tensor& cached_filter_data_tensor = cached_filter_data_; |
| return (cached_filter_data_tensor.NumElements() == 0); |
| } |
| |
| // Cache the converted filter in a tensor. |
| // Only one thread can execute this method at any given time. |
| void CacheFilter(OpKernelContext* context, |
| const std::shared_ptr<ConvFwdPd>& conv_fwd_pd, |
| Tfilter* filter_data, const Tensor& filter_tensor, |
| MklDnnData<Tfilter>& filter, const memory::desc& filter_md, |
| const MklDnnShape& filter_mkl_shape) TF_LOCKS_EXCLUDED(mu_) { |
| mutex_lock lock(mu_); |
| const Tensor& cached_filter_data_tensor = cached_filter_data_; |
| |
| // If filter is already cached, there's nothing to do. |
| if (cached_filter_data_tensor.NumElements() > 0) { |
| return; |
| } |
| |
| // Otherwise, cache filter |
| filter.SetUsrMem(filter_md, &filter_tensor); |
| filter.CheckReorderToOpMem(conv_fwd_pd.get()->weights_desc(), |
| this->cpu_engine_, context); |
| filter_data = static_cast<Tfilter*>(filter.GetOpMem().get_data_handle()); |
| |
| Tensor* filter_tensor_ptr = nullptr; |
| AllocateTensor(context, *conv_fwd_pd, &filter_tensor_ptr, |
| &filter_mkl_shape); |
| void* cached_filter_data = filter.GetTensorBuffer(filter_tensor_ptr); |
| size_t cached_filter_data_size = filter.GetOpMem().get_desc().get_size(); |
| memcpy(cached_filter_data, filter_data, cached_filter_data_size); |
| } |
| |
| bool AreMemoryDescriptorsEqual(const memory::desc& filter_md, |
| const Tensor& cached_filter_md) { |
| auto filter_md_data = filter_md.data; |
| const char* filter_data = reinterpret_cast<const char*>(&filter_md_data); |
| |
| auto cached_filter_md_data = cached_filter_md.scalar<int64_t>()(); |
| const char* cached_filter_data = |
| reinterpret_cast<const char*>(&cached_filter_md_data); |
| |
| for (size_t i = 0; i < sizeof(filter_md_data); ++i) { |
| if (*filter_data++ != *cached_filter_data++) { |
| return false; |
| } |
| } |
| return true; |
| } |
| |
| Tfilter* GetCachedFilter(OpKernelContext* context, |
| const memory::desc& filter_md) |
| TF_LOCKS_EXCLUDED(mu_) { |
| tf_shared_lock lock(mu_); |
| const Tensor& cached_filter_data = cached_filter_data_; |
| const Tensor& cached_filter_md = cached_filter_md_; |
| |
| // Check if the memory descriptor of the cached weights is the same as |
| // filter_md. If so, we can use the cached weights; otherwise |
| // return nullptr. |
| if (filter_md == *static_cast<memory::desc*>(cached_filter_md.data())) { |
| return static_cast<Tfilter*>( |
| const_cast<Tfilter*>(cached_filter_data.flat<Tfilter>().data())); |
| } |
| return nullptr; |
| } |
| }; |
| |
| // Base class for fused convolution forward operations |
| template <typename Device, typename Tinput, typename Tfilter, typename Tbias, |
| typename Toutput, typename Ttemp_output, typename Tpadding, |
| bool pad_enabled, bool native_format> |
| class MklFusedConvOp |
| : public MklConvOp<Device, Tinput, Tfilter, Tbias, Toutput, Ttemp_output, |
| Tpadding, false, false, false, native_format> { |
| public: |
| explicit MklFusedConvOp(OpKernelConstruction* context) |
| : MklConvOp<Device, Tinput, Tfilter, Tbias, Toutput, Ttemp_output, |
| Tpadding, false, false, false, native_format>(context) { |
| // Since we came here through the registration of _MklFusedConv2D, get |
| // all information from 'fused_ops' and 'num_args' |
| std::vector<string> fused_ops; |
| OP_REQUIRES_OK(context, context->GetAttr("fused_ops", &fused_ops)); |
| |
| int num_args; |
| OP_REQUIRES_OK(context, context->GetAttr("num_args", &num_args)); |
| OP_REQUIRES(context, !fused_ops.empty(), |
| errors::InvalidArgument( |
| "Fused Conv2D must have at least one fused op.")); |
| |
| // TODO(intel-tf): Compact the code for activation checking |
| if (fused_ops == std::vector<string>{"BiasAdd"}) { |
| this->set_fuse_biasadd(true); |
| OP_REQUIRES(context, num_args == 1, |
| errors::InvalidArgument( |
| "Fused Conv2D must have one extra argument: bias.")); |
| } else if (fused_ops == std::vector<string>{"Relu"}) { |
| this->set_fuse_activation(true, dnnl::algorithm::eltwise_relu); |
| } else if (fused_ops == std::vector<string>{"Relu6"}) { |
| this->set_fuse_activation(true, dnnl::algorithm::eltwise_bounded_relu, |
| 6.0); |
| } else if (fused_ops == std::vector<string>{"Elu"}) { |
| this->set_fuse_activation(true, dnnl::algorithm::eltwise_elu, 1.0); |
| } else if (fused_ops == std::vector<string>{"LeakyRelu"}) { |
| float leakyrelu_alpha; |
| OP_REQUIRES_OK(context, |
| context->GetAttr("leakyrelu_alpha", &leakyrelu_alpha)); |
| this->set_fuse_activation(true, dnnl::algorithm::eltwise_relu, |
| leakyrelu_alpha); |
| } else if (fused_ops == std::vector<string>{"FusedBatchNorm"}) { |
| float epsilon; |
| OP_REQUIRES_OK(context, context->GetAttr("epsilon", &epsilon)); |
| OP_REQUIRES( |
| context, num_args == 4, |
| errors::InvalidArgument( |
| "Fused Conv2D with batchnorm must have 4 extra arguments")); |
| this->set_fuse_bn(true, epsilon); |
| } else if (fused_ops == std::vector<string>{"BiasAdd", "Relu"}) { |
| this->set_fuse_biasadd(true); |
| this->set_fuse_activation(true, dnnl::algorithm::eltwise_relu); |
| OP_REQUIRES(context, num_args == 1, |
| errors::InvalidArgument( |
| "Fused Conv2D must have one extra argument: bias.")); |
| } else if (fused_ops == std::vector<string>{"BiasAdd", "Relu6"}) { |
| this->set_fuse_biasadd(true); |
| this->set_fuse_activation(true, dnnl::algorithm::eltwise_bounded_relu, |
| 6.0); |
| OP_REQUIRES(context, num_args == 1, |
| errors::InvalidArgument( |
| "Fused Conv2D must have one extra argument: bias.")); |
| } else if (fused_ops == std::vector<string>{"BiasAdd", "Elu"}) { |
| this->set_fuse_biasadd(true); |
| this->set_fuse_activation(true, dnnl::algorithm::eltwise_elu, 1.0); |
| OP_REQUIRES(context, num_args == 1, |
| errors::InvalidArgument( |
| "Fused Conv2D must have one extra argument: bias.")); |
| } else if (fused_ops == std::vector<string>{"BiasAdd", "LeakyRelu"}) { |
| this->set_fuse_biasadd(true); |
| float leakyrelu_alpha; |
| OP_REQUIRES_OK(context, |
| context->GetAttr("leakyrelu_alpha", &leakyrelu_alpha)); |
| this->set_fuse_activation(true, dnnl::algorithm::eltwise_relu, |
| leakyrelu_alpha); |
| OP_REQUIRES(context, num_args == 1, |
| errors::InvalidArgument( |
| "Fused Conv2D must have one extra argument: bias.")); |
| } else if (fused_ops == std::vector<string>{"BiasAdd", "Add"}) { |
| this->set_fuse_biasadd(true); |
| this->set_fuse_add(true); |
| OP_REQUIRES( |
| context, num_args == 2, |
| errors::InvalidArgument( |
| "Fused Conv2D must have two extra arguments: bias and add.")); |
| } else if (fused_ops == std::vector<string>{"FusedBatchNorm", "Relu"}) { |
| float epsilon; |
| OP_REQUIRES_OK(context, context->GetAttr("epsilon", &epsilon)); |
| OP_REQUIRES( |
| context, num_args == 4, |
| errors::InvalidArgument( |
| "Fused Conv2D with batchnorm must have 4 extra arguments")); |
| this->set_fuse_bn(true, epsilon); |
| this->set_fuse_activation(true, dnnl::algorithm::eltwise_relu); |
| } else if (fused_ops == std::vector<string>{"FusedBatchNorm", "Relu6"}) { |
| float epsilon; |
| OP_REQUIRES_OK(context, context->GetAttr("epsilon", &epsilon)); |
| OP_REQUIRES( |
| context, num_args == 4, |
| errors::InvalidArgument( |
| "Fused Conv2D with batchnorm must have four extra arguments.")); |
| this->set_fuse_bn(true, epsilon); |
| this->set_fuse_activation(true, dnnl::algorithm::eltwise_bounded_relu, |
| 6.0); |
| } else if (fused_ops == std::vector<string>{"FusedBatchNorm", "Elu"}) { |
| float epsilon; |
| OP_REQUIRES_OK(context, context->GetAttr("epsilon", &epsilon)); |
| OP_REQUIRES( |
| context, num_args == 4, |
| errors::InvalidArgument( |
| "Fused Conv2D with batchnorm must have four extra arguments.")); |
| this->set_fuse_bn(true, epsilon); |
| this->set_fuse_activation(true, dnnl::algorithm::eltwise_elu, 1.0); |
| } else if (fused_ops == |
| std::vector<string>{"FusedBatchNorm", "LeakyRelu"}) { |
| float epsilon, leakyrelu_alpha; |
| OP_REQUIRES_OK(context, context->GetAttr("epsilon", &epsilon)); |
| OP_REQUIRES_OK(context, |
| context->GetAttr("leakyrelu_alpha", &leakyrelu_alpha)); |
| OP_REQUIRES( |
| context, num_args == 4, |
| errors::InvalidArgument( |
| "Fused Conv2D with batchnorm must have four extra arguments.")); |
| this->set_fuse_bn(true, epsilon); |
| this->set_fuse_activation(true, dnnl::algorithm::eltwise_relu, |
| leakyrelu_alpha); |
| } else if (fused_ops == std::vector<string>{"BiasAdd", "Add", "Relu"}) { |
| this->set_fuse_biasadd(true); |
| this->set_fuse_add(true); |
| this->set_fuse_activation(true, dnnl::algorithm::eltwise_relu); |
| OP_REQUIRES( |
| context, num_args == 2, |
| errors::InvalidArgument( |
| "Fused Conv2D must have two extra arguments: bias and add.")); |
| } else if (fused_ops == std::vector<string>{"BiasAdd", "Add", "Relu6"}) { |
| this->set_fuse_biasadd(true); |
| this->set_fuse_add(true); |
| this->set_fuse_activation(true, dnnl::algorithm::eltwise_bounded_relu, |
| 6.0); |
| OP_REQUIRES( |
| context, num_args == 2, |
| errors::InvalidArgument( |
| "Fused Conv2D must have two extra arguments: bias and add.")); |
| } else if (fused_ops == std::vector<string>{"BiasAdd", "Add", "Elu"}) { |
| this->set_fuse_biasadd(true); |
| this->set_fuse_add(true); |
| this->set_fuse_activation(true, dnnl::algorithm::eltwise_elu, 1.0); |
| OP_REQUIRES( |
| context, num_args == 2, |
| errors::InvalidArgument( |
| "Fused Conv2D must have two extra arguments: bias and add.")); |
| } else if (fused_ops == |
| std::vector<string>{"BiasAdd", "Add", "LeakyRelu"}) { |
| this->set_fuse_biasadd(true); |
| this->set_fuse_add(true); |
| float leakyrelu_alpha; |
| OP_REQUIRES_OK(context, |
| context->GetAttr("leakyrelu_alpha", &leakyrelu_alpha)); |
| this->set_fuse_activation(true, dnnl::algorithm::eltwise_relu, |
| leakyrelu_alpha); |
| OP_REQUIRES( |
| context, num_args == 2, |
| errors::InvalidArgument( |
| "Fused Conv2D must have two extra arguments: bias and add.")); |
| } else { |
| OP_REQUIRES(context, false, |
| errors::Unimplemented("Fusion is not implemented: [", |
| absl::StrJoin(fused_ops, ","), "]")); |
| } |
| |
| if (pad_enabled) { |
| this->set_fuse_pad(true); |
| } |
| } |
| |
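| // Computes the scale used by the fused BatchNorm post-op as |
| // 1 / sqrt(variance + epsilon) and copies it element-wise into |
| // scale_buf_ptr. |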
| void ComputeBNScale(OpKernelContext* context, float epsilon, |
| int bn_variance_index, Tinput* scale_buf_ptr) override { |
| const Tensor& bn_var_tensor = MklGetInput(context, bn_variance_index); |
| |
| Eigen::Tensor<Tinput, 1, Eigen::RowMajor> bn_rsqrt = |
| (bn_var_tensor.flat<Tinput>() + static_cast<Tinput>(epsilon)).rsqrt(); |
| Tinput* bn_rsqrt_data = bn_rsqrt.data(); |
| size_t num_elem = bn_var_tensor.shape().dim_size(0); |
| for (size_t i = 0; i < num_elem; i++) { |
| scale_buf_ptr[i] = bn_rsqrt_data[i]; |
| } |
| return; |
| } |
| |
| virtual ~MklFusedConvOp() {} |
| }; |
| |
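| // Fused depthwise convolution. The constructor decodes 'fused_ops' to |
| // enable BiasAdd, optionally followed by Relu, Relu6, or Elu. |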
| template <typename Device, typename Tinput, typename Tfilter, typename Tbias, |
| typename Toutput, typename Ttemp_output, typename Tpadding, |
| bool pad_enabled, bool bias_enabled, bool is_depthwise, |
| bool native_format> |
| class MklFusedDepthwiseConvOp |
| : public MklConvOp<Device, Tinput, Tfilter, Tbias, Toutput, Ttemp_output, |
| Tpadding, bias_enabled, false, is_depthwise, |
| native_format> { |
| public: |
| explicit MklFusedDepthwiseConvOp(OpKernelConstruction* context) |
| : MklConvOp<Device, Tinput, Tfilter, Tbias, Toutput, Ttemp_output, |
| Tpadding, bias_enabled, false, is_depthwise, native_format>( |
| context) { |
| // Since we came here through the registration of |
| // _MklFusedDepthwiseConv2dNative, get all |
| // information from 'fused_ops' and 'num_args' |
| std::vector<string> fused_ops; |
| OP_REQUIRES_OK(context, context->GetAttr("fused_ops", &fused_ops)); |
| |
| int num_args; |
| OP_REQUIRES_OK(context, context->GetAttr("num_args", &num_args)); |
| OP_REQUIRES(context, !fused_ops.empty(), |
| errors::InvalidArgument( |
| "Fused DepthwiseConv2D must have at least one fused op.")); |
| |
| if (fused_ops == std::vector<string>{"BiasAdd"}) { |
| this->set_fuse_biasadd(true); |
| } else if (fused_ops == std::vector<string>{"BiasAdd", "Relu"}) { |
| this->set_fuse_biasadd(true); |
| this->set_fuse_activation(true, dnnl::algorithm::eltwise_relu); |
| } else if (fused_ops == std::vector<string>{"BiasAdd", "Relu6"}) { |
| this->set_fuse_biasadd(true); |
| this->set_fuse_activation(true, dnnl::algorithm::eltwise_bounded_relu, |
| 6.0); |
| } else if (fused_ops == std::vector<string>{"BiasAdd", "Elu"}) { |
| this->set_fuse_biasadd(true); |
| this->set_fuse_activation(true, dnnl::algorithm::eltwise_elu, 1.0); |
| } else { |
| OP_REQUIRES(context, false, |
| errors::Unimplemented("Fusion is not implemented: [", |
| absl::StrJoin(fused_ops, ","), "]")); |
| } |
| |
| OP_REQUIRES( |
| context, num_args == 1, |
| errors::InvalidArgument( |
| "Fused DepthwiseConv2D must have one extra argument: bias.")); |
| |
| if (pad_enabled) { |
| this->set_fuse_pad(true); |
| } |
| } |
| |
| virtual ~MklFusedDepthwiseConvOp() {} |
| }; |
| |
| // We create a new class for each version of quantized convolution and |
| // inherit from the FP32 version of the base class. |
| template <typename Device, typename Tinput, typename Tbias, typename Toutput, |
| typename Ttemp_output, bool bias_enabled, bool is_depthwise, |
| bool native_format = false> |
| class MklQuantizedConv2DOp |
| : public MklConvOp<Device, Tinput, qint8, Tbias, Toutput, Ttemp_output, |
| int32, bias_enabled, false, is_depthwise, |
| native_format> { |
| public: |
| virtual ~MklQuantizedConv2DOp() { |
| if (this->input_bias_ != nullptr) { |
| delete this->input_bias_; |
| input_bias_ = nullptr; |
| } |
| |
| if (this->scaled_bias_ != nullptr) { |
| delete this->scaled_bias_; |
| scaled_bias_ = nullptr; |
| } |
| } |
| |
| explicit MklQuantizedConv2DOp(OpKernelConstruction* context) |
| : MklConvOp<Device, Tinput, qint8, Tbias, Toutput, Ttemp_output, int32, |
| bias_enabled, false, is_depthwise, native_format>(context) { |
| bool is_filter_const; |
| OP_REQUIRES_OK(context, |
| context->GetAttr("is_filter_const", &is_filter_const)); |
| |
| if (bias_enabled) { |
| OP_REQUIRES_OK(context, |
| context->GetAttr("is_bias_const", &is_bias_const_)); |
| } |
| |
| OP_REQUIRES(context, is_filter_const, |
| errors::InvalidArgument("Filter must be a constant")); |
| } |
| |
| void Compute(OpKernelContext* context) override { |
| // Compute int32 output tensor |
| MklConvOp<Device, Tinput, qint8, Tbias, Toutput, Ttemp_output, int32, |
| bias_enabled, false, is_depthwise, |
| native_format>::Compute(context); |
| |
| // Compute additional outputs: min/max scalars. |
| int bias_index_offset; |
| bias_index_offset = bias_enabled ? 1 : 0; |
| |
| const float min_input = |
| context->input(2 + bias_index_offset).flat<float>()(0); |
| const float max_input = |
| context->input(3 + bias_index_offset).flat<float>()(0); |
| |
| MklDnnShape output_min_mkl_shape, output_max_mkl_shape; |
| output_min_mkl_shape.SetMklTensor(false); |
| output_max_mkl_shape.SetMklTensor(false); |
| |
| Tensor* output_min = nullptr; |
| Tensor* output_max = nullptr; |
| if (std::is_same<Toutput, quint8>::value || |
| std::is_same<Toutput, qint8>::value) { |
| AllocateOutputSetMklShape(context, 1, &output_min, {}, |
| output_min_mkl_shape, native_format); |
| AllocateOutputSetMklShape(context, 2, &output_max, {}, |
| output_max_mkl_shape, native_format); |
| // This is the case where convolution and requantization are fused. |
| output_min->flat<float>()(0) = |
| context->input(6 + bias_index_offset).flat<float>()(0); |
| output_max->flat<float>()(0) = |
| context->input(7 + bias_index_offset).flat<float>()(0); |
| } else { |
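| // Requantization is not fused: the qint32 output range is derived from |
| // the input and filter ranges, either as scalars or per-channel, |
| // depending on whether the filter range is a scalar. |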
| const Tensor& min_filter = context->input(4 + bias_index_offset); |
| const Tensor& max_filter = context->input(5 + bias_index_offset); |
| if (min_filter.dims() == 0) { |
| float min_output_value; |
| float max_output_value; |
| MklQuantizationRangeForMultiplication<Tinput, qint8, qint32>( |
| min_input, max_input, min_filter.flat<float>()(0), |
| max_filter.flat<float>()(0), &min_output_value, &max_output_value); |
| AllocateOutputSetMklShape(context, 1, &output_min, {}, |
| output_min_mkl_shape, native_format); |
| AllocateOutputSetMklShape(context, 2, &output_max, {}, |
| output_max_mkl_shape, native_format); |
| output_min->flat<float>()(0) = min_output_value; |
| output_max->flat<float>()(0) = max_output_value; |
| } else { |
| size_t depth = min_filter.NumElements(); |
| AllocateOutputSetMklShape(context, 1, &output_min, |
| {static_cast<ptrdiff_t>(depth)}, |
| output_min_mkl_shape, native_format); |
| AllocateOutputSetMklShape(context, 2, &output_max, |
| {static_cast<ptrdiff_t>(depth)}, |
| output_max_mkl_shape, native_format); |
| MklQuantizationRangeForMultiplication<Tinput, qint8, qint32>( |
| min_input, max_input, min_filter, max_filter, &output_min, |
| &output_max); |
| } |
| } |
| } |
| |
| protected: |
| void ExtendConvFwdParams(OpKernelContext* context, |
| MklConvFwdParams& params) override { |
| MklConvOp<Device, Tinput, qint8, Tbias, Toutput, Ttemp_output, int32, |
| bias_enabled, false, is_depthwise, |
| native_format>::ExtendConvFwdParams(context, params); |
| |
| // When the output type is quint8 or qint8, the output data is |
| // requantized into that type. A post_op "output_scale" is added to do |
| // the conversion. |
| if (std::is_same<Toutput, quint8>::value || |
| std::is_same<Toutput, qint8>::value) { |
| int bias_index_offset; |
| bias_index_offset = bias_enabled ? 1 : 0; |
| |
| const float min_input = |
| context->input(2 + bias_index_offset).flat<float>()(0); |
| const float max_input = |
| context->input(3 + bias_index_offset).flat<float>()(0); |
| const Tensor& min_filter_vector = context->input(4 + bias_index_offset); |
| const Tensor& max_filter_vector = context->input(5 + bias_index_offset); |
| |
| // min_freezed_output and max_freezed_output are the actual range |
| // for the output. |
| const float min_freezed_output = |
| context->input(6 + bias_index_offset).flat<float>()(0); |
| const float max_freezed_output = |
| context->input(7 + bias_index_offset).flat<float>()(0); |
| |
| float int_output_limit = |
| std::is_same<Toutput, quint8>::value ? 255.0f : 127.0f; |
| size_t depth = min_filter_vector.NumElements(); |
| const float* min_filter = min_filter_vector.flat<float>().data(); |
| const float* max_filter = max_filter_vector.flat<float>().data(); |
| std::vector<float> scales(depth); |
| float float_input_range = |
| std::max(std::abs(min_input), std::abs(max_input)); |
| float float_output_range = |
| std::max(std::abs(min_freezed_output), std::abs(max_freezed_output)); |
| const float int_const_scale_limit = |
| (std::is_same<Tinput, quint8>::value) ? 255.0 * 127.0 : 127.0 * 127.0; |
| for (size_t i = 0; i < depth; ++i) { |
| // For simplicity and symmetry, we set the filter range to the larger |
| // of the absolute values of min_filter and max_filter. |
| float float_filter_range = |
| std::max(std::abs(min_filter[i]), std::abs(max_filter[i])); |
| // To understand the scaling, please see mkl_requantize_ops_test. |
| scales[i] = int_output_limit * float_input_range * float_filter_range / |
| (int_const_scale_limit * float_output_range); |
| } |
| // We create a partial key here for use with primitive-key caching to |
| // improve key-creation performance. Instead of the actual values, we |
| // use the pointers to min/max_filter_vector; this works because the |
| // filter vector is constant. |
| FactoryKeyCreator param_key; |
| param_key.AddAsKey<float>(min_input); |
| param_key.AddAsKey<float>(max_input); |
| param_key.AddAsKey<float>(min_freezed_output); |
| param_key.AddAsKey<float>(max_freezed_output); |
| param_key.AddAsKey<const float*>(min_filter); |
| param_key.AddAsKey<const float*>(max_filter); |
| params.post_op_params.push_back( |
| {"output_scale", dnnl::algorithm::undef, scales, param_key.GetKey()}); |
| } |
| } |
| |
| Tbias* GetBiasHandle(OpKernelContext* context, |
| std::shared_ptr<ConvFwdPd>& conv_fwd_pd, |
| const Tensor& bias_tensor) override { |
| if (!bias_enabled) { |
| return nullptr; |
| } |
| if (std::is_same<Tbias, qint32>::value) { |
| return static_cast<Tbias*>( |
| const_cast<Tbias*>(bias_tensor.flat<Tbias>().data())); |
| } |
| int bias_index_offset; |
| bias_index_offset = bias_enabled ? 1 : 0; |
| |
| const float min_input = |
| context->input(2 + bias_index_offset).flat<float>()(0); |
| const float max_input = |
| context->input(3 + bias_index_offset).flat<float>()(0); |
| const Tensor& min_filter_vector = context->input(4 + bias_index_offset); |
| const Tensor& max_filter_vector = context->input(5 + bias_index_offset); |
| const float* min_filter = min_filter_vector.flat<float>().data(); |
| const float* max_filter = max_filter_vector.flat<float>().data(); |
| |
| const float int_const_scale_limit = |
| (std::is_same<Tinput, quint8>::value) ? 255.0 * 127.0 : 127.0 * 127.0; |
| // Re-scale the bias if any of the following conditions is met: |
| // 1. Bias is not const; |
| // 2. Bias is const, but the bias cache is empty (first iteration); |
| // 3. The recomputed scales differ from the cached ones, i.e. the input |
| // or filter quantization ranges have changed. |
| |
| size_t depth = min_filter_vector.NumElements(); |
| bool scales_are_valid = (depth == scales_.size()); |
| scales_.resize(depth); |
| for (size_t i = 0; i < depth; ++i) { |
| float tmp_scale = |
| int_const_scale_limit / |
| (std::max(std::abs(max_input), std::abs(min_input)) * |
| std::max(std::abs(max_filter[i]), std::abs(min_filter[i]))); |
| if (scales_are_valid && std::abs(tmp_scale - scales_[i]) > 1e-6) { |
| scales_are_valid = false; |
| } |
| scales_[i] = tmp_scale; |
| } |
| if (!is_bias_const_ || IsBiasCacheEmpty(context) || !scales_are_valid) { |
| dnnl::primitive_attr bias_attr; |
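| // Mask 0 applies a single common output scale; mask 1 applies |
| // per-element scales along the one-dimensional bias. |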
| if (depth == 1) { |
| bias_attr.set_output_scales(0, scales_); |
| } else { |
| bias_attr.set_output_scales(1, scales_); |
| } |
| |
| auto bias_md = memory::desc({static_cast<int>(bias_tensor.NumElements())}, |
| MklDnnType<Tbias>(), memory::format_tag::x); |
| void* bias_buf = static_cast<void*>( |
| const_cast<Tbias*>(bias_tensor.flat<Tbias>().data())); |
| if (!input_bias_) { |
| input_bias_ = new memory(bias_md, this->cpu_engine_, bias_buf); |
| } else { |
| input_bias_->set_data_handle(bias_buf); |
| } |
| |
| if (!scaled_bias_buf_) |
| AllocTmpBuffer<Tbias>(context, &scaled_bias_tensor_, |
| conv_fwd_pd->bias_desc(), &scaled_bias_buf_); |
| if (!scaled_bias_) { |
| scaled_bias_ = new memory(bias_md, this->cpu_engine_, scaled_bias_buf_); |
| } else { |
| scaled_bias_->set_data_handle(scaled_bias_buf_); |
| } |
| auto reorder_desc = |
| ReorderPd(this->cpu_engine_, input_bias_->get_desc(), |
| this->cpu_engine_, scaled_bias_->get_desc(), bias_attr); |
| CreateAndExecuteReorder(reorder_desc, *input_bias_, *scaled_bias_, |
| this->cpu_engine_, context); |
| |
| Tbias* bias_data = |
| reinterpret_cast<Tbias*>(scaled_bias_->get_data_handle()); |
| if (is_bias_const_) |
| CacheBias(context, conv_fwd_pd, bias_data, scaled_bias_); |
| |
| return bias_data; |
| } |
| return GetCachedBias(context); |
| } |
| |
| bool is_bias_const_; |
| Tensor cached_bias_data_ TF_GUARDED_BY(bias_cache_mu_); |
| |
| memory* input_bias_ = nullptr; |
| memory* scaled_bias_ = nullptr; |
| |
| Tensor scaled_bias_tensor_; |
| void* scaled_bias_buf_ = nullptr; |
| |
| private: |
| std::vector<float> scales_; |
| mutex bias_cache_mu_; |
| // Allocates the tensor that holds the cached (scaled) bias data. |
| void AllocateTensor(OpKernelContext* context, const ConvFwdPd& conv_prim_desc, |
| Tensor** bias_tensor) { |
| DCHECK(bias_tensor); |
| TensorShape bias_tf_shape; |
| bias_tf_shape.AddDim( |
| (conv_prim_desc.bias_desc().get_size() / sizeof(Tbias))); |
| OP_REQUIRES_OK(context, |
| context->allocate_temp(DataTypeToEnum<Tbias>::value, |
| bias_tf_shape, &cached_bias_data_)); |
| *bias_tensor = &cached_bias_data_; |
| } |
| |
| // The TF_LOCKS_EXCLUDED annotation ensures that the lock |
| // (bias_cache_mu_) is not held by the caller when entering the |
| // function, since it is acquired inside the function. |
| inline bool IsBiasCacheEmpty(OpKernelContext* context) |
| TF_LOCKS_EXCLUDED(bias_cache_mu_) { |
| tf_shared_lock lock(bias_cache_mu_); |
| return (cached_bias_data_.NumElements() == 0); |
| } |
| |
| // Cache the converted bias in a tensor. |
| // Only one thread can execute this method at any given time. |
| void CacheBias(OpKernelContext* context, |
| const std::shared_ptr<ConvFwdPd>& conv_fwd_pd, |
| Tbias* bias_data, const memory* scaled_bias) |
| TF_LOCKS_EXCLUDED(bias_cache_mu_) { |
| mutex_lock lock(bias_cache_mu_); |
| |
| // If bias is already cached, there's nothing to do. |
| if (cached_bias_data_.NumElements() > 0) { |
| return; |
| } |
| |
| // Otherwise, cache bias |
| Tensor* bias_tensor_ptr = nullptr; |
| AllocateTensor(context, *conv_fwd_pd, &bias_tensor_ptr); |
| void* cached_bias_data = const_cast<void*>( |
| static_cast<const void*>(bias_tensor_ptr->flat<Tbias>().data())); |
| size_t cached_bias_data_size = scaled_bias->get_desc().get_size(); |
| memcpy(cached_bias_data, bias_data, cached_bias_data_size); |
| } |
| |
| Tbias* GetCachedBias(OpKernelContext* context) |
| TF_LOCKS_EXCLUDED(bias_cache_mu_) { |
| tf_shared_lock lock(bias_cache_mu_); |
| const Tensor& cached_bias_data = cached_bias_data_; |
| |
| return static_cast<Tbias*>( |
| const_cast<Tbias*>(cached_bias_data.flat<Tbias>().data())); |
| } |
| }; |
| |
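| // Quantized Conv2D with a fused ReLU post-op appended to the |
| // convolution forward parameters. |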
| template <typename Device, typename Tinput, typename Tbias, typename Toutput, |
| typename Ttemp_output, bool bias_enabled, bool is_depthwise, |
| bool native_format = false> |
| class MklQuantizedConv2DReluOp |
| : public MklQuantizedConv2DOp<Device, Tinput, Tbias, Toutput, Ttemp_output, |
| bias_enabled, is_depthwise, native_format> { |
| public: |
| virtual ~MklQuantizedConv2DReluOp() {} |
| |
| explicit MklQuantizedConv2DReluOp(OpKernelConstruction* context) |
| : MklQuantizedConv2DOp<Device, Tinput, Tbias, Toutput, Ttemp_output, |
| bias_enabled, is_depthwise, native_format>( |
| context) {} |
| |
| protected: |
| void ExtendConvFwdParams(OpKernelContext* context, |
| MklConvFwdParams& params) override { |
| MklQuantizedConv2DOp<Device, Tinput, Tbias, Toutput, Ttemp_output, |
| bias_enabled, is_depthwise, |
| native_format>::ExtendConvFwdParams(context, params); |
| |
| params.post_op_params.push_back( |
| {"activation", dnnl::algorithm::eltwise_relu, {1.0, 0.0, 0.0}, ""}); |
| } |
| }; |
| |
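| // Quantized Conv2D with fused residual-add (sum) and ReLU post-ops. The |
| // summand is forwarded to, or reordered into, the output buffer so the |
| // convolution can accumulate into it. |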
| template <typename Device, typename Tinput, typename Tbias, typename Toutput, |
| typename Ttemp_output, bool bias_enabled, bool is_depthwise, |
| bool native_format = false> |
| class MklQuantizedConv2DSumReluOp |
| : public MklQuantizedConv2DOp<Device, Tinput, Tbias, Toutput, Ttemp_output, |
| bias_enabled, is_depthwise, native_format> { |
| public: |
| virtual ~MklQuantizedConv2DSumReluOp() {} |
| |
| explicit MklQuantizedConv2DSumReluOp(OpKernelConstruction* context) |
| : MklQuantizedConv2DOp<Device, Tinput, Tbias, Toutput, Ttemp_output, |
| bias_enabled, is_depthwise, native_format>( |
| context) {} |
| |
| protected: |
| void ExtendConvFwdParams(OpKernelContext* context, |
| MklConvFwdParams& params) override { |
| MklQuantizedConv2DOp<Device, Tinput, Tbias, Toutput, Ttemp_output, |
| bias_enabled, is_depthwise, |
| native_format>::ExtendConvFwdParams(context, params); |
| // Calculate the scale (beta in oneDNN API term) for sum |
| if (std::is_same<Toutput, quint8>::value) { |
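| // With requantization fused, the summand tensor is followed by its |
| // min/max inputs, so it is the third data input from the end. |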
| int summand_idx = native_format ? context->num_inputs() - 1 - 2 |
| : context->num_inputs() / 2 - 1 - 2; |
| DataType summand_type = this->input_type(summand_idx); |
| bool summand_condition = |
| (summand_type == DT_QINT8) || (summand_type == DT_QUINT8); |
| CHECK((summand_condition)); |
| int bias_index_offset = bias_enabled ? 1 : 0; |
| const float min_freezed_output = |
| context->input(6 + bias_index_offset).flat<float>()(0); |
| const float max_freezed_output = |
| context->input(7 + bias_index_offset).flat<float>()(0); |
| const float min_freezed_summand = |
| context->input(9 + bias_index_offset).flat<float>()(0); |
| const float max_freezed_summand = |
| context->input(10 + bias_index_offset).flat<float>()(0); |
| |
| float scale_output = |
| std::max(std::abs(min_freezed_output), std::abs(max_freezed_output)); |
| float scale_summand = std::max(std::abs(min_freezed_summand), |
| std::abs(max_freezed_summand)); |
| // If summand_type is DT_QUINT8, the same as the output type, the |
| // 255.0f scaling factors cancel and are omitted. Otherwise the summand |
| // is DT_QINT8 and the extra 255.0f / 127.0f factor is applied. |
| if (summand_type == DT_QUINT8) { |
| params.post_op_params.push_back({"sum", |
| dnnl::algorithm::undef, |
| {scale_summand / scale_output}, |
| ""}); |
| } else { |
| params.post_op_params.push_back( |
| {"sum", |
| dnnl::algorithm::undef, |
| {255.0f * scale_summand / (scale_output * 127.0f)}, |
| ""}); |
| } |
| } else { |
| params.post_op_params.push_back( |
| {"sum", dnnl::algorithm::undef, {1.0}, ""}); |
| } |
| params.post_op_params.push_back( |
| {"activation", dnnl::algorithm::eltwise_relu, {1.0, 0.0, 0.0}, ""}); |
| } |
| |
| void AllocateOutputTensor(OpKernelContext* context, |
| const ConvFwdPd& conv_prim_desc, |
| const memory::dims& output_dims_mkl_order, |
| MklTensorFormat output_tf_format, |
| MklDnnShape* output_mkl_shape, |
| Tensor** output_tensor) override { |
| int summand_idx = native_format ? context->num_inputs() - 1 |
| : context->num_inputs() / 2 - 1; |
| if (std::is_same<Toutput, quint8>::value) { |
| summand_idx -= 2; |
| DataType summand_type = this->input_type(summand_idx); |
| bool summand_condition = |
| (summand_type == DT_QINT8) || (summand_type == DT_QUINT8); |
| CHECK((summand_condition)); |
| Tensor& summand = const_cast<Tensor&>(MklGetInput(context, summand_idx)); |
| MklDnnShape summand_mkl_shape; |
| GetMklShape(context, summand_idx, &summand_mkl_shape, native_format); |
| auto dst_md = summand_mkl_shape.GetMklLayout(); |
| |
| // TODO(intel-tf): Handle both non-MKL and MKL tensors |
| if (summand_type == DT_QINT8) { |
| OP_REQUIRES_OK( |
| context, summand.BitcastFrom(summand, DT_QUINT8, summand.shape())); |
| dst_md.data.data_type = |
| static_cast<dnnl_data_type_t>(MklDnnType<Toutput>()); |
| summand_mkl_shape.SetMklLayout(&dst_md); |
| summand_mkl_shape.SetElemType(MklDnnType<Toutput>()); |
| } |
| // TODO(intel-tf): Support cases when summand cannot be forwarded. |
| OP_REQUIRES(context, |
| native_format |
| ? context->forward_input_to_output_with_shape( |
| summand_idx, 0, summand.shape(), output_tensor) |
| : ForwardMklTensorInToOutWithMklShape( |
| context, summand_idx, 0, output_tensor, |
| summand_mkl_shape, false), |
| errors::InvalidArgument( |
| "Summand cannot be forwarded in the current fusion.")); |
| return; |
| } |
| MklConvOp<Device, Tinput, qint8, Tbias, Toutput, Ttemp_output, int32, |
| bias_enabled, false, false, |
| native_format>::AllocateOutputTensor(context, conv_prim_desc, |
| output_dims_mkl_order, |
| output_tf_format, |
| output_mkl_shape, |
| output_tensor); |
| const Tensor& summand = MklGetInput(context, summand_idx); |
| if (summand.dtype() != DT_FLOAT) |
| TF_CHECK_OK(Status(error::Code::FAILED_PRECONDITION, |
| "Current fusion requires summand to be float")); |
| MklDnnShape summand_mkl_shape; |
| GetMklShape(context, summand_idx, &summand_mkl_shape, native_format); |
| // We need to compute the scale for the summand. |
| int bias_index_offset = bias_enabled ? 1 : 0; |
| const float min_input = |
| context->input(2 + bias_index_offset).flat<float>()(0); |
| const float max_input = |
| context->input(3 + bias_index_offset).flat<float>()(0); |
| const Tensor& min_filter_vector = context->input(4 + bias_index_offset); |
| const Tensor& max_filter_vector = context->input(5 + bias_index_offset); |
| const float* min_filter = min_filter_vector.flat<float>().data(); |
| const float* max_filter = max_filter_vector.flat<float>().data(); |
| |
| const float int_const_scale_limit = |
| (std::is_same<Tinput, quint8>::value) ? 255.0 * 127.0 : 127.0 * 127.0; |
| size_t depth = min_filter_vector.NumElements(); |
| std::vector<float> scales(depth); |
| for (size_t i = 0; i < depth; ++i) { |
| // TODO(intel-tf): The scale factors for UINT8 (inputs) and INT8 |
| // (weights) are computed in the regular way here. A cleaner design |
| // that handles all such mappings in one function, and supports other |
| // quantized type mappings, should be implemented in the future. |
| scales[i] = int_const_scale_limit / |
| (std::max(std::abs(max_input), std::abs(min_input)) * |
| std::max(std::abs(max_filter[i]), std::abs(min_filter[i]))); |
| } |
| dnnl::primitive_attr reorder_attr; |
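| // Mask 0 applies one common scale; mask 2 applies per-channel scales |
| // along dimension 1 (the channel dimension in MKL's NCHW order). |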
| if (depth == 1) { |
| reorder_attr.set_output_scales(0, scales); |
| } else { |
| reorder_attr.set_output_scales(2, scales); |
| } |
| auto summand_md = |
| summand_mkl_shape.IsMklTensor() |
| ? summand_mkl_shape.GetMklLayout() |
| : memory::desc(output_dims_mkl_order, MklDnnType<Tbias>(), |
| memory::format_tag::nhwc); |
| void* summand_buf = |
| static_cast<void*>(const_cast<Tbias*>(summand.flat<Tbias>().data())); |
| void* dst_buf = |
| static_cast<void*>((*output_tensor)->flat<Ttemp_output>().data()); |
| summand_.reset(new memory(summand_md, this->cpu_engine_, summand_buf)); |
| dst_.reset( |
| new memory(conv_prim_desc.dst_desc(), this->cpu_engine_, dst_buf)); |
| auto reorder_desc = |
| ReorderPd(this->cpu_engine_, summand_md, this->cpu_engine_, |
| conv_prim_desc.dst_desc(), reorder_attr); |
| CreateAndExecuteReorder(reorder_desc, *summand_, *dst_, this->cpu_engine_, |
| context); |
| } |
| |
| std::shared_ptr<dnnl::memory> summand_; |
| std::shared_ptr<dnnl::memory> dst_; |
| }; |
| |
| // Fused 3D convolution forward operation. |
| template <typename Device, typename Tinput, typename Tfilter, typename Tbias, |
| typename Toutput, typename Ttemp_output, typename Tpadding, |
| bool pad_enabled, bool native_format> |
| class MklFusedConv3DOp |
| : public MklConvOp<Device, Tinput, Tfilter, Tbias, Toutput, Ttemp_output, |
| Tpadding, false, false, false, native_format> { |
| public: |
| explicit MklFusedConv3DOp(OpKernelConstruction* context) |
| : MklConvOp<Device, Tinput, Tfilter, Tbias, Toutput, Ttemp_output, |
| Tpadding, false, false, false, native_format>(context) { |
| // Since we came here through the registration of _MklFusedConv3D, get |
| // all information from 'fused_ops' and 'num_args' |
| std::vector<string> fused_ops; |
| OP_REQUIRES_OK(context, context->GetAttr("fused_ops", &fused_ops)); |
| |
| int num_args; |
| OP_REQUIRES_OK(context, context->GetAttr("num_args", &num_args)); |
| |
| std::vector<int> padding_list; |
| OP_REQUIRES_OK(context, context->GetAttr("padding_list", &padding_list)); |
| if (padding_list.empty()) { |
| OP_REQUIRES(context, !fused_ops.empty(), |
| errors::InvalidArgument("Fused Conv3D must have at least one " |
| "fused op when Pad is not fused.")); |
| if (std::find(fused_ops.begin(), fused_ops.end(), "BiasAdd") != |
| fused_ops.end() && |
| std::find(fused_ops.begin(), fused_ops.end(), "Add") == |
| fused_ops.end()) { |
| OP_REQUIRES(context, num_args == 1, |
| errors::InvalidArgument( |
| "Fused Conv3D must have one extra argument: bias.")); |
| } else if (std::find(fused_ops.begin(), fused_ops.end(), "BiasAdd") != |
| fused_ops.end() && |
| std::find(fused_ops.begin(), fused_ops.end(), "Add") != |
| fused_ops.end()) { |
| OP_REQUIRES( |
| context, num_args == 2, |
| errors::InvalidArgument( |
| "Fused Conv3D must have two extra arguments: bias and add.")); |
| } |
| } |
| |
| if (fused_ops == std::vector<string>{"BiasAdd"}) { |
| this->set_fuse_biasadd(true); |
| } else if (fused_ops == std::vector<string>{"BiasAdd", "LeakyRelu"}) { |
| this->set_fuse_biasadd(true); |
| float leakyrelu_alpha; |
| OP_REQUIRES_OK(context, |
| context->GetAttr("leakyrelu_alpha", &leakyrelu_alpha)); |
| this->set_fuse_activation(true, dnnl::algorithm::eltwise_relu, |
| leakyrelu_alpha); |
| } else if (fused_ops == std::vector<string>{"BiasAdd", "Relu"}) { |
| this->set_fuse_biasadd(true); |
| this->set_fuse_activation(true, dnnl::algorithm::eltwise_relu); |
| } else if (fused_ops == std::vector<string>{"BiasAdd", "Relu6"}) { |
| this->set_fuse_biasadd(true); |
| this->set_fuse_activation(true, dnnl::algorithm::eltwise_bounded_relu, |
| 6.0); |
| } else if (fused_ops == std::vector<string>{"BiasAdd", "Elu"}) { |
| this->set_fuse_biasadd(true); |
| this->set_fuse_activation(true, dnnl::algorithm::eltwise_elu, 1.0); |
| } else if (fused_ops == std::vector<string>{"BiasAdd", "Add"}) { |
| this->set_fuse_biasadd(true); |
| this->set_fuse_add(true); |
| } else if (fused_ops == std::vector<string>{"BiasAdd", "Add", "Relu"}) { |
| this->set_fuse_biasadd(true); |
| this->set_fuse_add(true); |
| this->set_fuse_activation(true, dnnl::algorithm::eltwise_relu); |
| } else if (fused_ops == std::vector<string>{"BiasAdd", "Add", "Relu6"}) { |
| this->set_fuse_biasadd(true); |
| this->set_fuse_add(true); |
| this->set_fuse_activation(true, dnnl::algorithm::eltwise_bounded_relu, |
| 6.0); |
| } else if (fused_ops == std::vector<string>{"BiasAdd", "Add", "Elu"}) { |
| this->set_fuse_biasadd(true); |
| this->set_fuse_add(true); |
| this->set_fuse_activation(true, dnnl::algorithm::eltwise_elu, 1.0); |
| } else if (fused_ops == |
| std::vector<string>{"BiasAdd", "Add", "LeakyRelu"}) { |
| this->set_fuse_biasadd(true); |
| this->set_fuse_add(true); |
| float leakyrelu_alpha; |
| OP_REQUIRES_OK(context, |
| context->GetAttr("leakyrelu_alpha", &leakyrelu_alpha)); |
| this->set_fuse_activation(true, dnnl::algorithm::eltwise_relu, |
| leakyrelu_alpha); |
| } else { |
| if (padding_list.empty()) { |
| OP_REQUIRES(context, false, |
| errors::Unimplemented("Fusion is not implemented: [", |
| absl::StrJoin(fused_ops, ","), "]")); |
| } |
| } |
| } |
| |
| virtual ~MklFusedConv3DOp() {} |
| }; |
| |
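| // Helper macros for registering quantized convolution kernels over the |
| // supported combinations of input, bias, and output types. |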
| #define REGISTER_MKL_KERNEL(op, kernel, input_type, bias_type, output_type, \ |
| accu_type, has_bias, is_depthwise, is_native) \ |
| REGISTER_KERNEL_BUILDER( \ |
| Name(op) \ |
| .Device(DEVICE_CPU) \ |
| .TypeConstraint<input_type>("Tinput") \ |
| .TypeConstraint<qint8>("Tfilter") BIAS_TYPE_CONSTRAINT(bias_type) \ |
| .TypeConstraint<output_type>("out_type") LABEL, \ |
| kernel TEMPLATE_ARGS(CPUDevice, input_type, bias_type, output_type, \ |
| accu_type, has_bias, is_depthwise, is_native)); |
| |
| #define REGISTER_MKL_KERNEL_ALL_INPUT_TYPES(op, kernel, bias_type, \ |
| output_type, accu_type, has_bias, \ |
| is_depthwise, is_native) \ |
| REGISTER_MKL_KERNEL(op, kernel, qint8, bias_type, output_type, accu_type, \ |
| has_bias, is_depthwise, is_native); \ |
| REGISTER_MKL_KERNEL(op, kernel, quint8, bias_type, output_type, accu_type, \ |
| has_bias, is_depthwise, is_native); |
| |
| #define REGISTER_MKL_KERNEL_ALL_BIAS_TYPES(op, kernel, input_type, \ |
| output_type, accu_type, has_bias, \ |
| is_depthwise, is_native) \ |
| REGISTER_MKL_KERNEL(op, kernel, input_type, qint32, output_type, accu_type, \ |
| has_bias, is_depthwise, is_native); \ |
| REGISTER_MKL_KERNEL(op, kernel, input_type, float, output_type, accu_type, \ |
| has_bias, is_depthwise, is_native); |
| |
| #define REGISTER_MKL_KERNEL_ALL_INPUT_AND_BIAS_TYPES( \ |
| op, kernel, output_type, accu_type, has_bias, is_depthwise, is_native) \ |
| REGISTER_MKL_KERNEL_ALL_INPUT_TYPES(op, kernel, qint32, output_type, \ |
| accu_type, has_bias, is_depthwise, \ |
| is_native); \ |
| REGISTER_MKL_KERNEL_ALL_INPUT_TYPES(op, kernel, float, output_type, \ |
| accu_type, has_bias, is_depthwise, \ |
| is_native); |
| |
| #define LABEL |
| #define TEMPLATE_ARGS(CPUDevice, input_type, bias_type, output_type, \ |
| accu_type, has_bias, is_depthwise, is_native) |
| #define BIAS_TYPE_CONSTRAINT(bias_type) |
| |
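| // Register NoOp kernels for the public quantized convolution ops; the |
| // graph rewrite pass replaces them with the _MklQuantized* kernels |
| // registered below. |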
| REGISTER_MKL_KERNEL("QuantizedConv2D", NoOp, quint8, float, qint32, qint32, |
| false, false, false); |
| REGISTER_MKL_KERNEL_ALL_INPUT_TYPES("QuantizedConv2DWithBias", NoOp, float, |
| qint32, qint32, false, false, false); |
| REGISTER_MKL_KERNEL_ALL_INPUT_TYPES("QuantizedConv2DWithBiasAndRelu", NoOp, |
| float, qint32, qint32, false, false, false); |
| REGISTER_MKL_KERNEL("QuantizedConv2DWithBiasSumAndRelu", NoOp, quint8, float, |
| qint32, qint32, false, false, false); |
| REGISTER_MKL_KERNEL("QuantizedConv2DAndRequantize", NoOp, quint8, float, qint8, |
| qint8, false, false, false); |
| REGISTER_MKL_KERNEL("QuantizedConv2DPerChannel", NoOp, quint8, float, qint32, |
| qint32, false, false, false); |
| REGISTER_MKL_KERNEL("QuantizedConv2DAndRelu", NoOp, quint8, float, qint32, |
| qint32, false, false, false); |
| REGISTER_MKL_KERNEL("QuantizedConv2DAndReluAndRequantize", NoOp, quint8, float, |
| quint8, quint8, false, false, false); |
| REGISTER_MKL_KERNEL("QuantizedDepthwiseConv2D", NoOp, quint8, float, qint32, |
| qint32, false, false, false); |
| REGISTER_MKL_KERNEL("QuantizedDepthwiseConv2DWithBias", NoOp, quint8, float, |
| qint32, qint32, false, false, false); |
| REGISTER_MKL_KERNEL("QuantizedDepthwiseConv2DWithBiasAndRelu", NoOp, quint8, |
| float, qint32, qint32, false, false, false); |
| #undef BIAS_TYPE_CONSTRAINT |
| |
| #define BIAS_TYPE_CONSTRAINT(bias_type) .TypeConstraint<bias_type>("Tbias") |
| REGISTER_MKL_KERNEL_ALL_INPUT_AND_BIAS_TYPES( |
| "QuantizedConv2DWithBiasAndRequantize", NoOp, qint8, qint8, false, false, |
| false); |
| REGISTER_MKL_KERNEL_ALL_INPUT_AND_BIAS_TYPES( |
| "QuantizedConv2DWithBiasAndReluAndRequantize", NoOp, quint8, quint8, false, |
| false, false); |
| REGISTER_MKL_KERNEL_ALL_BIAS_TYPES( |
| "QuantizedConv2DWithBiasSumAndReluAndRequantize", NoOp, quint8, quint8, |
| quint8, false, false, false); |
| REGISTER_MKL_KERNEL_ALL_BIAS_TYPES( |
| "QuantizedConv2DWithBiasSignedSumAndReluAndRequantize", NoOp, quint8, |
| quint8, qint8, false, false, false); |
| REGISTER_MKL_KERNEL_ALL_BIAS_TYPES( |
| "QuantizedDepthwiseConv2DWithBiasAndReluAndRequantize", NoOp, quint8, |
| quint8, quint8, false, false, false); |
| #undef BIAS_TYPE_CONSTRAINT |
| #undef TEMPLATE_ARGS |
| #undef LABEL |
| |
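| // Register the MKL quantized convolution kernels under the MKL |
| // quantized-op label. |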
| #define LABEL .Label(mkl_op_registry::kMklQuantizedOpLabel) |
| #define TEMPLATE_ARGS(CPUDevice, input_type, bias_type, output_type, \ |
| accu_type, has_bias, is_depthwise, is_native) \ |
| <CPUDevice, input_type, bias_type, output_type, accu_type, has_bias, \ |
| is_depthwise, is_native> |
| #define BIAS_TYPE_CONSTRAINT(bias_type) |
| REGISTER_MKL_KERNEL_ALL_INPUT_TYPES("_MklQuantizedConv2D", MklQuantizedConv2DOp, |
| float, qint32, qint32, false, false, true); |
| REGISTER_MKL_KERNEL("_MklQuantizedConv2DPerChannel", MklQuantizedConv2DOp, |
| quint8, float, qint32, qint32, false, false, true); |
| REGISTER_MKL_KERNEL_ALL_INPUT_TYPES("_MklQuantizedConv2DWithBias", |
| MklQuantizedConv2DOp, float, qint32, qint32, |
| true, false, true); |
| REGISTER_MKL_KERNEL_ALL_INPUT_TYPES("_MklQuantizedConv2DWithBiasAndRelu", |
| MklQuantizedConv2DReluOp, float, qint32, |
| qint32, true, false, true); |
| REGISTER_MKL_KERNEL("_MklQuantizedConv2DWithBiasSumAndRelu", |
| MklQuantizedConv2DSumReluOp, quint8, float, qint32, qint32, |
| true, false, true); |
| REGISTER_MKL_KERNEL("_MklQuantizedConv2DAndRequantize", MklQuantizedConv2DOp, |
| quint8, float, qint8, qint8, false, false, true); |
| REGISTER_MKL_KERNEL("_MklQuantizedConv2DAndRelu", MklQuantizedConv2DReluOp, |
| quint8, float, qint32, qint32, false, false, true); |
| REGISTER_MKL_KERNEL("_MklQuantizedConv2DAndReluAndRequantize", |
| MklQuantizedConv2DReluOp, quint8, float, quint8, quint8, |
| false, false, true); |
| REGISTER_MKL_KERNEL("_MklQuantizedDepthwiseConv2D", MklQuantizedConv2DOp, |
| quint8, float, qint32, qint32, false, true, true); |
| REGISTER_MKL_KERNEL("_MklQuantizedDepthwiseConv2DWithBias", |
| MklQuantizedConv2DOp, quint8, float, qint32, qint32, true, |
| true, true); |
| REGISTER_MKL_KERNEL("_MklQuantizedDepthwiseConv2DWithBiasAndRelu", |
| MklQuantizedConv2DReluOp, quint8, float, qint32, qint32, |
| true, true, true); |
| #undef BIAS_TYPE_CONSTRAINT |
| |
| #define BIAS_TYPE_CONSTRAINT(bias_type) .TypeConstraint<bias_type>("Tbias") |
| REGISTER_MKL_KERNEL_ALL_INPUT_AND_BIAS_TYPES( |
| "_MklQuantizedConv2DWithBiasAndRequantize", MklQuantizedConv2DOp, qint8, |
| qint8, true, false, true); |
| REGISTER_MKL_KERNEL_ALL_INPUT_AND_BIAS_TYPES( |
| "_MklQuantizedConv2DWithBiasAndReluAndRequantize", MklQuantizedConv2DReluOp, |
| quint8, quint8, true, false, true); |
| REGISTER_MKL_KERNEL_ALL_BIAS_TYPES( |
| "_MklQuantizedConv2DWithBiasSumAndReluAndRequantize", |
| MklQuantizedConv2DSumReluOp, quint8, quint8, quint8, true, false, true); |
| REGISTER_MKL_KERNEL_ALL_BIAS_TYPES( |
| "_MklQuantizedConv2DWithBiasSignedSumAndReluAndRequantize", |
| MklQuantizedConv2DSumReluOp, quint8, quint8, qint8, true, false, true); |
| REGISTER_MKL_KERNEL_ALL_BIAS_TYPES( |
| "_MklQuantizedDepthwiseConv2DWithBiasAndReluAndRequantize", |
| MklQuantizedConv2DReluOp, quint8, quint8, quint8, true, true, true); |
| #undef BIAS_TYPE_CONSTRAINT |
| #undef TEMPLATE_ARGS |
| #undef LABEL |
| |
| // Register NoOp kernels for ops that will be rewritten to the _Mkl* versions |
| |
| #define REGISTER_NO_OP_CPU_2D_DEPTHWISE(T) \ |
| REGISTER_KERNEL_BUILDER(Name("_FusedDepthwiseConv2dNative") \ |
| .Device(DEVICE_CPU) \ |
| .TypeConstraint<T>("T"), \ |
| NoOp); |
| |
| TF_CALL_float(REGISTER_NO_OP_CPU_2D_DEPTHWISE); |
| TF_CALL_bfloat16(REGISTER_NO_OP_CPU_2D_DEPTHWISE); |
| |
| // Register 2D operations |
| #define REGISTER_MKL_CPU_2D(T) \ |
| REGISTER_KERNEL_BUILDER( \ |
| Name("_MklConv2D") \ |
| .Device(DEVICE_CPU) \ |
| .TypeConstraint<T>("T") \ |
| .Label(mkl_op_registry::kMklLayoutDependentOpLabel), \ |
| MklConvOp<CPUDevice, T, T, T, T, T, int32, false, false, false, false>); \ |
| REGISTER_KERNEL_BUILDER( \ |
| Name("_MklConv2DWithBias") \ |
| .Device(DEVICE_CPU) \ |
| .TypeConstraint<T>("T") \ |
| .Label(mkl_op_registry::kMklLayoutDependentOpLabel), \ |
| MklConvOp<CPUDevice, T, T, T, T, T, int32, true, false, false, false>); \ |
| REGISTER_KERNEL_BUILDER( \ |
| Name("__MklDummyConv2DWithBias") \ |
| .Device(DEVICE_CPU) \ |
| .TypeConstraint<T>("T") \ |
| .Label(mkl_op_registry::kMklLayoutDependentOpLabel), \ |
| MklDummyOp<CPUDevice, T>); \ |
| REGISTER_KERNEL_BUILDER( \ |
| Name("_MklPadWithConv2D") \ |
| .Device(DEVICE_CPU) \ |
| .TypeConstraint<T>("T") \ |
| .TypeConstraint<int32>("Tpaddings") \ |
| .Label(mkl_op_registry::kMklLayoutDependentOpLabel), \ |
| MklConvOp<CPUDevice, T, T, T, T, T, int32, false, true, false, false>); \ |
| REGISTER_KERNEL_BUILDER( \ |
| Name("_MklPadWithConv2D") \ |
| .Device(DEVICE_CPU) \ |
| .TypeConstraint<T>("T") \ |
| .TypeConstraint<int64_t>("Tpaddings") \ |
| .Label(mkl_op_registry::kMklLayoutDependentOpLabel), \ |
| MklConvOp<CPUDevice, T, T, T, T, T, int64, false, true, false, false>); \ |
| REGISTER_KERNEL_BUILDER( \ |
| Name("__MklDummyPadWithConv2D") \ |
| .Device(DEVICE_CPU) \ |
| .TypeConstraint<T>("T") \ |
| .TypeConstraint<int32>("Tpaddings") \ |
| .Label(mkl_op_registry::kMklLayoutDependentOpLabel), \ |
| MklDummyOp<CPUDevice, T>); \ |
| REGISTER_KERNEL_BUILDER( \ |
| Name("_MklNativeConv2D") \ |
| .Device(DEVICE_CPU) \ |
| .TypeConstraint<T>("T") \ |
| .Label(mkl_op_registry::kMklNameChangeOpLabel), \ |
| MklConvOp<CPUDevice, T, T, T, T, T, int32, false, false, false, true>); \ |
| REGISTER_KERNEL_BUILDER( \ |
| Name("_MklNativeConv2DWithBias") \ |
| .Device(DEVICE_CPU) \ |
| .TypeConstraint<T>("T") \ |
| .Label(mkl_op_registry::kMklNameChangeOpLabel), \ |
| MklConvOp<CPUDevice, T, T, T, T, T, int32, true, false, false, true>); \ |
| REGISTER_KERNEL_BUILDER( \ |
| Name("_MklNativePadWithConv2D") \ |
| .Device(DEVICE_CPU) \ |
| .TypeConstraint<T>("T") \ |
| .TypeConstraint<int32>("Tpaddings") \ |
| .Label(mkl_op_registry::kMklNameChangeOpLabel), \ |
| MklConvOp<CPUDevice, T, T, T, T, T, int32, false, true, false, true>); \ |
| REGISTER_KERNEL_BUILDER( \ |
| Name("_MklNativePadWithConv2D") \ |
| .Device(DEVICE_CPU) \ |
| .TypeConstraint<T>("T") \ |
| .TypeConstraint<int64_t>("Tpaddings") \ |
| .Label(mkl_op_registry::kMklNameChangeOpLabel), \ |
| MklConvOp<CPUDevice, T, T, T, T, T, int64, false, true, false, true>); |
| |
| TF_CALL_float(REGISTER_MKL_CPU_2D); |
| TF_CALL_bfloat16(REGISTER_MKL_CPU_2D); |
| |
| #define REGISTER_MKL_CPU_2D_DEPTHWISE(T) \ |
| REGISTER_KERNEL_BUILDER( \ |
| Name("_MklDepthwiseConv2dNative") \ |
| .Device(DEVICE_CPU) \ |
| .TypeConstraint<T>("T") \ |
| .Label(mkl_op_registry::kMklLayoutDependentOpLabel), \ |
| MklConvOp<CPUDevice, T, T, T, T, T, int32, false, false, true, false>); \ |
| REGISTER_KERNEL_BUILDER( \ |
| Name("_MklFusedDepthwiseConv2dNative") \ |
| .Device(DEVICE_CPU) \ |
| .TypeConstraint<T>("T") \ |
| .Label(mkl_op_registry::kMklLayoutDependentOpLabel), \ |
| MklFusedDepthwiseConvOp<CPUDevice, T, T, T, T, T, int32, false, true, \ |
| true, false>); \ |
| REGISTER_KERNEL_BUILDER( \ |
| Name("_MklNativeFusedDepthwiseConv2dNative") \ |
| .Device(DEVICE_CPU) \ |
| .TypeConstraint<T>("T") \ |
| .Label(mkl_op_registry::kMklNameChangeOpLabel), \ |
| MklFusedDepthwiseConvOp<CPUDevice, T, T, T, T, T, int32, false, true, \ |
| true, true>); \ |
| REGISTER_KERNEL_BUILDER( \ |
| Name("_MklNativeDepthwiseConv2dNative") \ |
| .Device(DEVICE_CPU) \ |
| .TypeConstraint<T>("T") \ |
| .Label(mkl_op_registry::kMklNameChangeOpLabel), \ |
| MklConvOp<CPUDevice, T, T, T, T, T, int32, false, false, true, true>); |
| |
| TF_CALL_float(REGISTER_MKL_CPU_2D_DEPTHWISE); |
| TF_CALL_bfloat16(REGISTER_MKL_CPU_2D_DEPTHWISE); |
| |
| // Note that we are registering _MklFusedConv2D. |
| // We check the fused_ops attribute to decide whether bias is enabled. |
| #define REGISTER_MKL_CPU_2D_FUSED(T) \ |
| REGISTER_KERNEL_BUILDER( \ |
| Name("_MklFusedConv2D") \ |
| .Device(DEVICE_CPU) \ |
| .TypeConstraint<T>("T") \ |
| .Label(mkl_op_registry::kMklLayoutDependentOpLabel), \ |
| MklFusedConvOp<CPUDevice, T, T, T, T, T, int32, false, false>); \ |
| REGISTER_KERNEL_BUILDER( \ |
| Name("_MklPadWithFusedConv2D") \ |
| .Device(DEVICE_CPU) \ |
| .TypeConstraint<int32>("Tpaddings") \ |
| .TypeConstraint<T>("T") \ |
| .Label(mkl_op_registry::kMklLayoutDependentOpLabel), \ |
| MklFusedConvOp<CPUDevice, T, T, T, T, T, int32, true, false>); \ |
| REGISTER_KERNEL_BUILDER( \ |
| Name("_MklPadWithFusedConv2D") \ |
| .Device(DEVICE_CPU) \ |
| .TypeConstraint<T>("T") \ |
| .TypeConstraint<int64_t>("Tpaddings") \ |
| .Label(mkl_op_registry::kMklLayoutDependentOpLabel), \ |
| MklFusedConvOp<CPUDevice, T, T, T, T, T, int64, true, false>); \ |
| REGISTER_KERNEL_BUILDER( \ |
| Name("__MklDummyPadWithFusedConv2D") \ |
| .Device(DEVICE_CPU) \ |
| .TypeConstraint<T>("T") \ |
| .TypeConstraint<int32>("Tpaddings") \ |
| .Label(mkl_op_registry::kMklLayoutDependentOpLabel), \ |
| MklDummyOp<CPUDevice, T>); \ |
| REGISTER_KERNEL_BUILDER( \ |
| Name("_MklNativeFusedConv2D") \ |
| .Device(DEVICE_CPU) \ |
| .TypeConstraint<T>("T") \ |
| .Label(mkl_op_registry::kMklNameChangeOpLabel), \ |
| MklFusedConvOp<CPUDevice, T, T, T, T, T, int32, false, true>); \ |
| REGISTER_KERNEL_BUILDER( \ |
| Name("_MklNativePadWithFusedConv2D") \ |
| .Device(DEVICE_CPU) \ |
| .TypeConstraint<int32>("Tpaddings") \ |
| .TypeConstraint<T>("T") \ |
| .Label(mkl_op_registry::kMklNameChangeOpLabel), \ |
| MklFusedConvOp<CPUDevice, T, T, T, T, T, int32, true, true>); \ |
| REGISTER_KERNEL_BUILDER( \ |
| Name("_MklNativePadWithFusedConv2D") \ |
| .Device(DEVICE_CPU) \ |
| .TypeConstraint<T>("T") \ |
| .TypeConstraint<int64_t>("Tpaddings") \ |
| .Label(mkl_op_registry::kMklNameChangeOpLabel), \ |
| MklFusedConvOp<CPUDevice, T, T, T, T, T, int64, true, true>); |
| |
| TF_CALL_float(REGISTER_MKL_CPU_2D_FUSED); |
| TF_CALL_bfloat16(REGISTER_MKL_CPU_2D_FUSED); |
| |
| // Register 3D operations |
| #define REGISTER_MKL_CPU_3D(T) \ |
| REGISTER_KERNEL_BUILDER( \ |
| Name("_MklConv3D") \ |
| .Device(DEVICE_CPU) \ |
| .TypeConstraint<T>("T") \ |
| .Label(mkl_op_registry::kMklLayoutDependentOpLabel), \ |
| MklConvOp<CPUDevice, T, T, T, T, T, int32, false, false, false, false>); \ |
| REGISTER_KERNEL_BUILDER( \ |
| Name("_MklNativeConv3D") \ |
| .Device(DEVICE_CPU) \ |
| .TypeConstraint<T>("T") \ |
| .Label(mkl_op_registry::kMklNameChangeOpLabel), \ |
| MklConvOp<CPUDevice, T, T, T, T, T, int32, false, false, false, true>); \ |
| REGISTER_KERNEL_BUILDER( \ |
| Name("_MklNativeFusedConv3D") \ |
| .Device(DEVICE_CPU) \ |
| .TypeConstraint<T>("T") \ |
| .Label(mkl_op_registry::kMklNameChangeOpLabel), \ |
| MklFusedConv3DOp<CPUDevice, T, T, T, T, T, int32, false, true>); |
| TF_CALL_float(REGISTER_MKL_CPU_3D); |
| TF_CALL_bfloat16(REGISTER_MKL_CPU_3D); |
| |
| REGISTER_KERNEL_BUILDER( |
| Name("_FusedConv3D").Device(DEVICE_CPU).TypeConstraint<float>("T"), NoOp); |
| REGISTER_KERNEL_BUILDER( |
| Name("_FusedConv3D").Device(DEVICE_CPU).TypeConstraint<bfloat16>("T"), |
| NoOp); |
| } // namespace tensorflow |
| #endif // INTEL_MKL |