/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// See docs in ../ops/nn_ops.cc.
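// This file implements MKL-DNN based kernels that compute the filter
// (weight) gradient of Conv2D, Conv3D, and depthwise Conv2D, with and
// without a bias gradient (see the kernel registrations at the bottom).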
#ifdef INTEL_MKL
#include <algorithm>
#include <vector>
#include "mkldnn.hpp"
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_slice.h"
#include "tensorflow/core/kernels/conv_grad_ops.h"
#include "tensorflow/core/kernels/mkl_conv_ops.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/util/mkl_util.h"
#include "tensorflow/core/util/padding.h"
#include "tensorflow/core/util/tensor_format.h"
#include "tensorflow/core/util/use_cudnn.h"
#include "tensorflow/core/util/work_sharder.h"
using mkldnn::convolution_backward_weights;
using mkldnn::memory;
using mkldnn::prop_kind;
using mkldnn::stream;
namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
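// Bundles all parameters that describe a convolution backward filter
// primitive: tensor dimensions, strides, dilations, and padding. Instances
// also serve as cache keys in MklConvBwdFilterPrimitiveFactory below.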
struct MklConvBwdFilterParams {
memory::dims src_dims;
memory::dims diff_filter_dims;
memory::dims diff_bias_dims;
memory::dims diff_dst_dims;
memory::dims strides;
memory::dims dilations;
memory::dims padding_left;
memory::dims padding_right;
padding_kind padding;
MklConvBwdFilterParams(memory::dims src_dims, memory::dims diff_filter_dims,
memory::dims diff_bias_dims,
memory::dims diff_dst_dims, memory::dims strides,
memory::dims dilations, memory::dims padding_left,
memory::dims padding_right, padding_kind padding)
: src_dims(src_dims),
diff_filter_dims(diff_filter_dims),
diff_bias_dims(diff_bias_dims),
diff_dst_dims(diff_dst_dims),
strides(strides),
dilations(dilations),
padding_left(padding_left),
padding_right(padding_right),
padding(padding) {}
};
template <typename T>
class MklConvBwdFilterPrimitive : public MklPrimitive {
public:
explicit MklConvBwdFilterPrimitive(
const MklConvBwdFilterParams& convBwdFilterDims)
: cpu_engine_(engine::cpu, 0) {
context_.bwd_filter_stream.reset(new stream(stream::kind::eager));
// create the conv bwd filter primitive if it does not yet exist
if (context_.conv_bwd_filter == nullptr) {
Setup(convBwdFilterDims);
}
}
~MklConvBwdFilterPrimitive() {}
// Convolution backward weights with bias
// src_data: input data buffer of src
// diff_filter_data: output data buffer of diff_filter
// diff_bias_data: output data buffer of diff_bias
// diff_dst_data: input data buffer of diff_dst
void Execute(const T* src_data, const T* diff_filter_data,
const T* diff_bias_data, const T* diff_dst_data) {
context_.src_mem->set_data_handle(
static_cast<void*>(const_cast<T*>(src_data)));
context_.diff_filter_mem->set_data_handle(
static_cast<void*>(const_cast<T*>(diff_filter_data)));
context_.diff_bias_mem->set_data_handle(
static_cast<void*>(const_cast<T*>(diff_bias_data)));
context_.diff_dst_mem->set_data_handle(
static_cast<void*>(const_cast<T*>(diff_dst_data)));
context_.bwd_filter_stream->submit(context_.bwd_filter_primitives);
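// After execution, point the data handles back at DummyData so the
// cached primitive does not retain pointers into TensorFlow tensor
// buffers that may be freed.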
context_.src_mem->set_data_handle(DummyData);
context_.diff_filter_mem->set_data_handle(DummyData);
context_.diff_bias_mem->set_data_handle(DummyData);
context_.diff_dst_mem->set_data_handle(DummyData);
return;
}
// Convolution backward weights without bias
// src_data: input data buffer of src
// diff_filter_data: output data buffer of diff_filter
// diff_dst_data: input data buffer of diff_dst
void Execute(const T* src_data, const T* diff_filter_data,
const T* diff_dst_data) {
context_.src_mem->set_data_handle(
static_cast<void*>(const_cast<T*>(src_data)));
context_.diff_filter_mem->set_data_handle(
static_cast<void*>(const_cast<T*>(diff_filter_data)));
context_.diff_dst_mem->set_data_handle(
static_cast<void*>(const_cast<T*>(diff_dst_data)));
context_.bwd_filter_stream->submit(context_.bwd_filter_primitives);
context_.src_mem->set_data_handle(DummyData);
context_.diff_filter_mem->set_data_handle(DummyData);
context_.diff_dst_mem->set_data_handle(DummyData);
return;
}
memory::format GetSrcMemoryFormat() const { return context_.src_fmt; }
memory::format GetDiffDstMemoryFormat() const {
return context_.diff_dst_fmt;
}
memory::format GetDiffFilterMemoryFormat() const {
return context_.diff_filter_fmt;
}
// primitive descriptor of the convolution backward weights primitive
std::shared_ptr<mkldnn::convolution_backward_weights::primitive_desc>
GetPrimitiveDesc() const {
return context_.bwd_filter_pd;
}
private:
// Primitive reuse context for Conv2D bwd filter op
struct ConvBwdFilterContext {
// expected memory format for this primitive instance
memory::format src_fmt;
memory::format diff_dst_fmt;
memory::format diff_filter_fmt;
// convolution bwd filter primitive descriptor and primitive
std::shared_ptr<mkldnn::convolution_backward_weights::primitive_desc>
bwd_filter_pd;
std::shared_ptr<mkldnn::primitive> conv_bwd_filter;
// MKLDNN memory
std::shared_ptr<mkldnn::memory> src_mem;
std::shared_ptr<mkldnn::memory> diff_filter_mem;
std::shared_ptr<mkldnn::memory> diff_bias_mem;
std::shared_ptr<mkldnn::memory> diff_dst_mem;
// descriptors & primitive descriptors
std::shared_ptr<mkldnn::convolution_backward_weights::desc> bwd_filter_desc;
std::shared_ptr<mkldnn::convolution_forward::desc> fwd_desc;
std::shared_ptr<mkldnn::convolution_forward::primitive_desc> fwd_pd;
// memory desc: forward & backward can share same memory desc
std::shared_ptr<mkldnn::memory::desc> src_md;
std::shared_ptr<mkldnn::memory::desc> diff_filter_md;
std::shared_ptr<mkldnn::memory::desc> diff_bias_md;
std::shared_ptr<mkldnn::memory::desc> diff_dst_md;
// MKL pipeline
std::shared_ptr<mkldnn::stream> bwd_filter_stream;
std::vector<mkldnn::primitive> bwd_filter_primitives;
ConvBwdFilterContext()
: src_fmt(memory::format::any),
diff_dst_fmt(memory::format::any),
diff_filter_fmt(memory::format::any),
src_mem(nullptr),
diff_filter_mem(nullptr),
diff_bias_mem(nullptr),
diff_dst_mem(nullptr),
bwd_filter_desc(nullptr),
fwd_desc(nullptr),
fwd_pd(nullptr),
src_md(nullptr),
diff_filter_md(nullptr),
diff_bias_md(nullptr),
diff_dst_md(nullptr),
bwd_filter_stream(nullptr) {}
};
// Setup Conv2d backward filter (weights) primitives.
void Setup(const MklConvBwdFilterParams& convBwdFilterDims) {
// create memory descriptors for convolution data w/ no specified format
context_.src_md.reset(new memory::desc(
{convBwdFilterDims.src_dims}, MklDnnType<T>(), memory::format::any));
context_.diff_dst_md.reset(
new memory::desc({convBwdFilterDims.diff_dst_dims}, MklDnnType<T>(),
memory::format::any));
context_.diff_filter_md.reset(
new memory::desc({convBwdFilterDims.diff_filter_dims}, MklDnnType<T>(),
memory::format::any));
if (!convBwdFilterDims.diff_bias_dims.empty())
context_.diff_bias_md.reset(
new memory::desc({convBwdFilterDims.diff_bias_dims}, MklDnnType<T>(),
memory::format::x));
// create a convolution backward weights descriptor (with or without bias)
if (!convBwdFilterDims.diff_bias_dims.empty()) {
context_.bwd_filter_desc.reset(new convolution_backward_weights::desc(
convolution_direct, *context_.src_md, *context_.diff_filter_md,
*context_.diff_bias_md, *context_.diff_dst_md,
convBwdFilterDims.strides, convBwdFilterDims.dilations,
convBwdFilterDims.padding_left, convBwdFilterDims.padding_right,
convBwdFilterDims.padding));
} else {
context_.bwd_filter_desc.reset(new convolution_backward_weights::desc(
convolution_direct, *context_.src_md, *context_.diff_filter_md,
*context_.diff_dst_md, convBwdFilterDims.strides,
convBwdFilterDims.dilations, convBwdFilterDims.padding_left,
convBwdFilterDims.padding_right, convBwdFilterDims.padding));
}
// create fwd primitive_desc; MKL-DNN requires it as a hint when
// constructing the backward weights primitive_desc
context_.fwd_desc.reset(new convolution_forward::desc(
prop_kind::forward, convolution_direct, *context_.src_md,
*context_.diff_filter_md, *context_.diff_dst_md,
convBwdFilterDims.strides, convBwdFilterDims.dilations,
convBwdFilterDims.padding_left, convBwdFilterDims.padding_right,
convBwdFilterDims.padding));
context_.fwd_pd.reset(new convolution_forward::primitive_desc(
*context_.fwd_desc, cpu_engine_));
// create backward conv primitive_desc
context_.bwd_filter_pd.reset(
new convolution_backward_weights::primitive_desc(
*context_.bwd_filter_desc, cpu_engine_, *context_.fwd_pd));
// store the expected memory format
auto bwd_filter_pd = context_.bwd_filter_pd.get();
context_.src_fmt = static_cast<mkldnn::memory::format>(
bwd_filter_pd->src_primitive_desc().desc().data.format);
context_.diff_filter_fmt = static_cast<mkldnn::memory::format>(
bwd_filter_pd->diff_weights_primitive_desc().desc().data.format);
context_.diff_dst_fmt = static_cast<mkldnn::memory::format>(
bwd_filter_pd->diff_dst_primitive_desc().desc().data.format);
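// These formats are compared against the incoming tensor layouts in
// Compute() to decide whether src/diff_dst/diff_filter reorders are
// required.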
// create memory primitive based on dummy data
context_.src_mem.reset(
new memory(bwd_filter_pd->src_primitive_desc(), DummyData));
context_.diff_filter_mem.reset(
new memory(bwd_filter_pd->diff_weights_primitive_desc(), DummyData));
context_.diff_dst_mem.reset(
new memory(bwd_filter_pd->diff_dst_primitive_desc(), DummyData));
// create convolution primitive and add it to net
if (!convBwdFilterDims.diff_bias_dims.empty()) {
context_.diff_bias_mem.reset(
new memory({{{convBwdFilterDims.diff_bias_dims},
MklDnnType<T>(),
memory::format::x},
cpu_engine_},
DummyData));
context_.conv_bwd_filter.reset(new convolution_backward_weights(
*context_.bwd_filter_pd, *context_.src_mem, *context_.diff_dst_mem,
*context_.diff_filter_mem, *context_.diff_bias_mem));
} else {
context_.conv_bwd_filter.reset(new convolution_backward_weights(
*context_.bwd_filter_pd, *context_.src_mem, *context_.diff_dst_mem,
*context_.diff_filter_mem));
}
context_.bwd_filter_primitives.push_back(*context_.conv_bwd_filter);
}
struct ConvBwdFilterContext context_;
engine cpu_engine_;
};
template <typename T>
class MklConvBwdFilterPrimitiveFactory : public MklPrimitiveFactory<T> {
public:
static MklConvBwdFilterPrimitive<T>* Get(
const MklConvBwdFilterParams& convBwdFilterDims, bool do_not_cache) {
MklConvBwdFilterPrimitive<T>* conv_bwd_filter = nullptr;
if (do_not_cache) { /* Create new primitive always */
conv_bwd_filter = new MklConvBwdFilterPrimitive<T>(convBwdFilterDims);
} else {
// look in the pool for a reusable primitive
conv_bwd_filter = dynamic_cast<MklConvBwdFilterPrimitive<T>*>(
MklConvBwdFilterPrimitiveFactory<T>::GetInstance().GetConvBwdFilter(
convBwdFilterDims));
if (conv_bwd_filter == nullptr) {
conv_bwd_filter = new MklConvBwdFilterPrimitive<T>(convBwdFilterDims);
MklConvBwdFilterPrimitiveFactory<T>::GetInstance().SetConvBwdFilter(
convBwdFilterDims, conv_bwd_filter);
}
}
return conv_bwd_filter;
}
private:
MklConvBwdFilterPrimitiveFactory() {}
~MklConvBwdFilterPrimitiveFactory() {}
static MklConvBwdFilterPrimitiveFactory& GetInstance() {
static MklConvBwdFilterPrimitiveFactory instance_;
return instance_;
}
static string CreateKey(const MklConvBwdFilterParams& convBwdFilterDims) {
string prefix = "conv_bwd_filter";
FactoryKeyCreator key_creator;
key_creator.AddAsKey(prefix);
key_creator.AddAsKey(convBwdFilterDims.src_dims);
key_creator.AddAsKey(convBwdFilterDims.diff_filter_dims);
key_creator.AddAsKey(convBwdFilterDims.diff_bias_dims);
key_creator.AddAsKey(convBwdFilterDims.diff_dst_dims);
key_creator.AddAsKey(convBwdFilterDims.strides);
key_creator.AddAsKey(convBwdFilterDims.dilations);
key_creator.AddAsKey(convBwdFilterDims.padding_left);
key_creator.AddAsKey(convBwdFilterDims.padding_right);
return key_creator.GetKey();
}
MklPrimitive* GetConvBwdFilter(
const MklConvBwdFilterParams& convBwdFilterDims) {
string key = CreateKey(convBwdFilterDims);
return this->GetOp(key);
}
void SetConvBwdFilter(const MklConvBwdFilterParams& convBwdFilterDims,
MklPrimitive* op) {
string key = CreateKey(convBwdFilterDims);
this->SetOp(key, op);
}
};
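// Illustrative usage of the factory (a sketch that mirrors Compute() in the
// op below; the dims and data pointers are placeholders, and this snippet is
// not executed anywhere in this file):
//
//   MklConvBwdFilterParams params(src_dims, diff_filter_dims,
//                                 {} /* no bias */, diff_dst_dims, strides,
//                                 dilations, padding_left, padding_right,
//                                 padding_kind::zero);
//   MklConvBwdFilterPrimitive<float>* conv_bwd_filter =
//       MklConvBwdFilterPrimitiveFactory<float>::Get(params,
//                                                    false /* cache it */);
//   conv_bwd_filter->Execute(src_data, diff_filter_data, diff_dst_data);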
template <typename Device, class T, bool bias_enabled, bool is_depthwise>
class MklConvCustomBackpropFilterOp
: public MklConvBackpropCommonOp<Device, T, is_depthwise> {
public:
explicit MklConvCustomBackpropFilterOp(OpKernelConstruction* context)
: MklConvBackpropCommonOp<Device, T, is_depthwise>(context) {}
~MklConvCustomBackpropFilterOp() {}
void Compute(OpKernelContext* context) {
try {
MklDnnData<T> src(&cpu_engine_);
MklDnnData<T> diff_dst(&cpu_engine_);
MklDnnData<T> diff_filter(&cpu_engine_); // output
// This flag is true for Conv2D and false for Conv3D
bool is_conv2d = (this->strides_.size() == 4);
// Input tensors
const int kInputIdx = 0, kFilterIdx = 1, kOutbpropIdx = 2;
const Tensor& src_tensor = MklGetInput(context, kInputIdx);
const Tensor& filter_tensor = MklGetInput(context, kFilterIdx);
const Tensor& diff_dst_tensor = MklGetInput(context, kOutbpropIdx);
MklDnnShape src_mkl_shape, filter_mkl_shape, diff_dst_mkl_shape;
GetMklShape(context, kInputIdx, &src_mkl_shape);
GetMklShape(context, kFilterIdx, &filter_mkl_shape);
GetMklShape(context, kOutbpropIdx, &diff_dst_mkl_shape);
// Allow operator-specific sanity checking of shapes.
ValidateMklShapes(src_mkl_shape, filter_mkl_shape, diff_dst_mkl_shape);
// Allow operator-specific generation of shapes.
// E.g., Conv2DBackpropFilter receives the filter as filter_sizes, a
// tensor containing the shape of the filter, so filter.shape() is not
// the correct way to get the filter shape. These operator-specific
// calls allow this class to handle that case.
TensorShape src_tf_shape = MakeInputTfShape(context, src_tensor);
TensorShape filter_tf_shape = MakeFilterTfShape(context, filter_tensor);
TensorShape diff_dst_tf_shape = GetTfShape(context, kOutbpropIdx);
// Corner cases: output with 0 elements and 0 batch size.
Tensor* diff_filter_tensor = nullptr;
if (src_tf_shape.num_elements() == 0 ||
filter_tf_shape.num_elements() == 0 ||
diff_dst_tf_shape.num_elements() == 0) {
MklDnnShape diff_filter_mkl_shape;
diff_filter_mkl_shape.SetMklTensor(false);
TensorShape diff_filter_tf_shape =
GetOutputTfShape(src_tf_shape, filter_tf_shape, diff_dst_tf_shape);
const int kOutputIdx = 0;
AllocateOutputSetMklShape(context, kOutputIdx, &diff_filter_tensor,
diff_filter_tf_shape, diff_filter_mkl_shape);
CHECK_NOTNULL(diff_filter_tensor);
// If the output tensor has more than 0 elements, zero it out.
auto diff_filter_data = diff_filter_tensor->flat<T>().data();
for (size_t i = 0; i < diff_filter_tf_shape.num_elements(); ++i) {
diff_filter_data[i] = static_cast<T>(0);
}
return;
}
// By default, all dims are in MKL order. The only dims in TF order
// are those with the prefix tf_order.
memory::dims diff_dst_dims, fwd_src_dims, fwd_filter_dims;
memory::dims padding_left, padding_right, dilations, strides,
fwd_dst_dims;
memory::dims fwd_dst_dims_tf_order;
// Get forward convolution parameters.
MklDnnConvUtil conv_utl(context, this->strides_, this->padding_,
this->data_format_, this->dilations_);
conv_utl.GetConvFwdSizesInMklOrder(
src_tf_shape, filter_tf_shape, &fwd_src_dims, &fwd_filter_dims,
&strides, &dilations, &fwd_dst_dims_tf_order, &fwd_dst_dims,
&padding_left, &padding_right, false, is_depthwise);
if (!context->status().ok()) return;
auto tf_fmt = is_conv2d
? TFDataFormatToMklDnnDataFormat(this->data_format_)
: TFDataFormatToMklDnn3DDataFormat(this->data_format_);
auto fwd_src_md =
src_mkl_shape.IsMklTensor()
? src_mkl_shape.GetMklLayout()
: memory::desc(fwd_src_dims, MklDnnType<T>(), tf_fmt);
conv_utl.GetInputSizeInMklOrder(diff_dst_tf_shape, &diff_dst_dims);
if (!context->status().ok()) return;
auto diff_dst_md =
diff_dst_mkl_shape.IsMklTensor()
? diff_dst_mkl_shape.GetMklLayout()
: memory::desc(diff_dst_dims, MklDnnType<T>(), tf_fmt);
memory::dims diff_bias_dims = {};
int64 depth = 0;
if (bias_enabled) {
TensorShape obp_tf_shape = GetTfShape(context, 2);
depth = (this->data_format_ == FORMAT_NCHW)
? obp_tf_shape.dim_size(1)
: obp_tf_shape.dim_size(is_conv2d ? 3 : 4);
diff_bias_dims = {static_cast<int>(depth)};
}
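// TensorFlow dilations are 1-based while MKL-DNN dilations are 0-based
// (0 means no dilation), so subtract 1 from each dimension.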
for (int i = 0; i < dilations.size(); i++) dilations[i] -= 1;
MklConvBwdFilterPrimitive<T>* conv_bwd_filter = nullptr;
MklConvBwdFilterParams convBwdFilterDims(
fwd_src_dims, fwd_filter_dims, diff_bias_dims, diff_dst_dims, strides,
dilations, padding_left, padding_right,
TFPaddingToMklDnnPadding(this->padding_));
// MKL-DNN allocates large buffers when a conv gradient filter primitive
// is created, so we don't cache conv backward primitives when the env
// variable TF_MKL_OPTIMIZE_PRIMITIVE_MEMUSE is set to true.
bool do_not_cache = MklPrimitiveFactory<T>::IsPrimitiveMemOptEnabled();
conv_bwd_filter = MklConvBwdFilterPrimitiveFactory<T>::Get(
convBwdFilterDims, do_not_cache);
auto bwd_filter_pd = conv_bwd_filter->GetPrimitiveDesc();
// allocate output tensors: diff_filter and, when bias is enabled, diff_bias
auto bwd_output_dims = GetOutputDims(fwd_src_dims, fwd_filter_dims);
// diff_filter
MklDnnShape diff_filter_mkl_shape;
diff_filter_mkl_shape.SetMklTensor(false);
if (is_conv2d) {
if (!is_depthwise) {
// Conv2D: bwd_output_dims is in OIHW format.
TensorShape diff_filter_tf_shape(
{bwd_output_dims[MklDnnDims::Dim_H],
bwd_output_dims[MklDnnDims::Dim_W],
bwd_output_dims[MklDnnDims::Dim_I],
bwd_output_dims[MklDnnDims::Dim_O]});
AllocateOutputSetMklShape(context, 0, &diff_filter_tensor,
diff_filter_tf_shape,
diff_filter_mkl_shape);
} else {
// Depthwise Conv2d: bwd_output_dims is in GOIHW format.
//                  | TensorFlow       | MKLDNN
// ----------------------------------------------------------------
// filter_out_depth | depth_multiplier | depth_multiplier *
//                  |                  | group_count
// ----------------------------------------------------------------
// filter_in_depth  | in_depth         | in_depth / group_count
// For depthwise convolution, group_count == in_depth, so here
// G = original I and I = 1.
// GOIHW is the MKL-DNN format, while TensorFlow's filter format is
// HWIO; since G = original I, the shape extracted below is HWGO.
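// Worked example (values assumed for illustration): with in_depth = 8
// and depth_multiplier = 2, bwd_output_dims has G = 8, O = 2, I = 1,
// so the TensorFlow filter shape built below is {H, W, 8, 2}.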
TensorShape diff_filter_tf_shape(
{bwd_output_dims[MklDnnFilterGroupDims::MKL_GROUP_FILTER_DIM_H],
bwd_output_dims[MklDnnFilterGroupDims::MKL_GROUP_FILTER_DIM_W],
bwd_output_dims[MklDnnFilterGroupDims::MKL_GROUP_FILTER_DIM_G],
bwd_output_dims[MklDnnFilterGroupDims::MKL_GROUP_FILTER_DIM_O]});
AllocateOutputSetMklShape(context, 0, &diff_filter_tensor,
diff_filter_tf_shape,
diff_filter_mkl_shape);
}
} else {
// Conv3D: bwd_output_dims is in OIDHW format.
TensorShape diff_filter_tf_shape(
{bwd_output_dims[MklDnnDims3D::Dim3d_D],
bwd_output_dims[MklDnnDims3D::Dim3d_H],
bwd_output_dims[MklDnnDims3D::Dim3d_W],
bwd_output_dims[MklDnnDims3D::Dim3d_I],
bwd_output_dims[MklDnnDims3D::Dim3d_O]});
AllocateOutputSetMklShape(context, 0, &diff_filter_tensor,
diff_filter_tf_shape, diff_filter_mkl_shape);
}
Tensor* diff_bias_tensor = nullptr;
if (bias_enabled) {
TensorShape diff_bias_shape({depth});
AllocateBiasGradTensor(context, diff_bias_shape, &diff_bias_tensor);
}
// check whether src and diff_dst need to be reordered into the formats
// the primitive expects
T* src_data = nullptr;
if (fwd_src_md.data.format != conv_bwd_filter->GetSrcMemoryFormat()) {
src.SetUsrMem(fwd_src_md, &src_tensor);
src.CheckReorderToOpMem(bwd_filter_pd->src_primitive_desc());
src_data = static_cast<T*>(src.GetOpMem().get_data_handle());
} else {
src_data = static_cast<T*>(const_cast<T*>(src_tensor.flat<T>().data()));
}
T* diff_dst_data = nullptr;
if (diff_dst_md.data.format !=
conv_bwd_filter->GetDiffDstMemoryFormat()) {
diff_dst.SetUsrMem(diff_dst_md, &diff_dst_tensor);
diff_dst.CheckReorderToOpMem(bwd_filter_pd->diff_dst_primitive_desc());
diff_dst_data = static_cast<T*>(diff_dst.GetOpMem().get_data_handle());
} else {
diff_dst_data =
static_cast<T*>(const_cast<T*>(diff_dst_tensor.flat<T>().data()));
}
// For backward filter, convert diff_filter back to the TensorFlow layout.
// Here we prepare to reorder the op memory back to user memory.
bool diff_filter_reorder_required = false;
T* diff_filter_data = nullptr;
if (GetOutputFormat(tf_fmt) !=
conv_bwd_filter->GetDiffFilterMemoryFormat()) {
// Allocate the diff_filter tensor in TensorFlow layout
diff_filter.SetUsrMem(bwd_output_dims, GetOutputFormat(tf_fmt),
diff_filter_tensor);
diff_filter_reorder_required = true;
diff_filter.PrepareReorderToUserMemIfReq(
bwd_filter_pd->diff_weights_primitive_desc());
diff_filter_data =
static_cast<T*>(diff_filter.GetOpMem().get_data_handle());
} else {
diff_filter_data = static_cast<T*>(
const_cast<T*>(diff_filter_tensor->flat<T>().data()));
}
// Execute convolution filter bwd
if (bias_enabled) {
T* diff_bias_data =
static_cast<T*>(const_cast<T*>(diff_bias_tensor->flat<T>().data()));
conv_bwd_filter->Execute(src_data, diff_filter_data, diff_bias_data,
diff_dst_data);
} else {
conv_bwd_filter->Execute(src_data, diff_filter_data, diff_dst_data);
}
// Reorder diff_filter back to Tensorflow layout if necessary
if (diff_filter_reorder_required) {
diff_filter.InsertReorderToUserMem();
}
// delete primitive since it is not cached.
if (do_not_cache) delete conv_bwd_filter;
} catch (mkldnn::error& e) {
string error_msg = "Status: " + std::to_string(e.status) +
", message: " + string(e.message) + ", in file " +
string(__FILE__) + ":" + std::to_string(__LINE__);
OP_REQUIRES_OK(
context,
errors::Aborted("Operation received an exception:", error_msg));
}
}
private:
const int kInputIndex_Filter = 1;
const int kInputIndex_InputSizes = 0;
const int kDilationH = 0, kDilationW = 1;
engine cpu_engine_ = engine(engine::cpu, 0);
// Validate input shapes. Asserts that the filter is not in MKL layout.
void ValidateMklShapes(const MklDnnShape& input_mkl_shape,
const MklDnnShape& filter_mkl_shape,
const MklDnnShape& obp_mkl_shape) {
CHECK(!filter_mkl_shape.IsMklTensor())
<< "ConvBackpropFilter: filter should not be in MKL Layout";
}
// Get TensorFlow shape of input tensor.
TensorShape MakeInputTfShape(OpKernelContext* context,
const Tensor& input_tensor) {
size_t input_idx = 0;
return GetTfShape(context, input_idx);
}
// Get TensorFlow shape of filter tensor.
TensorShape MakeFilterTfShape(OpKernelContext* context,
const Tensor& filter_tensor) {
TensorShape filter_tf_shape;
CHECK_EQ(TensorShapeUtils::IsVector(filter_tensor.shape()), true);
CHECK_EQ(TensorShapeUtils::MakeShape(filter_tensor.vec<int32>(),
&filter_tf_shape)
.ok(),
true);
return filter_tf_shape;
}
// Get the TensorFlow shape of the output tensor (diff_filter),
// which is the same as the shape of the filter.
TensorShape GetOutputTfShape(const TensorShape& input_shape,
const TensorShape& filter_shape,
const TensorShape& outbprop_shape) {
return filter_shape;
}
// Get the shape of output (diff_filter) in MKL-DNN order.
// Computes shape of output from input shape (fwd_input_dims)
// and filter shape (fwd_filter_dims).
const memory::dims& GetOutputDims(const memory::dims& fwd_input_dims,
const memory::dims& fwd_filter_dims) {
return fwd_filter_dims;
}
// Output layout is TensorFlow's filter layout.
// Conv2D: HWIO; Conv3D: DHWIO; Depthwise Conv: HWIGO
memory::format GetOutputFormat(const memory::format data_format) {
return is_depthwise
? memory::format::hwigo
: ((this->strides_.size() == 4) ? memory::format::hwio
: memory::format::dhwio);
}
// Allocate output tensor.
void AllocateOutputTensor(
OpKernelContext* context,
const convolution_backward_weights::primitive_desc& conv_pd,
const memory::dims& output_dims_mkl_order,
memory::format output_tf_format, Tensor** output_tensor) {
CHECK_NOTNULL(output_tensor);
// For BackpropFilter, we convert the output tensor back to the
// TensorFlow layout, because BackpropFilter is typically the last
// operator in the graph that emits the filter gradient, which is then
// provided to the ApplyGradient method to update the filter. It may be
// possible to eliminate this conversion by forwarding the filter in MKL
// layout if we support the ApplyGradient method for MKL layout
// propagation.
MklDnnShape output_mkl_shape;
output_mkl_shape.SetMklTensor(false);
// output_dims_mkl_order is in OIHW format.
// Allocate shape of TF tensor in HWIO format.
TensorShape output_tf_shape({output_dims_mkl_order[MklDnnDims::Dim_H],
output_dims_mkl_order[MklDnnDims::Dim_W],
output_dims_mkl_order[MklDnnDims::Dim_I],
output_dims_mkl_order[MklDnnDims::Dim_O]});
AllocateOutputSetMklShape(context, 0, output_tensor, output_tf_shape,
output_mkl_shape);
}
// Allocate tensor for bias grad
void AllocateBiasGradTensor(OpKernelContext* context,
const TensorShape& bias_grad_shape,
Tensor** bias_grad_tensor) {
CHECK_NOTNULL(bias_grad_tensor);
MklDnnShape bias_grad_mkl_shape;
bias_grad_mkl_shape.SetMklTensor(false);
AllocateOutputSetMklShape(context, 1, bias_grad_tensor, bias_grad_shape,
bias_grad_mkl_shape);
}
};
#define REGISTER_MKL_FILTER_KERNELS(T) \
REGISTER_KERNEL_BUILDER( \
Name("_MklConv2DBackpropFilter") \
.Device(DEVICE_CPU) \
.TypeConstraint<T>("T") \
.Label(mkl_op_registry::kMklLayoutDependentOpLabel), \
MklConvCustomBackpropFilterOp<CPUDevice, T, false, false>); \
REGISTER_KERNEL_BUILDER( \
Name("_MklConv2DBackpropFilterWithBias") \
.Device(DEVICE_CPU) \
.TypeConstraint<T>("T") \
.Label(mkl_op_registry::kMklLayoutDependentOpLabel), \
MklConvCustomBackpropFilterOp<CPUDevice, T, true, false>); \
REGISTER_KERNEL_BUILDER( \
Name("_MklDepthwiseConv2dNativeBackpropFilter") \
.Device(DEVICE_CPU) \
.TypeConstraint<T>("T") \
.Label(mkl_op_registry::kMklLayoutDependentOpLabel), \
MklConvCustomBackpropFilterOp<CPUDevice, T, false, true>); \
REGISTER_KERNEL_BUILDER( \
Name("__MklDummyConv2DBackpropFilterWithBias") \
.Device(DEVICE_CPU) \
.TypeConstraint<T>("T") \
.Label(mkl_op_registry::kMklLayoutDependentOpLabel), \
MklDummyOp<CPUDevice, T>); \
REGISTER_KERNEL_BUILDER( \
Name("_MklConv3DBackpropFilterV2") \
.Device(DEVICE_CPU) \
.TypeConstraint<T>("T") \
.Label(mkl_op_registry::kMklLayoutDependentOpLabel), \
MklConvCustomBackpropFilterOp<CPUDevice, T, false, false>);
TF_CALL_float(REGISTER_MKL_FILTER_KERNELS);
TF_CALL_bfloat16(REGISTER_MKL_FILTER_KERNELS);
#undef REGISTER_MKL_FILTER_KERNELS
} // namespace tensorflow
#endif // INTEL_MKL