Bugfix for mkl slice view not working properly with block format nchw16c when offset is 8.
diff --git a/tensorflow/core/kernels/mkl_slice_op.cc b/tensorflow/core/kernels/mkl_slice_op.cc
index b259c28..c3deab7 100644
--- a/tensorflow/core/kernels/mkl_slice_op.cc
+++ b/tensorflow/core/kernels/mkl_slice_op.cc
@@ -387,6 +387,7 @@
       // Step 1 (as per above description) - Create memory for user data.
       // We use blocked format here to describe input tensor.
       const Tensor& input_tensor = MklGetInput(context, 0);
+      memory::dims input_dims, input_strides;
       MklDnnShape input_mkl_shape;
       GetMklShape(context, 0, &input_mkl_shape);
 
@@ -397,10 +398,14 @@
         size_dims = MklDnnDimsInNCHW(size_dims, input_tf_format);
         auto input_md = input_mkl_shape.GetMklLayout();
         src.SetUsrMem(input_md, &input_tensor);
+
+        // Handle data format safely, change them to block format.
+        // Compute parameters of reorder primitive first.
+        input_dims = input_mkl_shape.GetSizesAsMklDnnDims();
+        input_strides = CalculateTFStrides(input_dims);
       } else {
         // Initialize input dimensions and strides to be used when input is not
         // in MklDnn layout.
-        memory::dims input_dims, input_strides;
         input_dims = TFShapeToMklDnnDims(input_tensor.shape());
         input_strides = CalculateTFStrides(input_dims);
         // Create input memory descriptor.
@@ -409,6 +414,13 @@
         src.SetUsrMem(input_md, &input_tensor);
       }
 
+      // If format not equal to block format, execute reorder.
+      // Or else do nothing for it.
+      auto op_md =
+          MklDnnData<T>::CreateBlockedMemDesc(input_dims, input_strides);
+      auto op_pd = memory::primitive_desc(op_md, cpu_engine);
+      src.CheckReorderToOpMem(op_pd);
+
       // Step 2 - Create memory for output.
       auto output_strides = CalculateTFStrides(size_dims);
       auto output_md =
@@ -421,7 +433,7 @@
       output.SetUsrMem(output_md, output_tensor);
 
       // Step 3 - create reorder primitive.
-      MklSliceParams sliceParams(src.GetUsrMem(), output.GetUsrMem(),
+      MklSliceParams sliceParams(&src.GetOpMem(), output.GetUsrMem(),
                                  begin_dims, size_dims);
       MklSlicePrimitive<T>* reorder_prim =
           MklSlicePrimitiveFactory<T>::Get(sliceParams);
diff --git a/tensorflow/python/kernel_tests/slice_op_test.py b/tensorflow/python/kernel_tests/slice_op_test.py
index 258b39b..dc81c4f 100644
--- a/tensorflow/python/kernel_tests/slice_op_test.py
+++ b/tensorflow/python/kernel_tests/slice_op_test.py
@@ -28,8 +28,12 @@
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import gradients_impl
 from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import random_ops
+from tensorflow.python.ops import init_ops
+from tensorflow.python.ops import variables
 from tensorflow.python.platform import test
 
+import tensorflow as tf
 
 class SliceTest(test.TestCase):
 
@@ -42,6 +46,27 @@
         slice_val = self.evaluate(slice_t)
       self.assertAllEqual(slice_val, inp[2, k:k])
 
+  def testView(self):
+    cout = 45
+    shape = [64, 28, 28, 32]
+    dtype = dtypes.float32
+    gain = 3.14
+    kernel_size = [1, 1]
+
+    convolution = tf.keras.layers.Conv2D
+    inputs = random_ops.random_normal(shape, dtype=dtype)
+    middle = convolution(
+      padding="valid", filters=cout,
+      kernel_size=kernel_size, use_bias=False,
+      kernel_initializer=init_ops.convolutional_orthogonal_2d(gain=gain)
+    ).apply(inputs)
+
+    outputs = array_ops.slice(middle, [8, 8, 8, 8], [16, 16, 16, 16])
+    my_ops = variables.global_variables_initializer()
+    with self.session(use_gpu=True) as sess:
+      sess.run(my_ops)
+      t = outputs.eval()
+
   def testInt32(self):
     inp = np.random.rand(4, 4).astype("i")
     for k in xrange(4):