add max_pool2d to AD, add tests for both autodiff and inference mode

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/19661

Differential Revision: D15074431

Pulled By: wanchaol

fbshipit-source-id: b31cf2126c2c5d6a12c2ef5dc67b57677652f1fc
diff --git a/test/test_jit.py b/test/test_jit.py
index 33b8b00..ebabfd4 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -1809,6 +1809,17 @@
         m = self.createScriptModuleFromGraph(trace)
         self.assertEqual(outputs, m(*inputs))
 
+    def test_max_pool(self):
+        x = torch.rand(20, 16, 10, 10)
+
+        def max_pool2d(x):
+            return F.max_pool2d(x, 2) + 2
+
+        trace = torch.jit.trace(max_pool2d, (x))
+        graph = trace.graph_for(x)
+        FileCheck().check("aten::max_pool2d(").run(graph)
+        self.assertEqual(max_pool2d(x), trace(x))
+
     def test_repeated_input(self):
         def fn(a, b):
             return a + b
@@ -13068,7 +13079,8 @@
     ('fractional_max_pool2d', (S, S, S, S), (3, [2, 3],)),
     ('max_pool1d', (S, S, S), (2, 1)),
     ('max_pool1d', (S, S, S), (2, 1, 1, 1, False, True), 'with_indices'),
-    ('max_pool2d', (S, S, S, S), (2, 1)),
+    ('max_pool2d', (S, S, S, S), (2, 1), '', (True, 'aten::max_pool2d_with_indices')),
+    ('max_pool2d', (S, S, S, S), (2, 1, 1, 1, False, True), 'with_indices', (True, 'aten::max_pool2d_with_indices')),
     ('max_pool3d', (S, S, S, S, S), (2, 1)),
     ('max_unpool1d', torch.tensor([[[2., 4]]]), (torch.tensor([[[1, 3]]]), 2, 2, 0)),
     ('max_unpool2d', torch.tensor([[[[2., 4]]]]), (torch.tensor([[[[1, 3]]]]), 2, 2, 0)),
diff --git a/torch/csrc/jit/autodiff.cpp b/torch/csrc/jit/autodiff.cpp
index 5120a04..908a5b5 100644
--- a/torch/csrc/jit/autodiff.cpp
+++ b/torch/csrc/jit/autodiff.cpp
@@ -37,6 +37,7 @@
   static OperatorSet need_trim_grad_ops = {
       "aten::kthvalue(Tensor self, int k, int dim, bool keepdim) -> (Tensor, Tensor)",
       "aten::topk(Tensor self, int k, int dim, bool largest, bool sorted) -> (Tensor, Tensor)",
+      "aten::max_pool2d(Tensor self, int[] kernel_size, int[] stride, int[] padding, int[] dilation, bool ceil_mode) -> Tensor",
   };
   if (need_trim_grad_ops.find(n)) {
     return true;
diff --git a/torch/csrc/jit/symbolic_script.cpp b/torch/csrc/jit/symbolic_script.cpp
index 6799955..2f26f44 100644
--- a/torch/csrc/jit/symbolic_script.cpp
+++ b/torch/csrc/jit/symbolic_script.cpp
@@ -1008,6 +1008,18 @@
 
             return torch.avg_pool2d(self, kernel_size, stride, padding, ceil_mode, count_include_pad), backward
 
+        def max_pool2d(self,
+                       kernel_size: List[int],
+                       stride: List[int],
+                       padding: List[int],
+                       dilation: List[int],
+                       ceil_mode: bool):
+            output, indices = torch.max_pool2d_with_indices(self, kernel_size, stride, padding, dilation, ceil_mode)
+            def backward(grad_output):
+                grad_self = torch.max_pool2d_with_indices_backward(grad_output, self, kernel_size, stride, padding, dilation, ceil_mode, indices)
+                return grad_self, None, None, None, None, None
+            return output, backward
+
         def batch_norm(input : Tensor,
                        weight : Optional[Tensor],
                        bias : Optional[Tensor],