blob: 0776478f20ee1f98acd4c90382fd88e9a34b1be4 [file] [log] [blame]
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifdef INTEL_MKL
#include <limits>
#include <vector>
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/kernels/mkl_pooling_ops_common.h"
namespace tensorflow {
using mkldnn::prop_kind;
// Populates `context_` with everything needed to run forward pooling: the
// source/destination memory descriptors, the pooling descriptor and its
// primitive descriptor, memory objects bound to DummyData, and the pooling
// primitive itself. When max pooling is configured for forward_training, a
// workspace memory object is created as well.
template <typename T>
void MklPoolingFwdPrimitive<T>::Setup(const MklPoolingParams& fwdParams) {
  DCHECK(fwdParams.alg_kind == ALGORITHM::pooling_max ||
         fwdParams.alg_kind == ALGORITHM::pooling_avg ||
         fwdParams.alg_kind == ALGORITHM::pooling_avg_include_padding ||
         fwdParams.alg_kind == ALGORITHM::pooling_avg_exclude_padding)
      << "Pooling algorithm kind is not supported";

  context_.alg_kind = fwdParams.alg_kind;
  context_.prop_kind = fwdParams.prop_kind;

  // FIXME: Pooling doesn't expose a way to get the src_primitive_desc, so the
  // source format is currently hard-coded for the quantized types. This may
  // be broken with future CPU architectures.
  const bool is_quantized =
      std::is_same<T, qint8>::value || std::is_same<T, quint8>::value;
  if (is_quantized) {
    // A 4-element src_dims means 2D pooling; anything else is treated as 3D.
    const bool is_2d = (fwdParams.src_dims.size() == 4);
    context_.src_fmt = is_2d ? MEMORY_FORMAT::nhwc : MEMORY_FORMAT::ndhwc;
  } else {
    context_.src_fmt = fwdParams.src_format;
  }

  // Memory descriptors: the source uses the format chosen above, the
  // destination format is left for the library to pick.
  context_.src_md.reset(new memory::desc({fwdParams.src_dims}, MklDnnType<T>(),
                                         context_.src_fmt));
  context_.dst_md.reset(new memory::desc({fwdParams.dst_dims}, MklDnnType<T>(),
                                         MEMORY_FORMAT::any));

  // Pooling descriptor and the primitive descriptor derived from it.
  context_.fwd_desc.reset(new pooling_forward::desc(
      fwdParams.prop_kind, fwdParams.alg_kind, *context_.src_md,
      *context_.dst_md, fwdParams.strides, fwdParams.filter_dims,
      fwdParams.padding_left, fwdParams.padding_right, padding_kind::zero));
  context_.fwd_pd.reset(
      new pooling_forward::primitive_desc(*context_.fwd_desc, cpu_engine_));
  auto* pd = context_.fwd_pd.get();

#ifndef ENABLE_MKLDNN_V1
  // Remember which destination format the primitive descriptor selected.
  context_.dst_fmt =
      static_cast<MEMORY_FORMAT>(pd->PRIMITIVE_DESC_DST.desc().data.format);
#endif  // ENABLE_MKLDNN_V1

  // MKL-DNN memory objects start out pointing at dummy data; Execute() swaps
  // in the real buffers.
  context_.src_mem.reset(
      new MEMORY_CONSTRUCTOR(pd->PRIMITIVE_DESC_SRC, cpu_engine_, DummyData));
  context_.dst_mem.reset(
      new MEMORY_CONSTRUCTOR(pd->PRIMITIVE_DESC_DST, cpu_engine_, DummyData));

  // For max pooling in training mode, a workspace (ws) has to be returned for
  // the backward computation.
  const bool needs_workspace =
      fwdParams.alg_kind == ALGORITHM::pooling_max &&
      fwdParams.prop_kind == prop_kind::forward_training;

#ifdef ENABLE_MKLDNN_V1
  if (needs_workspace) {
    context_.ws_mem.reset(new MEMORY_CONSTRUCTOR(pd->PRIMITIVE_DESC_WORKSPACE,
                                                 cpu_engine_, DummyData));
    context_.net_args.push_back({{MKLDNN_ARG_SRC, *context_.src_mem},
                                 {MKLDNN_ARG_DST, *context_.dst_mem},
                                 {MKLDNN_ARG_WORKSPACE, *context_.ws_mem}});
  } else {
    context_.net_args.push_back({{MKLDNN_ARG_SRC, *context_.src_mem},
                                 {MKLDNN_ARG_DST, *context_.dst_mem}});
  }
  context_.fwd.reset(new pooling_forward(*context_.fwd_pd));
#else
  if (needs_workspace) {
    const auto ws_mem_pd = pd->PRIMITIVE_DESC_WORKSPACE;
    const auto ws_pd = ws_mem_pd.desc().data;
    // Store workspace's dims, format, data type and size so the workspace
    // tensor can be created later.
    context_.ws_fmt = static_cast<MEMORY_FORMAT>(ws_pd.format);
    context_.ws_dims.assign(ws_pd.dims, ws_pd.dims + ws_pd.ndims);
    context_.ws_dt = static_cast<mkldnn::memory::data_type>(ws_pd.data_type);
    context_.ws_size = ws_mem_pd.get_size();
    context_.ws_mem.reset(new memory(ws_mem_pd, DummyData));
    context_.fwd.reset(new pooling_forward(*context_.fwd_pd, *context_.src_mem,
                                           *context_.dst_mem,
                                           *context_.ws_mem));
  } else {
    context_.fwd.reset(new pooling_forward(*context_.fwd_pd, *context_.src_mem,
                                           *context_.dst_mem));
  }
#endif  // ENABLE_MKLDNN_V1

  context_.fwd_primitives.push_back(*context_.fwd);
}
// Runs the forward pooling primitive on the supplied buffers.
//
//   src_data: input buffer, bound to the source memory object.
//   dst_data: output buffer, bound to the destination memory object.
//   ws_data:  workspace buffer; only used (and DCHECKed to be non-null) when
//             the primitive was set up for max pooling in forward_training.
//
// All memory objects are pointed back at DummyData before returning.
template <typename T>
void MklPoolingFwdPrimitive<T>::Execute(const T* src_data, T* dst_data,
                                        void* ws_data) {
  // Max pooling in training mode must have a workspace.
  const bool uses_workspace =
      context_.alg_kind == ALGORITHM::pooling_max &&
      context_.prop_kind == prop_kind::forward_training;

  // Bind the caller's buffers.
  context_.src_mem->set_data_handle(
      static_cast<void*>(const_cast<T*>(src_data)));
  context_.dst_mem->set_data_handle(static_cast<void*>(dst_data));
  if (uses_workspace) {
    DCHECK(ws_data != nullptr);
    context_.ws_mem->set_data_handle(ws_data);
  }

#ifdef ENABLE_MKLDNN_V1
  DCHECK_EQ(context_.fwd_primitives.size(), context_.net_args.size());
  const size_t num_primitives = context_.fwd_primitives.size();
  for (size_t idx = 0; idx < num_primitives; ++idx) {
    context_.fwd_primitives.at(idx).execute(*context_.fwd_stream,
                                            context_.net_args.at(idx));
  }
#else
  context_.fwd_stream->submit(context_.fwd_primitives);
#endif  // ENABLE_MKLDNN_V1

  // Detach the caller's buffers again.
  context_.src_mem->set_data_handle(DummyData);
  context_.dst_mem->set_data_handle(DummyData);
  if (uses_workspace) {
    DCHECK(ws_data != nullptr);
    context_.ws_mem->set_data_handle(DummyData);
  }
}
// Explicit instantiations of the forward pooling primitive for the element
// types float, quint8, qint8 and bfloat16.
template class MklPoolingFwdPrimitive<float>;
template class MklPoolingFwdPrimitive<quint8>;
template class MklPoolingFwdPrimitive<qint8>;
template class MklPoolingFwdPrimitive<bfloat16>;
// Populates `context_` with everything needed to run backward pooling: the
// diff_src/diff_dst memory descriptors, the backward pooling descriptor, a
// forward primitive descriptor used as a hint, the backward primitive
// descriptor, memory objects bound to DummyData, and the backward primitive.
// For max pooling a workspace memory object is created as well.
template <typename T>
void MklPoolingBwdPrimitive<T>::Setup(const MklPoolingParams& bwdParams) {
  DCHECK(bwdParams.alg_kind == ALGORITHM::pooling_max ||
         bwdParams.alg_kind == ALGORITHM::pooling_avg ||
         bwdParams.alg_kind == ALGORITHM::pooling_avg_include_padding ||
         bwdParams.alg_kind == ALGORITHM::pooling_avg_exclude_padding)
      << "Pooling algorithm kind is not supported";
  context_.alg_kind = bwdParams.alg_kind;

  // Create memory descriptors. The MEMORY_FORMAT macro is used (as in the
  // forward primitive) so the same line works for both MKL-DNN API versions.
  context_.diff_src_md.reset(new memory::desc(
      {bwdParams.src_dims}, MklDnnType<T>(), MEMORY_FORMAT::any));
  context_.diff_dst_md.reset(new memory::desc(
      {bwdParams.dst_dims}, MklDnnType<T>(), bwdParams.src_format));
  context_.bwd_desc.reset(new pooling_backward::desc(
      bwdParams.alg_kind, *context_.diff_src_md, *context_.diff_dst_md,
      bwdParams.strides, bwdParams.filter_dims, bwdParams.padding_left,
      bwdParams.padding_right, padding_kind::zero));

  // Create a forward primitive descriptor,
  // which will be used as a hint for creating the backward primitive.
  context_.fwd_desc.reset(new pooling_forward::desc(
      bwdParams.prop_kind, bwdParams.alg_kind, *context_.diff_src_md,
      *context_.diff_dst_md, bwdParams.strides, bwdParams.filter_dims,
      bwdParams.padding_left, bwdParams.padding_right, padding_kind::zero));
  context_.fwd_pd.reset(
      new pooling_forward::primitive_desc(*context_.fwd_desc, cpu_engine_));
  context_.bwd_pd.reset(new pooling_backward::primitive_desc(
      *context_.bwd_desc, cpu_engine_, *context_.fwd_pd));

#ifndef ENABLE_MKLDNN_V1
  // Record the formats: diff_src is whatever the primitive descriptor chose,
  // diff_dst is the caller-provided source format.
  context_.diff_src_fmt = static_cast<MEMORY_FORMAT>(
      context_.bwd_pd.get()->PRIMITIVE_DESC_DIFF_SRC.desc().data.format);
  context_.diff_dst_fmt = bwdParams.src_format;
#endif  // ENABLE_MKLDNN_V1

  // Create MKL-DNN internal memory objects with dummy data; Execute() swaps
  // in the real buffers.
#ifdef ENABLE_MKLDNN_V1
  context_.diff_src_mem.reset(new memory(context_.bwd_pd.get()->diff_src_desc(),
                                         cpu_engine_, DummyData));
  context_.diff_dst_mem.reset(new memory(context_.bwd_pd.get()->diff_dst_desc(),
                                         cpu_engine_, DummyData));
#else
  context_.diff_src_mem.reset(
      new memory(context_.bwd_pd.get()->diff_src_primitive_desc(), DummyData));
  context_.diff_dst_mem.reset(new memory(
      {{{bwdParams.dst_dims}, MklDnnType<T>(), context_.diff_dst_fmt},
       cpu_engine_},
      DummyData));
#endif  // ENABLE_MKLDNN_V1

  // For max pooling, the workspace produced by the forward pass is an
  // additional input to the backward computation.
  if (bwdParams.alg_kind == ALGORITHM::pooling_max) {
#ifdef ENABLE_MKLDNN_V1
    // Bind the workspace to DummyData like every other memory object here;
    // Execute() always installs the real workspace handle before running.
    context_.ws_mem.reset(new memory(context_.fwd_pd.get()->workspace_desc(),
                                     cpu_engine_, DummyData));
    context_.net_args.push_back(
        {{MKLDNN_ARG_DIFF_DST, *context_.diff_dst_mem},
         {MKLDNN_ARG_WORKSPACE, *context_.ws_mem},
         {MKLDNN_ARG_DIFF_SRC, *context_.diff_src_mem}});
    context_.bwd.reset(new pooling_backward(*context_.bwd_pd));
#else
    auto ws_pd = context_.fwd_pd.get()->PRIMITIVE_DESC_WORKSPACE.desc().data;
    // Store the workspace's dims, format and data type, then build the
    // workspace memory object from them.
    context_.ws_dims.assign(ws_pd.dims, ws_pd.dims + ws_pd.ndims);
    context_.ws_fmt = static_cast<MEMORY_FORMAT>(ws_pd.format);
    context_.ws_dt = static_cast<mkldnn::memory::data_type>(ws_pd.data_type);
    context_.ws_mem.reset(new memory(
        {{{context_.ws_dims}, context_.ws_dt, context_.ws_fmt}, cpu_engine_},
        DummyData));
    context_.bwd.reset(
        new pooling_backward(*context_.bwd_pd, *context_.diff_dst_mem,
                             *context_.ws_mem, *context_.diff_src_mem));
#endif  // ENABLE_MKLDNN_V1
  } else {
#ifdef ENABLE_MKLDNN_V1
    context_.net_args.push_back(
        {{MKLDNN_ARG_DIFF_DST, *context_.diff_dst_mem},
         {MKLDNN_ARG_DIFF_SRC, *context_.diff_src_mem}});
    context_.bwd.reset(new pooling_backward(*context_.bwd_pd));
#else
    context_.bwd.reset(new pooling_backward(
        *context_.bwd_pd, *context_.diff_dst_mem, *context_.diff_src_mem));
#endif  // ENABLE_MKLDNN_V1
  }
  context_.bwd_primitives.push_back(*context_.bwd);
}
// Runs the backward pooling primitive on the supplied buffers.
//
//   diff_dst_data: incoming gradient buffer, bound to the diff_dst memory.
//   diff_src_data: outgoing gradient buffer, bound to the diff_src memory.
//   ws_data:       workspace buffer; only used (and DCHECKed to be non-null)
//                  when the primitive was set up for max pooling.
//
// All memory objects are pointed back at DummyData before returning.
template <typename T>
void MklPoolingBwdPrimitive<T>::Execute(const T* diff_dst_data,
                                        T* diff_src_data, const void* ws_data) {
  const bool uses_workspace = (context_.alg_kind == ALGORITHM::pooling_max);

  // Bind the caller's buffers.
  context_.diff_dst_mem->set_data_handle(
      static_cast<void*>(const_cast<T*>(diff_dst_data)));
  context_.diff_src_mem->set_data_handle(static_cast<void*>(diff_src_data));
  if (uses_workspace) {
    DCHECK(ws_data != nullptr);
    context_.ws_mem->set_data_handle(const_cast<void*>(ws_data));
  }

#ifdef ENABLE_MKLDNN_V1
  DCHECK_EQ(context_.bwd_primitives.size(), context_.net_args.size());
  const size_t num_primitives = context_.bwd_primitives.size();
  for (size_t idx = 0; idx < num_primitives; ++idx) {
    context_.bwd_primitives.at(idx).execute(*context_.bwd_stream,
                                            context_.net_args.at(idx));
  }
#else
  context_.bwd_stream->submit(context_.bwd_primitives);
#endif  // ENABLE_MKLDNN_V1

  // Detach the caller's buffers again.
  context_.diff_dst_mem->set_data_handle(DummyData);
  context_.diff_src_mem->set_data_handle(DummyData);
  if (uses_workspace) {
    DCHECK(ws_data != nullptr);
    context_.ws_mem->set_data_handle(DummyData);
  }
}
// Explicit instantiations of the backward pooling primitive for the element
// types float and bfloat16.
template class MklPoolingBwdPrimitive<float>;
template class MklPoolingBwdPrimitive<bfloat16>;
// Initialization from a TensorFlow-format input shape.
//
// Reads the batch, depth and spatial extents out of `tensor_in_shape`
// according to `data_format`, then delegates to the common Init overload for
// the window, stride and output-size computation. Reports InvalidArgument on
// `context` if the input is not 4- or 5-dimensional.
void MklPoolParameters::Init(OpKernelContext* context,
                             const std::vector<int32>& ksize,
                             const std::vector<int32>& stride, Padding padding,
                             TensorFormat data_format,
                             const TensorShape& tensor_in_shape) {
  // Only 4-D (Pool2D) and 5-D (Pool3D) inputs are accepted.
  const int num_dims = tensor_in_shape.dims();
  OP_REQUIRES(context, num_dims == 4 || num_dims == 5,
              errors::InvalidArgument("tensor_in must be 4 or 5-dimensional"));

  depth = GetTensorDim(tensor_in_shape, data_format, 'C');
  if (num_dims == 5) {
    // Pool3D: the three spatial dimensions are addressed as '0', '1', '2'.
    tensor_in_planes = GetTensorDim(tensor_in_shape, data_format, '0');
    tensor_in_rows = GetTensorDim(tensor_in_shape, data_format, '1');
    tensor_in_cols = GetTensorDim(tensor_in_shape, data_format, '2');
  } else {
    // Pool2D
    tensor_in_cols = GetTensorDim(tensor_in_shape, data_format, 'W');
    tensor_in_rows = GetTensorDim(tensor_in_shape, data_format, 'H');
  }
  tensor_in_batch = GetTensorDim(tensor_in_shape, data_format, 'N');

  Init(context, ksize, stride, padding, data_format);
}
// Initialization from an MKL-format input shape.
//
// Reads the batch, depth and spatial extents from `mklInputShape`, choosing
// the 2D or 3D accessors based on the number of entries in `ksize`, then
// delegates to the common Init overload.
void MklPoolParameters::Init(OpKernelContext* context,
                             const std::vector<int32>& ksize,
                             const std::vector<int32>& stride, Padding padding,
                             TensorFormat data_format,
                             const MklDnnShape* mklInputShape) {
  // A 4-element ksize selects Pool2D; anything else is handled as Pool3D.
  const bool is_pool2d = (ksize.size() == 4);
  if (!is_pool2d) {
    // Pool3D
    depth = mklInputShape->GetDimension3D('C');
    tensor_in_cols = mklInputShape->GetDimension3D('W');
    tensor_in_rows = mklInputShape->GetDimension3D('H');
    tensor_in_planes = mklInputShape->GetDimension3D('D');
    tensor_in_batch = mklInputShape->GetDimension3D('N');
  } else {
    // Pool2D
    depth = mklInputShape->GetDimension('C');
    tensor_in_cols = mklInputShape->GetDimension('W');
    tensor_in_rows = mklInputShape->GetDimension('H');
    tensor_in_batch = mklInputShape->GetDimension('N');
  }

  Init(context, ksize, stride, padding, data_format);
}
// Common initialization for TensorFlow and MKL formats.
//
// Expects the input extents (depth, tensor_in_rows, tensor_in_cols and, for
// Pool3D, tensor_in_planes) to have been set by one of the other Init
// overloads. Extracts the window sizes and strides from `ksize` / `stride`
// according to `data_format`, then computes the output extents and padding.
// A 4-element `ksize` selects Pool2D; anything else is handled as Pool3D.
//
// Any validation failure is reported on `context` and the function returns
// early, leaving the remaining fields untouched.
void MklPoolParameters::Init(OpKernelContext* context,
                             const std::vector<int32>& ksize,
                             const std::vector<int32>& stride, Padding padding,
                             TensorFormat data_format) {
  // Get the data format.
  this->data_format = data_format;

  const bool is_pool2d = (ksize.size() == 4);
  if (is_pool2d) {
    // Pool2D
    // Get the window sizes.
    window_rows = GetTensorDim(ksize, data_format, 'H');
    window_cols = GetTensorDim(ksize, data_format, 'W');
    depth_window = GetTensorDim(ksize, data_format, 'C');

    // Get the strides.
    row_stride = GetTensorDim(stride, data_format, 'H');
    col_stride = GetTensorDim(stride, data_format, 'W');
    depth_stride = GetTensorDim(stride, data_format, 'C');

    // We only support 2D pooling across width/height and depthwise
    // pooling, not a combination.
    OP_REQUIRES(context,
                (depth_window == 1 || (window_rows == 1 && window_cols == 1)),
                errors::Unimplemented(
                    "MaxPooling supports exactly one of pooling across depth "
                    "or pooling across width/height."));
  } else {
    // Pool3D
    // Get the window sizes.
    window_planes = GetTensorDim(ksize, data_format, '0');
    window_rows = GetTensorDim(ksize, data_format, '1');
    window_cols = GetTensorDim(ksize, data_format, '2');
    depth_window = GetTensorDim(ksize, data_format, 'C');

    // Get the strides.
    planes_stride = GetTensorDim(stride, data_format, '0');
    row_stride = GetTensorDim(stride, data_format, '1');
    col_stride = GetTensorDim(stride, data_format, '2');
    depth_stride = GetTensorDim(stride, data_format, 'C');

    // We only support 3D pooling across depth/width/height and depthwise
    // pooling, not a combination.
    OP_REQUIRES(context,
                (depth_window == 1 ||
                 (window_rows == 1 && window_cols == 1 && window_planes == 1)),
                errors::Unimplemented(
                    "AvgPooling3D supports exactly one of pooling across depth "
                    "or pooling across depth/width/height."));
  }

  if (depth_window == 1) {  // We are pooling in the D (Pool3D only), H and W.
    if (!is_pool2d) {
      OP_REQUIRES_OK(
          context, GetWindowedOutputSizeVerbose(tensor_in_planes, window_planes,
                                                planes_stride, padding,
                                                &out_planes, &pad_P1, &pad_P2));
    }
    OP_REQUIRES_OK(context, GetWindowedOutputSizeVerbose(
                                tensor_in_rows, window_rows, row_stride,
                                padding, &out_height, &pad_top, &pad_bottom));
    OP_REQUIRES_OK(context, GetWindowedOutputSizeVerbose(
                                tensor_in_cols, window_cols, col_stride,
                                padding, &out_width, &pad_left, &pad_right));

    // TF can work with int64, but mkldnn only supports int32.
    // Fail if the depth, height or width are greater than MAX_INT.
    // We check depth only for 3D pooling case.
    if (!is_pool2d) {
      OP_REQUIRES(context,
                  FastBoundsCheck(out_planes, std::numeric_limits<int>::max()),
                  errors::InvalidArgument("output depth/planes is too large"));
    }
    OP_REQUIRES(context,
                FastBoundsCheck(out_height, std::numeric_limits<int>::max()),
                errors::InvalidArgument("output height is too large"));
    OP_REQUIRES(context,
                FastBoundsCheck(out_width, std::numeric_limits<int>::max()),
                errors::InvalidArgument("output width is too large"));

    out_depth = depth;  // Output will have the same depth as the input.
  } else {  // We are pooling in the depth dimension.
    // Reject a non-positive depth window up front: the modulo and division
    // below would otherwise divide by zero (or yield a negative depth).
    OP_REQUIRES(context, depth_window > 0,
                errors::InvalidArgument("Depthwise max pooling requires the"
                                        " depth window to be positive"));

    // Our current version of depthwise max pooling does not support
    // any padding, and expects the depth_window to equal the depth
    // stride (no overlapping).
    OP_REQUIRES(context, depth % depth_window == 0,
                errors::Unimplemented("Depthwise max pooling requires the"
                                      " depth window to evenly divide the"
                                      " input depth"));
    OP_REQUIRES(context, depth_stride == depth_window,
                errors::Unimplemented("Depthwise max pooling requires the"
                                      " depth window to equal the depth"
                                      " stride"));

    // The current version of depthwise max is only implemented on CPU.
    OP_REQUIRES(context,
                (DeviceType(static_cast<Device*>(context->device())
                                ->attributes()
                                .device_type()) == DeviceType(DEVICE_CPU)),
                errors::Unimplemented("Depthwise max pooling is currently "
                                      "only implemented for CPU devices."));

    out_depth = depth / depth_window;
  }
}
} // namespace tensorflow
#endif // INTEL_MKL