/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_KERNELS_MKL_POOLING_OPS_COMMON_H_
#define TENSORFLOW_CORE_KERNELS_MKL_POOLING_OPS_COMMON_H_
#ifdef INTEL_MKL
#include <memory>
#include <string>
#include <vector>
#include "mkldnn.hpp"
#include "tensorflow/core/util/mkl_util.h"
#include "tensorflow/core/util/padding.h"
using mkldnn::memory;
using mkldnn::pooling_backward;
using mkldnn::pooling_forward;
using mkldnn::stream;
namespace tensorflow {
using mkldnn::memory;
using mkldnn::pooling_avg;
using mkldnn::pooling_avg_exclude_padding;
using mkldnn::pooling_avg_include_padding;
using mkldnn::pooling_max;
using mkldnn::prop_kind;
struct MklPoolingParams {
memory::dims src_dims;
memory::dims dst_dims;
memory::dims filter_dims;
memory::dims strides;
memory::dims padding_left;
memory::dims padding_right;
mkldnn::algorithm alg_kind;
mkldnn::prop_kind prop_kind;
MklPoolingParams(memory::dims src_dims, memory::dims dst_dims,
memory::dims filter_dims, memory::dims strides,
memory::dims padding_left, memory::dims padding_right,
mkldnn::algorithm alg_kind, mkldnn::prop_kind prop_kind)
: src_dims(src_dims),
dst_dims(dst_dims),
filter_dims(filter_dims),
strides(strides),
padding_left(padding_left),
padding_right(padding_right),
alg_kind(alg_kind),
prop_kind(prop_kind) {}
};
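// Illustrative only (the shapes below are hypothetical, not mandated by this
// header): forward parameters for a 2x2, stride-2, unpadded max-pool over a
// 1x32x28x28 (NCHW) input could be built as
//
//   MklPoolingParams fwdParams(
//       {1, 32, 28, 28},  // src_dims (NCHW)
//       {1, 32, 14, 14},  // dst_dims (NCHW)
//       {2, 2},           // filter_dims
//       {2, 2},           // strides
//       {0, 0}, {0, 0},   // padding_left, padding_right
//       pooling_max, prop_kind::forward_training);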
template <typename T>
class MklPoolingFwdPrimitive : public MklPrimitive {
public:
explicit MklPoolingFwdPrimitive(const MklPoolingParams& fwdParams)
: cpu_engine_(engine::cpu, 0) {
context_.fwd_stream.reset(new stream(stream::kind::eager));
if (context_.fwd == nullptr) Setup(fwdParams);
}
~MklPoolingFwdPrimitive() {}
// Pooling forward execute.
//   src_data: input data buffer of src
//   dst_data: output data buffer of dst
//   ws_data: output data buffer of workspace (max-pooling only)
void Execute(const T* src_data, T* dst_data, void* ws_data = nullptr);
std::shared_ptr<mkldnn::pooling_forward::primitive_desc> GetPoolingFwdPd()
const {
return context_.fwd_pd;
}
memory::format GetSrcMemoryFormat() const { return context_.src_fmt; }
memory::format GetDstMemoryFormat() const { return context_.dst_fmt; }
private:
void Setup(const MklPoolingParams& fwdParams);
struct PoolingFwdContext {
// algorithm
mkldnn::algorithm alg_kind;
// Kind of propagation, forward or backward
mkldnn::prop_kind prop_kind;
// expected memory format
memory::format src_fmt;
memory::format dst_fmt;
memory::format ws_fmt;
// workspace shape
memory::dims ws_dims;
memory::data_type ws_dt;
size_t ws_size;
// MKL-DNN memory, just dummy data
std::shared_ptr<mkldnn::memory> ws_mem;
std::shared_ptr<mkldnn::memory> src_mem;
std::shared_ptr<mkldnn::memory> dst_mem;
// desc & primitive desc
std::shared_ptr<mkldnn::pooling_forward::desc> fwd_desc;
std::shared_ptr<mkldnn::pooling_forward::primitive_desc> fwd_pd;
// memory desc
std::shared_ptr<mkldnn::memory::desc> src_md;
std::shared_ptr<mkldnn::memory::desc> dst_md;
// Pooling primitive
std::shared_ptr<mkldnn::pooling_forward> fwd;
std::shared_ptr<mkldnn::stream> fwd_stream;
std::vector<mkldnn::primitive> fwd_primitives;
PoolingFwdContext()
: src_fmt(memory::format::any),
dst_fmt(memory::format::any),
ws_fmt(memory::format::any),
ws_mem(nullptr),
src_mem(nullptr),
dst_mem(nullptr),
fwd_desc(nullptr),
fwd_pd(nullptr),
src_md(nullptr),
dst_md(nullptr),
fwd(nullptr),
fwd_stream(nullptr) {}
};
struct PoolingFwdContext context_;
engine cpu_engine_;
};
template <typename T>
class MklPoolingFwdPrimitiveFactory : public MklPrimitiveFactory<T> {
public:
static MklPoolingFwdPrimitive<T>* Get(const MklPoolingParams& fwdParams) {
MklPoolingFwdPrimitive<T>* pooling_forward = nullptr;
// Get pooling primitive from the pool
pooling_forward = static_cast<MklPoolingFwdPrimitive<T>*>(
MklPoolingFwdPrimitiveFactory<T>::GetInstance().GetPoolingFwd(
fwdParams));
if (pooling_forward == nullptr) {
pooling_forward = new MklPoolingFwdPrimitive<T>(fwdParams);
MklPoolingFwdPrimitiveFactory<T>::GetInstance().SetPoolingFwd(
fwdParams, pooling_forward);
}
return pooling_forward;
}
static MklPoolingFwdPrimitiveFactory& GetInstance() {
static MklPoolingFwdPrimitiveFactory instance_;
return instance_;
}
private:
MklPoolingFwdPrimitiveFactory() {}
~MklPoolingFwdPrimitiveFactory() {}
// Creates the key used to look up and cache a pooling primitive for
// reuse. A pooling key is a string that concatenates the pooling
// parameters as well as the algorithm kind (max versus avg).
static string CreateKey(const MklPoolingParams& fwdParams) {
string prefix = "pooling_fwd";
FactoryKeyCreator key_creator;
key_creator.AddAsKey(prefix);
key_creator.AddAsKey(fwdParams.src_dims);
key_creator.AddAsKey(fwdParams.dst_dims);
key_creator.AddAsKey(fwdParams.filter_dims);
key_creator.AddAsKey(fwdParams.strides);
key_creator.AddAsKey(fwdParams.padding_left);
key_creator.AddAsKey(fwdParams.padding_right);
key_creator.AddAsKey<int>(static_cast<int>(fwdParams.alg_kind));
key_creator.AddAsKey<int>(static_cast<int>(fwdParams.prop_kind));
return key_creator.GetKey();
}
MklPrimitive* GetPoolingFwd(const MklPoolingParams& fwdParams) {
string key = CreateKey(fwdParams);
return this->GetOp(key);
}
void SetPoolingFwd(const MklPoolingParams& fwdParams, MklPrimitive* op) {
string key = CreateKey(fwdParams);
this->SetOp(key, op);
}
};
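// Usage sketch (illustrative only; `fwdParams`, `src_buf`, `dst_buf`, and
// `ws_buf` are hypothetical). Get() returns a cached primitive whose key
// matches fwdParams, or creates and caches a new one, so repeated calls with
// identical parameters avoid rebuilding MKL-DNN descriptors:
//
//   MklPoolingFwdPrimitive<float>* pooling_fwd =
//       MklPoolingFwdPrimitiveFactory<float>::Get(fwdParams);
//   pooling_fwd->Execute(src_buf, dst_buf, ws_buf);  // ws_buf: max-pool only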
template <typename T>
class MklPoolingBwdPrimitive : public MklPrimitive {
public:
explicit MklPoolingBwdPrimitive(const MklPoolingParams& bwdParams)
: cpu_engine(engine::cpu, 0) {
context_.bwd_stream.reset(new stream(stream::kind::eager));
if (context_.bwd == nullptr) Setup(bwdParams);
}
~MklPoolingBwdPrimitive() {}
// Pooling backward execute.
//   diff_dst_data: input data buffer of diff_dst
//   diff_src_data: output data buffer of diff_src
//   ws_data: input data buffer of workspace (max-pooling only)
void Execute(const T* diff_dst_data, T* diff_src_data,
const void* ws_data = nullptr);
public:
std::shared_ptr<mkldnn::pooling_forward::primitive_desc> GetPoolingFwdPd()
const {
return context_.fwd_pd;
}
std::shared_ptr<mkldnn::pooling_backward::primitive_desc> GetPoolingBwdPd()
const {
return context_.bwd_pd;
}
memory::format GetDiffDstFormat() const { return context_.diff_dst_fmt; }
mkldnn::memory::data_type GetWorkspaceDataType() const {
return context_.ws_dt;
}
memory::format GetWorkspaceFormat() const { return context_.ws_fmt; }
private:
void Setup(const MklPoolingParams& bwdParams);
// Primitive reuse context for pooling bwd ops
struct PoolingBwdContext {
// algorithm
mkldnn::algorithm alg_kind;
// expected memory format
mkldnn::memory::format diff_src_fmt;
mkldnn::memory::format diff_dst_fmt;
mkldnn::memory::format ws_fmt;
// workspace attribute
mkldnn::memory::dims ws_dims;
mkldnn::memory::data_type ws_dt;
// MKL-DNN memory
std::shared_ptr<mkldnn::memory> ws_mem;
std::shared_ptr<mkldnn::memory> diff_src_mem;
std::shared_ptr<mkldnn::memory> diff_dst_mem;
// memory desc
std::shared_ptr<mkldnn::memory::desc> diff_src_md;
std::shared_ptr<mkldnn::memory::desc> diff_dst_md;
// desc & primitive desc
std::shared_ptr<mkldnn::pooling_forward::desc> fwd_desc;
std::shared_ptr<mkldnn::pooling_backward::desc> bwd_desc;
std::shared_ptr<mkldnn::pooling_forward::primitive_desc> fwd_pd;
std::shared_ptr<mkldnn::pooling_backward::primitive_desc> bwd_pd;
// pooling primitive
std::shared_ptr<mkldnn::pooling_backward> bwd;
std::shared_ptr<mkldnn::stream> bwd_stream;
std::vector<mkldnn::primitive> bwd_primitives;
PoolingBwdContext()
: diff_src_fmt(memory::format::any),
diff_dst_fmt(memory::format::any),
ws_fmt(memory::format::any),
ws_mem(nullptr),
diff_src_mem(nullptr),
diff_dst_mem(nullptr),
diff_src_md(nullptr),
diff_dst_md(nullptr),
fwd_desc(nullptr),
bwd_desc(nullptr),
fwd_pd(nullptr),
bwd_pd(nullptr),
bwd(nullptr),
bwd_stream(nullptr) {}
};
struct PoolingBwdContext context_;
engine cpu_engine;
};
template <typename T>
class MklPoolingBwdPrimitiveFactory : public MklPrimitiveFactory<T> {
public:
static MklPoolingBwdPrimitive<T>* Get(const MklPoolingParams& bwdParams) {
MklPoolingBwdPrimitive<T>* pooling_backward = nullptr;
// Find a pooling backward primitive from the pool
// If it does not exist, create a new one
pooling_backward = static_cast<MklPoolingBwdPrimitive<T>*>(
MklPoolingBwdPrimitiveFactory<T>::GetInstance().GetPoolingBwd(
bwdParams));
if (pooling_backward == nullptr) {
pooling_backward = new MklPoolingBwdPrimitive<T>(bwdParams);
MklPoolingBwdPrimitiveFactory<T>::GetInstance().SetPoolingBwd(
bwdParams, pooling_backward);
}
return pooling_backward;
}
static MklPoolingBwdPrimitiveFactory& GetInstance() {
static MklPoolingBwdPrimitiveFactory instance_;
return instance_;
}
private:
MklPoolingBwdPrimitiveFactory() {}
~MklPoolingBwdPrimitiveFactory() {}
// Creates the key used to look up and cache a pooling primitive for
// reuse. A pooling key is a string that concatenates the pooling
// parameters as well as the algorithm kind (max versus avg).
static string CreateKey(const MklPoolingParams& bwdParams) {
string prefix = "pooling_bwd";
FactoryKeyCreator key_creator;
key_creator.AddAsKey(prefix);
key_creator.AddAsKey(bwdParams.src_dims);
key_creator.AddAsKey(bwdParams.dst_dims);
key_creator.AddAsKey(bwdParams.filter_dims);
key_creator.AddAsKey(bwdParams.strides);
key_creator.AddAsKey(bwdParams.padding_left);
key_creator.AddAsKey(bwdParams.padding_right);
key_creator.AddAsKey<int>(static_cast<int>(bwdParams.alg_kind));
return key_creator.GetKey();
}
MklPrimitive* GetPoolingBwd(const MklPoolingParams& bwdParams) {
string key = CreateKey(bwdParams);
return this->GetOp(key);
}
void SetPoolingBwd(const MklPoolingParams& bwdParams, MklPrimitive* op) {
string key = CreateKey(bwdParams);
this->SetOp(key, op);
}
};
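// Usage sketch (illustrative only; the buffers are hypothetical). Backward
// max-pooling additionally consumes the workspace written by the matching
// forward primitive:
//
//   MklPoolingBwdPrimitive<float>* pooling_bwd =
//       MklPoolingBwdPrimitiveFactory<float>::Get(bwdParams);
//   pooling_bwd->Execute(diff_dst_buf, diff_src_buf, ws_buf);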
typedef Eigen::ThreadPoolDevice CPUDevice;
struct MklPoolParameters {
int depth;
int tensor_in_planes; // Pool3D
int tensor_in_cols;
int tensor_in_rows;
int tensor_in_batch;
int window_planes; // Pool3D
int window_rows;
int window_cols;
int depth_window;
int planes_stride; // Pool3D
int row_stride;
int col_stride;
int depth_stride;
int64 out_planes; // Pool3D
int64 out_height;
int64 out_width;
int out_depth;
int64 pad_P1; // Pool3D
int64 pad_P2; // Pool3D
int64 pad_left;
int64 pad_right;
int64 pad_top;
int64 pad_bottom;
int pad_depth;
TensorFormat data_format;
MklPoolParameters()
: depth(0),
tensor_in_planes(0),
tensor_in_cols(0),
tensor_in_rows(0),
tensor_in_batch(0),
window_planes(0),
window_rows(0),
window_cols(0),
depth_window(0),
planes_stride(0),
row_stride(0),
col_stride(0),
depth_stride(0),
out_planes(0),
out_height(0),
out_width(0),
out_depth(0),
pad_P1(0),
pad_P2(0),
pad_left(0),
pad_right(0),
pad_top(0),
pad_bottom(0),
pad_depth(0),
data_format(TensorFormat::FORMAT_NCHW) {}
// Updates context->status if there is an invalid input.
void Init(OpKernelContext* context, const std::vector<int32>& ksize,
const std::vector<int32>& stride, Padding padding,
TensorFormat data_format, const TensorShape& tensor_in_shape);
void Init(OpKernelContext* context, const std::vector<int32>& ksize,
const std::vector<int32>& stride, Padding padding,
TensorFormat data_format, const MklDnnShape* mkl_in_shape);
private:
// Common initialization for TensorFlow and MKL formats
void Init(OpKernelContext* context, const std::vector<int32>& ksize,
const std::vector<int32>& stride, Padding padding,
TensorFormat data_format);
};
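// The Init() overloads derive the output dimensions with the standard
// TensorFlow windowed-output arithmetic; as a worked example for a single
// spatial dimension:
//
//   VALID: out = (in - window) / stride + 1, e.g. (28 - 2) / 2 + 1 = 14
//   SAME:  out = ceil(in / stride),          e.g. ceil(28 / 2)     = 14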
template <class T>
class MklPoolingOpBase : public OpKernel {
public:
explicit MklPoolingOpBase(OpKernelConstruction* context)
: OpKernel(context), workspace_enabled_(false) {
string data_format;
if (std::is_same<T, qint8>::value || std::is_same<T, quint8>::value) {
// The quantized pooling ops currently don't have a data_format
// attribute, so assume NHWC.
data_format = "NHWC";
} else {
OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
}
OP_REQUIRES(context, FormatFromString(data_format, &this->data_format_tf_),
errors::InvalidArgument("Invalid data format"));
OP_REQUIRES_OK(context, context->GetAttr("ksize", &this->ksize_));
OP_REQUIRES(context, this->ksize_.size() == 4 || this->ksize_.size() == 5,
errors::InvalidArgument("Sliding window ksize field must "
"specify 4 or 5 dimensions"));
OP_REQUIRES_OK(context, context->GetAttr("strides", &this->stride_));
OP_REQUIRES(context, this->stride_.size() == 4 || this->stride_.size() == 5,
errors::InvalidArgument("Sliding window strides field must "
"specify 4 or 5 dimensions"));
OP_REQUIRES_OK(context, context->GetAttr("padding", &this->padding_));
OP_REQUIRES(context, this->ksize_[0] == 1 && this->stride_[0] == 1,
errors::Unimplemented("Pooling is not yet supported on the "
"batch dimension."));
bool is_pool2d = (this->ksize_.size() == 4);
this->data_format_mkldnn_ =
is_pool2d ? TFDataFormatToMklDnnDataFormat(this->data_format_tf_)
: TFDataFormatToMklDnn3DDataFormat(this->data_format_tf_);
// This node may not have the workspace_enabled attribute if it did not
// go through the graph-rewrite pass, so we do not check for errors when
// retrieving the attribute value.
context->GetAttr("workspace_enabled", &this->workspace_enabled_);
}
void Compute(OpKernelContext* context) override = 0;
protected:
// Calculates the output shape of the pooling op in MKL-DNN order.
// MKL-DNN always uses NCHW (Pool2D) or NCDHW (Pool3D) for its output,
// while the TensorFlow output is NHWC/NCHW (Pool2D) or NDHWC/NCDHW
// (Pool3D) depending on the data format. The function expects the
// output height and width to have already been int32 bounds-checked.
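// For example (hypothetical shapes): a Pool2D op with tensor_in_batch = 1,
// out_depth = 32, out_height = 14, and out_width = 14 yields
// *output_dims_mkl_order = {1, 32, 14, 14}, i.e. NCHW regardless of the
// TensorFlow data format.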
void GetOutputDims(const MklPoolParameters& mkl_pool_params,
memory::dims* output_dims_mkl_order) {
if (this->ksize_.size() == 4) {
// Pooling2D: MKL-DNN always needs output in NCHW format.
*output_dims_mkl_order = {mkl_pool_params.tensor_in_batch,
mkl_pool_params.out_depth,
static_cast<int>(mkl_pool_params.out_height),
static_cast<int>(mkl_pool_params.out_width)};
} else {
// Pooling3D: MKL-DNN always needs output in NCDHW format.
*output_dims_mkl_order = {mkl_pool_params.tensor_in_batch,
mkl_pool_params.out_depth,
static_cast<int>(mkl_pool_params.out_planes),
static_cast<int>(mkl_pool_params.out_height),
static_cast<int>(mkl_pool_params.out_width)};
}
}
void InitMklPoolParameters(OpKernelContext* context,
MklPoolParameters* pool_params,
const MklDnnShape& original_input_mkl_shape,
const TensorShape& input_tensor_shape) {
if (!original_input_mkl_shape.IsMklTensor()) {
pool_params->Init(context, this->ksize_, this->stride_, this->padding_,
this->data_format_tf_, input_tensor_shape);
} else {
pool_params->Init(context, this->ksize_, this->stride_, this->padding_,
this->data_format_tf_, &original_input_mkl_shape);
}
}
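// Copies the pooling window, stride, and padding values from
// MklPoolParameters into MKL-DNN dims vectors. As a sketch with
// hypothetical values: a 3x3, stride-2 Pool2D with pad_top = pad_bottom = 1
// and no left/right padding maps to
//   *filter_dims   = {3, 3}
//   *strides       = {2, 2}
//   *padding_left  = {1, 0}   // {pad_top, pad_left}
//   *padding_right = {1, 0}   // {pad_bottom, pad_right}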
void PoolParamsToDims(const MklPoolParameters* pool_params,
memory::dims* filter_dims, memory::dims* strides,
memory::dims* padding_left, memory::dims* padding_right,
bool is_pool2d) {
if (is_pool2d) {
// Pool2D
*filter_dims =
memory::dims({pool_params->window_rows, pool_params->window_cols});
*strides =
memory::dims({pool_params->row_stride, pool_params->col_stride});
*padding_left = memory::dims({static_cast<int>(pool_params->pad_top),
static_cast<int>(pool_params->pad_left)});
*padding_right = memory::dims({static_cast<int>(pool_params->pad_bottom),
static_cast<int>(pool_params->pad_right)});
} else {
// Pool3D
*filter_dims =
memory::dims({pool_params->window_planes, pool_params->window_rows,
pool_params->window_cols});
*strides =
memory::dims({pool_params->planes_stride, pool_params->row_stride,
pool_params->col_stride});
*padding_left = memory::dims({static_cast<int>(pool_params->pad_P1),
static_cast<int>(pool_params->pad_top),
static_cast<int>(pool_params->pad_left)});
*padding_right = memory::dims({static_cast<int>(pool_params->pad_P2),
static_cast<int>(pool_params->pad_bottom),
static_cast<int>(pool_params->pad_right)});
}
}
void AllocateEmptyOutputTensor(OpKernelContext* context,
const int kOutputIndex,
MklPoolParameters* pool_params,
const memory::dims output_dims_mkl_order,
Tensor** output_tensor) {
MklDnnShape output_mkl_shape;
output_mkl_shape.SetMklTensor(false);
TensorShape output_tf_shape;
if (pool_params->data_format == TensorFormat::FORMAT_NCHW) {
output_tf_shape = MklDnnDimsToTFShape(output_dims_mkl_order);
} else {
memory::dims output_dims_order;
// Determine whether this is Pooling2D (NHWC) or Pooling3D (NDHWC).
if (this->ksize_.size() == 4) {
output_dims_order = {pool_params->tensor_in_batch,
static_cast<int>(pool_params->out_height),
static_cast<int>(pool_params->out_width),
pool_params->out_depth};
} else {
output_dims_order = {pool_params->tensor_in_batch,
static_cast<int>(pool_params->out_planes),
static_cast<int>(pool_params->out_height),
static_cast<int>(pool_params->out_width),
pool_params->out_depth};
}
output_tf_shape = MklDnnDimsToTFShape(output_dims_order);
}
AllocateOutputSetMklShape(context, kOutputIndex, output_tensor,
output_tf_shape, output_mkl_shape);
CHECK_NOTNULL(*output_tensor);
}
// Rounds the number of bytes required by `pd` up to the nearest multiple
// of sizeof(T) and returns the resulting element count, so that the
// allocated buffer always covers the primitive's memory.
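// For example, a primitive desc of 10 bytes with T = float
// (sizeof(T) == 4) rounds up to 3 elements, i.e. 12 bytes.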
size_t GetNumTElements(const memory::primitive_desc& pd) {
size_t num_bytes = pd.get_size();
size_t ret_val = num_bytes / sizeof(T);
if (num_bytes % sizeof(T) != 0) {
ret_val++;
}
return ret_val;
}
std::vector<int32> ksize_;
std::vector<int32> stride_;
Padding padding_;
TensorFormat data_format_tf_;
memory::format data_format_mkldnn_;
bool workspace_enabled_;
};
template <class T>
class MklPoolingForwardOpBase : public MklPoolingOpBase<T> {
public:
explicit MklPoolingForwardOpBase(OpKernelConstruction* context)
: MklPoolingOpBase<T>(context) {}
void Compute(OpKernelContext* context) override = 0;
protected:
void ConfigureInput(OpKernelContext* context,
const MklDnnShape& input_mkl_shape,
const Tensor& input_tensor,
MklPoolParameters* pool_params,
MklDnnData<T>* dnn_data_input) {
CHECK_NOTNULL(pool_params);
CHECK_NOTNULL(dnn_data_input);
TensorShape input_tensor_shape = input_tensor.shape();
if (input_tensor.NumElements() != 0) {
memory::desc input_md =
input_mkl_shape.IsMklTensor()
? input_mkl_shape.GetMklLayout()
: memory::desc(
(this->ksize_.size() == 4)
? TFShapeToMklDnnDimsInNCHW(input_tensor_shape,
this->data_format_tf_)
: TFShapeToMklDnnDimsInNCDHW(input_tensor_shape,
this->data_format_tf_),
MklDnnType<T>(), this->data_format_mkldnn_);
dnn_data_input->SetUsrMem(input_md, &input_tensor);
if (this->ksize_.size() == 5) {
// Pool3D
std::vector<int> mkldnn_sizes(5, -1);
mkldnn_sizes[MklDnnDims3D::Dim3d_N] = input_md.data.dims[0];
mkldnn_sizes[MklDnnDims3D::Dim3d_C] = input_md.data.dims[1];
mkldnn_sizes[MklDnnDims3D::Dim3d_D] = input_md.data.dims[2];
mkldnn_sizes[MklDnnDims3D::Dim3d_H] = input_md.data.dims[3];
mkldnn_sizes[MklDnnDims3D::Dim3d_W] = input_md.data.dims[4];
dnn_data_input->SetOpMemDesc(mkldnn_sizes, this->data_format_mkldnn_);
}
}
this->InitMklPoolParameters(context, pool_params, input_mkl_shape,
input_tensor_shape);
}
void AllocateOutputTensor(
OpKernelContext* context,
const pooling_forward::primitive_desc& pool_fwd_prim_desc,
const memory::dims output_dims_mkl_order,
const memory::format& output_tf_format, Tensor** output_tensor) {
CHECK_NOTNULL(output_tensor);
memory::primitive_desc dst_pd = pool_fwd_prim_desc.dst_primitive_desc();
MklDnnShape output_mkl_shape;
output_mkl_shape.SetMklTensor(true);
output_mkl_shape.SetMklLayout(&dst_pd);
output_mkl_shape.SetElemType(MklDnnType<T>());
output_mkl_shape.SetTfLayout(output_dims_mkl_order.size(),
output_dims_mkl_order, output_tf_format);
TensorShape output_tf_shape;
// only allocate enough space for the elements we need.
output_tf_shape.AddDim(this->GetNumTElements(dst_pd));
AllocateOutputSetMklShape(context, kOutputTensorIndexOutput, output_tensor,
output_tf_shape, output_mkl_shape);
CHECK_NOTNULL(*output_tensor);
}
void SanityCheckInput(OpKernelContext* context, const Tensor& input_tensor,
const MklDnnShape& input_mkl_shape) {
if (!input_mkl_shape.IsMklTensor()) {
OP_REQUIRES(context, input_tensor.dims() == 4 || input_tensor.dims() == 5,
errors::InvalidArgument("Input must be 4 or 5-dimensional"));
} else {
OP_REQUIRES(
context,
input_mkl_shape.GetDimension() == 4 ||
input_mkl_shape.GetDimension() == 5,
errors::InvalidArgument("Input shape must be 4 or 5-dimensional"));
}
}
// .Input("value: T")
// .Output("output: T")
const int kInputTensorIndexInput = 0;
const int kOutputTensorIndexOutput = 0;
};  // MklPoolingForwardOpBase
template <class T>
class MklPoolingBackwardOpBase : public MklPoolingOpBase<T> {
public:
explicit MklPoolingBackwardOpBase(OpKernelConstruction* context)
: MklPoolingOpBase<T>(context) {}
void Compute(OpKernelContext* context) override = 0;
protected:
const int kOutputTensorIndexOutput = 0;
void AllocateOutputTensor(
OpKernelContext* context,
const pooling_backward::primitive_desc& pool_bkwd_prim_desc,
const memory::dims output_dims_mkl_order,
const memory::format& output_tf_format, Tensor** output_tensor) {
CHECK_NOTNULL(output_tensor);
memory::primitive_desc dst_pd =
pool_bkwd_prim_desc.diff_src_primitive_desc();
MklDnnShape output_mkl_shape;
output_mkl_shape.SetMklTensor(true);
output_mkl_shape.SetMklLayout(&dst_pd);
output_mkl_shape.SetElemType(MklDnnType<T>());
output_mkl_shape.SetTfLayout(output_dims_mkl_order.size(),
output_dims_mkl_order, output_tf_format);
TensorShape output_tf_shape;
output_tf_shape.AddDim(this->GetNumTElements(dst_pd));
AllocateOutputSetMklShape(context, kOutputTensorIndexOutput, output_tensor,
output_tf_shape, output_mkl_shape);
CHECK_NOTNULL(*output_tensor);
}
memory::desc ConfigureInputGradient(
const MklDnnShape& input_gradient_mkl_shape,
const Tensor& input_gradient_tensor,
MklDnnData<T>* input_gradient_dnn_data,
const memory::desc& original_output_md) {
// Configure the gradient as is
memory::desc original_input_grad_md =
input_gradient_mkl_shape.IsMklTensor()
? input_gradient_mkl_shape.GetMklLayout()
: memory::desc(
(this->ksize_.size() == 4)
? TFShapeToMklDnnDimsInNCHW(input_gradient_tensor.shape(),
this->data_format_tf_)
: TFShapeToMklDnnDimsInNCDHW(
input_gradient_tensor.shape(),
this->data_format_tf_),
MklDnnType<T>(), this->data_format_mkldnn_);
input_gradient_dnn_data->SetUsrMem(original_input_grad_md,
&input_gradient_tensor);
// Check whether the input gradient (diff_dst) is already in the format
// of the original output. If not, build a memory descriptor with the
// same shape but with the original output's format so the gradient can
// be reordered to match.
memory::format original_output_format =
static_cast<memory::format>(original_output_md.data.format);
bool grad_reorder_needed =
input_gradient_dnn_data->IsReorderNeeded(original_output_format);
memory::dims diff_dst_dims =
input_gradient_mkl_shape.IsMklTensor()
? input_gradient_mkl_shape.GetSizesAsMklDnnDims()
: TFShapeToMklDnnDimsInNCHW(input_gradient_tensor.shape(),
this->data_format_tf_);
memory::desc target_diff_dst_md =
memory::desc(diff_dst_dims, MklDnnType<T>(), original_output_format);
return grad_reorder_needed ? target_diff_dst_md : original_input_grad_md;
}
};
//-------------------------------------------------------------------
// Utility functions
typedef struct {
size_t in_dim;
size_t in_sizes[4];
size_t in_strides[4];
size_t out_sizes[4];
size_t out_strides[4];
int in_offset[4];
size_t kernel_stride[2];
size_t kernel_size[2];
} MklPoolingOpParams;
// Transfers the right parameters for pooling to the op parameters
// Updates context->status if there is an invalid input.
void ExtractMklOpParams(OpKernelContext* context, TensorFormat data_format,
const MklPoolParameters& params,
MklPoolingOpParams* mkl_params);
} // namespace tensorflow
#endif // INTEL_MKL
#endif // TENSORFLOW_CORE_KERNELS_MKL_POOLING_OPS_COMMON_H_