Bugfix for mkl slice view not working properly with block format nchw16c when offset is 8.
diff --git a/tensorflow/core/kernels/mkl_slice_op.cc b/tensorflow/core/kernels/mkl_slice_op.cc
index b259c28..c3deab7 100644
--- a/tensorflow/core/kernels/mkl_slice_op.cc
+++ b/tensorflow/core/kernels/mkl_slice_op.cc
@@ -387,6 +387,7 @@
// Step 1 (as per above description) - Create memory for user data.
// We use blocked format here to describe input tensor.
const Tensor& input_tensor = MklGetInput(context, 0);
+ memory::dims input_dims, input_strides;
MklDnnShape input_mkl_shape;
GetMklShape(context, 0, &input_mkl_shape);
@@ -397,10 +398,14 @@
size_dims = MklDnnDimsInNCHW(size_dims, input_tf_format);
auto input_md = input_mkl_shape.GetMklLayout();
src.SetUsrMem(input_md, &input_tensor);
+
+ // Handle data format safely, change them to block format.
+ // Compute parameters of reorder primitive first.
+ input_dims = input_mkl_shape.GetSizesAsMklDnnDims();
+ input_strides = CalculateTFStrides(input_dims);
} else {
// Initialize input dimensions and strides to be used when input is not
// in MklDnn layout.
- memory::dims input_dims, input_strides;
input_dims = TFShapeToMklDnnDims(input_tensor.shape());
input_strides = CalculateTFStrides(input_dims);
// Create input memory descriptor.
@@ -409,6 +414,13 @@
src.SetUsrMem(input_md, &input_tensor);
}
+ // If format not equal to block format, execute reorder.
+ // Or else do nothing for it.
+ auto op_md =
+ MklDnnData<T>::CreateBlockedMemDesc(input_dims, input_strides);
+ auto op_pd = memory::primitive_desc(op_md, cpu_engine);
+ src.CheckReorderToOpMem(op_pd);
+
// Step 2 - Create memory for output.
auto output_strides = CalculateTFStrides(size_dims);
auto output_md =
@@ -421,7 +433,7 @@
output.SetUsrMem(output_md, output_tensor);
// Step 3 - create reorder primitive.
- MklSliceParams sliceParams(src.GetUsrMem(), output.GetUsrMem(),
+ MklSliceParams sliceParams(&src.GetOpMem(), output.GetUsrMem(),
begin_dims, size_dims);
MklSlicePrimitive<T>* reorder_prim =
MklSlicePrimitiveFactory<T>::Get(sliceParams);
diff --git a/tensorflow/python/kernel_tests/slice_op_test.py b/tensorflow/python/kernel_tests/slice_op_test.py
index 258b39b..dc81c4f 100644
--- a/tensorflow/python/kernel_tests/slice_op_test.py
+++ b/tensorflow/python/kernel_tests/slice_op_test.py
@@ -28,8 +28,12 @@
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import random_ops
+from tensorflow.python.ops import init_ops
+from tensorflow.python.ops import variables
from tensorflow.python.platform import test
+import tensorflow as tf
class SliceTest(test.TestCase):
@@ -42,6 +46,27 @@
slice_val = self.evaluate(slice_t)
self.assertAllEqual(slice_val, inp[2, k:k])
+ def testView(self):
+ cout = 45
+ shape = [64, 28, 28, 32]
+ dtype = dtypes.float32
+ gain = 3.14
+ kernel_size = [1, 1]
+
+ convolution = tf.keras.layers.Conv2D
+ inputs = random_ops.random_normal(shape, dtype=dtype)
+ middle = convolution(
+ padding="valid", filters=cout,
+ kernel_size=kernel_size, use_bias=False,
+ kernel_initializer=init_ops.convolutional_orthogonal_2d(gain=gain)
+ ).apply(inputs)
+
+ outputs = array_ops.slice(middle, [8, 8, 8, 8], [16, 16, 16, 16])
+ my_ops = variables.global_variables_initializer()
+ with self.session(use_gpu=True) as sess:
+ sess.run(my_ops)
+ t = outputs.eval()
+
def testInt32(self):
inp = np.random.rand(4, 4).astype("i")
for k in xrange(4):