Roll forward https://github.com/tensorflow/tensorflow/commit/ab0ca4bbc66a476aea305f81c69e0201b5876d0a, which validates the orig_output, grad, and argmax input shapes in the MaxPool and MaxPool3D gradient kernels. The internal test that the original change broke has been fixed.
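
For reference, a minimal sketch of the failure mode these checks reject, mirroring the new eager tests in this change (assumes TF2 eager execution):

    import tensorflow as tf
    from tensorflow.python.ops import gen_nn_ops

    x = tf.ones((1, 1, 1, 1))         # orig_input; forward output shape is [1,1,1,1]
    bad_grad = tf.ones((1, 1, 1, 2))  # shape does not match the forward output
    # With this change the kernel raises InvalidArgumentError
    # ("Expected grad shape to be [1,1,1,1], but got [1,1,1,2]")
    # instead of operating on the mismatched buffer.
    gen_nn_ops.max_pool_grad(
        x, x, bad_grad, ksize=[1, 1, 1, 1], strides=[1, 1, 1, 1],
        padding="VALID")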

PiperOrigin-RevId: 401913101
Change-Id: I67f095899187e38101fbb10289c5e444b0a9e8c0
diff --git a/tensorflow/core/kernels/maxpooling_op.cc b/tensorflow/core/kernels/maxpooling_op.cc
index ce89b02..9edd5cf 100644
--- a/tensorflow/core/kernels/maxpooling_op.cc
+++ b/tensorflow/core/kernels/maxpooling_op.cc
@@ -325,6 +325,14 @@
     if (!context->status().ok()) {
       return;
     }
+    OP_REQUIRES(context, tensor_out.shape() == params.forward_output_shape(),
+                errors::InvalidArgument("Expected orig_output shape to be ",
+                                        params.forward_output_shape(),
+                                        ", but got ", tensor_out.shape()));
+    OP_REQUIRES(context, out_backprop.shape() == params.forward_output_shape(),
+                errors::InvalidArgument("Expected grad shape to be ",
+                                        params.forward_output_shape(),
+                                        ", but got ", out_backprop.shape()));
 
     Tensor* output = nullptr;
     OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
@@ -538,6 +546,18 @@
                           /*explicit_paddings=*/{},
                           FORMAT_NHWC,
                           tensor_in.shape()};
+    if (!context->status().ok()) {
+      return;
+    }
+    OP_REQUIRES(context, tensor_out.shape() == params.forward_output_shape(),
+                errors::InvalidArgument("Expected orig_output shape to be ",
+                                        params.forward_output_shape(),
+                                        ", but got ", tensor_out.shape()));
+    OP_REQUIRES(
+        context, out_grad_backprop.shape() == tensor_in.shape(),
+        errors::InvalidArgument("Expected grad shape to be ", tensor_in.shape(),
+                                ", but got ", out_grad_backprop.shape()));
+
     Tensor* output = nullptr;
     OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
                                 {2}, 0, tensor_out.shape(), &output));
@@ -742,6 +762,17 @@
                           /*explicit_paddings=*/{},
                           data_format_,
                           tensor_in.shape()};
+    if (!context->status().ok()) {
+      return;
+    }
+    OP_REQUIRES(context, tensor_out.shape() == params.forward_output_shape(),
+                errors::InvalidArgument("Expected orig_output shape to be ",
+                                        params.forward_output_shape(),
+                                        ", but got ", tensor_out.shape()));
+    OP_REQUIRES(
+        context, out_grad_backprop.shape() == tensor_in.shape(),
+        errors::InvalidArgument("Expected grad shape to be ", tensor_in.shape(),
+                                ", but got ", out_grad_backprop.shape()));
 
     functor::MaxPoolGradBackwardNoMask<T>()(
         data_format_, tensor_in.flat<T>().data(), tensor_out.flat<T>().data(),
@@ -1096,6 +1127,14 @@
     if (!context->status().ok()) {
       return;
     }
+    OP_REQUIRES(context, grad_in.shape() == params.forward_output_shape(),
+                errors::InvalidArgument("Expected grad shape to be ",
+                                        params.forward_output_shape(),
+                                        ", but got ", grad_in.shape()));
+    OP_REQUIRES(context, argmax.shape() == params.forward_output_shape(),
+                errors::InvalidArgument("Expected argmax shape to be ",
+                                        params.forward_output_shape(),
+                                        ", but got ", argmax.shape()));
 
     TensorShape out_shape({params.tensor_in_batch, params.tensor_in_rows,
                            params.tensor_in_cols, params.depth});
@@ -1156,6 +1195,14 @@
     if (!context->status().ok()) {
       return;
     }
+    OP_REQUIRES(
+        context, grad_in.shape() == tensor_in.shape(),
+        errors::InvalidArgument("Expected grad shape to be ", tensor_in.shape(),
+                                ", but got ", grad_in.shape()));
+    OP_REQUIRES(context, argmax.shape() == params.forward_output_shape(),
+                errors::InvalidArgument("Expected argmax shape to be ",
+                                        params.forward_output_shape(),
+                                        ", but got ", argmax.shape()));
 
     TensorShape out_shape({params.tensor_in_batch, params.out_height,
                            params.out_width, params.depth});
diff --git a/tensorflow/core/kernels/pooling_ops_3d.cc b/tensorflow/core/kernels/pooling_ops_3d.cc
index d4dc87c..d4444b6 100644
--- a/tensorflow/core/kernels/pooling_ops_3d.cc
+++ b/tensorflow/core/kernels/pooling_ops_3d.cc
@@ -366,6 +366,19 @@
 
     OP_REQUIRES_OK(context, Get3dOutputSize(input_size, window, stride,
                                             padding_, &out, &padding));
+
+    const int64_t depth = GetTensorDim(tensor_in, data_format_, 'C');
+    const int64_t in_batch = GetTensorDim(tensor_in, data_format_, 'N');
+    TensorShape out_shape = ShapeFromFormat(data_format_, in_batch,
+                                            {{out[2], out[1], out[0]}}, depth);
+    OP_REQUIRES(
+        context, tensor_out.shape() == out_shape,
+        errors::InvalidArgument("Expected orig_output shape to be ", out_shape,
+                                ", but got ", tensor_out.shape()));
+    OP_REQUIRES(context, out_backprop.shape() == out_shape,
+                errors::InvalidArgument("Expected grad shape to be ", out_shape,
+                                        ", but got ", out_backprop.shape()));
+
     LaunchMaxPooling3dGradOp<Device, T>::launch(
         context, tensor_in, tensor_out, out_backprop, window, stride, out,
         padding, data_format_, input_backprop);
@@ -712,6 +725,14 @@
     Pool3dParameters params{context,  ksize_,       stride_,
                             padding_, data_format_, tensor_in.shape()};
     if (!context->status().ok()) return;  // params is invalid
+    OP_REQUIRES(context, tensor_out.shape() == params.forward_output_shape(),
+                errors::InvalidArgument("Expected orig_output shape to be ",
+                                        params.forward_output_shape(),
+                                        ", but got ", tensor_out.shape()));
+    OP_REQUIRES(
+        context, out_grad_backprop.shape() == tensor_in.shape(),
+        errors::InvalidArgument("Expected grad shape to be ", tensor_in.shape(),
+                                ", but got ", out_grad_backprop.shape()));
 
     Tensor* output = nullptr;
     OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
diff --git a/tensorflow/core/kernels/pooling_ops_common.cc b/tensorflow/core/kernels/pooling_ops_common.cc
index 817072c..d621e77 100644
--- a/tensorflow/core/kernels/pooling_ops_common.cc
+++ b/tensorflow/core/kernels/pooling_ops_common.cc
@@ -465,6 +465,16 @@
   if (!context->status().ok()) {
     return;
   }
+  if (tensor_out) {
+    OP_REQUIRES(context, tensor_out->shape() == params.forward_output_shape(),
+                errors::InvalidArgument("Expected orig_output shape to be ",
+                                        params.forward_output_shape(),
+                                        ", but got ", tensor_out->shape()));
+  }
+  OP_REQUIRES(context, out_backprop.shape() == params.forward_output_shape(),
+              errors::InvalidArgument("Expected grad shape to be ",
+                                      params.forward_output_shape(),
+                                      ", but got ", out_backprop.shape()));
 
   TensorFormat transformed_input_data_format = data_format;
 
diff --git a/tensorflow/core/kernels/pooling_ops_common.h b/tensorflow/core/kernels/pooling_ops_common.h
index 5e6f2e4..1a41a5a 100644
--- a/tensorflow/core/kernels/pooling_ops_common.h
+++ b/tensorflow/core/kernels/pooling_ops_common.h
@@ -83,11 +83,6 @@
   TensorFormat data_format;
 };
 
-// Checks if the sizes of the paddings are less than the size of window.
-// This is required for MaxPool because it pads with -inf, so the pooling
-// window cannot fully cover the padded area.
-Status CheckPaddingSize(PoolParameters& params);
-
 // An implementation of MaxPooling (forward).
 // TODO (yongtang): Remove MaxPoolingOp and use MaxPoolingV2Op,
 //     QuantizedMaxPoolingOp depends on MaxPoolingOp so keep intact for now
diff --git a/tensorflow/python/kernel_tests/pooling_ops_3d_test.py b/tensorflow/python/kernel_tests/pooling_ops_3d_test.py
index 203d3ad..12710c4 100644
--- a/tensorflow/python/kernel_tests/pooling_ops_3d_test.py
+++ b/tensorflow/python/kernel_tests/pooling_ops_3d_test.py
@@ -16,9 +16,13 @@
 
 import numpy as np
 
+from tensorflow.python.eager import context
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import errors
+from tensorflow.python.framework import errors_impl
 from tensorflow.python.framework import test_util
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gen_nn_ops
 from tensorflow.python.ops import gradient_checker
 from tensorflow.python.ops import gradients_impl
 from tensorflow.python.ops import nn_ops
@@ -515,6 +519,44 @@
           pool_3d = f(input_tensor, ksize=[2, 2, 0], strides=1, padding="VALID")
           self.evaluate(pool_3d)
 
+  def testMaxPoolGradEagerShapeErrors(self):
+    with context.eager_mode():
+      orig_in = array_ops.ones((1, 1, 1, 1, 1))
+
+      # Test invalid orig_out shape
+      orig_out = array_ops.ones((1, 1, 1, 1, 2))
+      grad = array_ops.ones((1, 1, 1, 1, 1))
+      with self.assertRaisesRegex(
+          errors_impl.InvalidArgumentError,
+          r"Expected orig_output shape to be \[1,1,1,1,1\], but got "
+          r"\[1,1,1,1,2\]"):
+        gen_nn_ops.max_pool3d_grad(
+            orig_in, orig_out, grad, ksize=[1, 1, 1, 1, 1],
+            strides=[1, 1, 1, 1, 1], padding="VALID")
+      with self.assertRaisesRegex(
+          errors_impl.InvalidArgumentError,
+          r"Expected orig_output shape to be \[1,1,1,1,1\], but got "
+          r"\[1,1,1,1,2\]"):
+        gen_nn_ops.max_pool3d_grad_grad(
+            orig_in, orig_out, grad, ksize=[1, 1, 1, 1, 1],
+            strides=[1, 1, 1, 1, 1], padding="VALID")
+
+      # Test invalid grad shape
+      orig_out = array_ops.ones((1, 1, 1, 1, 1))
+      grad = array_ops.ones((1, 1, 1, 1, 2))
+      with self.assertRaisesRegex(
+          errors_impl.InvalidArgumentError,
+          r"Expected grad shape to be \[1,1,1,1,1\], but got \[1,1,1,1,2\]"):
+        gen_nn_ops.max_pool3d_grad(
+            orig_in, orig_out, grad, ksize=[1, 1, 1, 1, 1],
+            strides=[1, 1, 1, 1, 1], padding="VALID")
+      with self.assertRaisesRegex(
+          errors_impl.InvalidArgumentError,
+          r"Expected grad shape to be \[1,1,1,1,1\], but got \[1,1,1,1,2\]"):
+        gen_nn_ops.max_pool3d_grad_grad(
+            orig_in, orig_out, grad, ksize=[1, 1, 1, 1, 1],
+            strides=[1, 1, 1, 1, 1], padding="VALID")
+
 
 if __name__ == "__main__":
   test.main()
diff --git a/tensorflow/python/kernel_tests/pooling_ops_test.py b/tensorflow/python/kernel_tests/pooling_ops_test.py
index 6f1b5ed..79adfff 100644
--- a/tensorflow/python/kernel_tests/pooling_ops_test.py
+++ b/tensorflow/python/kernel_tests/pooling_ops_test.py
@@ -618,6 +618,7 @@
 
   @parameterized.parameters(
       GetTestConfigsDicts(nn_ops.max_pool, nn_ops.max_pool_v2))
+  @test_util.xla_allow_fallback("XLA doesn't support explicit padding")
   @test_util.run_deprecated_v1
   def testMaxPoolNegativeInputExpPaddingAdv(self, **kwargs):
     expected_output = [-1, -1, -3, -5, -7, -7, -9, -11, -19, -19, -21, -23, -31,
@@ -2390,6 +2391,82 @@
             explicit_paddings=[1, 1, 1, 1, 1, 1, 0, 0],
             data_format="NHWC"))
 
+  def testMaxPoolGradEagerShapeErrors(self):
+    with context.eager_mode():
+      orig_in = array_ops.ones((1, 1, 1, 1))
+
+      # Test invalid orig_out shape
+      orig_out = array_ops.ones((1, 1, 1, 2))
+      grad = array_ops.ones((1, 1, 1, 1))
+      with self.assertRaisesRegex(
+          errors_impl.InvalidArgumentError,
+          r"Expected orig_output shape to be \[1,1,1,1\], but got \[1,1,1,2\]"):
+        gen_nn_ops.max_pool_grad(
+            orig_in, orig_out, grad, ksize=[1, 1, 1, 1], strides=[1, 1, 1, 1],
+            padding="VALID")
+      with self.assertRaisesRegex(
+          errors_impl.InvalidArgumentError,
+          r"Expected orig_output shape to be \[1,1,1,1\], but got \[1,1,1,2\]"):
+        gen_nn_ops.max_pool_grad_grad(
+            orig_in, orig_out, grad, ksize=[1, 1, 1, 1], strides=[1, 1, 1, 1],
+            padding="VALID")
+
+      # Test invalid grad shape
+      orig_out = array_ops.ones((1, 1, 1, 1))
+      grad = array_ops.ones((1, 1, 1, 2))
+      with self.assertRaisesRegex(
+          errors_impl.InvalidArgumentError,
+          r"Expected grad shape to be \[1,1,1,1\], but got \[1,1,1,2\]"):
+        gen_nn_ops.max_pool_grad(
+            orig_in, orig_out, grad, ksize=[1, 1, 1, 1], strides=[1, 1, 1, 1],
+            padding="VALID")
+      with self.assertRaisesRegex(
+          errors_impl.InvalidArgumentError,
+          r"Expected grad shape to be \[1,1,1,1\], but got \[1,1,1,2\]"):
+        gen_nn_ops.max_pool_grad_grad(
+            orig_in, orig_out, grad, ksize=[1, 1, 1, 1], strides=[1, 1, 1, 1],
+            padding="VALID")
+
+  def testMaxPoolGradWithArgmaxEagerShapeErrors(self):
+    with context.eager_mode():
+      inp = array_ops.ones((1, 1, 1, 1))
+
+      # Test invalid grad shape
+      grad = array_ops.ones((1, 1, 1, 2))
+      argmax = array_ops.zeros((1, 1, 1, 1), dtype=dtypes.int64)
+      with self.assertRaisesRegex(
+          errors_impl.InvalidArgumentError,
+          r"Expected grad shape to be \[1,1,1,1\], but got \[1,1,1,2\]"):
+        gen_nn_ops.max_pool_grad_with_argmax(
+            inp, grad, argmax, ksize=[1, 1, 1, 1], strides=[1, 1, 1, 1],
+            padding="VALID")
+      # max_pool_grad_grad_with_argmax is only implemented for GPUs
+      if test.is_gpu_available():
+        with self.assertRaisesRegex(
+            errors_impl.InvalidArgumentError,
+            r"Expected grad shape to be \[1,1,1,1\], but got \[1,1,1,2\]"):
+          gen_nn_ops.max_pool_grad_grad_with_argmax(
+              inp, grad, argmax, ksize=[1, 1, 1, 1], strides=[1, 1, 1, 1],
+              padding="VALID")
+
+      # Test invalid argmax shape
+      grad = array_ops.ones((1, 1, 1, 1))
+      argmax = array_ops.ones((1, 1, 1, 2), dtype=dtypes.int64)
+      with self.assertRaisesRegex(
+          errors_impl.InvalidArgumentError,
+          r"Expected argmax shape to be \[1,1,1,1\], but got \[1,1,1,2\]"):
+        gen_nn_ops.max_pool_grad_with_argmax(
+            inp, grad, argmax, ksize=[1, 1, 1, 1], strides=[1, 1, 1, 1],
+            padding="VALID")
+      # max_pool_grad_grad_with_argmax is only implemented for GPUs
+      if test.is_gpu_available():
+        with self.assertRaisesRegex(
+            errors_impl.InvalidArgumentError,
+            r"Expected argmax shape to be \[1,1,1,1\], but got \[1,1,1,2\]"):
+          gen_nn_ops.max_pool_grad_grad_with_argmax(
+              inp, grad, argmax, ksize=[1, 1, 1, 1], strides=[1, 1, 1, 1],
+              padding="VALID")
+
 
 def GetMaxPoolFwdTest(input_size, filter_size, strides, padding):